@@ -383,6 +383,36 @@ class EnergyModelTestsMixin(BasicModelTestsMixin):
383383 - Isolated atom energy shifts
384384 """
385385
386+ def total_energy_keys (self ):
387+ """Return list of total energy keys to test.
388+
389+ Subclasses can override to test multiple energy keys (e.g., energy_0, energy_1).
390+
391+ Returns:
392+ List of total energy key strings
393+ """
394+ return [AtomicDataDict .TOTAL_ENERGY_KEY ]
395+
396+ def per_atom_energy_keys (self ):
397+ """Return list of per-atom energy keys to test.
398+
399+ Subclasses can override to test multiple energy keys (e.g., per_atom_energy_0, per_atom_energy_1).
400+
401+ Returns:
402+ List of per-atom energy key strings
403+ """
404+ return [AtomicDataDict .PER_ATOM_ENERGY_KEY ]
405+
406+ def force_keys (self ):
407+ """Return list of force keys to test.
408+
409+ Subclasses can override to test multiple force keys (e.g., force_0, force_1).
410+
411+ Returns:
412+ List of force key strings
413+ """
414+ return [AtomicDataDict .FORCE_KEY ]
415+
386416 def test_large_separation (self , model , molecules , device ):
387417 instance , config , _ = model
388418 atol = {torch .float32 : 1e-4 , torch .float64 : 1e-10 }[instance .model_dtype ]
@@ -422,22 +452,27 @@ def test_large_separation(self, model, molecules, device):
422452 out2 = instance (from_dict (data2 ))
423453 out_both = instance (from_dict (data_both ))
424454
425- assert torch .allclose (
426- out1 [AtomicDataDict .TOTAL_ENERGY_KEY ]
427- + out2 [AtomicDataDict .TOTAL_ENERGY_KEY ],
428- out_both [AtomicDataDict .TOTAL_ENERGY_KEY ],
429- atol = atol ,
430- )
431- if AtomicDataDict .FORCE_KEY in out1 :
432- # check forces if it's a force model
433- assert torch .allclose (
434- torch .cat (
435- (out1 [AtomicDataDict .FORCE_KEY ], out2 [AtomicDataDict .FORCE_KEY ]),
436- dim = 0 ,
437- ),
438- out_both [AtomicDataDict .FORCE_KEY ],
439- atol = atol ,
440- )
455+ # test all total energy keys
456+ for energy_key in self .total_energy_keys ():
457+ if energy_key in out1 and energy_key in out2 and energy_key in out_both :
458+ assert torch .allclose (
459+ out1 [energy_key ] + out2 [energy_key ],
460+ out_both [energy_key ],
461+ atol = atol ,
462+ ), f"Large separation test failed for { energy_key } "
463+
464+ # test all force keys
465+ for force_key in self .force_keys ():
466+ if force_key in out1 :
467+ # check forces if it's a force model
468+ assert torch .allclose (
469+ torch .cat (
470+ (out1 [force_key ], out2 [force_key ]),
471+ dim = 0 ,
472+ ),
473+ out_both [force_key ],
474+ atol = atol ,
475+ ), f"Large separation test failed for { force_key } "
441476
442477 atoms_both2 = atoms1 .copy ()
443478 atoms3 = atoms2 .copy ()
@@ -450,16 +485,24 @@ def test_large_separation(self, model, molecules, device):
450485 )
451486
452487 out_both2 = instance (data_both2 )
453- assert torch .allclose (
454- out_both2 [AtomicDataDict .TOTAL_ENERGY_KEY ],
455- out_both [AtomicDataDict .TOTAL_ENERGY_KEY ],
456- atol = atol ,
457- )
458- assert torch .allclose (
459- out_both2 [AtomicDataDict .PER_ATOM_ENERGY_KEY ],
460- out_both [AtomicDataDict .PER_ATOM_ENERGY_KEY ],
461- atol = atol ,
462- )
488+
489+ # test total energy invariance to rigid translation
490+ for energy_key in self .total_energy_keys ():
491+ if energy_key in out_both and energy_key in out_both2 :
492+ assert torch .allclose (
493+ out_both2 [energy_key ],
494+ out_both [energy_key ],
495+ atol = atol ,
496+ ), f"Translation invariance test failed for { energy_key } "
497+
498+ # test per-atom energy invariance to rigid translation
499+ for per_atom_energy_key in self .per_atom_energy_keys ():
500+ if per_atom_energy_key in out_both and per_atom_energy_key in out_both2 :
501+ assert torch .allclose (
502+ out_both2 [per_atom_energy_key ],
503+ out_both [per_atom_energy_key ],
504+ atol = atol ,
505+ ), f"Translation invariance test failed for { per_atom_energy_key } "
463506
464507 def test_cross_frame_grad (self , model , device , nequip_dataset ):
465508 batch = AtomicDataDict .batched_from_list (
@@ -470,67 +513,71 @@ def test_cross_frame_grad(self, model, device, nequip_dataset):
470513 data [AtomicDataDict .POSITIONS_KEY ].requires_grad = True
471514
472515 output = energy_model (data )
473- grads = torch .autograd .grad (
474- outputs = output [AtomicDataDict .TOTAL_ENERGY_KEY ][- 1 ],
475- inputs = data [AtomicDataDict .POSITIONS_KEY ],
476- allow_unused = True ,
477- )[0 ]
478516
479- last_frame_n_atom = batch [AtomicDataDict .NUM_NODES_KEY ][- 1 ]
517+ # test cross-frame gradient isolation for all total energy keys
518+ for energy_key in self .total_energy_keys ():
519+ if energy_key in output :
520+ grads = torch .autograd .grad (
521+ outputs = output [energy_key ][- 1 ],
522+ inputs = data [AtomicDataDict .POSITIONS_KEY ],
523+ allow_unused = True ,
524+ retain_graph = True ,
525+ )[0 ]
526+
527+ last_frame_n_atom = batch [AtomicDataDict .NUM_NODES_KEY ][- 1 ]
480528
481- in_frame_grad = grads [- last_frame_n_atom :]
482- cross_frame_grad = grads [:- last_frame_n_atom ]
529+ in_frame_grad = grads [- last_frame_n_atom :]
530+ cross_frame_grad = grads [:- last_frame_n_atom ]
483531
484- assert cross_frame_grad .abs ().max ().item () == 0
485- assert in_frame_grad .abs ().max ().item () > 0
532+ assert cross_frame_grad .abs ().max ().item () == 0 , (
533+ f"Cross-frame gradient test failed for { energy_key } "
534+ )
535+ assert in_frame_grad .abs ().max ().item () > 0 , (
536+ f"In-frame gradient test failed for { energy_key } "
537+ )
486538
487539 def test_numeric_gradient (self , model , atomic_batch , device ):
488540 """
489541 Tests the ForceStressOutput model by comparing numerical gradients of the forces to the analytical gradients.
490542 """
491543 model , _ , out_fields = model
492- # proceed with tests only if forces are available
493- if AtomicDataDict .FORCE_KEY in out_fields :
494- # physical predictions (energy, forces, etc) will be converted to default_dtype (float64) before comparing
495- data = AtomicDataDict .to_ (atomic_batch , device )
496- output = model (data )
497- forces = output [AtomicDataDict .FORCE_KEY ]
498- epsilon = 1e-3
499544
500- # Compute numerical gradients for each atom and direction and compare to analytical gradients
501- for iatom in range (len (data [AtomicDataDict .POSITIONS_KEY ])):
502- for idir in range (3 ):
503- # Shift `iatom` an `epsilon` in the `idir` direction
504- pos = data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ]
505- data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ] = pos + epsilon
506- output = model (data )
507- e_plus = (
508- output [AtomicDataDict .TOTAL_ENERGY_KEY ]
509- .sum ()
510- .to (torch .get_default_dtype ())
511- )
512-
513- # Shift `iatom` an `epsilon` in the negative `idir` direction
514- data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ] -= epsilon * 2
515- output = model (data )
516- e_minus = (
517- output [AtomicDataDict .TOTAL_ENERGY_KEY ]
518- .sum ()
519- .to (torch .get_default_dtype ())
520- )
521-
522- # Symmetric difference to get the partial forces to all the atoms
523- numeric = - (e_plus - e_minus ) / (epsilon * 2 )
524- analytical = forces [iatom , idir ].to (torch .get_default_dtype ())
525-
526- assert torch .isclose (
527- numeric , analytical , atol = 2e-2
528- ) or torch .isclose (numeric , analytical , rtol = 5e-3 ), (
529- f"numeric: { numeric .item ()} , analytical: { analytical .item ()} "
530- )
531-
532- # Reset the position
533- data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ] += epsilon
545+ # test numeric gradients for each pair of (total_energy_key, force_key)
546+ for energy_key , force_key in zip (self .total_energy_keys (), self .force_keys ()):
547+ # proceed with tests only if forces are available
548+ if force_key in out_fields :
549+ # physical predictions (energy, forces, etc) will be converted to default_dtype (float64) before comparing
550+ data = AtomicDataDict .to_ (atomic_batch , device )
551+ output = model (data )
552+ forces = output [force_key ]
553+ epsilon = 1e-3
554+
555+ # compute numerical gradients for each atom and direction and compare to analytical gradients
556+ for iatom in range (len (data [AtomicDataDict .POSITIONS_KEY ])):
557+ for idir in range (3 ):
558+ # shift `iatom` an `epsilon` in the `idir` direction
559+ pos = data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ]
560+ data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ] = pos + epsilon
561+ output = model (data )
562+ e_plus = output [energy_key ].sum ().to (torch .get_default_dtype ())
563+
564+ # shift `iatom` an `epsilon` in the negative `idir` direction
565+ data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ] -= epsilon * 2
566+ output = model (data )
567+ e_minus = output [energy_key ].sum ().to (torch .get_default_dtype ())
568+
569+ # symmetric difference to get the partial forces to all the atoms
570+ numeric = - (e_plus - e_minus ) / (epsilon * 2 )
571+ analytical = forces [iatom , idir ].to (torch .get_default_dtype ())
572+
573+ assert torch .isclose (
574+ numeric , analytical , atol = 2e-2
575+ ) or torch .isclose (numeric , analytical , rtol = 5e-3 ), (
576+ f"numeric: { numeric .item ()} , analytical: { analytical .item ()} for { energy_key } /{ force_key } "
577+ )
578+
579+ # reset the position
580+ data [AtomicDataDict .POSITIONS_KEY ][iatom , idir ] += epsilon
534581
535582 def test_partial_forces (
536583 self , model , partial_model , atomic_batch , device , strict_locality
@@ -802,10 +849,15 @@ def test_isolated_atom_energies(self, model, device):
802849 ),
803850 device ,
804851 )
805- energies = instance (data )[AtomicDataDict .TOTAL_ENERGY_KEY ]
806- assert torch .allclose (
807- energies , scale_shift_module .shifts .reshape (energies .shape )
808- )
852+ output = instance (data )
853+
854+ # test isolated atom energies for all total energy keys
855+ for energy_key in self .total_energy_keys ():
856+ if energy_key in output :
857+ energies = output [energy_key ]
858+ assert torch .allclose (
859+ energies , scale_shift_module .shifts .reshape (energies .shape )
860+ ), f"Isolated atom energy test failed for { energy_key } "
809861
810862 def test_embedding_cutoff (self , model , device ):
811863 """Test that edge embeddings/features go to zero at cutoff and gradients are correct."""
@@ -861,17 +913,20 @@ def test_embedding_cutoff(self, model, device):
861913 torch .zeros_like (grads ),
862914 )
863915
864- if AtomicDataDict .PER_ATOM_ENERGY_KEY in out :
865- # are the first two atom's energies unaffected by atom at the cutoff?
866- grads = torch .autograd .grad (
867- outputs = out [AtomicDataDict .PER_ATOM_ENERGY_KEY ][:2 ].sum (),
868- inputs = in_dict [AtomicDataDict .POSITIONS_KEY ],
869- )[0 ]
870- # only care about gradient wrt moved atom
871- assert grads .shape == (3 , 3 ), (
872- f"Expected gradient shape (3, 3) for 3 atoms in 3D, got { grads .shape } "
873- )
874- torch .testing .assert_close (
875- grads [2 ],
876- torch .zeros_like (grads [2 ]),
877- )
916+ # test that first two atoms' energies are unaffected by atom at cutoff
917+ for per_atom_energy_key in self .per_atom_energy_keys ():
918+ if per_atom_energy_key in out :
919+ # are the first two atom's energies unaffected by atom at the cutoff?
920+ grads = torch .autograd .grad (
921+ outputs = out [per_atom_energy_key ][:2 ].sum (),
922+ inputs = in_dict [AtomicDataDict .POSITIONS_KEY ],
923+ retain_graph = True ,
924+ )[0 ]
925+ # only care about gradient wrt moved atom
926+ assert grads .shape == (3 , 3 ), (
927+ f"Expected gradient shape (3, 3) for 3 atoms in 3D, got { grads .shape } "
928+ )
929+ torch .testing .assert_close (
930+ grads [2 ],
931+ torch .zeros_like (grads [2 ]),
932+ )
0 commit comments