Skip to content

Commit d9ee003

Browse files
committed
make energy model tests more extensible
1 parent 3481090 commit d9ee003

File tree

1 file changed

+150
-95
lines changed

1 file changed

+150
-95
lines changed

nequip/utils/unittests/model_tests_basic.py

Lines changed: 150 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,36 @@ class EnergyModelTestsMixin(BasicModelTestsMixin):
383383
- Isolated atom energy shifts
384384
"""
385385

386+
def total_energy_keys(self):
387+
"""Return list of total energy keys to test.
388+
389+
Subclasses can override to test multiple energy keys (e.g., energy_0, energy_1).
390+
391+
Returns:
392+
List of total energy key strings
393+
"""
394+
return [AtomicDataDict.TOTAL_ENERGY_KEY]
395+
396+
def per_atom_energy_keys(self):
397+
"""Return list of per-atom energy keys to test.
398+
399+
Subclasses can override to test multiple energy keys (e.g., per_atom_energy_0, per_atom_energy_1).
400+
401+
Returns:
402+
List of per-atom energy key strings
403+
"""
404+
return [AtomicDataDict.PER_ATOM_ENERGY_KEY]
405+
406+
def force_keys(self):
407+
"""Return list of force keys to test.
408+
409+
Subclasses can override to test multiple force keys (e.g., force_0, force_1).
410+
411+
Returns:
412+
List of force key strings
413+
"""
414+
return [AtomicDataDict.FORCE_KEY]
415+
386416
def test_large_separation(self, model, molecules, device):
387417
instance, config, _ = model
388418
atol = {torch.float32: 1e-4, torch.float64: 1e-10}[instance.model_dtype]
@@ -422,22 +452,27 @@ def test_large_separation(self, model, molecules, device):
422452
out2 = instance(from_dict(data2))
423453
out_both = instance(from_dict(data_both))
424454

425-
assert torch.allclose(
426-
out1[AtomicDataDict.TOTAL_ENERGY_KEY]
427-
+ out2[AtomicDataDict.TOTAL_ENERGY_KEY],
428-
out_both[AtomicDataDict.TOTAL_ENERGY_KEY],
429-
atol=atol,
430-
)
431-
if AtomicDataDict.FORCE_KEY in out1:
432-
# check forces if it's a force model
433-
assert torch.allclose(
434-
torch.cat(
435-
(out1[AtomicDataDict.FORCE_KEY], out2[AtomicDataDict.FORCE_KEY]),
436-
dim=0,
437-
),
438-
out_both[AtomicDataDict.FORCE_KEY],
439-
atol=atol,
440-
)
455+
# test all total energy keys
456+
for energy_key in self.total_energy_keys():
457+
if energy_key in out1 and energy_key in out2 and energy_key in out_both:
458+
assert torch.allclose(
459+
out1[energy_key] + out2[energy_key],
460+
out_both[energy_key],
461+
atol=atol,
462+
), f"Large separation test failed for {energy_key}"
463+
464+
# test all force keys
465+
for force_key in self.force_keys():
466+
if force_key in out1:
467+
# check forces if it's a force model
468+
assert torch.allclose(
469+
torch.cat(
470+
(out1[force_key], out2[force_key]),
471+
dim=0,
472+
),
473+
out_both[force_key],
474+
atol=atol,
475+
), f"Large separation test failed for {force_key}"
441476

442477
atoms_both2 = atoms1.copy()
443478
atoms3 = atoms2.copy()
@@ -450,16 +485,24 @@ def test_large_separation(self, model, molecules, device):
450485
)
451486

452487
out_both2 = instance(data_both2)
453-
assert torch.allclose(
454-
out_both2[AtomicDataDict.TOTAL_ENERGY_KEY],
455-
out_both[AtomicDataDict.TOTAL_ENERGY_KEY],
456-
atol=atol,
457-
)
458-
assert torch.allclose(
459-
out_both2[AtomicDataDict.PER_ATOM_ENERGY_KEY],
460-
out_both[AtomicDataDict.PER_ATOM_ENERGY_KEY],
461-
atol=atol,
462-
)
488+
489+
# test total energy invariance to rigid translation
490+
for energy_key in self.total_energy_keys():
491+
if energy_key in out_both and energy_key in out_both2:
492+
assert torch.allclose(
493+
out_both2[energy_key],
494+
out_both[energy_key],
495+
atol=atol,
496+
), f"Translation invariance test failed for {energy_key}"
497+
498+
# test per-atom energy invariance to rigid translation
499+
for per_atom_energy_key in self.per_atom_energy_keys():
500+
if per_atom_energy_key in out_both and per_atom_energy_key in out_both2:
501+
assert torch.allclose(
502+
out_both2[per_atom_energy_key],
503+
out_both[per_atom_energy_key],
504+
atol=atol,
505+
), f"Translation invariance test failed for {per_atom_energy_key}"
463506

464507
def test_cross_frame_grad(self, model, device, nequip_dataset):
465508
batch = AtomicDataDict.batched_from_list(
@@ -470,67 +513,71 @@ def test_cross_frame_grad(self, model, device, nequip_dataset):
470513
data[AtomicDataDict.POSITIONS_KEY].requires_grad = True
471514

472515
output = energy_model(data)
473-
grads = torch.autograd.grad(
474-
outputs=output[AtomicDataDict.TOTAL_ENERGY_KEY][-1],
475-
inputs=data[AtomicDataDict.POSITIONS_KEY],
476-
allow_unused=True,
477-
)[0]
478516

479-
last_frame_n_atom = batch[AtomicDataDict.NUM_NODES_KEY][-1]
517+
# test cross-frame gradient isolation for all total energy keys
518+
for energy_key in self.total_energy_keys():
519+
if energy_key in output:
520+
grads = torch.autograd.grad(
521+
outputs=output[energy_key][-1],
522+
inputs=data[AtomicDataDict.POSITIONS_KEY],
523+
allow_unused=True,
524+
retain_graph=True,
525+
)[0]
526+
527+
last_frame_n_atom = batch[AtomicDataDict.NUM_NODES_KEY][-1]
480528

481-
in_frame_grad = grads[-last_frame_n_atom:]
482-
cross_frame_grad = grads[:-last_frame_n_atom]
529+
in_frame_grad = grads[-last_frame_n_atom:]
530+
cross_frame_grad = grads[:-last_frame_n_atom]
483531

484-
assert cross_frame_grad.abs().max().item() == 0
485-
assert in_frame_grad.abs().max().item() > 0
532+
assert cross_frame_grad.abs().max().item() == 0, (
533+
f"Cross-frame gradient test failed for {energy_key}"
534+
)
535+
assert in_frame_grad.abs().max().item() > 0, (
536+
f"In-frame gradient test failed for {energy_key}"
537+
)
486538

487539
def test_numeric_gradient(self, model, atomic_batch, device):
488540
"""
489541
Tests the ForceStressOutput model by comparing numerical gradients of the forces to the analytical gradients.
490542
"""
491543
model, _, out_fields = model
492-
# proceed with tests only if forces are available
493-
if AtomicDataDict.FORCE_KEY in out_fields:
494-
# physical predictions (energy, forces, etc) will be converted to default_dtype (float64) before comparing
495-
data = AtomicDataDict.to_(atomic_batch, device)
496-
output = model(data)
497-
forces = output[AtomicDataDict.FORCE_KEY]
498-
epsilon = 1e-3
499544

500-
# Compute numerical gradients for each atom and direction and compare to analytical gradients
501-
for iatom in range(len(data[AtomicDataDict.POSITIONS_KEY])):
502-
for idir in range(3):
503-
# Shift `iatom` an `epsilon` in the `idir` direction
504-
pos = data[AtomicDataDict.POSITIONS_KEY][iatom, idir]
505-
data[AtomicDataDict.POSITIONS_KEY][iatom, idir] = pos + epsilon
506-
output = model(data)
507-
e_plus = (
508-
output[AtomicDataDict.TOTAL_ENERGY_KEY]
509-
.sum()
510-
.to(torch.get_default_dtype())
511-
)
512-
513-
# Shift `iatom` an `epsilon` in the negative `idir` direction
514-
data[AtomicDataDict.POSITIONS_KEY][iatom, idir] -= epsilon * 2
515-
output = model(data)
516-
e_minus = (
517-
output[AtomicDataDict.TOTAL_ENERGY_KEY]
518-
.sum()
519-
.to(torch.get_default_dtype())
520-
)
521-
522-
# Symmetric difference to get the partial forces to all the atoms
523-
numeric = -(e_plus - e_minus) / (epsilon * 2)
524-
analytical = forces[iatom, idir].to(torch.get_default_dtype())
525-
526-
assert torch.isclose(
527-
numeric, analytical, atol=2e-2
528-
) or torch.isclose(numeric, analytical, rtol=5e-3), (
529-
f"numeric: {numeric.item()}, analytical: {analytical.item()}"
530-
)
531-
532-
# Reset the position
533-
data[AtomicDataDict.POSITIONS_KEY][iatom, idir] += epsilon
545+
# test numeric gradients for each pair of (total_energy_key, force_key)
546+
for energy_key, force_key in zip(self.total_energy_keys(), self.force_keys()):
547+
# proceed with tests only if forces are available
548+
if force_key in out_fields:
549+
# physical predictions (energy, forces, etc) will be converted to default_dtype (float64) before comparing
550+
data = AtomicDataDict.to_(atomic_batch, device)
551+
output = model(data)
552+
forces = output[force_key]
553+
epsilon = 1e-3
554+
555+
# compute numerical gradients for each atom and direction and compare to analytical gradients
556+
for iatom in range(len(data[AtomicDataDict.POSITIONS_KEY])):
557+
for idir in range(3):
558+
# shift `iatom` an `epsilon` in the `idir` direction
559+
pos = data[AtomicDataDict.POSITIONS_KEY][iatom, idir]
560+
data[AtomicDataDict.POSITIONS_KEY][iatom, idir] = pos + epsilon
561+
output = model(data)
562+
e_plus = output[energy_key].sum().to(torch.get_default_dtype())
563+
564+
# shift `iatom` an `epsilon` in the negative `idir` direction
565+
data[AtomicDataDict.POSITIONS_KEY][iatom, idir] -= epsilon * 2
566+
output = model(data)
567+
e_minus = output[energy_key].sum().to(torch.get_default_dtype())
568+
569+
# symmetric difference to get the partial forces to all the atoms
570+
numeric = -(e_plus - e_minus) / (epsilon * 2)
571+
analytical = forces[iatom, idir].to(torch.get_default_dtype())
572+
573+
assert torch.isclose(
574+
numeric, analytical, atol=2e-2
575+
) or torch.isclose(numeric, analytical, rtol=5e-3), (
576+
f"numeric: {numeric.item()}, analytical: {analytical.item()} for {energy_key}/{force_key}"
577+
)
578+
579+
# reset the position
580+
data[AtomicDataDict.POSITIONS_KEY][iatom, idir] += epsilon
534581

535582
def test_partial_forces(
536583
self, model, partial_model, atomic_batch, device, strict_locality
@@ -802,10 +849,15 @@ def test_isolated_atom_energies(self, model, device):
802849
),
803850
device,
804851
)
805-
energies = instance(data)[AtomicDataDict.TOTAL_ENERGY_KEY]
806-
assert torch.allclose(
807-
energies, scale_shift_module.shifts.reshape(energies.shape)
808-
)
852+
output = instance(data)
853+
854+
# test isolated atom energies for all total energy keys
855+
for energy_key in self.total_energy_keys():
856+
if energy_key in output:
857+
energies = output[energy_key]
858+
assert torch.allclose(
859+
energies, scale_shift_module.shifts.reshape(energies.shape)
860+
), f"Isolated atom energy test failed for {energy_key}"
809861

810862
def test_embedding_cutoff(self, model, device):
811863
"""Test that edge embeddings/features go to zero at cutoff and gradients are correct."""
@@ -861,17 +913,20 @@ def test_embedding_cutoff(self, model, device):
861913
torch.zeros_like(grads),
862914
)
863915

864-
if AtomicDataDict.PER_ATOM_ENERGY_KEY in out:
865-
# are the first two atom's energies unaffected by atom at the cutoff?
866-
grads = torch.autograd.grad(
867-
outputs=out[AtomicDataDict.PER_ATOM_ENERGY_KEY][:2].sum(),
868-
inputs=in_dict[AtomicDataDict.POSITIONS_KEY],
869-
)[0]
870-
# only care about gradient wrt moved atom
871-
assert grads.shape == (3, 3), (
872-
f"Expected gradient shape (3, 3) for 3 atoms in 3D, got {grads.shape}"
873-
)
874-
torch.testing.assert_close(
875-
grads[2],
876-
torch.zeros_like(grads[2]),
877-
)
916+
# test that first two atoms' energies are unaffected by atom at cutoff
917+
for per_atom_energy_key in self.per_atom_energy_keys():
918+
if per_atom_energy_key in out:
919+
# are the first two atom's energies unaffected by atom at the cutoff?
920+
grads = torch.autograd.grad(
921+
outputs=out[per_atom_energy_key][:2].sum(),
922+
inputs=in_dict[AtomicDataDict.POSITIONS_KEY],
923+
retain_graph=True,
924+
)[0]
925+
# only care about gradient wrt moved atom
926+
assert grads.shape == (3, 3), (
927+
f"Expected gradient shape (3, 3) for 3 atoms in 3D, got {grads.shape}"
928+
)
929+
torch.testing.assert_close(
930+
grads[2],
931+
torch.zeros_like(grads[2]),
932+
)

0 commit comments

Comments
 (0)