-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add OpenMM-torch interface #664
base: master
Are you sure you want to change the base?
Changes from 10 commits
8fc6b55
c178e44
af86f14
156d617
7b00594
ea8985e
263df69
d86e324
1b05329
490c351
10171ba
00c8257
af58dc9
f89f6f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
OpenMM integration | ||
================== | ||
|
||
.. py:currentmodule:: metatensor.torch.atomistic | ||
|
||
:py:mod:`openmm_interface` contains the ``get_metatensor_force`` function, | ||
which can be used to load a :py:class:`MetatensorAtomisticModel` into an | ||
``openmm.Force`` object able to calculate forces on an ``openmm.System``. | ||
|
||
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force | ||
:show-inheritance: | ||
:members: | ||
|
||
In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``, | ||
we recommend installing ``OpenMM`` from conda, using | ||
``conda install -c conda-forge openmm-torch nnpops``. Subsequently, | ||
metatensor can be installed with ``pip install metatensor[torch]``, and a minimal | ||
script demonstrating how to run a simple simulation is illustrated below: | ||
|
||
.. code-block:: python | ||
|
||
import openmm | ||
from metatensor.torch.atomistic.openmm_interface import get_metatensor_force | ||
|
||
# load an input geometry file | ||
topology = openmm.app.PDBFile('input.pdb').getTopology() | ||
system = openmm.System() | ||
for atom in topology.atoms(): | ||
system.addParticle(atom.element.mass) | ||
|
||
# get the force object from an exported model saved as 'model.pt' | ||
force = get_metatensor_force(system, topology, 'model.pt') | ||
system.addForce(force) | ||
|
||
integrator = openmm.VerletIntegrator(0.001) | ||
platform = openmm.Platform.getPlatformByName('CUDA') | ||
simulation = openmm.app.Simulation(topology, system, integrator, platform) | ||
simulation.context.setPositions(openmm.app.PDBFile('input.pdb').getPositions()) | ||
simulation.reporters.append(openmm.app.PDBReporter('output.pdb', 100)) | ||
|
||
simulation.step(1000) |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,235 @@ | ||||||||||||||||||||||||||||||||
from typing import Iterable, List, Optional | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
from metatensor.torch import Labels, TensorBlock | ||||||||||||||||||||||||||||||||
from metatensor.torch.atomistic import ( | ||||||||||||||||||||||||||||||||
ModelEvaluationOptions, | ||||||||||||||||||||||||||||||||
ModelOutput, | ||||||||||||||||||||||||||||||||
System, | ||||||||||||||||||||||||||||||||
load_atomistic_model, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||
import openmm | ||||||||||||||||||||||||||||||||
import openmmtorch | ||||||||||||||||||||||||||||||||
from NNPOps.neighbors import getNeighborPairs | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
HAS_OPENMM = True | ||||||||||||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class MLPotential: | ||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class MLPotentialImpl: | ||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class MLPotentialImplFactory: | ||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
HAS_OPENMM = False | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def get_metatensor_force( | ||||||||||||||||||||||||||||||||
system: openmm.System, | ||||||||||||||||||||||||||||||||
topology: openmm.app.Topology, | ||||||||||||||||||||||||||||||||
path: str, | ||||||||||||||||||||||||||||||||
extensions_directory: Optional[str] = None, | ||||||||||||||||||||||||||||||||
forceGroup: int = 0, | ||||||||||||||||||||||||||||||||
atoms: Optional[Iterable[int]] = None, | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
check_consistency: bool = False, | ||||||||||||||||||||||||||||||||
) -> openmm.System: | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this needs a docstring & doc for all parameters |
||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if not HAS_OPENMM: | ||||||||||||||||||||||||||||||||
raise ImportError( | ||||||||||||||||||||||||||||||||
"Could not import openmm and/or nnpops. If you want to use metatensor with " | ||||||||||||||||||||||||||||||||
"openmm, please install openmm-torch and nnpops with conda." | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
model = load_atomistic_model(path, extensions_directory=extensions_directory) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you print Also, is there a mechanism to register papers to cite in OpenMM? If so, we should register the model references. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think they have some MD loggers, but I don't think they're accessible from our function (they're attached to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps the best solution is to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. printing everything is a good starting point, and enough for this PR! |
||||||||||||||||||||||||||||||||
# Get the atomic numbers of the ML region. | ||||||||||||||||||||||||||||||||
all_atoms = list(topology.atoms()) | ||||||||||||||||||||||||||||||||
atomic_numbers = [atom.element.atomic_number for atom in all_atoms] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if atoms is None: | ||||||||||||||||||||||||||||||||
selected_atoms = None | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
selected_atoms = Labels( | ||||||||||||||||||||||||||||||||
names=["system", "atom"], | ||||||||||||||||||||||||||||||||
values=torch.tensor( | ||||||||||||||||||||||||||||||||
[[0, selected_atom] for selected_atom in atoms], | ||||||||||||||||||||||||||||||||
dtype=torch.int32, | ||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
class MetatensorForce(torch.nn.Module): | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||
model: torch.jit._script.RecursiveScriptModule, | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
We don't really need the types in |
||||||||||||||||||||||||||||||||
atomic_numbers: List[int], | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
selected_atoms: Optional[Labels], | ||||||||||||||||||||||||||||||||
check_consistency: bool, | ||||||||||||||||||||||||||||||||
) -> None: | ||||||||||||||||||||||||||||||||
super(MetatensorForce, self).__init__() | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
self.model = model | ||||||||||||||||||||||||||||||||
self.register_buffer( | ||||||||||||||||||||||||||||||||
"atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int32) | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
self.evaluation_options = ModelEvaluationOptions( | ||||||||||||||||||||||||||||||||
length_unit="nm", | ||||||||||||||||||||||||||||||||
outputs={ | ||||||||||||||||||||||||||||||||
"energy": ModelOutput( | ||||||||||||||||||||||||||||||||
quantity="energy", | ||||||||||||||||||||||||||||||||
unit="kJ/mol", | ||||||||||||||||||||||||||||||||
per_atom=False, | ||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||
selected_atoms=selected_atoms, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
requested_nls = self.model.requested_neighbor_lists() | ||||||||||||||||||||||||||||||||
if len(requested_nls) > 1: | ||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||
"The model requested more than one neighbor list. " | ||||||||||||||||||||||||||||||||
"Currently, only models with a single neighbor list are supported " | ||||||||||||||||||||||||||||||||
"by the OpenMM interface." | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
elif len(requested_nls) == 1: | ||||||||||||||||||||||||||||||||
self.requested_neighbor_list = requested_nls[0] | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
# no neighbor list requested | ||||||||||||||||||||||||||||||||
self.requested_neighbor_list = None | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
self.check_consistency = check_consistency | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def _attach_neighbors(self, system: System) -> System: | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if self.requested_neighbor_list is None: | ||||||||||||||||||||||||||||||||
return system | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
cell: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||||
if not torch.all(system.cell == 0.0): | ||||||||||||||||||||||||||||||||
cell = system.cell | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# Get the neighbor pairs, shifts and edge indices. | ||||||||||||||||||||||||||||||||
neighbors, interatomic_vectors, _, _ = getNeighborPairs( | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
for clarity |
||||||||||||||||||||||||||||||||
system.positions, | ||||||||||||||||||||||||||||||||
self.requested_neighbor_list.engine_cutoff("nm"), | ||||||||||||||||||||||||||||||||
-1, | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you give the name of this argument? (and maybe others as well) |
||||||||||||||||||||||||||||||||
cell, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
mask = neighbors[0] >= 0 | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will there negative values in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think that NL returns a huge Tensor with all pairs, where only a certain number are not -1, i.e., masked There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand this behavior if we are setting |
||||||||||||||||||||||||||||||||
neighbors = neighbors[:, mask] | ||||||||||||||||||||||||||||||||
interatomic_vectors = interatomic_vectors[mask, :] | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if self.requested_neighbor_list.full_list: | ||||||||||||||||||||||||||||||||
neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1) | ||||||||||||||||||||||||||||||||
interatomic_vectors = torch.concatenate( | ||||||||||||||||||||||||||||||||
(interatomic_vectors, -interatomic_vectors) | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if cell is not None: | ||||||||||||||||||||||||||||||||
interatomic_vectors_unit_cell = ( | ||||||||||||||||||||||||||||||||
system.positions[interatomic_vectors[0]] | ||||||||||||||||||||||||||||||||
- system.positions[interatomic_vectors[1]] | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
cell_shifts = ( | ||||||||||||||||||||||||||||||||
interatomic_vectors_unit_cell - interatomic_vectors | ||||||||||||||||||||||||||||||||
) @ torch.linalg.inv(cell) | ||||||||||||||||||||||||||||||||
cell_shifts = torch.round(cell_shifts).to(torch.int32) | ||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||
cell_shifts = torch.zeros( | ||||||||||||||||||||||||||||||||
(neighbors.shape[1], 3), | ||||||||||||||||||||||||||||||||
dtype=torch.int32, | ||||||||||||||||||||||||||||||||
device=system.positions.device, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
neighbor_list = TensorBlock( | ||||||||||||||||||||||||||||||||
values=interatomic_vectors.reshape(-1, 3, 1), | ||||||||||||||||||||||||||||||||
samples=Labels( | ||||||||||||||||||||||||||||||||
names=[ | ||||||||||||||||||||||||||||||||
"first_atom", | ||||||||||||||||||||||||||||||||
"second_atom", | ||||||||||||||||||||||||||||||||
"cell_shift_a", | ||||||||||||||||||||||||||||||||
"cell_shift_b", | ||||||||||||||||||||||||||||||||
"cell_shift_c", | ||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||
values=torch.concatenate([neighbors.T, cell_shifts], dim=-1), | ||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||
components=[ | ||||||||||||||||||||||||||||||||
Labels( | ||||||||||||||||||||||||||||||||
names=["xyz"], | ||||||||||||||||||||||||||||||||
values=torch.arange( | ||||||||||||||||||||||||||||||||
3, dtype=torch.int32, device=system.positions.device | ||||||||||||||||||||||||||||||||
).reshape(-1, 1), | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||
properties=Labels( | ||||||||||||||||||||||||||||||||
names=["distance"], | ||||||||||||||||||||||||||||||||
values=torch.tensor( | ||||||||||||||||||||||||||||||||
[[0]], dtype=torch.int32, device=system.positions.device | ||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
system.add_neighbor_list(self.requested_neighbor_list, neighbor_list) | ||||||||||||||||||||||||||||||||
return system | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
def forward( | ||||||||||||||||||||||||||||||||
self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None | ||||||||||||||||||||||||||||||||
) -> torch.Tensor: | ||||||||||||||||||||||||||||||||
# move labels if necessary | ||||||||||||||||||||||||||||||||
selected_atoms = self.evaluation_options.selected_atoms | ||||||||||||||||||||||||||||||||
if selected_atoms is not None: | ||||||||||||||||||||||||||||||||
if selected_atoms.device != positions.device: | ||||||||||||||||||||||||||||||||
self.evaluation_options.selected_atoms = selected_atoms.to( | ||||||||||||||||||||||||||||||||
positions.device | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
if cell is None: | ||||||||||||||||||||||||||||||||
cell = torch.zeros( | ||||||||||||||||||||||||||||||||
(3, 3), dtype=positions.dtype, device=positions.device | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# create System | ||||||||||||||||||||||||||||||||
system = System( | ||||||||||||||||||||||||||||||||
types=self.atomic_numbers, | ||||||||||||||||||||||||||||||||
positions=positions, | ||||||||||||||||||||||||||||||||
cell=cell, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
system = self._attach_neighbors(system) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
energy = ( | ||||||||||||||||||||||||||||||||
self.model( | ||||||||||||||||||||||||||||||||
[system], | ||||||||||||||||||||||||||||||||
self.evaluation_options, | ||||||||||||||||||||||||||||||||
check_consistency=self.check_consistency, | ||||||||||||||||||||||||||||||||
)["energy"] | ||||||||||||||||||||||||||||||||
.block() | ||||||||||||||||||||||||||||||||
.values.reshape(()) | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
return energy | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
metatensor_force = MetatensorForce( | ||||||||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||||||||
atomic_numbers, | ||||||||||||||||||||||||||||||||
selected_atoms, | ||||||||||||||||||||||||||||||||
check_consistency, | ||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# torchscript everything | ||||||||||||||||||||||||||||||||
module = torch.jit.script(metatensor_force) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
# create the OpenMM force | ||||||||||||||||||||||||||||||||
force = openmmtorch.TorchForce(module) | ||||||||||||||||||||||||||||||||
isPeriodic = ( | ||||||||||||||||||||||||||||||||
topology.getPeriodicBoxVectors() is not None | ||||||||||||||||||||||||||||||||
) or system.usesPeriodicBoundaryConditions() | ||||||||||||||||||||||||||||||||
force.setUsesPeriodicBoundaryConditions(isPeriodic) | ||||||||||||||||||||||||||||||||
force.setForceGroup(forceGroup) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
return force |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.