Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store additional properties in ASE calculator #658

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


import ase # isort: skip
import ase.md # isort: skip
import ase.neighborlist # isort: skip
import ase.calculators.calculator # isort: skip
from ase.calculators.calculator import ( # isort: skip
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
extensions_directory=None,
check_consistency=False,
device=None,
properties_to_store: Optional[List[str]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
properties_to_store: Optional[List[str]] = None,
extra_metatensor_outputs: Optional[Dict[str, ModelOutput]] = None,

I would rename this and handle it a bit differently from "normal" properties to make it clearer that this mechanism is only intended for things which are not already known ASE properties.

IMO the data should also be stored as TensorMap in calc.results["extra_metatensor_outputs"]["..."], again to make it clear we are operating outside of the ASE interface.

):
"""
:param model: model to use for the calculation. This can be a file path, a
Expand All @@ -84,6 +86,12 @@ def __init__(
running, defaults to False.
:param device: torch device to use for the calculation. If ``None``, we will try
the options in the model's ``supported_device`` in order.
:param properties_to_store: list of model outputs to store as results of the ASE
calculator at every step. This is useful when you want to store properties
that are not used in the propagation of the dynamics and/or are not standard
ASE properties ('energy', 'forces', 'stress', 'stresses', 'dipole',
'charges', 'magmom', 'magmoms', 'free_energy', 'energies'). These properties
will be available as ``atoms.calc.results['<stored_property>']``.
"""
super().__init__()

Expand Down Expand Up @@ -158,6 +166,10 @@ def __init__(
# We do our own check to verify if a property is implemented in `calculate()`,
# so we pretend to be able to compute all properties ASE knows about.
self.implemented_properties = ALL_ASE_PROPERTIES
self.properties_to_store = (
properties_to_store if properties_to_store is not None else []
)
self.additional_properties_to_store = []

def todict(self):
if "model_path" not in self.parameters:
Expand Down Expand Up @@ -253,8 +265,12 @@ def calculate(
system_changes=system_changes,
)

properties_to_calculate = (
properties + self.properties_to_store + self.additional_properties_to_store
)

with record_function("ASECalculator::prepare_inputs"):
outputs = _ase_properties_to_metatensor_outputs(properties)
outputs = _ase_properties_to_metatensor_outputs(properties_to_calculate)
capabilities = self._model.capabilities()
for name in outputs.keys():
if name not in capabilities.outputs:
Expand All @@ -268,11 +284,11 @@ def calculate(
)

do_backward = False
if "forces" in properties:
if "forces" in properties_to_calculate:
do_backward = True
positions.requires_grad_(True)

if "stress" in properties:
if "stress" in properties_to_calculate:
do_backward = True

strain = torch.eye(
Expand All @@ -284,7 +300,7 @@ def calculate(

cell = cell @ strain

if "stresses" in properties:
if "stresses" in properties_to_calculate:
raise NotImplementedError("'stresses' are not implemented yet")

run_options = ModelEvaluationOptions(
Expand Down Expand Up @@ -335,14 +351,14 @@ def calculate(

self.results = {}

if "energies" in properties:
if "energies" in properties_to_calculate:
energies_values = energies.detach().reshape(-1)
energies_values = energies_values.to(device="cpu").to(
dtype=torch.float64
)
self.results["energies"] = energies_values.numpy()

if "energy" in properties:
if "energy" in properties_to_calculate:
energy_values = energy.detach()
energy_values = energy_values.to(device="cpu").to(dtype=torch.float64)
self.results["energy"] = energy_values.numpy()[0, 0]
Expand All @@ -352,18 +368,54 @@ def calculate(
energy.backward(-torch.ones_like(energy))

with record_function("ASECalculator::convert_outputs"):
if "forces" in properties:
if "forces" in properties_to_calculate:
forces_values = system.positions.grad.reshape(-1, 3)
forces_values = forces_values.to(device="cpu").to(dtype=torch.float64)
self.results["forces"] = forces_values.numpy()

if "stress" in properties:
if "stress" in properties_to_calculate:
stress_values = -strain.grad.reshape(3, 3) / atoms.cell.volume
stress_values = stress_values.to(device="cpu").to(dtype=torch.float64)
self.results["stress"] = _full_3x3_to_voigt_6_stress(
stress_values.numpy()
)

def request_properties_every_n_steps(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't think this should be part of the same object. The calculator should follow the ASE API for calculators, and we should provide a separate tool (class, function, …) to handle calculations that have to run at other times.

self,
dyn: ase.md.md.MolecularDynamics,
properties: List[str],
n: int,
):
"""
Makes a property available every n steps of the dynamics.

:param dyn: ASE molecular dynamics object
:param properties: list of properties to be made available
at regular intervals
:param n: number of steps between each property calculation
"""

# prepare for step 0, where the properties must be available
self.additional_properties_to_store.extend(properties)

def _request_properties():
self.additional_properties_to_store.extend(properties)

def _unrequest_properties():
for prop in properties:
self.additional_properties_to_store.remove(prop)

def _manage_additional_properties():
step_count = dyn.nsteps
if step_count % n == n - 1:
_request_properties()
elif step_count % n == 0:
_unrequest_properties()
else:
pass

dyn.attach(_manage_additional_properties, interval=1, mode="step")


def _find_best_device(devices: List[str]) -> torch.device:
"""
Expand Down