Skip to content

Commit 1e60902

Browse files
committed
remove redundant code to simplify torchsim tests
1 parent fa8107b commit 1e60902

1 file changed

Lines changed: 42 additions & 61 deletions

File tree

nequip/utils/unittests/model_tests_torchsim.py

Lines changed: 42 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
import uuid
1616
import numpy as np
1717

18-
from nequip.data import to_ase
1918
from nequip.utils.versions import _TORCH_GE_2_6, _TORCH_GE_2_10
2019
from nequip.ase import NequIPCalculator
2120

22-
from hydra.utils import instantiate
2321
from .utils import _check_and_print
2422
from .model_tests_compilation import CompilationTestsMixin
2523

@@ -285,7 +283,9 @@ def test_torchsim_calculator_consistency(
285283
as reference, while TorchSim loads from batch-compiled model.
286284
"""
287285
compiled_path, mode = torchsim_compiled_model
288-
config, tmpdir, env, model_dtype, model_source, _ = fake_model_training_session
286+
config, tmpdir, env, model_dtype, model_source, structures = (
287+
fake_model_training_session
288+
)
289289

290290
# get model path for ASE calculator
291291
if model_source == "checkpoint":
@@ -308,56 +308,48 @@ def test_torchsim_calculator_consistency(
308308
chemical_species_to_atom_type_map=True,
309309
)
310310

311-
# get validation data
312-
datamodule = instantiate(config.data, _recursive_=False)
313-
datamodule.prepare_data()
314-
datamodule.setup("validate")
315-
dloader = datamodule.val_dataloader()[0]
316-
317311
# test on validation structures
318-
for data in dloader:
319-
atoms_list = to_ase(data.copy())
320-
for atoms in atoms_list:
321-
# NequIP calculator results
322-
nequip_atoms = atoms.copy()
323-
nequip_atoms.calc = nequip_calc
324-
nequip_E = nequip_atoms.get_potential_energy()
325-
nequip_F = nequip_atoms.get_forces()
326-
nequip_S = nequip_atoms.get_stress(voigt=False)
327-
328-
# TorchSim calculator results
329-
# convert atoms to SimState
330-
sim_state = ts.io.atoms_to_state(
331-
[atoms], device=device, dtype=torch.float64
332-
)
333-
ts_results = torchsim_calc(sim_state)
312+
for atoms in structures:
313+
# NequIP calculator results
314+
nequip_atoms = atoms.copy()
315+
nequip_atoms.calc = nequip_calc
316+
nequip_E = nequip_atoms.get_potential_energy()
317+
nequip_F = nequip_atoms.get_forces()
318+
nequip_S = nequip_atoms.get_stress(voigt=False)
319+
320+
# TorchSim calculator results
321+
# convert atoms to SimState
322+
sim_state = ts.io.atoms_to_state(
323+
[atoms], device=device, dtype=torch.float64
324+
)
325+
ts_results = torchsim_calc(sim_state)
334326

335-
# compare energies
336-
np.testing.assert_allclose(
337-
nequip_E,
338-
ts_results["energy"].cpu().numpy()[0],
339-
rtol=torchsim_tol,
340-
atol=torchsim_tol,
341-
)
327+
# compare energies
328+
np.testing.assert_allclose(
329+
nequip_E,
330+
ts_results["energy"].cpu().numpy()[0],
331+
rtol=torchsim_tol,
332+
atol=torchsim_tol,
333+
)
334+
335+
# compare forces
336+
np.testing.assert_allclose(
337+
nequip_F,
338+
ts_results["forces"].cpu().numpy(),
339+
rtol=torchsim_tol,
340+
atol=torchsim_tol,
341+
)
342342

343-
# compare forces
343+
# compare stress (if available)
344+
if "stress" in ts_results:
344345
np.testing.assert_allclose(
345-
nequip_F,
346-
ts_results["forces"].cpu().numpy(),
346+
nequip_S,
347+
ts_results["stress"].cpu().numpy()[0],
347348
rtol=torchsim_tol,
348349
atol=torchsim_tol,
349350
)
350351

351-
# compare stress (if available)
352-
if "stress" in ts_results:
353-
np.testing.assert_allclose(
354-
nequip_S,
355-
ts_results["stress"].cpu().numpy()[0],
356-
rtol=torchsim_tol,
357-
atol=torchsim_tol,
358-
)
359-
360-
del nequip_atoms, sim_state, ts_results
352+
del nequip_atoms, sim_state, ts_results
361353

362354
@pytest.mark.skipif(not _TORCHSIM_INSTALLED, reason="torch-sim not installed")
363355
@pytest.mark.parametrize("batch_size", [2, 3])
@@ -377,7 +369,7 @@ def test_torchsim_batched_evaluation(
377369
This is a key feature of torch-sim for efficient MD simulations.
378370
"""
379371
compiled_path, _ = torchsim_compiled_model
380-
config, _, _, _, _, _ = fake_model_training_session
372+
config, _, _, _, _, structures = fake_model_training_session
381373

382374
# load calculator
383375
torchsim_calc = NequIPTorchSimCalc.from_compiled_model(
@@ -386,22 +378,11 @@ def test_torchsim_batched_evaluation(
386378
chemical_species_to_atom_type_map=True,
387379
)
388380

389-
# get test structures
390-
datamodule = instantiate(config.data, _recursive_=False)
391-
datamodule.prepare_data()
392-
datamodule.setup("validate")
393-
dloader = datamodule.val_dataloader()[0]
394-
395-
structures = []
396-
for data in dloader:
397-
structures += to_ase(data.copy())
398-
if len(structures) >= batch_size:
399-
break
400-
381+
# check if we have enough structures for batched evaluation
401382
if len(structures) < batch_size:
402383
pytest.skip(f"Not enough structures for batch_size={batch_size}")
403384

404-
structures = structures[:batch_size]
385+
test_structures = structures[:batch_size]
405386

406387
# === test batched vs individual evaluation ===
407388

@@ -410,7 +391,7 @@ def test_torchsim_batched_evaluation(
410391
individual_forces = []
411392
individual_stresses = []
412393

413-
for atoms in structures:
394+
for atoms in test_structures:
414395
sim_state = ts.io.atoms_to_state(
415396
[atoms], device=device, dtype=torch.float64
416397
)
@@ -422,7 +403,7 @@ def test_torchsim_batched_evaluation(
422403

423404
# batched evaluation
424405
batched_sim_state = ts.io.atoms_to_state(
425-
structures, device=device, dtype=torch.float64
406+
test_structures, device=device, dtype=torch.float64
426407
)
427408
batched_result = torchsim_calc(batched_sim_state)
428409

0 commit comments

Comments
 (0)