-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add OpenMM-torch interface #664
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this using openmm-ml instead of openmm-torch directly? My understanding of the former was that it is supposed to provide specific models, not a whole new set of models?
This will also need tests, documentation and examples. And I guess this required #402 to be done first as well, since OpenMM relies heavily on conda.
python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
Outdated
Show resolved
Hide resolved
python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
Outdated
Show resolved
Hide resolved
python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
Outdated
Show resolved
Hide resolved
python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
Outdated
Show resolved
Hide resolved
python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
Outdated
Show resolved
Hide resolved
python/metatensor-torch/metatensor/torch/atomistic/openmm_interface.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could not find how this code handle dtype/device? The model might be on a different device than the positions/cell.
For the dtype, the model might require a specific one and the system should be converted to this dtype
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force | ||
:show-inheritance: | ||
:members: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force | |
:show-inheritance: | |
:members: | |
.. autofunction:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force |
forceGroup: int = 0, | ||
atoms: Optional[Iterable[int]] = None, | ||
check_consistency: bool = False, | ||
) -> openmm.System: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this needs a docstring & doc for all parameters
path: str, | ||
extensions_directory: Optional[str] = None, | ||
forceGroup: int = 0, | ||
atoms: Optional[Iterable[int]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
atoms: Optional[Iterable[int]] = None, | |
selected_atoms: Optional[Iterable[int]] = None, |
def __init__( | ||
self, | ||
model: torch.jit._script.RecursiveScriptModule, | ||
atomic_numbers: List[int], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
atomic_numbers: List[int], | |
atomic_types: List[int], |
cell = system.cell | ||
|
||
# Get the neighbor pairs, shifts and edge indices. | ||
neighbors, interatomic_vectors, _, _ = getNeighborPairs( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
neighbors, interatomic_vectors, _, _ = getNeighborPairs( | |
neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs( |
for clarity
neighbors, interatomic_vectors, _, _ = getNeighborPairs( | ||
system.positions, | ||
self.requested_neighbor_list.engine_cutoff("nm"), | ||
-1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you give the name of this argument? (and maybe others as well)
-1, | ||
cell, | ||
) | ||
mask = neighbors[0] >= 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will there negative values in neighbors[0]
if n_pairs=-1
above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think that NL returns a huge Tensor with all pairs, where only a certain number are not -1, i.e., masked
energy = ( | ||
self.model( | ||
[system], | ||
self.evaluation_options, | ||
check_consistency=self.check_consistency, | ||
)["energy"] | ||
.block() | ||
.values.reshape(()) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
energy = ( | |
self.model( | |
[system], | |
self.evaluation_options, | |
check_consistency=self.check_consistency, | |
)["energy"] | |
.block() | |
.values.reshape(()) | |
) | |
outputs = self.model( | |
[system], | |
self.evaluation_options, | |
check_consistency=self.check_consistency, | |
) | |
energy = outputs["energy"].block().values.reshape(()) |
) | ||
|
||
model = load_atomistic_model(path, extensions_directory=extensions_directory) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you print model.metadata()
to the log somewhere? This prints information about the model.
Also, is there a mechanism to register papers to cite in OpenMM? If so, we should register the model references.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think they have some MD loggers, but I don't think they're accessible from our function (they're attached to a Simulation
object that is independent). Literally zero clue regarding the citation mechanism
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps the best solution is to print()
and whatever
|
||
def __init__( | ||
self, | ||
model: torch.jit._script.RecursiveScriptModule, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model: torch.jit._script.RecursiveScriptModule, | |
model, |
We don't really need the types in __init__
, and this one is private
Typical code to start a simulation would look like this:
The device is handled by OpenMM torch (it moves the model to the desired device). Let me know what you think the best solution is (leave as is, check at each step, etc) |
Adds a MD interface to OpenMM via OpenMM-torch
Contributor (creator of pull-request) checklist
Reviewer checklist
📚 Download documentation preview for this pull-request
⚙️ Download Python wheels for this pull-request (you can install these with pip)