1515import uuid
1616import numpy as np
1717
18- from nequip .data import to_ase
1918from nequip .utils .versions import _TORCH_GE_2_6 , _TORCH_GE_2_10
2019from nequip .ase import NequIPCalculator
2120
22- from hydra .utils import instantiate
2321from .utils import _check_and_print
2422from .model_tests_compilation import CompilationTestsMixin
2523
@@ -285,7 +283,9 @@ def test_torchsim_calculator_consistency(
285283 as reference, while TorchSim loads from batch-compiled model.
286284 """
287285 compiled_path , mode = torchsim_compiled_model
288- config , tmpdir , env , model_dtype , model_source , _ = fake_model_training_session
286+ config , tmpdir , env , model_dtype , model_source , structures = (
287+ fake_model_training_session
288+ )
289289
290290 # get model path for ASE calculator
291291 if model_source == "checkpoint" :
@@ -308,56 +308,48 @@ def test_torchsim_calculator_consistency(
308308 chemical_species_to_atom_type_map = True ,
309309 )
310310
311- # get validation data
312- datamodule = instantiate (config .data , _recursive_ = False )
313- datamodule .prepare_data ()
314- datamodule .setup ("validate" )
315- dloader = datamodule .val_dataloader ()[0 ]
316-
317311 # test on validation structures
318- for data in dloader :
319- atoms_list = to_ase (data .copy ())
320- for atoms in atoms_list :
321- # NequIP calculator results
322- nequip_atoms = atoms .copy ()
323- nequip_atoms .calc = nequip_calc
324- nequip_E = nequip_atoms .get_potential_energy ()
325- nequip_F = nequip_atoms .get_forces ()
326- nequip_S = nequip_atoms .get_stress (voigt = False )
327-
328- # TorchSim calculator results
329- # convert atoms to SimState
330- sim_state = ts .io .atoms_to_state (
331- [atoms ], device = device , dtype = torch .float64
332- )
333- ts_results = torchsim_calc (sim_state )
312+ for atoms in structures :
313+ # NequIP calculator results
314+ nequip_atoms = atoms .copy ()
315+ nequip_atoms .calc = nequip_calc
316+ nequip_E = nequip_atoms .get_potential_energy ()
317+ nequip_F = nequip_atoms .get_forces ()
318+ nequip_S = nequip_atoms .get_stress (voigt = False )
319+
320+ # TorchSim calculator results
321+ # convert atoms to SimState
322+ sim_state = ts .io .atoms_to_state (
323+ [atoms ], device = device , dtype = torch .float64
324+ )
325+ ts_results = torchsim_calc (sim_state )
334326
335- # compare energies
336- np .testing .assert_allclose (
337- nequip_E ,
338- ts_results ["energy" ].cpu ().numpy ()[0 ],
339- rtol = torchsim_tol ,
340- atol = torchsim_tol ,
341- )
327+ # compare energies
328+ np .testing .assert_allclose (
329+ nequip_E ,
330+ ts_results ["energy" ].cpu ().numpy ()[0 ],
331+ rtol = torchsim_tol ,
332+ atol = torchsim_tol ,
333+ )
334+
335+ # compare forces
336+ np .testing .assert_allclose (
337+ nequip_F ,
338+ ts_results ["forces" ].cpu ().numpy (),
339+ rtol = torchsim_tol ,
340+ atol = torchsim_tol ,
341+ )
342342
343- # compare forces
343+ # compare stress (if available)
344+ if "stress" in ts_results :
344345 np .testing .assert_allclose (
345- nequip_F ,
346- ts_results ["forces " ].cpu ().numpy (),
346+ nequip_S ,
347+ ts_results ["stress " ].cpu ().numpy ()[ 0 ] ,
347348 rtol = torchsim_tol ,
348349 atol = torchsim_tol ,
349350 )
350351
351- # compare stress (if available)
352- if "stress" in ts_results :
353- np .testing .assert_allclose (
354- nequip_S ,
355- ts_results ["stress" ].cpu ().numpy ()[0 ],
356- rtol = torchsim_tol ,
357- atol = torchsim_tol ,
358- )
359-
360- del nequip_atoms , sim_state , ts_results
352+ del nequip_atoms , sim_state , ts_results
361353
362354 @pytest .mark .skipif (not _TORCHSIM_INSTALLED , reason = "torch-sim not installed" )
363355 @pytest .mark .parametrize ("batch_size" , [2 , 3 ])
@@ -377,7 +369,7 @@ def test_torchsim_batched_evaluation(
377369 This is a key feature of torch-sim for efficient MD simulations.
378370 """
379371 compiled_path , _ = torchsim_compiled_model
380- config , _ , _ , _ , _ , _ = fake_model_training_session
372+ config , _ , _ , _ , _ , structures = fake_model_training_session
381373
382374 # load calculator
383375 torchsim_calc = NequIPTorchSimCalc .from_compiled_model (
@@ -386,22 +378,11 @@ def test_torchsim_batched_evaluation(
386378 chemical_species_to_atom_type_map = True ,
387379 )
388380
389- # get test structures
390- datamodule = instantiate (config .data , _recursive_ = False )
391- datamodule .prepare_data ()
392- datamodule .setup ("validate" )
393- dloader = datamodule .val_dataloader ()[0 ]
394-
395- structures = []
396- for data in dloader :
397- structures += to_ase (data .copy ())
398- if len (structures ) >= batch_size :
399- break
400-
381+ # check if we have enough structures for batched evaluation
401382 if len (structures ) < batch_size :
402383 pytest .skip (f"Not enough structures for batch_size={ batch_size } " )
403384
404- structures = structures [:batch_size ]
385+ test_structures = structures [:batch_size ]
405386
406387 # === test batched vs individual evaluation ===
407388
@@ -410,7 +391,7 @@ def test_torchsim_batched_evaluation(
410391 individual_forces = []
411392 individual_stresses = []
412393
413- for atoms in structures :
394+ for atoms in test_structures :
414395 sim_state = ts .io .atoms_to_state (
415396 [atoms ], device = device , dtype = torch .float64
416397 )
@@ -422,7 +403,7 @@ def test_torchsim_batched_evaluation(
422403
423404 # batched evaluation
424405 batched_sim_state = ts .io .atoms_to_state (
425- structures , device = device , dtype = torch .float64
406+ test_structures , device = device , dtype = torch .float64
426407 )
427408 batched_result = torchsim_calc (batched_sim_state )
428409
0 commit comments