Add OpenMM-torch interface #664
base: master
Conversation
Force-pushed from 2818ef7 to 8fc6b55
Why is this using openmm-ml instead of openmm-torch directly? My understanding of the former was that it is supposed to provide specific models, not a whole new set of models.
This will also need tests, documentation and examples. And I guess this requires #402 to be done first as well, since OpenMM relies heavily on conda.
(six resolved review threads on python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py, now outdated)
I could not find how this code handles dtype/device. The model might be on a different device than the positions/cell.
For the dtype, the model might require a specific one, and the system should be converted to this dtype.
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
    :show-inheritance:
    :members:
Suggested change:

-.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
-    :show-inheritance:
-    :members:
+.. autofunction:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
    forceGroup: int = 0,
    atoms: Optional[Iterable[int]] = None,
    check_consistency: bool = False,
) -> openmm.System:
this needs a docstring & doc for all parameters
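For reference, a sketch of what the requested docstring could look like; the parameter descriptions here are my guesses from the surrounding diff, not the final wording:

from typing import Iterable, Optional

import openmm
import openmm.app


def get_metatensor_force(
    system: openmm.System,
    topology: openmm.app.Topology,
    path: str,
    extensions_directory: Optional[str] = None,
    forceGroup: int = 0,
    atoms: Optional[Iterable[int]] = None,
    check_consistency: bool = False,
) -> openmm.System:
    """Add a force computed by an exported metatensor model to an OpenMM system.

    :param system: the OpenMM ``System`` the force will act on
    :param topology: the OpenMM ``Topology`` of the simulated system
    :param path: path to the exported metatensor model
    :param extensions_directory: path to the TorchScript extensions the model
        requires, if any
    :param forceGroup: OpenMM force group to assign to the new force
    :param atoms: indices of the atoms the model should apply to; all atoms
        if ``None``
    :param check_consistency: whether to enable metatensor's internal
        consistency checks during evaluation
    """
    ...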
    path: str,
    extensions_directory: Optional[str] = None,
    forceGroup: int = 0,
    atoms: Optional[Iterable[int]] = None,
Suggested change:

-    atoms: Optional[Iterable[int]] = None,
+    selected_atoms: Optional[Iterable[int]] = None,
def __init__(
    self,
    model: torch.jit._script.RecursiveScriptModule,
    atomic_numbers: List[int],
Suggested change:

-    atomic_numbers: List[int],
+    atomic_types: List[int],
cell = system.cell

# Get the neighbor pairs, shifts and edge indices.
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
Suggested change:

-neighbors, interatomic_vectors, _, _ = getNeighborPairs(
+neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs(

for clarity
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
    system.positions,
    self.requested_neighbor_list.engine_cutoff("nm"),
    -1,
could you give the name of this argument? (and maybe others as well)
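Something like this, assuming NNPOps' keyword names are max_num_pairs and box_vectors (worth double-checking against the installed version):

neighbors, interatomic_vectors, _, _ = getNeighborPairs(
    system.positions,
    cutoff=self.requested_neighbor_list.engine_cutoff("nm"),
    max_num_pairs=-1,  # return all pairs, padded with -1 entries
    box_vectors=cell,
)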
    -1,
    cell,
)
mask = neighbors[0] >= 0
Will there be negative values in `neighbors[0]` if `n_pairs=-1` above?
Yes, I think the NL returns a huge Tensor with all pairs, where only a certain number are not -1, i.e., the rest are masked.
I understand this behavior if we are setting `max_n_pairs=300`: if there are fewer pairs, the other ones are filled with -1. You are saying this is also the case when giving `max_n_pairs=-1`?
Yes, I tried a few tests (with `max_n_pairs=-1`) and the mask is still necessary. I think they're always returning a number of pairs that is a multiple of 4. E.g. if you have 27 pairs, 28 will be returned and one of them is masked with center=neighbor=-1.
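To make the padding behavior concrete for future readers, here is a tiny standalone illustration (the tensor values are made up):

import torch

# hypothetical output for a system with 3 real pairs, padded to 4:
# padded entries have center = neighbor = -1
neighbors = torch.tensor([[0, 1, 2, -1],
                          [1, 2, 0, -1]])

mask = neighbors[0] >= 0        # True for real pairs only
neighbors = neighbors[:, mask]  # drop the padded columns -> shape (2, 3)
print(neighbors)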
energy = (
    self.model(
        [system],
        self.evaluation_options,
        check_consistency=self.check_consistency,
    )["energy"]
    .block()
    .values.reshape(())
)
Suggested change:

-energy = (
-    self.model(
-        [system],
-        self.evaluation_options,
-        check_consistency=self.check_consistency,
-    )["energy"]
-    .block()
-    .values.reshape(())
-)
+outputs = self.model(
+    [system],
+    self.evaluation_options,
+    check_consistency=self.check_consistency,
+)
+energy = outputs["energy"].block().values.reshape(())
)

model = load_atomistic_model(path, extensions_directory=extensions_directory)
Could you print `model.metadata()` to the log somewhere? This prints information about the model.
Also, is there a mechanism to register papers to cite in OpenMM? If so, we should register the model references.
I think they have some MD loggers, but I don't think they're accessible from our function (they're attached to a `Simulation` object that is independent). Literally zero clue regarding the citation mechanism.
Perhaps the best solution is to `print()` and whatever.
printing everything is a good starting point, and enough for this PR!
def __init__(
    self,
    model: torch.jit._script.RecursiveScriptModule,
Suggested change:

-    model: torch.jit._script.RecursiveScriptModule,
+    model,

We don't really need the types in `__init__`, and this one is private.
Typical code to start a simulation would look like this:
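(The original snippet was lost in this copy; below is a minimal sketch. The file names, the empty-System setup, and the exact `get_metatensor_force` call are assumptions based on the tests in this PR.)

import openmm
import openmm.app
from metatensor.torch.atomistic.openmm_interface import get_metatensor_force

pdb = openmm.app.PDBFile("system.pdb")  # hypothetical input file

# build a System containing only the metatensor force
system = openmm.System()
for atom in pdb.topology.atoms():
    system.addParticle(atom.element.mass)

force = get_metatensor_force(system, pdb.topology, "exported-model.pt")
system.addForce(force)

integrator = openmm.VerletIntegrator(0.001 * openmm.unit.picoseconds)
simulation = openmm.app.Simulation(pdb.topology, system, integrator)
simulation.context.setPositions(pdb.positions)
simulation.step(1000)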
The device is handled by OpenMM-torch (it moves the model to the desired device). Let me know what you think the best solution is (leave as is, check at each step, etc.).
The way we are doing this in other engines is to always convert from the engine dtype to the model dtype for inputs, and the other way around for outputs. We could add a warning about mismatched dtype if we do the conversion, but I would rather allow the code to run instead of crashing.
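A sketch of that round-trip (everything here is illustrative; in practice the model dtype would come from the model's capabilities):

import torch

engine_dtype = torch.float32   # what the engine hands us
model_dtype = torch.float64    # assumed to be required by the model

positions = torch.zeros((8, 3), dtype=engine_dtype)
cell = torch.eye(3, dtype=engine_dtype)

# engine dtype -> model dtype for inputs
positions = positions.to(model_dtype)
cell = cell.to(model_dtype)

energy = positions.sum()  # placeholder for the actual model evaluation

# model dtype -> engine dtype for outputs
energy = energy.to(engine_dtype)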
Printing metadata doesn't work because …
Looks good! We should find a way to run the tests for this. I guess the main issue is that openmm requires conda for installation, right?
Either we keep the code in this repo, and we find a way to use conda from tox; or we move the code to another repo, with its own CI setup.
Maybe adding a section in `tox.ini` for "tests that require conda", and in there creating a conda env by calling conda explicitly, could work.
``conda install -c conda-forge openmm-torch nnpops``. Subsequently,
metatensor can be installed with ``pip install metatensor[torch]``, and a minimal
Since there is #665, should we recommend building metatensor-torch from sources for now? Something like `pip install --no-binary=metatensor-torch metatensor[torch]`?
Sounds good
@@ -17,10 +17,6 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
### Removed
-->

### Changed
this removal looks like a merge issue
:param topology: The OpenMM Topology object.
:param path: The path to the exported metatensor model.
:param extensions_directory: The path to the extensions for the model.
:param forceGroup: The force group to which the force should be added.
should we link to OpenMM docs for force group here? It might not be clear to everyone.
.. autofunction:: metatensor.torch.atomistic.openmm_force.get_metatensor_force

In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``,
we recommend installing ``OpenMM`` from conda, using
You should mention that NNPOps is a required dependency here as well.
module = torch.jit.script(metatensor_force)

# create the OpenMM force
force = openmmtorch.TorchForce(module)
So `model.to(device)` is called by TorchForce? Could you leave a comment if this is the case, so future readers don't wonder what's happening?
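If that is the case, the requested comment could look like this (assuming the device handling described above is accurate):

module = torch.jit.script(metatensor_force)

# create the OpenMM force; TorchForce takes care of moving the module to the
# device selected by the OpenMM platform, so no explicit model.to(device) is
# needed here
force = openmmtorch.TorchForce(module)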
)


def model_different_units():
This might not be so interesting to run here; I did it in ASE to ensure our code did unit conversions right, but we don't have to do it for all interfaces.
force = get_metatensor_force(system, topology, model_path, check_consistency=True)
system.addForce(force)
integrator = openmm.VerletIntegrator(0.001 * openmm.unit.picoseconds)
platform = openmm.Platform.getPlatformByName("CPU")
should we try to run this test on CUDA if it is available?
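A sketch of a CUDA-with-CPU-fallback selection, using OpenMM's standard platform names:

# try CUDA first, fall back to CPU if the platform is unavailable
try:
    platform = openmm.Platform.getPlatformByName("CUDA")
except openmm.OpenMMException:
    platform = openmm.Platform.getPlatformByName("CPU")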
state.getForces(asNumpy=True).value_in_unit(
    openmm.unit.ev / (openmm.unit.angstrom * openmm.unit.mole)
)
/ 6.0221367e23
maybe put this in a constant on top?
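Something like this, with a hypothetical constant name and the value taken from the code above:

AVOGADRO = 6.0221367e23  # defined once at the top of the file

forces = state.getForces(asNumpy=True).value_in_unit(
    openmm.unit.ev / (openmm.unit.angstrom * openmm.unit.mole)
) / AVOGADRO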
Also, why are you doing `* openmm.unit.mole` above, just to divide again here?
try:
    import NNPOps  # noqa: F401
    import openmm
    import openmmtorch  # noqa: F401

    HAS_OPENMM = True
except ImportError:
    HAS_OPENMM = False
Since we want to skip the whole file if one of the dependencies is missing, we should use `pytest.importorskip`:

Suggested change:

-try:
-    import NNPOps  # noqa: F401
-    import openmm
-    import openmmtorch  # noqa: F401
-
-    HAS_OPENMM = True
-except ImportError:
-    HAS_OPENMM = False
+openmm = pytest.importorskip("openmm")
+pytest.importorskip("NNPOps")
+pytest.importorskip("openmmtorch")
Adds an MD interface to OpenMM via OpenMM-torch.