-
Notifications
You must be signed in to change notification settings - Fork 24
Add OpenMM-torch interface #664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8c403c1
1e3a0a6
1792629
a673e36
cea817b
f2e3477
21b2629
7e2f643
5e94523
e74d0fd
c660e07
94e0973
3ea577a
aed0dcf
4624483
fc19b41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
OpenMM integration | ||
================== | ||
|
||
.. py:currentmodule:: metatensor.torch.atomistic | ||
|
||
:py:mod:`openmm_interface` contains the ``get_metatensor_force`` function, | ||
which can be used to load a :py:class:`MetatensorAtomisticModel` into an | ||
``openmm.Force`` object able to calculate forces on an ``openmm.System``. | ||
|
||
.. autofunction:: metatensor.torch.atomistic.openmm_force.get_metatensor_force | ||
|
||
In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``, | ||
we recommend installing ``OpenMM`` from conda, using | ||
``conda install -c conda-forge openmm-torch nnpops``. Subsequently, | ||
metatensor can be installed with ``pip install metatensor[torch]``, and a minimal | ||
Comment on lines
+14
to
+15
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since there is #665, should we recommend building metatensor-torch from sources for now? Something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good |
||
script demonstrating how to run a simple simulation is illustrated below: | ||
|
||
.. code-block:: python | ||
|
||
import openmm | ||
from metatensor.torch.atomistic.openmm_interface import get_metatensor_force | ||
|
||
# load an input geometry file | ||
topology = openmm.app.PDBFile('input.pdb').getTopology() | ||
system = openmm.System() | ||
for atom in topology.atoms(): | ||
system.addParticle(atom.element.mass) | ||
|
||
# get the force object from an exported model saved as 'model.pt' | ||
force = get_metatensor_force(system, topology, 'model.pt') | ||
system.addForce(force) | ||
|
||
integrator = openmm.VerletIntegrator(0.001) | ||
platform = openmm.Platform.getPlatformByName('CUDA') | ||
simulation = openmm.app.Simulation(topology, system, integrator, platform) | ||
simulation.context.setPositions(openmm.app.PDBFile('input.pdb').getPositions()) | ||
simulation.reporters.append(openmm.app.PDBReporter('output.pdb', 100)) | ||
|
||
simulation.step(1000) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,10 +17,6 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | |
### Removed | ||
--> | ||
|
||
### Changed | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this removal looks like a merge issue |
||
|
||
- We now require Python >= 3.9 | ||
|
||
## [Version 0.2.2](https://github.com/lab-cosmo/metatensor/releases/tag/metatensor-operations-v0.2.2) - 2024-06-19 | ||
|
||
### Fixed | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,286 @@ | ||
import warnings | ||
from typing import Iterable, List, Optional | ||
|
||
import torch | ||
|
||
from metatensor.torch import Labels, TensorBlock | ||
from metatensor.torch.atomistic import ( | ||
ModelEvaluationOptions, | ||
ModelOutput, | ||
NeighborListOptions, | ||
System, | ||
load_atomistic_model, | ||
) | ||
|
||
from .ase_calculator import STR_TO_DTYPE | ||
|
||
|
||
try: | ||
import NNPOps.neighbors | ||
import openmm | ||
import openmmtorch | ||
|
||
HAS_OPENMM = True | ||
except ImportError: | ||
|
||
class openmm: | ||
class System: | ||
pass | ||
|
||
class Force: | ||
pass | ||
|
||
class app: | ||
class Topology: | ||
pass | ||
|
||
HAS_OPENMM = False | ||
|
||
|
||
def get_metatensor_force( | ||
system: openmm.System, | ||
topology: openmm.app.Topology, | ||
path: str, | ||
extensions_directory: Optional[str] = None, | ||
forceGroup: int = 0, | ||
selected_atoms: Optional[Iterable[int]] = None, | ||
check_consistency: bool = False, | ||
) -> openmm.Force: | ||
""" | ||
Create an OpenMM Force from a metatensor atomistic model. | ||
|
||
:param system: The OpenMM System object. | ||
:param topology: The OpenMM Topology object. | ||
:param path: The path to the exported metatensor model. | ||
:param extensions_directory: The path to the extensions for the model. | ||
:param forceGroup: The force group to which the force should be added. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we link to OpenMM docs for force group here? It might not be clear to everyone. |
||
:param selected_atoms: The indices of the atoms on which non-zero forces should | ||
be computed (e.g., the ML region). If None, forces will be computed for all | ||
atoms. | ||
:param check_consistency: Whether to check various consistency conditions inside | ||
the model. | ||
|
||
return: The OpenMM Force object. | ||
""" | ||
|
||
if not HAS_OPENMM: | ||
raise ImportError( | ||
"Could not import openmm and/or nnpops. If you want to use metatensor with " | ||
"openmm, please install openmm-torch and nnpops with conda." | ||
) | ||
|
||
model = load_atomistic_model(path, extensions_directory=extensions_directory) | ||
|
||
# Print the model's metadata | ||
print(model.metadata()) | ||
|
||
# Get the atomic numbers of the ML region. | ||
all_atoms = list(topology.atoms()) | ||
atomic_types = [atom.element.atomic_number for atom in all_atoms] | ||
|
||
if selected_atoms is None: | ||
selected_atoms = None | ||
else: | ||
selected_atoms = Labels( | ||
names=["system", "atom"], | ||
values=torch.tensor( | ||
[[0, selected_atom] for selected_atom in selected_atoms], | ||
dtype=torch.int32, | ||
), | ||
) | ||
|
||
class MetatensorForce(torch.nn.Module): | ||
|
||
def __init__( | ||
self, | ||
model, | ||
atomic_types: List[int], | ||
selected_atoms: Optional[Labels], | ||
check_consistency: bool, | ||
) -> None: | ||
super(MetatensorForce, self).__init__() | ||
|
||
self.model = model | ||
self.register_buffer( | ||
"atomic_types", torch.tensor(atomic_types, dtype=torch.int32) | ||
) | ||
self.evaluation_options = ModelEvaluationOptions( | ||
length_unit="nm", | ||
outputs={ | ||
"energy": ModelOutput( | ||
quantity="energy", | ||
unit="kJ/mol", | ||
per_atom=False, | ||
), | ||
}, | ||
selected_atoms=selected_atoms, | ||
) | ||
|
||
requested_nls = self.model.requested_neighbor_lists() | ||
if len(requested_nls) > 1: | ||
raise ValueError( | ||
"The model requested more than one neighbor list. " | ||
"Currently, only models with a single neighbor list are supported " | ||
"by the OpenMM interface." | ||
) | ||
elif len(requested_nls) == 1: | ||
self.requested_neighbor_list = requested_nls[0] | ||
else: | ||
# no neighbor list requested | ||
self.requested_neighbor_list = None | ||
|
||
self.check_consistency = check_consistency | ||
self.dtype = STR_TO_DTYPE[self.model.capabilities().dtype] | ||
self.already_warned = False | ||
|
||
def forward( | ||
self, positions: torch.Tensor, cell: Optional[torch.Tensor] = None | ||
) -> torch.Tensor: | ||
# move labels if necessary | ||
selected_atoms = self.evaluation_options.selected_atoms | ||
if selected_atoms is not None: | ||
if selected_atoms.device != positions.device: | ||
self.evaluation_options.selected_atoms = selected_atoms.to( | ||
positions.device | ||
) | ||
|
||
if cell is None: | ||
cell = torch.zeros( | ||
(3, 3), dtype=positions.dtype, device=positions.device | ||
) | ||
|
||
# create System | ||
system = System( | ||
types=self.atomic_types, | ||
positions=positions, | ||
cell=cell, | ||
) | ||
system = _attach_neighbors(system, self.requested_neighbor_list) | ||
|
||
original_dtype = system.dtype | ||
if self.dtype != system.dtype: | ||
if not self.already_warned: | ||
model_dtype_string = dtype_to_str(self.dtype) | ||
system_dtype_string = dtype_to_str(system.dtype) | ||
warnings.warn( | ||
f"Model dtype {model_dtype_string} does not match the dtype " | ||
f"of the system {system_dtype_string}. The system will be " | ||
f"converted to {model_dtype_string} temporarily to allow the " | ||
"model to run.", | ||
stacklevel=2, | ||
) | ||
self.already_warned = True | ||
system = system.to(dtype=self.dtype) | ||
|
||
outputs = self.model( | ||
[system], | ||
self.evaluation_options, | ||
check_consistency=self.check_consistency, | ||
) | ||
energy = outputs["energy"].block().values.reshape(()) | ||
return energy.to(original_dtype) | ||
|
||
metatensor_force = MetatensorForce( | ||
model, | ||
atomic_types, | ||
selected_atoms, | ||
check_consistency, | ||
) | ||
|
||
# torchscript everything | ||
module = torch.jit.script(metatensor_force) | ||
|
||
# create the OpenMM force | ||
force = openmmtorch.TorchForce(module) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so |
||
isPeriodic = ( | ||
topology.getPeriodicBoxVectors() is not None | ||
) or system.usesPeriodicBoundaryConditions() | ||
force.setUsesPeriodicBoundaryConditions(isPeriodic) | ||
force.setForceGroup(forceGroup) | ||
|
||
return force | ||
|
||
|
||
def _attach_neighbors( | ||
system: System, requested_nl_options: Optional[NeighborListOptions] | ||
) -> System: | ||
|
||
if requested_nl_options is None: | ||
return system | ||
|
||
cell: Optional[torch.Tensor] = None | ||
if not torch.all(system.cell == 0.0): | ||
cell = system.cell | ||
|
||
# Get the neighbor pairs, shifts and edge indices. | ||
neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs( | ||
positions=system.positions, | ||
cutoff=requested_nl_options.engine_cutoff("nm"), | ||
max_num_pairs=-1, | ||
box_vectors=cell, | ||
) | ||
mask = neighbors[0] >= 0 | ||
neighbors = neighbors[:, mask] | ||
neighbors = neighbors.flip(0) # [neighbor, center] -> [center, neighbor] | ||
interatomic_vectors = interatomic_vectors[mask, :] | ||
|
||
if requested_nl_options.full_list: | ||
neighbors = torch.concatenate((neighbors, neighbors.flip(0)), dim=1) | ||
interatomic_vectors = torch.concatenate( | ||
(interatomic_vectors, -interatomic_vectors) | ||
) | ||
|
||
if cell is not None: | ||
interatomic_vectors_unit_cell = ( | ||
system.positions[neighbors[1]] - system.positions[neighbors[0]] | ||
) | ||
cell_shifts = ( | ||
interatomic_vectors_unit_cell - interatomic_vectors | ||
) @ torch.linalg.inv(cell) | ||
cell_shifts = torch.round(cell_shifts).to(torch.int32) | ||
else: | ||
cell_shifts = torch.zeros( | ||
(neighbors.shape[1], 3), | ||
dtype=torch.int32, | ||
device=system.positions.device, | ||
) | ||
|
||
neighbor_list = TensorBlock( | ||
values=interatomic_vectors.reshape(-1, 3, 1), | ||
samples=Labels( | ||
names=[ | ||
"first_atom", | ||
"second_atom", | ||
"cell_shift_a", | ||
"cell_shift_b", | ||
"cell_shift_c", | ||
], | ||
values=torch.concatenate([neighbors.T, cell_shifts], dim=-1), | ||
), | ||
components=[ | ||
Labels( | ||
names=["xyz"], | ||
values=torch.arange( | ||
3, dtype=torch.int32, device=system.positions.device | ||
).reshape(-1, 1), | ||
) | ||
], | ||
properties=Labels( | ||
names=["distance"], | ||
values=torch.tensor( | ||
[[0]], dtype=torch.int32, device=system.positions.device | ||
), | ||
), | ||
) | ||
|
||
system.add_neighbor_list(requested_nl_options, neighbor_list) | ||
return system | ||
|
||
|
||
def dtype_to_str(dtype: torch.dtype) -> str: | ||
if dtype == torch.float32: | ||
return "float32" | ||
elif dtype == torch.float64: | ||
return "float64" | ||
else: | ||
raise ValueError(f"Unsupported dtype {dtype}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you should mention than NNPOps is a required dependency as well in here