@@ -156,29 +156,37 @@ def test_batch_invariance(self, fake_model_training_session):
         print(orig_train_loss)
         print(new_train_loss)
         assert len(orig_train_loss) == len(new_train_loss)
-        assert all(
-            [
-                math.isclose(a, b, rel_tol=tol)
-                for a, b in zip(orig_train_loss.values(), new_train_loss.values())
-            ]
-        )
+        for name, orig_val, new_val in zip(
+            orig_train_loss.keys(),
+            orig_train_loss.values(),
+            new_train_loss.values(),
+        ):
+            if not math.isclose(orig_val, new_val, rel_tol=tol):
+                raise AssertionError(
+                    f"Training loss mismatch for '{name}': "
+                    f"original={orig_val}, new={new_val}, "
+                    f"diff={abs(orig_val - new_val)}, rel_tol={tol}"
+                )
 
         # == test val metrics invariance to batch size ==
         batchsize1_val_metrics = nequip_module.val_metrics[0].metrics_values_epoch
         print(batchsize5_val_metrics)
         print(batchsize1_val_metrics)
         assert len(batchsize5_val_metrics) == len(batchsize1_val_metrics)
-        assert all(
-            [
-                math.isclose(a, b, rel_tol=tol) if "maxabserr" not in name else True
-                # ^ do not include maxabserr in testing
-                for name, a, b in zip(
-                    batchsize5_val_metrics.keys(),
-                    batchsize5_val_metrics.values(),
-                    batchsize1_val_metrics.values(),
+        for name, batch5_val, batch1_val in zip(
+            batchsize5_val_metrics.keys(),
+            batchsize5_val_metrics.values(),
+            batchsize1_val_metrics.values(),
+        ):
+            # do not include maxabserr or total energy in testing (per atom energy tested)
+            if ("maxabserr" in name) or ("total_energy" in name):
+                continue
+            if not math.isclose(batch5_val, batch1_val, rel_tol=tol):
+                raise AssertionError(
+                    f"Validation metric mismatch for '{name}': "
+                    f"batch_size=5 value={batch5_val}, batch_size=1 value={batch1_val}, "
+                    f"diff={abs(batch5_val - batch1_val)}, rel_tol={tol}"
                 )
-            ]
-        )
 
     # TODO: will fail if train dataloader has shuffle=True
     def test_restarts(self, fake_model_training_session):
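
For reference, here is a minimal standalone sketch of the comparison pattern this patch introduces: iterate over the two metric dicts in key order, skip metrics that are not expected to be batch-size invariant, and raise a descriptive `AssertionError` instead of a bare `assert all(...)`. The helper name `assert_metrics_close`, the default `skip` tokens, and the example values are illustrative only, not part of the patch.

```python
import math


def assert_metrics_close(ref, new, rel_tol, skip=("maxabserr", "total_energy")):
    """Compare two metric dicts key by key and fail with a descriptive message."""
    assert len(ref) == len(new)
    for name, ref_val, new_val in zip(ref.keys(), ref.values(), new.values()):
        # skip metrics that are not expected to be invariant to batch size
        if any(token in name for token in skip):
            continue
        if not math.isclose(ref_val, new_val, rel_tol=rel_tol):
            raise AssertionError(
                f"Metric mismatch for '{name}': ref={ref_val}, new={new_val}, "
                f"diff={abs(ref_val - new_val)}, rel_tol={rel_tol}"
            )


# illustrative values only
assert_metrics_close(
    {"per_atom_energy_mae": 0.10, "forces_rmse": 0.20},
    {"per_atom_energy_mae": 0.10000001, "forces_rmse": 0.20000002},
    rel_tol=1e-5,
)
```

Reporting the failing metric name and the absolute difference makes a CI failure far easier to triage than a bare `assert all([...])`, which only says that some unspecified entry differed.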