
Add OpenMM-torch interface #664

Open
wants to merge 14 commits into master
1 change: 1 addition & 0 deletions docs/src/atomistic/reference/index.rst
@@ -38,6 +38,7 @@ atomistic systems as input of a machine learning model with
systems
models/index
ase
openmm


C++ API reference
41 changes: 41 additions & 0 deletions docs/src/atomistic/reference/openmm.rst
@@ -0,0 +1,41 @@
OpenMM integration
==================

.. py:currentmodule:: metatensor.torch.atomistic

:py:mod:`openmm_interface` contains the ``get_metatensor_force`` function,
which can be used to load a :py:class:`MetatensorAtomisticModel` into an
``openmm.Force`` object able to calculate forces on an ``openmm.System``.

.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
:show-inheritance:
:members:
Contributor

Suggested change
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
:show-inheritance:
:members:
.. autofunction:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force


To run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``, we recommend
installing ``OpenMM`` from conda with ``conda install -c conda-forge openmm-torch nnpops``.
Metatensor can then be installed with ``pip install metatensor[torch]``. A minimal
script demonstrating how to run a simple simulation is shown below:

.. code-block:: python

import openmm
from metatensor.torch.atomistic.openmm_interface import get_metatensor_force

# load an input geometry file
topology = openmm.app.PDBFile('input.pdb').getTopology()
system = openmm.System()
for atom in topology.atoms():
system.addParticle(atom.element.mass)

# get the force object from an exported model saved as 'model.pt'
force = get_metatensor_force(system, topology, 'model.pt')
system.addForce(force)

integrator = openmm.VerletIntegrator(0.001)
platform = openmm.Platform.getPlatformByName('CUDA')
simulation = openmm.app.Simulation(topology, system, integrator, platform)
simulation.context.setPositions(openmm.app.PDBFile('input.pdb').getPositions())
simulation.reporters.append(openmm.app.PDBReporter('output.pdb', 100))

simulation.step(1000)
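
As an aside (not part of this diff): the function below also accepts an ``atoms``
argument, which a review comment further down suggests renaming to ``selected_atoms``,
to restrict the model to a subset of particles (e.g. an ML region embedded in a
classical system). A minimal, hypothetical sketch of how it could be used:

    force = get_metatensor_force(
        system,
        topology,
        'model.pt',
        atoms=range(10),  # hypothetical ML region: the first ten atoms
    )
    system.addForce(force)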
235 changes: 235 additions & 0 deletions python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
@@ -0,0 +1,235 @@
from typing import Iterable, List, Optional

import torch

from metatensor.torch import Labels, TensorBlock
from metatensor.torch.atomistic import (
ModelEvaluationOptions,
ModelOutput,
System,
load_atomistic_model,
)


try:
import openmm
import openmmtorch
from NNPOps.neighbors import getNeighborPairs

HAS_OPENMM = True
except ImportError:

class MLPotential:
pass

class MLPotentialImpl:
pass

class MLPotentialImplFactory:
pass

HAS_OPENMM = False


def get_metatensor_force(
system: openmm.System,
topology: openmm.app.Topology,
path: str,
extensions_directory: Optional[str] = None,
forceGroup: int = 0,
atoms: Optional[Iterable[int]] = None,
Contributor

Suggested change
atoms: Optional[Iterable[int]] = None,
selected_atoms: Optional[Iterable[int]] = None,

check_consistency: bool = False,
) -> openmm.Force:
Contributor

this needs a docstring & doc for all parameters
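
For illustration only, a possible docstring covering the parameters in this
signature (descriptions inferred from the code below, not the author's wording):

    """Load an exported metatensor atomistic model and wrap it in an ``openmm.Force``.

    :param system: the ``openmm.System`` to which the force will be added
    :param topology: the ``openmm.app.Topology`` used to read the atomic numbers
    :param path: path to the exported model file, e.g. ``'model.pt'``
    :param extensions_directory: directory containing the TorchScript extensions
        required by the model, if any
    :param forceGroup: force group assigned to the returned force
    :param atoms: indices of the atoms the model should be applied to; all atoms
        if ``None``
    :param check_consistency: whether to run metatensor consistency checks at
        each evaluation
    :return: an ``openmm.Force`` (an ``openmmtorch.TorchForce``) computing the
        model's energy
    """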


if not HAS_OPENMM:
raise ImportError(
"Could not import openmm and/or nnpops. If you want to use metatensor with "
"openmm, please install openmm-torch and nnpops with conda."
)

model = load_atomistic_model(path, extensions_directory=extensions_directory)

Contributor

Could you print model.metadata() to the log somewhere? This prints information about the model.

Also, is there a mechanism to register papers to cite in OpenMM? If so, we should register the model references.

Contributor Author

I think they have some MD loggers, but I don't think they're accessible from our function (they're attached to a Simulation object that is independent of it). No clue regarding the citation mechanism, though.

Contributor Author

Perhaps the best solution is to just print() the metadata.

Contributor

printing everything is a good starting point, and enough for this PR!
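
A minimal sketch of what this could look like right after the model is loaded
(assuming the default string representation of the metadata is what we want in
the log):

    model = load_atomistic_model(path, extensions_directory=extensions_directory)
    # print the model metadata (name, authors, references, ...) so it ends up in
    # the captured stdout / simulation log
    print(model.metadata())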

# Get the atomic numbers of the ML region.
all_atoms = list(topology.atoms())
atomic_numbers = [atom.element.atomic_number for atom in all_atoms]

if atoms is None:
selected_atoms = None
else:
selected_atoms = Labels(
names=["system", "atom"],
values=torch.tensor(
[[0, selected_atom] for selected_atom in atoms],
dtype=torch.int32,
),
)

class MetatensorForce(torch.nn.Module):

def __init__(
self,
model: torch.jit._script.RecursiveScriptModule,
Contributor

Suggested change
model: torch.jit._script.RecursiveScriptModule,
model,

We don't really need the types in __init__, and this one is private

atomic_numbers: List[int],
Contributor

Suggested change
atomic_numbers: List[int],
atomic_types: List[int],

selected_atoms: Optional[Labels],
check_consistency: bool,
) -> None:
super(MetatensorForce, self).__init__()

self.model = model
self.register_buffer(
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32)
)
self.evaluation_options = ModelEvaluationOptions(
length_unit="nm",
outputs={
"energy": ModelOutput(
quantity="energy",
unit="kJ/mol",
per_atom=False,
),
},
selected_atoms=selected_atoms,
)

requested_nls = self.model.requested_neighbor_lists()
if len(requested_nls) > 1:
raise ValueError(
"The model requested more than one neighbor list. "
"Currently, only models with a single neighbor list are supported "
"by the OpenMM interface."
)
elif len(requested_nls) == 1:
self.requested_neighbor_list = requested_nls[0]
else:
# no neighbor list requested
self.requested_neighbor_list = None

self.check_consistency = check_consistency

def _attach_neighbors(self, system: System) -> System:

if self.requested_neighbor_list is None:
return system

cell: Optional[torch.Tensor] = None
if not torch.all(system.cell == 0.0):
cell = system.cell

# Get the neighbor pairs, shifts and edge indices.
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
Contributor

Suggested change
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs(

for clarity

system.positions,
self.requested_neighbor_list.engine_cutoff("nm"),
-1,
Contributor

could you give the name of this argument? (and maybe others as well)

cell,
)
mask = neighbors[0] >= 0
Contributor

will there be negative values in neighbors[0] if n_pairs=-1 above?

Contributor Author

Yes, I think that NL returns a huge Tensor with all pairs, where only a certain number are not -1, i.e., masked

Contributor

I understand this behavior if we are setting max_n_pairs=300: if there are fewer pairs, the other ones are filled with -1. You are saying this is also the case when giving max_n_pairs=-1?
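
For reference, a small standalone check of this behavior (my reading of NNPOps,
worth verifying: with the third argument set to -1 the output is sized for all
possible pairs, and pairs beyond the cutoff are marked with -1, so the mask is
still needed):

    import torch
    from NNPOps.neighbors import getNeighborPairs

    positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1], [0.0, 0.0, 5.0]])
    # cutoff of 0.5, all possible pairs considered (-1)
    neighbors, vectors, distances, n_pairs = getNeighborPairs(positions, 0.5, -1)
    print(neighbors)  # expected: one valid pair, remaining entries are -1
    mask = neighbors[0] >= 0
    print(neighbors[:, mask])  # only the pair within the cutoff remains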

neighbors = neighbors[:, mask]
interatomic_vectors = interatomic_vectors[mask, :]

if self.requested_neighbor_list.full_list:
neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1)
interatomic_vectors = torch.concatenate(
(interatomic_vectors, -interatomic_vectors)
)

if cell is not None:
interatomic_vectors_unit_cell = (
system.positions[interatomic_vectors[0]]
- system.positions[interatomic_vectors[1]]
)
cell_shifts = (
interatomic_vectors_unit_cell - interatomic_vectors
) @ torch.linalg.inv(cell)
cell_shifts = torch.round(cell_shifts).to(torch.int32)
else:
cell_shifts = torch.zeros(
(neighbors.shape[1], 3),
dtype=torch.int32,
device=system.positions.device,
)

neighbor_list = TensorBlock(
values=interatomic_vectors.reshape(-1, 3, 1),
samples=Labels(
names=[
"first_atom",
"second_atom",
"cell_shift_a",
"cell_shift_b",
"cell_shift_c",
],
values=torch.concatenate([neighbors.T, cell_shifts], dim=-1),
),
components=[
Labels(
names=["xyz"],
values=torch.arange(
3, dtype=torch.int32, device=system.positions.device
).reshape(-1, 1),
)
],
properties=Labels(
names=["distance"],
values=torch.tensor(
[[0]], dtype=torch.int32, device=system.positions.device
),
),
)

system.add_neighbor_list(self.requested_neighbor_list, neighbor_list)
return system

def forward(
self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None
) -> torch.Tensor:
# move labels if necessary
selected_atoms = self.evaluation_options.selected_atoms
if selected_atoms is not None:
if selected_atoms.device != positions.device:
self.evaluation_options.selected_atoms = selected_atoms.to(
positions.device
)

if cell is None:
cell = torch.zeros(
(3, 3), dtype=positions.dtype, device=positions.device
)

# create System
system = System(
types=self.atomic_numbers,
positions=positions,
cell=cell,
)
system = self._attach_neighbors(system)

energy = (
self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)["energy"]
.block()
.values.reshape(())
)
Contributor

Suggested change
energy = (
self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)["energy"]
.block()
.values.reshape(())
)
outputs = self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)
energy = outputs["energy"].block().values.reshape(())

return energy

metatensor_force = MetatensorForce(
model,
atomic_numbers,
selected_atoms,
check_consistency,
)

# torchscript everything
module = torch.jit.script(metatensor_force)

# create the OpenMM force
force = openmmtorch.TorchForce(module)
isPeriodic = (
topology.getPeriodicBoxVectors() is not None
) or system.usesPeriodicBoundaryConditions()
force.setUsesPeriodicBoundaryConditions(isPeriodic)
force.setForceGroup(forceGroup)

return force