Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenMM-torch interface #664

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open

Add OpenMM-torch interface #664

wants to merge 14 commits into from

Conversation

frostedoyster
Copy link
Contributor

@frostedoyster frostedoyster commented Jun 21, 2024

Adds a MD interface to OpenMM via OpenMM-torch

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

📚 Download documentation preview for this pull-request

⚙️ Download Python wheels for this pull-request (you can install these with pip)

Copy link
Contributor

@Luthaf Luthaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this using openmm-ml instead of openmm-torch directly? My understanding of the former was that it is supposed to provide specific models, not a whole new set of models?


This will also need tests, documentation and examples. And I guess this required #402 to be done first as well, since OpenMM relies heavily on conda.

@frostedoyster frostedoyster changed the title Add OpenMM-ML interface Add OpenMM-torch interface Jun 24, 2024
@frostedoyster frostedoyster marked this pull request as ready for review June 27, 2024 18:37
Copy link
Contributor

@Luthaf Luthaf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not find how this code handle dtype/device? The model might be on a different device than the positions/cell.

For the dtype, the model might require a specific one and the system should be converted to this dtype

Comment on lines 10 to 12
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
:show-inheritance:
:members:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
:show-inheritance:
:members:
.. autofunction:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force

forceGroup: int = 0,
atoms: Optional[Iterable[int]] = None,
check_consistency: bool = False,
) -> openmm.System:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs a docstring & doc for all parameters

path: str,
extensions_directory: Optional[str] = None,
forceGroup: int = 0,
atoms: Optional[Iterable[int]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
atoms: Optional[Iterable[int]] = None,
selected_atoms: Optional[Iterable[int]] = None,

def __init__(
self,
model: torch.jit._script.RecursiveScriptModule,
atomic_numbers: List[int],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
atomic_numbers: List[int],
atomic_types: List[int],

cell = system.cell

# Get the neighbor pairs, shifts and edge indices.
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs(

for clarity

neighbors, interatomic_vectors, _, _ = getNeighborPairs(
system.positions,
self.requested_neighbor_list.engine_cutoff("nm"),
-1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you give the name of this argument? (and maybe others as well)

-1,
cell,
)
mask = neighbors[0] >= 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will there negative values in neighbors[0] if n_pairs=-1 above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think that NL returns a huge Tensor with all pairs, where only a certain number are not -1, i.e., masked

Comment on lines 206 to 214
energy = (
self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)["energy"]
.block()
.values.reshape(())
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
energy = (
self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)["energy"]
.block()
.values.reshape(())
)
outputs = self.model(
[system],
self.evaluation_options,
check_consistency=self.check_consistency,
)
energy = outputs["energy"].block().values.reshape(())

)

model = load_atomistic_model(path, extensions_directory=extensions_directory)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you print model.metadata() to the log somewhere? This prints information about the model.

Also, is there a mechanism to register papers to cite in OpenMM? If so, we should register the model references.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think they have some MD loggers, but I don't think they're accessible from our function (they're attached to a Simulation object that is independent). Literally zero clue regarding the citation mechanism

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps the best solution is to print() and whatever


def __init__(
self,
model: torch.jit._script.RecursiveScriptModule,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model: torch.jit._script.RecursiveScriptModule,
model,

We don't really need the types in __init__, and this one is private

@frostedoyster
Copy link
Contributor Author

Typical code to start a simulation would look like this:

platform = openmm.Platform.getPlatformByName('CUDA')
properties = {'Precision': 'double'}
simulation = openmm.app.Simulation(topology, system, integrator, platform, properties)

The device is handled by OpenMM torch (it moves the model to the desired device).
The dtype is more tricky, as it is not handled by OpenMM (I think they're trying to allow model developers to do things in mixed precision, so they won't call to(dtype)). At the moment, if the precision is wrong, the model will crash at runtime, and if the user matches the dtypes correctly, the model will run.

Let me know what you think the best solution is (leave as is, check at each step, etc)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants