Skip to content

Commit c8f56f3

Browse files
committed
make test_embedding_cutoff more verbose
1 parent 05561c2 commit c8f56f3

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

nequip/utils/unittests/model_tests_basic.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -833,10 +833,14 @@ def test_embedding_cutoff(self, model, device):
833833
if key == AtomicDataDict.EDGE_EMBEDDING_KEY:
834834
# we can only check that other edges are unaffected if we know it's an embedding
835835
# For example, an Allegro edge feature is many body so will be affected
836-
assert torch.allclose(edge_embed[:2], edge_embed2[:2])
837-
assert edge_embed[2:].abs().sum() > 1e-6 # some nonzero terms
838-
assert torch.allclose(
839-
edge_embed2[2:], torch.zeros(1, device=device, dtype=edge_embed2.dtype)
836+
torch.testing.assert_close(edge_embed[:2], edge_embed2[:2])
837+
assert edge_embed[2:].abs().sum() > 1e-6, (
838+
f"Edge embeddings at cutoff should have some nonzero terms before moving atom to cutoff, "
839+
f"but got sum of absolute values: {edge_embed[2:].abs().sum()}"
840+
)
841+
torch.testing.assert_close(
842+
edge_embed2[2:],
843+
torch.zeros_like(edge_embed2[2:]),
840844
)
841845

842846
# test gradients
@@ -852,8 +856,9 @@ def test_embedding_cutoff(self, model, device):
852856
inputs=in_dict[AtomicDataDict.POSITIONS_KEY],
853857
retain_graph=True,
854858
)[0]
855-
assert torch.allclose(
856-
grads, torch.zeros(1, device=device, dtype=grads.dtype)
859+
torch.testing.assert_close(
860+
grads,
861+
torch.zeros_like(grads),
857862
)
858863

859864
if AtomicDataDict.PER_ATOM_ENERGY_KEY in out:
@@ -862,7 +867,11 @@ def test_embedding_cutoff(self, model, device):
862867
outputs=out[AtomicDataDict.PER_ATOM_ENERGY_KEY][:2].sum(),
863868
inputs=in_dict[AtomicDataDict.POSITIONS_KEY],
864869
)[0]
865-
print(grads)
866870
# only care about gradient wrt moved atom
867-
assert grads.shape == (3, 3)
868-
assert torch.allclose(grads[2], torch.zeros(1, device=device))
871+
assert grads.shape == (3, 3), (
872+
f"Expected gradient shape (3, 3) for 3 atoms in 3D, got {grads.shape}"
873+
)
874+
torch.testing.assert_close(
875+
grads[2],
876+
torch.zeros_like(grads[2]),
877+
)

0 commit comments

Comments
 (0)