diff --git a/tests/test_training.py b/tests/test_training.py index f506324a..41ed5415 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -81,3 +81,7 @@ def test_training(self): for epoch in tqdm(range(n_epochs)): batch_losses = train_epoch(model, train_tasks, batch_size=batch_size) epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss))