Skip to content

Add OpenMM-torch interface #664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/atomistic/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ atomistic systems as input of a machine learning model with
systems
models/index
ase
openmm


C++ API reference
Expand Down
39 changes: 39 additions & 0 deletions docs/src/atomistic/reference/openmm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
OpenMM integration
==================

.. py:currentmodule:: metatensor.torch.atomistic

:py:mod:`openmm_interface` contains the ``get_metatensor_force`` function,
which can be used to load a :py:class:`MetatensorAtomisticModel` into an
``openmm.Force`` object able to calculate forces on an ``openmm.System``.

.. autofunction:: metatensor.torch.atomistic.openmm_force.get_metatensor_force

In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``,
we recommend installing ``OpenMM`` from conda, using
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should mention than NNPOps is a required dependency as well in here

``conda install -c conda-forge openmm-torch nnpops``. Subsequently,
metatensor can be installed with ``pip install metatensor[torch]``, and a minimal
Comment on lines +14 to +15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there is #665, should we recommend building metatensor-torch from sources for now? Something like pip install --no-binary=metatensor-torch metatensor[torch]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

script demonstrating how to run a simple simulation is illustrated below:

.. code-block:: python

import openmm
from metatensor.torch.atomistic.openmm_interface import get_metatensor_force

# load an input geometry file
topology = openmm.app.PDBFile('input.pdb').getTopology()
system = openmm.System()
for atom in topology.atoms():
system.addParticle(atom.element.mass)

# get the force object from an exported model saved as 'model.pt'
force = get_metatensor_force(system, topology, 'model.pt')
system.addForce(force)

integrator = openmm.VerletIntegrator(0.001)
platform = openmm.Platform.getPlatformByName('CUDA')
simulation = openmm.app.Simulation(topology, system, integrator, platform)
simulation.context.setPositions(openmm.app.PDBFile('input.pdb').getPositions())
simulation.reporters.append(openmm.app.PDBReporter('output.pdb', 100))

simulation.step(1000)
4 changes: 0 additions & 4 deletions python/metatensor-operations/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
### Removed
-->

### Changed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this removal looks like a merge issue


- We now require Python >= 3.9

## [Version 0.2.2](https://github.com/lab-cosmo/metatensor/releases/tag/metatensor-operations-v0.2.2) - 2024-06-19

### Fixed
Expand Down
286 changes: 286 additions & 0 deletions python/metatensor-torch/metatensor/torch/atomistic/openmm_force.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
import warnings
from typing import Iterable, List, Optional

import torch

from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import (
ModelEvaluationOptions,
ModelOutput,
NeighborListOptions,
System,
load_atomistic_model,
)

from .ase_calculator import STR_TO_DTYPE


try:
import NNPOps.neighbors
import openmm
import openmmtorch

HAS_OPENMM = True
except ImportError:

class openmm:
class System:
pass

class Force:
pass

class app:
class Topology:
pass

HAS_OPENMM = False


def get_metatensor_force(
system: openmm.System,
topology: openmm.app.Topology,
path: str,
extensions_directory: Optional[str] = None,
forceGroup: int = 0,
selected_atoms: Optional[Iterable[int]] = None,
check_consistency: bool = False,
) -> openmm.Force:
"""
Create an OpenMM Force from a metatensor atomistic model.

:param system: The OpenMM System object.
:param topology: The OpenMM Topology object.
:param path: The path to the exported metatensor model.
:param extensions_directory: The path to the extensions for the model.
:param forceGroup: The force group to which the force should be added.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we link to OpenMM docs for force group here? It might not be clear to everyone.

:param selected_atoms: The indices of the atoms on which non-zero forces should
be computed (e.g., the ML region). If None, forces will be computed for all
atoms.
:param check_consistency: Whether to check various consistency conditions inside
the model.

return: The OpenMM Force object.
"""

if not HAS_OPENMM:
raise ImportError(
"Could not import openmm and/or nnpops. If you want to use metatensor with "
"openmm, please install openmm-torch and nnpops with conda."
)

model = load_atomistic_model(path, extensions_directory=extensions_directory)

# Print the model's metadata
print(model.metadata())

# Get the atomic numbers of the ML region.
all_atoms = list(topology.atoms())
atomic_types = [atom.element.atomic_number for atom in all_atoms]

if selected_atoms is None:
selected_atoms = None
else:
selected_atoms = Labels(
names=["system", "atom"],
values=torch.tensor(
[[0, selected_atom] for selected_atom in selected_atoms],
dtype=torch.int32,
),
)

class MetatensorForce(torch.nn.Module):

def __init__(
self,
model,
atomic_types: List[int],
selected_atoms: Optional[Labels],
check_consistency: bool,
) -> None:
super(MetatensorForce, self).__init__()

self.model = model
self.register_buffer(
"atomic_types", torch.tensor(atomic_types, dtype=torch.int32)
)
self.evaluation_options = ModelEvaluationOptions(
length_unit="nm",
outputs={
"energy": ModelOutput(
quantity="energy",
unit="kJ/mol",
per_atom=False,
),
},
selected_atoms=selected_atoms,
)

requested_nls = self.model.requested_neighbor_lists()
if len(requested_nls) > 1:
raise ValueError(
"The model requested more than one neighbor list. "
"Currently, only models with a single neighbor list are supported "
"by the OpenMM interface."
)
elif len(requested_nls) == 1:
self.requested_neighbor_list = requested_nls[0]
else:
# no neighbor list requested
self.requested_neighbor_list = None

self.check_consistency = check_consistency
self.dtype = STR_TO_DTYPE[self.model.capabilities().dtype]
self.already_warned = False

def forward(
self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None
) -> torch.Tensor:
# move labels if necessary
selected_atoms = self.evaluation_options.selected_atoms
if selected_atoms is not None:
if selected_atoms.device != positions.device:
self.evaluation_options.selected_atoms = selected_atoms.to(
positions.device
)

if cell is None:
cell = torch.zeros(
(3, 3), dtype=positions.dtype, device=positions.device
)

# create System
system = System(
types=self.atomic_types,
positions=positions,
cell=cell,
)
system = _attach_neighbors(system, self.requested_neighbor_list)

original_dtype = system.dtype
if self.dtype != system.dtype:
if not self.already_warned:
model_dtype_string = dtype_to_str(self.dtype)
system_dtype_string = dtype_to_str(system.dtype)
warnings.warn(
f"Model dtype {model_dtype_string} does not match the dtype "
f"of the system {system_dtype_string}. The system will be "
f"converted to {model_dtype_string} temporarily to allow the "
"model to run.",
stacklevel=2,
)
self.already_warned = True
system = system.to(dtype=self.dtype)

outputs = self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)
energy = outputs["energy"].block().values.reshape(())
return energy.to(original_dtype)

metatensor_force = MetatensorForce(
model,
atomic_types,
selected_atoms,
check_consistency,
)

# torchscript everything
module = torch.jit.script(metatensor_force)

# create the OpenMM force
force = openmmtorch.TorchForce(module)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so model.to(device) is called by TorchForce? Could you leave a comment if this is the case, so future readers don't wonder what's happening?

isPeriodic = (
topology.getPeriodicBoxVectors() is not None
) or system.usesPeriodicBoundaryConditions()
force.setUsesPeriodicBoundaryConditions(isPeriodic)
force.setForceGroup(forceGroup)

return force


def _attach_neighbors(
system: System, requested_nl_options: Optional[NeighborListOptions]
) -> System:

if requested_nl_options is None:
return system

cell: Optional[torch.Tensor] = None
if not torch.all(system.cell == 0.0):
cell = system.cell

# Get the neighbor pairs, shifts and edge indices.
neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs(
positions=system.positions,
cutoff=requested_nl_options.engine_cutoff("nm"),
max_num_pairs=-1,
box_vectors=cell,
)
mask = neighbors[0] >= 0
neighbors = neighbors[:, mask]
neighbors = neighbors.flip(0) # [neighbor, center] -> [center, neighbor]
interatomic_vectors = interatomic_vectors[mask, :]

if requested_nl_options.full_list:
neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1)
interatomic_vectors = torch.concatenate(
(interatomic_vectors, -interatomic_vectors)
)

if cell is not None:
interatomic_vectors_unit_cell = (
system.positions[neighbors[1]] - system.positions[neighbors[0]]
)
cell_shifts = (
interatomic_vectors_unit_cell - interatomic_vectors
) @ torch.linalg.inv(cell)
cell_shifts = torch.round(cell_shifts).to(torch.int32)
else:
cell_shifts = torch.zeros(
(neighbors.shape[1], 3),
dtype=torch.int32,
device=system.positions.device,
)

neighbor_list = TensorBlock(
values=interatomic_vectors.reshape(-1, 3, 1),
samples=Labels(
names=[
"first_atom",
"second_atom",
"cell_shift_a",
"cell_shift_b",
"cell_shift_c",
],
values=torch.concatenate([neighbors.T, cell_shifts], dim=-1),
),
components=[
Labels(
names=["xyz"],
values=torch.arange(
3, dtype=torch.int32, device=system.positions.device
).reshape(-1, 1),
)
],
properties=Labels(
names=["distance"],
values=torch.tensor(
[[0]], dtype=torch.int32, device=system.positions.device
),
),
)

system.add_neighbor_list(requested_nl_options, neighbor_list)
return system


def dtype_to_str(dtype: torch.dtype) -> str:
if dtype == torch.float32:
return "float32"
elif dtype == torch.float64:
return "float64"
else:
raise ValueError(f"Unsupported dtype {dtype}.")
Loading
Loading