From c33bdc2d2ac9594402e19991a28974c425c50003 Mon Sep 17 00:00:00 2001
From: Vedanuj Goswami
Date: Tue, 7 Sep 2021 21:52:49 -0700
Subject: [PATCH] skip optimizer update when nan loss

---
 .pre-commit-config.yaml            |  2 +-
 mmf/trainers/core/training_loop.py | 19 +++++++++++++++----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3fd1da29d..79699ae26 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,4 +47,4 @@ repos:
     rev: 19.3b0
     hooks:
       - id: black
-        language_version: python3.7
+        language_version: python3
diff --git a/mmf/trainers/core/training_loop.py b/mmf/trainers/core/training_loop.py
index a03dbccbd..9ad1a5981 100644
--- a/mmf/trainers/core/training_loop.py
+++ b/mmf/trainers/core/training_loop.py
@@ -23,6 +23,7 @@ class TrainerTrainingLoopMixin(ABC):
     current_iteration: int = 0
     num_updates: int = 0
     meter: Meter = Meter()
+    skip_optim_update: bool = False
 
     def training_loop(self) -> None:
         self.max_updates = self._calculate_max_updates()
@@ -114,7 +115,10 @@ def run_training_epoch(self) -> None:
                     should_start_update = True
 
                     should_log = False
-                    if self.num_updates % self.logistics_callback.log_interval == 0:
+                    if (
+                        self.num_updates % self.logistics_callback.log_interval == 0
+                        and not self.skip_optim_update
+                    ):
                         should_log = True
                         # Calculate metrics every log interval for debugging
                         if self.training_config.evaluate_metrics:
@@ -123,6 +127,9 @@ def run_training_epoch(self) -> None:
                             )
                         self.meter.update_from_report(combined_report)
 
+                    if self.skip_optim_update:
+                        self.skip_optim_update = False
+
                     self.on_update_end(
                         report=combined_report, meter=self.meter, should_log=should_log
                     )
@@ -179,16 +186,17 @@ def _check_nan_losses(self, report):
         loss_dict = report.losses
         nan_loss_keys = []
         for key, value in loss_dict.items():
-            if torch.any(torch.isnan(value)).item():
+            if torch.isnan(value).any() or torch.isinf(value).any():
                 nan_loss_keys.append(key)
         if len(nan_loss_keys) > 0:
             keys_str = ", ".join(nan_loss_keys)
             error_msg = (
                 f"NaN occurred in the following loss(es): {keys_str}; "
-                f"exiting the training"
+                f"skipping optimizer update"
             )
             logger.info(error_msg)
-            raise RuntimeError(error_msg)
+            # raise RuntimeError(error_msg)
+            self.skip_optim_update = True
 
     def _forward(self, batch: Dict[str, Tensor]) -> Dict[str, Any]:
         # Move the sample list to device if it isn't as of now.
@@ -213,6 +221,9 @@ def _backward(self, loss: Tensor) -> None:
         self.profile("Backward time")
 
     def _finish_update(self):
+        if self.skip_optim_update:
+            self.scaler.unscale_(self.optimizer)
+            self.optimizer.zero_grad(set_to_none=True)
         if self.training_config.clip_gradients:
             clip_gradients(
                 self.model,
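
Outside the patch itself, the change boils down to a common AMP training pattern: when any loss contains NaN or Inf, record the inf check with the grad scaler and discard the gradients instead of stepping the optimizer. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; the toy model, the `losses_are_finite` helper, and the `train_step` wiring are illustrative assumptions, not MMF's trainer API.

import torch
from torch import nn

# Hypothetical stand-ins for the trainer's model/optimizer/scaler.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def losses_are_finite(losses: dict) -> bool:
    # Counterpart of the patched _check_nan_losses: any NaN or Inf loss
    # disqualifies this update instead of aborting training.
    return not any(
        torch.isnan(v).any() or torch.isinf(v).any() for v in losses.values()
    )


def train_step(batch: torch.Tensor, target: torch.Tensor) -> None:
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = nn.functional.mse_loss(model(batch), target)
    scaler.scale(loss).backward()

    if losses_are_finite({"mse": loss}):
        scaler.step(optimizer)
    else:
        # Skip the optimizer update: unscale_ records the inf check that
        # GradScaler.update() expects, then the bad gradients are discarded.
        scaler.unscale_(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)


# Example: a finite loss updates the weights; a non-finite one is skipped.
x, y = torch.randn(4, 8, device=device), torch.randn(4, 1, device=device)
train_step(x, y)                   # finite loss -> optimizer steps
train_step(x * float("inf"), y)    # NaN/Inf loss -> update skipped

Calling unscale_ before dropping the gradients mirrors the _finish_update change in the diff: GradScaler.update() expects an inf check to have been recorded for the optimizer in that iteration, and recording it also lets the scaler back off the loss scale when the gradients overflowed.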