From 40d994af980f562b2df0638eefddcddf6960fa03 Mon Sep 17 00:00:00 2001 From: Tom Andersson Date: Mon, 17 Jul 2023 17:36:13 +0100 Subject: [PATCH] Check for NaNs in loss --- tests/test_training.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_training.py b/tests/test_training.py index f506324a..41ed5415 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -81,3 +81,7 @@ def test_training(self): for epoch in tqdm(range(n_epochs)): batch_losses = train_epoch(model, train_tasks, batch_size=batch_size) epoch_losses.append(np.mean(batch_losses)) + + # Check for NaNs in the loss + loss = np.mean(epoch_losses) + self.assertFalse(np.isnan(loss))