From 8c403c12e5315114ca8070d7fc8ec84ce2e50ea5 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 5 Jul 2024 08:41:03 +0200 Subject: [PATCH 01/16] Register `ModelMetadataHolder::print()` --- metatensor-torch/src/register.cpp | 1 + python/metatensor-torch/tests/atomistic/metadata.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/metatensor-torch/src/register.cpp b/metatensor-torch/src/register.cpp index 461dbda40..b6e613ef3 100644 --- a/metatensor-torch/src/register.cpp +++ b/metatensor-torch/src/register.cpp @@ -382,6 +382,7 @@ TORCH_LIBRARY(metatensor, m) { ) .def("__repr__", &ModelMetadataHolder::print) .def("__str__", &ModelMetadataHolder::print) + .def("print", &ModelMetadataHolder::print) .def_readwrite("name", &ModelMetadataHolder::name) .def_readwrite("description", &ModelMetadataHolder::description) .def_readwrite("authors", &ModelMetadataHolder::authors) diff --git a/python/metatensor-torch/tests/atomistic/metadata.py b/python/metatensor-torch/tests/atomistic/metadata.py index fefcb817b..dc7bcee6d 100644 --- a/python/metatensor-torch/tests/atomistic/metadata.py +++ b/python/metatensor-torch/tests/atomistic/metadata.py @@ -226,3 +226,5 @@ def forward(self, x: ModelMetadataWrap) -> ModelMetadataWrap: * ref-4 """ assert str(metadata) == expected + assert metadata.__repr__() == expected + assert metadata.print() == expected From 1e3a0a62ef2f14f4ad0abf50b57a7668872fec6c Mon Sep 17 00:00:00 2001 From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Date: Mon, 17 Jun 2024 12:45:16 +0200 Subject: [PATCH 02/16] Add OpenMM interface --- python/metatensor-operations/CHANGELOG.md | 4 - .../torch/atomistic/openmm_interface.py | 138 ++++++++++++++++++ 2 files changed, 138 insertions(+), 4 deletions(-) create mode 100644 python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py diff --git a/python/metatensor-operations/CHANGELOG.md b/python/metatensor-operations/CHANGELOG.md index 771a2602d..aa72a50b6 100644 --- a/python/metatensor-operations/CHANGELOG.md +++ b/python/metatensor-operations/CHANGELOG.md @@ -17,10 +17,6 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ### Removed --> -### Changed - -- We now require Python >= 3.9 - ## [Version 0.2.2](https://github.com/lab-cosmo/metatensor/releases/tag/metatensor-operations-v0.2.2) - 2024-06-19 ### Fixed diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py new file mode 100644 index 000000000..7162309a4 --- /dev/null +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py @@ -0,0 +1,138 @@ +import torch +from typing import Iterable, Optional +from metatensor.torch.atomistic import load_atomistic_model, System, ModelOutput, ModelEvaluationOptions +from metatensor.torch import Labels +from typing import List + +try: + import openmm + import openmmtorch + from openmmml.mlpotential import MLPotential, MLPotentialImpl, MLPotentialImplFactory + HAS_OPENMM = True +except ImportError as e: + class MLPotential: + pass + class MLPotentialImpl: + pass + class MLPotentialImplFactory: + pass + HAS_OPENMM = False + + +class MetatensorPotentialImplFactory(MLPotentialImplFactory): + + def createImpl( + name: str, **args + ) -> MLPotentialImpl: + # TODO: extensions_directory + return MetatensorPotentialImpl(name, **args) + + +class MetatensorPotentialImpl(MLPotentialImpl): + + def __init__(self, name: str, path: str) -> None: + self.path = path + + def addForces( + self, + topology: openmm.app.Topology, + system: openmm.System, + atoms: Optional[Iterable[int]], + forceGroup: int, + **args, + ) -> None: + + if not HAS_OPENMM: + raise ImportError( + "Could not import openmm. If you want to use metatensor with " + "openmm, please install openmm-ml with conda." + ) + + model = load_atomistic_model( + self.path # TODO: extensions_directory + ) + + # Get the atomic numbers of the ML region. + all_atoms = list(topology.atoms()) + atomic_numbers = [atom.element.atomic_number for atom in all_atoms] + + # TODO: Set up selected_atoms as a Labels object + if atoms is None: + selected_atoms = None + else: + selected_atoms = Labels( + names=["system", "atom"], + values=torch.tensor( + [[0, selected_atom] for selected_atom in atoms], + dtype=torch.int32, + ), + ) + + class MetatensorForce(torch.nn.Module): + + def __init__( + self, + model: torch.jit._script.RecursiveScriptModule, + atomic_numbers: List[int], + selected_atoms: Optional[Labels], + ) -> None: + super(MetatensorForce, self).__init__() + + self.model = model + self.register_buffer("atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32)) + self.evaluation_options = ModelEvaluationOptions( + length_unit='nm', + outputs={ + "energy": ModelOutput( + quantity="energy", + unit="kJ/mol", + per_atom=False, + ), + }, + selected_atoms=selected_atoms, + ) + + + def forward( + self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None + ) -> torch.Tensor: + # move labels if necessary + selected_atoms = self.evaluation_options.selected_atoms + if selected_atoms is not None: + if selected_atoms.device != positions.device: + self.evaluation_options.selected_atoms = selected_atoms.to(positions.device) + + if cell is None: + cell = torch.zeros((3, 3), dtype=positions.dtype, device=positions.device) + + # create System + system = System( + types=self.atomic_numbers, + positions=positions, + cell=cell, + ) + + energy = self.model([system], self.evaluation_options, check_consistency=True)["energy"].block().values.reshape(()) + return energy + + metatensor_force = MetatensorForce( + model, + atomic_numbers, + selected_atoms, + ) + + # torchscript everything + module = torch.jit.script(metatensor_force) + + # create the OpenMM force + force = openmmtorch.TorchForce(module) + isPeriodic = ( + topology.getPeriodicBoxVectors() is not None + ) or system.usesPeriodicBoundaryConditions() + force.setUsesPeriodicBoundaryConditions(isPeriodic) + force.setForceGroup(forceGroup) + + system.addForce(force) + + +MLPotential.registerImplFactory("metatensor", MetatensorPotentialImplFactory) From 1792629d68911e38703ea1d04e703a82916f0aa6 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 21 Jun 2024 10:51:54 +0200 Subject: [PATCH 03/16] Linter --- .../torch/atomistic/openmm_interface.py | 61 +++++++++++++------ 1 file changed, 42 insertions(+), 19 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py index 7162309a4..33029da3f 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py @@ -1,29 +1,43 @@ +from typing import Iterable, List, Optional + import torch -from typing import Iterable, Optional -from metatensor.torch.atomistic import load_atomistic_model, System, ModelOutput, ModelEvaluationOptions + from metatensor.torch import Labels -from typing import List +from metatensor.torch.atomistic import ( + ModelEvaluationOptions, + ModelOutput, + System, + load_atomistic_model, +) + try: import openmm import openmmtorch - from openmmml.mlpotential import MLPotential, MLPotentialImpl, MLPotentialImplFactory + from openmmml.mlpotential import ( + MLPotential, + MLPotentialImpl, + MLPotentialImplFactory, + ) + HAS_OPENMM = True -except ImportError as e: +except ImportError: + class MLPotential: pass + class MLPotentialImpl: pass + class MLPotentialImplFactory: pass + HAS_OPENMM = False class MetatensorPotentialImplFactory(MLPotentialImplFactory): - def createImpl( - name: str, **args - ) -> MLPotentialImpl: + def createImpl(name: str, **args) -> MLPotentialImpl: # TODO: extensions_directory return MetatensorPotentialImpl(name, **args) @@ -47,10 +61,8 @@ def addForces( "Could not import openmm. If you want to use metatensor with " "openmm, please install openmm-ml with conda." ) - - model = load_atomistic_model( - self.path # TODO: extensions_directory - ) + + model = load_atomistic_model(self.path) # TODO: extensions_directory # Get the atomic numbers of the ML region. all_atoms = list(topology.atoms()) @@ -79,9 +91,11 @@ def __init__( super(MetatensorForce, self).__init__() self.model = model - self.register_buffer("atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32)) + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32) + ) self.evaluation_options = ModelEvaluationOptions( - length_unit='nm', + length_unit="nm", outputs={ "energy": ModelOutput( quantity="energy", @@ -92,7 +106,6 @@ def __init__( selected_atoms=selected_atoms, ) - def forward( self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -100,10 +113,14 @@ def forward( selected_atoms = self.evaluation_options.selected_atoms if selected_atoms is not None: if selected_atoms.device != positions.device: - self.evaluation_options.selected_atoms = selected_atoms.to(positions.device) - + self.evaluation_options.selected_atoms = selected_atoms.to( + positions.device + ) + if cell is None: - cell = torch.zeros((3, 3), dtype=positions.dtype, device=positions.device) + cell = torch.zeros( + (3, 3), dtype=positions.dtype, device=positions.device + ) # create System system = System( @@ -112,7 +129,13 @@ def forward( cell=cell, ) - energy = self.model([system], self.evaluation_options, check_consistency=True)["energy"].block().values.reshape(()) + energy = ( + self.model( + [system], self.evaluation_options, check_consistency=True + )["energy"] + .block() + .values.reshape(()) + ) return energy metatensor_force = MetatensorForce( From a673e36cabdacf2c98937e4673fc280a9bd64566 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 21 Jun 2024 17:00:57 +0200 Subject: [PATCH 04/16] Change interface to openmm-torch --- .../torch/atomistic/openmm_interface.py | 203 ++++++++---------- 1 file changed, 92 insertions(+), 111 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py index 33029da3f..b1d7b537a 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py @@ -14,11 +14,6 @@ try: import openmm import openmmtorch - from openmmml.mlpotential import ( - MLPotential, - MLPotentialImpl, - MLPotentialImplFactory, - ) HAS_OPENMM = True except ImportError: @@ -35,127 +30,113 @@ class MLPotentialImplFactory: HAS_OPENMM = False -class MetatensorPotentialImplFactory(MLPotentialImplFactory): - - def createImpl(name: str, **args) -> MLPotentialImpl: - # TODO: extensions_directory - return MetatensorPotentialImpl(name, **args) +def attach_metatensor_force( + system: openmm.System, + topology: openmm.app.Topology, + path: str, + extensions_directory: Optional[str] = None, + forceGroup: int = 0, + atoms: Optional[Iterable[int]] = None, +) -> openmm.System: + if not HAS_OPENMM: + raise ImportError( + "Could not import openmm. If you want to use metatensor with " + "openmm, please install openmm-ml with conda." + ) -class MetatensorPotentialImpl(MLPotentialImpl): + model = load_atomistic_model(path, extensions_directory=extensions_directory) + + # Get the atomic numbers of the ML region. + all_atoms = list(topology.atoms()) + atomic_numbers = [atom.element.atomic_number for atom in all_atoms] + + if atoms is None: + selected_atoms = None + else: + selected_atoms = Labels( + names=["system", "atom"], + values=torch.tensor( + [[0, selected_atom] for selected_atom in atoms], + dtype=torch.int32, + ), + ) - def __init__(self, name: str, path: str) -> None: - self.path = path + class MetatensorForce(torch.nn.Module): - def addForces( - self, - topology: openmm.app.Topology, - system: openmm.System, - atoms: Optional[Iterable[int]], - forceGroup: int, - **args, - ) -> None: + def __init__( + self, + model: torch.jit._script.RecursiveScriptModule, + atomic_numbers: List[int], + selected_atoms: Optional[Labels], + ) -> None: + super(MetatensorForce, self).__init__() - if not HAS_OPENMM: - raise ImportError( - "Could not import openmm. If you want to use metatensor with " - "openmm, please install openmm-ml with conda." + self.model = model + self.register_buffer( + "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32) ) - - model = load_atomistic_model(self.path) # TODO: extensions_directory - - # Get the atomic numbers of the ML region. - all_atoms = list(topology.atoms()) - atomic_numbers = [atom.element.atomic_number for atom in all_atoms] - - # TODO: Set up selected_atoms as a Labels object - if atoms is None: - selected_atoms = None - else: - selected_atoms = Labels( - names=["system", "atom"], - values=torch.tensor( - [[0, selected_atom] for selected_atom in atoms], - dtype=torch.int32, - ), + self.evaluation_options = ModelEvaluationOptions( + length_unit="nm", + outputs={ + "energy": ModelOutput( + quantity="energy", + unit="kJ/mol", + per_atom=False, + ), + }, + selected_atoms=selected_atoms, ) - class MetatensorForce(torch.nn.Module): - - def __init__( - self, - model: torch.jit._script.RecursiveScriptModule, - atomic_numbers: List[int], - selected_atoms: Optional[Labels], - ) -> None: - super(MetatensorForce, self).__init__() - - self.model = model - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32) - ) - self.evaluation_options = ModelEvaluationOptions( - length_unit="nm", - outputs={ - "energy": ModelOutput( - quantity="energy", - unit="kJ/mol", - per_atom=False, - ), - }, - selected_atoms=selected_atoms, - ) - - def forward( - self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None - ) -> torch.Tensor: - # move labels if necessary - selected_atoms = self.evaluation_options.selected_atoms - if selected_atoms is not None: - if selected_atoms.device != positions.device: - self.evaluation_options.selected_atoms = selected_atoms.to( - positions.device - ) - - if cell is None: - cell = torch.zeros( - (3, 3), dtype=positions.dtype, device=positions.device + def forward( + self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None + ) -> torch.Tensor: + # move labels if necessary + selected_atoms = self.evaluation_options.selected_atoms + if selected_atoms is not None: + if selected_atoms.device != positions.device: + self.evaluation_options.selected_atoms = selected_atoms.to( + positions.device ) - # create System - system = System( - types=self.atomic_numbers, - positions=positions, - cell=cell, + if cell is None: + cell = torch.zeros( + (3, 3), dtype=positions.dtype, device=positions.device ) - energy = ( - self.model( - [system], self.evaluation_options, check_consistency=True - )["energy"] - .block() - .values.reshape(()) - ) - return energy + # create System + system = System( + types=self.atomic_numbers, + positions=positions, + cell=cell, + ) - metatensor_force = MetatensorForce( - model, - atomic_numbers, - selected_atoms, - ) + energy = ( + self.model([system], self.evaluation_options, check_consistency=True)[ + "energy" + ] + .block() + .values.reshape(()) + ) + return energy - # torchscript everything - module = torch.jit.script(metatensor_force) + metatensor_force = MetatensorForce( + model, + atomic_numbers, + selected_atoms, + ) - # create the OpenMM force - force = openmmtorch.TorchForce(module) - isPeriodic = ( - topology.getPeriodicBoxVectors() is not None - ) or system.usesPeriodicBoundaryConditions() - force.setUsesPeriodicBoundaryConditions(isPeriodic) - force.setForceGroup(forceGroup) + # torchscript everything + module = torch.jit.script(metatensor_force) - system.addForce(force) + # create the OpenMM force + force = openmmtorch.TorchForce(module) + isPeriodic = ( + topology.getPeriodicBoxVectors() is not None + ) or system.usesPeriodicBoundaryConditions() + force.setUsesPeriodicBoundaryConditions(isPeriodic) + force.setForceGroup(forceGroup) + system.addForce(force) -MLPotential.registerImplFactory("metatensor", MetatensorPotentialImplFactory) + return system From cea817b9fb024c519c519cdaf02c8ab5ec17058a Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 21 Jun 2024 19:07:03 +0200 Subject: [PATCH 05/16] Let the user add the force to the system --- .../metatensor/torch/atomistic/openmm_interface.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py index b1d7b537a..96f7e23c5 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py @@ -30,7 +30,7 @@ class MLPotentialImplFactory: HAS_OPENMM = False -def attach_metatensor_force( +def get_metatensor_force( system: openmm.System, topology: openmm.app.Topology, path: str, @@ -137,6 +137,4 @@ def forward( force.setUsesPeriodicBoundaryConditions(isPeriodic) force.setForceGroup(forceGroup) - system.addForce(force) - - return system + return force From f2e3477d6187f308024a43945ffdceb68eb36e49 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 21 Jun 2024 19:28:38 +0200 Subject: [PATCH 06/16] Add basic documentation --- docs/src/atomistic/reference/index.rst | 1 + docs/src/atomistic/reference/openmm.rst | 41 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 docs/src/atomistic/reference/openmm.rst diff --git a/docs/src/atomistic/reference/index.rst b/docs/src/atomistic/reference/index.rst index b54bb7fd8..671e3dc24 100644 --- a/docs/src/atomistic/reference/index.rst +++ b/docs/src/atomistic/reference/index.rst @@ -38,6 +38,7 @@ atomistic systems as input of a machine learning model with systems models/index ase + openmm C++ API reference diff --git a/docs/src/atomistic/reference/openmm.rst b/docs/src/atomistic/reference/openmm.rst new file mode 100644 index 000000000..87739d9ea --- /dev/null +++ b/docs/src/atomistic/reference/openmm.rst @@ -0,0 +1,41 @@ +OpenMM integration +================== + +.. py:currentmodule:: metatensor.torch.atomistic + +:py:mod:`openmm_interface` contains the ``get_metatensor_force`` function, +which can be used to load a :py:class:`MetatensorAtomisticModel` into an +``openmm.Force`` object able to calculate forces on an ``openmm.System``. + +.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force + :show-inheritance: + :members: + +In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``, +we recommend installing ``OpenMM`` from conda, using +``conda install -c conda-forge openmm-torch nnpops``. Subsequently, +metatensor can be installed with ``pip install metatensor[torch]``, and a minimal +script demonstrating how to run a simple simulation is illustrated below: + +.. code-block:: python + + import openmm + from metatensor.torch.atomistic.openmm_interface import get_metatensor_force + + # load an input geometry file + topology = openmm.app.PDBFile('input.pdb').getTopology() + system = openmm.System() + for atom in topology.atoms(): + system.addParticle(atom.element.mass) + + # get the force object from an exported model saved as 'model.pt' + force = get_metatensor_force(system, topology, 'model.pt') + system.addForce(force) + + integrator = openmm.VerletIntegrator(0.001) + platform = openmm.Platform.getPlatformByName('CUDA') + simulation = openmm.app.Simulation(topology, system, integrator, platform) + simulation.context.setPositions(openmm.app.PDBFile('input.pdb').getPositions()) + simulation.reporters.append(openmm.app.PDBReporter('output.pdb', 100)) + + simulation.step(1000) From 21b26296e743ce2314d313053bd143a5504c6101 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 27 Jun 2024 20:36:49 +0200 Subject: [PATCH 07/16] Add neighbor list --- .../torch/atomistic/openmm_interface.py | 103 +++++++++++++++++- 1 file changed, 97 insertions(+), 6 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py index 96f7e23c5..df6692481 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py @@ -2,7 +2,7 @@ import torch -from metatensor.torch import Labels +from metatensor.torch import Labels, TensorBlock from metatensor.torch.atomistic import ( ModelEvaluationOptions, ModelOutput, @@ -12,6 +12,7 @@ try: + import NNPOps import openmm import openmmtorch @@ -37,12 +38,13 @@ def get_metatensor_force( extensions_directory: Optional[str] = None, forceGroup: int = 0, atoms: Optional[Iterable[int]] = None, + check_consistency: bool = False, ) -> openmm.System: if not HAS_OPENMM: raise ImportError( - "Could not import openmm. If you want to use metatensor with " - "openmm, please install openmm-ml with conda." + "Could not import openmm and/or nnpops. If you want to use metatensor with " + "openmm, please install openmm-torch and nnpops with conda." ) model = load_atomistic_model(path, extensions_directory=extensions_directory) @@ -69,6 +71,7 @@ def __init__( model: torch.jit._script.RecursiveScriptModule, atomic_numbers: List[int], selected_atoms: Optional[Labels], + check_consistency: bool, ) -> None: super(MetatensorForce, self).__init__() @@ -88,6 +91,91 @@ def __init__( selected_atoms=selected_atoms, ) + requested_nls = self.model.requested_neighbor_lists() + if len(requested_nls) > 1: + raise ValueError( + "The model requested more than one neighbor list. " + "Currently, only models with a single neighbor list are supported " + "by the OpenMM interface." + ) + elif len(requested_nls) == 1: + self.requested_neighbor_list = requested_nls[0] + else: + # no neighbor list requested + self.requested_neighbor_list = None + + self.check_consistency = check_consistency + + def _attach_neighbors(self, system: System) -> System: + + if self.requested_neighbor_list is None: + return system + + cell: Optional[torch.Tensor] = None + if not torch.all(system.cell == 0.0): + cell = system.cell + + # Get the neighbor pairs, shifts and edge indices. + neighbors, interatomic_vectors, _, _ = NNPOps.getNeighborPairs( + system.positions, self.requested_neighbor_list.engine_cutoff(), -1, cell + ) + mask = neighbors[0] >= 0 + neighbors = neighbors[:, mask] + interatomic_vectors = interatomic_vectors[mask, :] + + if self.requested_neighbor_list.full_list: + neighbors = torch.stack((neighbors, neighbors.flip(0)), dim=1) + interatomic_vectors = torch.stack( + (interatomic_vectors, -interatomic_vectors) + ) + + if cell is not None: + interatomic_vectors_unit_cell = ( + system.positions[interatomic_vectors[0]] + - system.positions[interatomic_vectors[1]] + ) + cell_shifts = ( + interatomic_vectors_unit_cell - interatomic_vectors + ) @ torch.linalg.inv(cell) + cell_shifts = torch.round(cell_shifts).to(torch.int32) + else: + cell_shifts = torch.zeros( + (neighbors.shape[1], 3), + dtype=self.dtype, + device=system.positions.device, + ) + + neighbor_list = TensorBlock( + values=interatomic_vectors.reshape(-1, 3, 1), + samples=Labels( + names=[ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.stack([neighbors.T, cell_shifts], dim=-1), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange( + 3, dtype=torch.int32, device=system.positions.device + ), + ) + ], + properties=Labels( + names=["distance"], + values=torch.tensor( + [[0]], dtype=torch.int32, device=system.positions.device + ), + ), + ) + + system.add_neighbor_list(self.requested_neighbor_list, neighbor_list) + return system + def forward( self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -112,9 +200,11 @@ def forward( ) energy = ( - self.model([system], self.evaluation_options, check_consistency=True)[ - "energy" - ] + self.model( + [system], + self.evaluation_options, + check_consistency=self.check_consistency, + )["energy"] .block() .values.reshape(()) ) @@ -124,6 +214,7 @@ def forward( model, atomic_numbers, selected_atoms, + check_consistency, ) # torchscript everything From 7e2f643503ddb1fdd34631188b79e45cafdbc0e9 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 27 Jun 2024 21:05:53 +0200 Subject: [PATCH 08/16] Fix bugs --- .../torch/atomistic/openmm_interface.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py index df6692481..d8bdaed3d 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py @@ -12,9 +12,9 @@ try: - import NNPOps import openmm import openmmtorch + from NNPOps.neighbors import getNeighborPairs HAS_OPENMM = True except ImportError: @@ -116,16 +116,19 @@ def _attach_neighbors(self, system: System) -> System: cell = system.cell # Get the neighbor pairs, shifts and edge indices. - neighbors, interatomic_vectors, _, _ = NNPOps.getNeighborPairs( - system.positions, self.requested_neighbor_list.engine_cutoff(), -1, cell + neighbors, interatomic_vectors, _, _ = getNeighborPairs( + system.positions, + self.requested_neighbor_list.engine_cutoff("nm"), + -1, + cell, ) mask = neighbors[0] >= 0 neighbors = neighbors[:, mask] interatomic_vectors = interatomic_vectors[mask, :] if self.requested_neighbor_list.full_list: - neighbors = torch.stack((neighbors, neighbors.flip(0)), dim=1) - interatomic_vectors = torch.stack( + neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1) + interatomic_vectors = torch.concatenate( (interatomic_vectors, -interatomic_vectors) ) @@ -141,7 +144,7 @@ def _attach_neighbors(self, system: System) -> System: else: cell_shifts = torch.zeros( (neighbors.shape[1], 3), - dtype=self.dtype, + dtype=torch.int32, device=system.positions.device, ) @@ -155,14 +158,14 @@ def _attach_neighbors(self, system: System) -> System: "cell_shift_b", "cell_shift_c", ], - values=torch.stack([neighbors.T, cell_shifts], dim=-1), + values=torch.concatenate([neighbors.T, cell_shifts], dim=-1), ), components=[ Labels( names=["xyz"], values=torch.arange( 3, dtype=torch.int32, device=system.positions.device - ), + ).reshape(-1, 1), ) ], properties=Labels( @@ -198,6 +201,7 @@ def forward( positions=positions, cell=cell, ) + system = self._attach_neighbors(system) energy = ( self.model( From 5e94523ce22845847ed969a637d4c7dae4351785 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 28 Jun 2024 14:35:40 +0200 Subject: [PATCH 09/16] Extract neighbor list --- .../{openmm_interface.py => openmm_force.py} | 152 +++++++++--------- 1 file changed, 78 insertions(+), 74 deletions(-) rename python/metatensor-torch/metatensor/torch/atomistic/{openmm_interface.py => openmm_force.py} (62%) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py similarity index 62% rename from python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py rename to python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index d8bdaed3d..810a5d398 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -8,6 +8,7 @@ ModelOutput, System, load_atomistic_model, + NeighborListOptions ) @@ -106,79 +107,6 @@ def __init__( self.check_consistency = check_consistency - def _attach_neighbors(self, system: System) -> System: - - if self.requested_neighbor_list is None: - return system - - cell: Optional[torch.Tensor] = None - if not torch.all(system.cell == 0.0): - cell = system.cell - - # Get the neighbor pairs, shifts and edge indices. - neighbors, interatomic_vectors, _, _ = getNeighborPairs( - system.positions, - self.requested_neighbor_list.engine_cutoff("nm"), - -1, - cell, - ) - mask = neighbors[0] >= 0 - neighbors = neighbors[:, mask] - interatomic_vectors = interatomic_vectors[mask, :] - - if self.requested_neighbor_list.full_list: - neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1) - interatomic_vectors = torch.concatenate( - (interatomic_vectors, -interatomic_vectors) - ) - - if cell is not None: - interatomic_vectors_unit_cell = ( - system.positions[interatomic_vectors[0]] - - system.positions[interatomic_vectors[1]] - ) - cell_shifts = ( - interatomic_vectors_unit_cell - interatomic_vectors - ) @ torch.linalg.inv(cell) - cell_shifts = torch.round(cell_shifts).to(torch.int32) - else: - cell_shifts = torch.zeros( - (neighbors.shape[1], 3), - dtype=torch.int32, - device=system.positions.device, - ) - - neighbor_list = TensorBlock( - values=interatomic_vectors.reshape(-1, 3, 1), - samples=Labels( - names=[ - "first_atom", - "second_atom", - "cell_shift_a", - "cell_shift_b", - "cell_shift_c", - ], - values=torch.concatenate([neighbors.T, cell_shifts], dim=-1), - ), - components=[ - Labels( - names=["xyz"], - values=torch.arange( - 3, dtype=torch.int32, device=system.positions.device - ).reshape(-1, 1), - ) - ], - properties=Labels( - names=["distance"], - values=torch.tensor( - [[0]], dtype=torch.int32, device=system.positions.device - ), - ), - ) - - system.add_neighbor_list(self.requested_neighbor_list, neighbor_list) - return system - def forward( self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -201,7 +129,7 @@ def forward( positions=positions, cell=cell, ) - system = self._attach_neighbors(system) + system = _attach_neighbors(system, self.requested_neighbor_list) energy = ( self.model( @@ -233,3 +161,79 @@ def forward( force.setForceGroup(forceGroup) return force + + +def _attach_neighbors(system: System, requested_nl_options: NeighborListOptions) -> System: + + if requested_nl_options is None: + return system + + cell: Optional[torch.Tensor] = None + if not torch.all(system.cell == 0.0): + cell = system.cell + + # Get the neighbor pairs, shifts and edge indices. + neighbors, interatomic_vectors, _, _ = getNeighborPairs( + system.positions, + requested_nl_options.engine_cutoff("nm"), + -1, + cell, + ) + mask = neighbors[0] >= 0 + neighbors = neighbors[:, mask] + neighbors = neighbors.flip(0) # [neighbor, center] -> [center, neighbor] + interatomic_vectors = interatomic_vectors[mask, :] + + if requested_nl_options.full_list: + neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1) + interatomic_vectors = torch.concatenate( + (interatomic_vectors, -interatomic_vectors) + ) + + if cell is not None: + interatomic_vectors_unit_cell = ( + system.positions[neighbors[1]] + - system.positions[neighbors[0]] + ) + cell_shifts = ( + interatomic_vectors_unit_cell - interatomic_vectors + ) @ torch.linalg.inv(cell) + print(cell_shifts) + cell_shifts = torch.round(cell_shifts).to(torch.int32) + else: + cell_shifts = torch.zeros( + (neighbors.shape[1], 3), + dtype=torch.int32, + device=system.positions.device, + ) + + neighbor_list = TensorBlock( + values=interatomic_vectors.reshape(-1, 3, 1), + samples=Labels( + names=[ + "first_atom", + "second_atom", + "cell_shift_a", + "cell_shift_b", + "cell_shift_c", + ], + values=torch.concatenate([neighbors.T, cell_shifts], dim=-1), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange( + 3, dtype=torch.int32, device=system.positions.device + ).reshape(-1, 1), + ) + ], + properties=Labels( + names=["distance"], + values=torch.tensor( + [[0]], dtype=torch.int32, device=system.positions.device + ), + ), + ) + + system.add_neighbor_list(requested_nl_options, neighbor_list) + return system From e74d0fd2d499facb9eac811357a178951c73d28c Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 28 Jun 2024 15:41:34 +0200 Subject: [PATCH 10/16] Add tests --- .../torch/atomistic/openmm_force.py | 10 +- .../tests/atomistic/openmm_force.py | 188 ++++++++++++++++++ 2 files changed, 193 insertions(+), 5 deletions(-) create mode 100644 python/metatensor-torch/tests/atomistic/openmm_force.py diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index 810a5d398..77dd0828f 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -6,9 +6,9 @@ from metatensor.torch.atomistic import ( ModelEvaluationOptions, ModelOutput, + NeighborListOptions, System, load_atomistic_model, - NeighborListOptions ) @@ -163,7 +163,9 @@ def forward( return force -def _attach_neighbors(system: System, requested_nl_options: NeighborListOptions) -> System: +def _attach_neighbors( + system: System, requested_nl_options: NeighborListOptions +) -> System: if requested_nl_options is None: return system @@ -192,13 +194,11 @@ def _attach_neighbors(system: System, requested_nl_options: NeighborListOptions) if cell is not None: interatomic_vectors_unit_cell = ( - system.positions[neighbors[1]] - - system.positions[neighbors[0]] + system.positions[neighbors[1]] - system.positions[neighbors[0]] ) cell_shifts = ( interatomic_vectors_unit_cell - interatomic_vectors ) @ torch.linalg.inv(cell) - print(cell_shifts) cell_shifts = torch.round(cell_shifts).to(torch.int32) else: cell_shifts = torch.zeros( diff --git a/python/metatensor-torch/tests/atomistic/openmm_force.py b/python/metatensor-torch/tests/atomistic/openmm_force.py new file mode 100644 index 000000000..2376380d8 --- /dev/null +++ b/python/metatensor-torch/tests/atomistic/openmm_force.py @@ -0,0 +1,188 @@ +import os + +import ase.io +import metatensor_lj_test +import numpy as np +import pytest +from ase.build import bulk + +from metatensor.torch.atomistic.ase_calculator import MetatensorCalculator +from metatensor.torch.atomistic.openmm_force import get_metatensor_force + + +try: + import NNPOps # noqa: F401 + import openmm + import openmmtorch # noqa: F401 + + HAS_OPENMM = True +except ImportError: + HAS_OPENMM = False + +CUTOFF = 5.0 +SIGMA = 1.5808 +EPSILON = 0.1729 + + +def model(): + return metatensor_lj_test.lennard_jones_model( + atomic_type=29, + cutoff=CUTOFF, + sigma=SIGMA, + epsilon=EPSILON, + length_unit="Angstrom", + energy_unit="eV", + with_extension=False, + ) + + +def model_different_units(): + return metatensor_lj_test.lennard_jones_model( + atomic_type=29, + cutoff=CUTOFF / ase.units.Bohr, + sigma=SIGMA / ase.units.Bohr, + epsilon=EPSILON / ase.units.kJ * ase.units.mol, + length_unit="Bohr", + energy_unit="kJ/mol", + with_extension=False, + ) + + +def _modify_pdb(path): + # makes the ASE pdb file compatible with OpenMM + with open(path, "r") as f: + lines = f.readlines() + count = 0 + new_lines = [] + for line in lines: + if "Cu" in line: + count += 1 + line = line.replace(" Cu", f"Cu{count}") + new_lines.append(line) + with open(path, "w") as f: + f.writelines(new_lines) + + +def _check_against_ase(tmpdir, atoms): + model_path = os.path.join(tmpdir, "model.pt") + structure_path = os.path.join(tmpdir, "structure.pdb") + + ase.io.write(structure_path, atoms) + _modify_pdb(structure_path) + + topology = openmm.app.PDBFile(structure_path).getTopology() + system = openmm.System() + for atom in topology.atoms(): + system.addParticle(atom.element.mass) + if atoms.pbc.any(): + system.setDefaultPeriodicBoxVectors( + atoms.cell[0] * openmm.unit.angstrom, + atoms.cell[1] * openmm.unit.angstrom, + atoms.cell[2] * openmm.unit.angstrom, + ) + force = get_metatensor_force(system, topology, model_path, check_consistency=True) + system.addForce(force) + integrator = openmm.VerletIntegrator(0.001 * openmm.unit.picoseconds) + platform = openmm.Platform.getPlatformByName("CPU") + properties = {} + simulation = openmm.app.Simulation( + topology, system, integrator, platform, properties + ) + simulation.context.setPositions(openmm.app.PDBFile(structure_path).getPositions()) + state = simulation.context.getState(getForces=True) + openmm_forces = ( + state.getForces(asNumpy=True).value_in_unit( + openmm.unit.ev / (openmm.unit.angstrom * openmm.unit.mole) + ) + / 6.0221367e23 + ) + + atoms = ase.io.read(structure_path) + calculator = MetatensorCalculator(model_path, check_consistency=True) + atoms.set_calculator(calculator) + ase_forces = atoms.get_forces() + + print(openmm_forces) + print(ase_forces) + assert np.allclose(openmm_forces, ase_forces) + + +@pytest.mark.skipif(not HAS_OPENMM, reason="OpenMM not available") +def test_diagonal_cell(tmpdir): + cell = np.array( + [ + [10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0], + ] + ) + atoms = bulk("Cu", cubic=False) + atoms *= (2, 2, 2) + atoms.positions += np.random.rand(*atoms.positions.shape) * 0.1 + atoms.cell = cell + atoms.pbc = True + atoms.wrap() + + m = model() + m.save(os.path.join(tmpdir, "model.pt")) + + _check_against_ase(tmpdir, atoms) + + +@pytest.mark.skipif(not HAS_OPENMM, reason="OpenMM not available") +def test_non_diagonal_cell(tmpdir): + cell = np.array( + [ + [10.0, 0.0, 0.0], + [3.0, 10.0, 0.0], + [3.0, -3.0, 10.0], + ] + ) + atoms = bulk("Cu", cubic=False) + atoms *= (2, 2, 2) + atoms.positions += np.random.rand(*atoms.positions.shape) * 0.1 + atoms.cell = cell + atoms.pbc = True + atoms.wrap() + + m = model() + m.save(os.path.join(tmpdir, "model.pt")) + + _check_against_ase(tmpdir, atoms) + + +@pytest.mark.skipif(not HAS_OPENMM, reason="OpenMM not available") +def test_non_diagonal_cell_different_units(tmpdir): + cell = np.array( + [ + [100.0, 0.0, 0.0], + [3.0, 100.0, 0.0], + [3.0, -3.0, 100.0], + ] + ) + atoms = bulk("Cu", cubic=False) + atoms *= (2, 2, 2) + atoms.positions += np.random.rand(*atoms.positions.shape) * 0.1 + atoms.cell = cell + atoms.pbc = True + atoms.wrap() + + m = model_different_units() + m.save(os.path.join(tmpdir, "model.pt")) + + _check_against_ase(tmpdir, atoms) + + +@pytest.mark.skipif(not HAS_OPENMM, reason="OpenMM not available") +def test_no_cell(tmpdir): + atoms = bulk("Cu", cubic=False) + atoms *= (2, 2, 2) + atoms.positions += np.random.rand(*atoms.positions.shape) * 0.1 + atoms.cell = None + atoms.pbc = False + atoms.wrap() + + m = model() + m.save(os.path.join(tmpdir, "model.pt")) + + _check_against_ase(tmpdir, atoms) From c660e07035718278bf4627c66a34ee84b7d22dc9 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 28 Jun 2024 15:58:09 +0200 Subject: [PATCH 11/16] Implement suggestions --- docs/src/atomistic/reference/openmm.rst | 4 +- .../torch/atomistic/openmm_force.py | 61 +++++++++++-------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/docs/src/atomistic/reference/openmm.rst b/docs/src/atomistic/reference/openmm.rst index 87739d9ea..aec8f71cd 100644 --- a/docs/src/atomistic/reference/openmm.rst +++ b/docs/src/atomistic/reference/openmm.rst @@ -7,9 +7,7 @@ OpenMM integration which can be used to load a :py:class:`MetatensorAtomisticModel` into an ``openmm.Force`` object able to calculate forces on an ``openmm.System``. -.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force - :show-inheritance: - :members: +.. autofunction:: metatensor.torch.atomistic.openmm_force.get_metatensor_force In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``, we recommend installing ``OpenMM`` from conda, using diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index 77dd0828f..b99204a44 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -13,9 +13,9 @@ try: + import NNPOps.neighbors import openmm import openmmtorch - from NNPOps.neighbors import getNeighborPairs HAS_OPENMM = True except ImportError: @@ -38,9 +38,25 @@ def get_metatensor_force( path: str, extensions_directory: Optional[str] = None, forceGroup: int = 0, - atoms: Optional[Iterable[int]] = None, + selected_atoms: Optional[Iterable[int]] = None, check_consistency: bool = False, -) -> openmm.System: +) -> openmm.Force: + """ + Create an OpenMM Force from a metatensor atomistic model. + + :param system: The OpenMM System object. + :param topology: The OpenMM Topology object. + :param path: The path to the exported metatensor model. + :param extensions_directory: The path to the extensions for the model. + :param forceGroup: The force group to which the force should be added. + :param selected_atoms: The indices of the atoms on which non-zero forces should + be computed (e.g., the ML region). If None, forces will be computed for all + atoms. + :param check_consistency: Whether to check various consistency conditions inside + the model. + + return: The OpenMM Force object. + """ if not HAS_OPENMM: raise ImportError( @@ -52,15 +68,15 @@ def get_metatensor_force( # Get the atomic numbers of the ML region. all_atoms = list(topology.atoms()) - atomic_numbers = [atom.element.atomic_number for atom in all_atoms] + atomic_types = [atom.element.atomic_number for atom in all_atoms] - if atoms is None: + if selected_atoms is None: selected_atoms = None else: selected_atoms = Labels( names=["system", "atom"], values=torch.tensor( - [[0, selected_atom] for selected_atom in atoms], + [[0, selected_atom] for selected_atom in selected_atoms], dtype=torch.int32, ), ) @@ -69,8 +85,8 @@ class MetatensorForce(torch.nn.Module): def __init__( self, - model: torch.jit._script.RecursiveScriptModule, - atomic_numbers: List[int], + model, + atomic_types: List[int], selected_atoms: Optional[Labels], check_consistency: bool, ) -> None: @@ -78,7 +94,7 @@ def __init__( self.model = model self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32) + "atomic_types", torch.tensor(atomic_types, dtype=torch.int32) ) self.evaluation_options = ModelEvaluationOptions( length_unit="nm", @@ -125,26 +141,23 @@ def forward( # create System system = System( - types=self.atomic_numbers, + types=self.atomic_types, positions=positions, cell=cell, ) system = _attach_neighbors(system, self.requested_neighbor_list) - energy = ( - self.model( - [system], - self.evaluation_options, - check_consistency=self.check_consistency, - )["energy"] - .block() - .values.reshape(()) + outputs = self.model( + [system], + self.evaluation_options, + check_consistency=self.check_consistency, ) + energy = outputs["energy"].block().values.reshape(()) return energy metatensor_force = MetatensorForce( model, - atomic_numbers, + atomic_types, selected_atoms, check_consistency, ) @@ -175,11 +188,11 @@ def _attach_neighbors( cell = system.cell # Get the neighbor pairs, shifts and edge indices. - neighbors, interatomic_vectors, _, _ = getNeighborPairs( - system.positions, - requested_nl_options.engine_cutoff("nm"), - -1, - cell, + neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs( + positions=system.positions, + cutoff=requested_nl_options.engine_cutoff("nm"), + max_num_pairs=-1, + box_vectors=cell, ) mask = neighbors[0] >= 0 neighbors = neighbors[:, mask] From 94e0973d8acee5f31a4d252833e5f62837d005e3 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 28 Jun 2024 16:10:46 +0200 Subject: [PATCH 12/16] Update fake classes for type hints --- .../metatensor/torch/atomistic/openmm_force.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index b99204a44..802d96739 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -20,14 +20,16 @@ HAS_OPENMM = True except ImportError: - class MLPotential: - pass + class openmm: + class System: + pass - class MLPotentialImpl: - pass + class Force: + pass - class MLPotentialImplFactory: - pass + class app: + class Topology: + pass HAS_OPENMM = False From 3ea577a7b475be776982b55815a1bfca2845891c Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 5 Jul 2024 08:33:50 +0200 Subject: [PATCH 13/16] Print metadata --- .../metatensor/torch/atomistic/openmm_force.py | 3 +++ python/metatensor-torch/tests/atomistic/openmm_force.py | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index 802d96739..977289db7 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -68,6 +68,9 @@ def get_metatensor_force( model = load_atomistic_model(path, extensions_directory=extensions_directory) + # Print the model's metadata + print(model.metadata().print()) + # Get the atomic numbers of the ML region. all_atoms = list(topology.atoms()) atomic_types = [atom.element.atomic_number for atom in all_atoms] diff --git a/python/metatensor-torch/tests/atomistic/openmm_force.py b/python/metatensor-torch/tests/atomistic/openmm_force.py index 2376380d8..27ce46a2f 100644 --- a/python/metatensor-torch/tests/atomistic/openmm_force.py +++ b/python/metatensor-torch/tests/atomistic/openmm_force.py @@ -102,8 +102,6 @@ def _check_against_ase(tmpdir, atoms): atoms.set_calculator(calculator) ase_forces = atoms.get_forces() - print(openmm_forces) - print(ase_forces) assert np.allclose(openmm_forces, ase_forces) From aed0dcf66fa5ceb1e358c1c920330c5e28a1edd9 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 5 Jul 2024 11:38:43 +0200 Subject: [PATCH 14/16] Add warning for dtype conversion --- .../torch/atomistic/openmm_force.py | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index 977289db7..e7a790a6c 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -1,3 +1,4 @@ +import warnings from typing import Iterable, List, Optional import torch @@ -11,6 +12,8 @@ load_atomistic_model, ) +from .ase_calculator import STR_TO_DTYPE + try: import NNPOps.neighbors @@ -127,6 +130,8 @@ def __init__( self.requested_neighbor_list = None self.check_consistency = check_consistency + self.dtype = STR_TO_DTYPE[self.model.capabilities().dtype] + self.already_warned = False def forward( self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None @@ -152,13 +157,28 @@ def forward( ) system = _attach_neighbors(system, self.requested_neighbor_list) + original_dtype = system.dtype + if self.dtype != system.dtype: + if not self.already_warned: + model_dtype_string = dtype_to_str(self.dtype) + system_dtype_string = dtype_to_str(system.dtype) + warnings.warn( + f"Model dtype {model_dtype_string} does not match the dtype " + f"of the system {system_dtype_string}. The system will be " + f"converted to {model_dtype_string} temporarily to allow the " + "model to run.", + stacklevel=2, + ) + self.already_warned = True + system = system.to(dtype=self.dtype) + outputs = self.model( [system], self.evaluation_options, check_consistency=self.check_consistency, ) energy = outputs["energy"].block().values.reshape(()) - return energy + return energy.to(original_dtype) metatensor_force = MetatensorForce( model, @@ -182,7 +202,7 @@ def forward( def _attach_neighbors( - system: System, requested_nl_options: NeighborListOptions + system: System, requested_nl_options: Optional[NeighborListOptions] ) -> System: if requested_nl_options is None: @@ -255,3 +275,12 @@ def _attach_neighbors( system.add_neighbor_list(requested_nl_options, neighbor_list) return system + + +def dtype_to_str(dtype: torch.dtype) -> str: + if dtype == torch.float32: + return "float32" + elif dtype == torch.float64: + return "float64" + else: + raise ValueError(f"Unsupported dtype {dtype}.") From 46244835844219513113e50db5217099bb32314f Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 5 Jul 2024 11:50:10 +0200 Subject: [PATCH 15/16] Revert "Register `ModelMetadataHolder::print()`" This reverts commit 8c403c12e5315114ca8070d7fc8ec84ce2e50ea5. --- metatensor-torch/src/register.cpp | 1 - python/metatensor-torch/tests/atomistic/metadata.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/metatensor-torch/src/register.cpp b/metatensor-torch/src/register.cpp index b6e613ef3..461dbda40 100644 --- a/metatensor-torch/src/register.cpp +++ b/metatensor-torch/src/register.cpp @@ -382,7 +382,6 @@ TORCH_LIBRARY(metatensor, m) { ) .def("__repr__", &ModelMetadataHolder::print) .def("__str__", &ModelMetadataHolder::print) - .def("print", &ModelMetadataHolder::print) .def_readwrite("name", &ModelMetadataHolder::name) .def_readwrite("description", &ModelMetadataHolder::description) .def_readwrite("authors", &ModelMetadataHolder::authors) diff --git a/python/metatensor-torch/tests/atomistic/metadata.py b/python/metatensor-torch/tests/atomistic/metadata.py index dc7bcee6d..fefcb817b 100644 --- a/python/metatensor-torch/tests/atomistic/metadata.py +++ b/python/metatensor-torch/tests/atomistic/metadata.py @@ -226,5 +226,3 @@ def forward(self, x: ModelMetadataWrap) -> ModelMetadataWrap: * ref-4 """ assert str(metadata) == expected - assert metadata.__repr__() == expected - assert metadata.print() == expected From fc19b4103ab103ab423926ef9c99cfc4af4506fc Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Fri, 5 Jul 2024 11:51:00 +0200 Subject: [PATCH 16/16] Use `print(metadata)` to print the metadata --- .../metatensor-torch/metatensor/torch/atomistic/openmm_force.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py index e7a790a6c..d0312eefb 100644 --- a/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py +++ b/python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py @@ -72,7 +72,7 @@ def get_metatensor_force( model = load_atomistic_model(path, extensions_directory=extensions_directory) # Print the model's metadata - print(model.metadata().print()) + print(model.metadata()) # Get the atomic numbers of the ML region. all_atoms = list(topology.atoms())