diff --git a/python/metatensor-torch/metatensor/torch/atomistic/ase_calculator.py b/python/metatensor-torch/metatensor/torch/atomistic/ase_calculator.py index 8d86e112c..c5f178832 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/ase_calculator.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/ase_calculator.py @@ -21,6 +21,7 @@ import ase # isort: skip +import ase.md # isort: skip import ase.neighborlist # isort: skip import ase.calculators.calculator # isort: skip from ase.calculators.calculator import ( # isort: skip @@ -73,6 +74,7 @@ def __init__( extensions_directory=None, check_consistency=False, device=None, + properties_to_store: Optional[List[str]] = None, ): """ :param model: model to use for the calculation. This can be a file path, a @@ -84,6 +86,12 @@ def __init__( running, defaults to False. :param device: torch device to use for the calculation. If ``None``, we will try the options in the model's ``supported_device`` in order. + :param properties_to_store: list of model outputs to store as results of the ASE + calculator at every step. This is useful when you want to store properties + that are not used in the propagation of the dynamics and/or are not standard + ASE properties ('energy', 'forces', 'stress', 'stresses', 'dipole', + 'charges', 'magmom', 'magmoms', 'free_energy', 'energies'). These properties + will be available as ``atoms.calc.results['']``. """ super().__init__() @@ -158,6 +166,10 @@ def __init__( # We do our own check to verify if a property is implemented in `calculate()`, # so we pretend to be able to compute all properties ASE knows about. self.implemented_properties = ALL_ASE_PROPERTIES + self.properties_to_store = ( + properties_to_store if properties_to_store is not None else [] + ) + self.additional_properties_to_store = [] def todict(self): if "model_path" not in self.parameters: @@ -253,8 +265,12 @@ def calculate( system_changes=system_changes, ) + properties_to_calculate = ( + properties + self.properties_to_store + self.additional_properties_to_store + ) + with record_function("ASECalculator::prepare_inputs"): - outputs = _ase_properties_to_metatensor_outputs(properties) + outputs = _ase_properties_to_metatensor_outputs(properties_to_calculate) capabilities = self._model.capabilities() for name in outputs.keys(): if name not in capabilities.outputs: @@ -268,11 +284,11 @@ def calculate( ) do_backward = False - if "forces" in properties: + if "forces" in properties_to_calculate: do_backward = True positions.requires_grad_(True) - if "stress" in properties: + if "stress" in properties_to_calculate: do_backward = True strain = torch.eye( @@ -284,7 +300,7 @@ def calculate( cell = cell @ strain - if "stresses" in properties: + if "stresses" in properties_to_calculate: raise NotImplementedError("'stresses' are not implemented yet") run_options = ModelEvaluationOptions( @@ -335,14 +351,14 @@ def calculate( self.results = {} - if "energies" in properties: + if "energies" in properties_to_calculate: energies_values = energies.detach().reshape(-1) energies_values = energies_values.to(device="cpu").to( dtype=torch.float64 ) self.results["energies"] = energies_values.numpy() - if "energy" in properties: + if "energy" in properties_to_calculate: energy_values = energy.detach() energy_values = energy_values.to(device="cpu").to(dtype=torch.float64) self.results["energy"] = energy_values.numpy()[0, 0] @@ -352,18 +368,54 @@ def calculate( energy.backward(-torch.ones_like(energy)) with record_function("ASECalculator::convert_outputs"): - if "forces" in properties: + if "forces" in properties_to_calculate: forces_values = system.positions.grad.reshape(-1, 3) forces_values = forces_values.to(device="cpu").to(dtype=torch.float64) self.results["forces"] = forces_values.numpy() - if "stress" in properties: + if "stress" in properties_to_calculate: stress_values = -strain.grad.reshape(3, 3) / atoms.cell.volume stress_values = stress_values.to(device="cpu").to(dtype=torch.float64) self.results["stress"] = _full_3x3_to_voigt_6_stress( stress_values.numpy() ) + def request_properties_every_n_steps( + self, + dyn: ase.md.md.MolecularDynamics, + properties: List[str], + n: int, + ): + """ + Makes a property available every n steps of the dynamics. + + :param dyn: ASE molecular dynamics object + :param properties: list of properties to be made available + at regular intervals + :param n: number of steps between each property calculation + """ + + # prepare for step 0, where the properties must be available + self.additional_properties_to_store.extend(properties) + + def _request_properties(): + self.additional_properties_to_store.extend(properties) + + def _unrequest_properties(): + for prop in properties: + self.additional_properties_to_store.remove(prop) + + def _manage_additional_properties(): + step_count = dyn.nsteps + if step_count % n == n - 1: + _request_properties() + elif step_count % n == 0: + _unrequest_properties() + else: + pass + + dyn.attach(_manage_additional_properties, interval=1, mode="step") + def _find_best_device(devices: List[str]) -> torch.device: """