diff --git a/pts/trainer.py b/pts/trainer.py index 852a04a..601fdd3 100644 --- a/pts/trainer.py +++ b/pts/trainer.py @@ -41,6 +41,9 @@ def __call__( train_iter: DataLoader, validation_iter: Optional[DataLoader] = None, ) -> None: + + self.training_loss = [] + self.validation_loss = [] optimizer = Adam( net.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay ) @@ -80,7 +83,7 @@ def __call__( }, refresh=False, ) loss.backward() if self.clip_gradient is not None: nn.utils.clip_grad_norm_(net.parameters(), self.clip_gradient) @@ -92,6 +95,9 @@ def __call__( break it.close() + # Append the average loss for the epoch to the list of loss values + self.training_loss.append(avg_epoch_loss) + # validation loop if validation_iter is not None: cumm_epoch_loss_val = 0.0 @@ -121,6 +127,8 @@ def __call__( break it.close() + self.validation_loss.append(avg_epoch_loss_val) # mark epoch end time and log time cost of current epoch toc = time.time() +