Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[WIP] Skip optimizer update when the loss is NaN #1078

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ repos:
rev: 19.3b0
hooks:
- id: black
language_version: python3.7
language_version: python3
19 changes: 15 additions & 4 deletions mmf/trainers/core/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class TrainerTrainingLoopMixin(ABC):
current_iteration: int = 0
num_updates: int = 0
meter: Meter = Meter()
skip_optim_update: bool = False

def training_loop(self) -> None:
self.max_updates = self._calculate_max_updates()
Expand Down Expand Up @@ -114,7 +115,10 @@ def run_training_epoch(self) -> None:
should_start_update = True

should_log = False
if self.num_updates % self.logistics_callback.log_interval == 0:
if (
self.num_updates % self.logistics_callback.log_interval == 0
and not self.skip_optim_update
):
should_log = True
# Calculate metrics every log interval for debugging
if self.training_config.evaluate_metrics:
Expand All @@ -123,6 +127,9 @@ def run_training_epoch(self) -> None:
)
self.meter.update_from_report(combined_report)

if self.skip_optim_update:
self.skip_optim_update = False

self.on_update_end(
report=combined_report, meter=self.meter, should_log=should_log
)
Expand Down Expand Up @@ -179,16 +186,17 @@ def _check_nan_losses(self, report):
loss_dict = report.losses
nan_loss_keys = []
for key, value in loss_dict.items():
if torch.any(torch.isnan(value)).item():
if torch.isnan(value).any() or torch.isinf(value).any():
nan_loss_keys.append(key)
if len(nan_loss_keys) > 0:
keys_str = ", ".join(nan_loss_keys)
error_msg = (
f"NaN occurred in the following loss(es): {keys_str}; "
f"exiting the training"
f"skipping optimizer update"
)
logger.info(error_msg)
raise RuntimeError(error_msg)
# raise RuntimeError(error_msg)
self.skip_optim_update = True

def _forward(self, batch: Dict[str, Tensor]) -> Dict[str, Any]:
# Move the sample list to device if it isn't as of now.
Expand All @@ -213,6 +221,9 @@ def _backward(self, loss: Tensor) -> None:
self.profile("Backward time")

def _finish_update(self):
if self.skip_optim_update:
self.scaler.unscale_(self.optimizer)
self.optimizer.zero_grad(set_to_none=True)
if self.training_config.clip_gradients:
clip_gradients(
self.model,
Expand Down