
Add OpenMM-torch interface #664

Open · wants to merge 16 commits into master
Conversation

@frostedoyster (Contributor) commented Jun 21, 2024

Adds an MD interface to OpenMM via OpenMM-torch.

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?


@Luthaf (Member) left a comment:

Why is this using openmm-ml instead of openmm-torch directly? My understanding of the former was that it is supposed to provide specific models, not host a whole new set of models.

This will also need tests, documentation, and examples. And I guess this requires #402 to be done first as well, since OpenMM relies heavily on conda.

@frostedoyster frostedoyster changed the title Add OpenMM-ML interface Add OpenMM-torch interface Jun 24, 2024
@frostedoyster frostedoyster marked this pull request as ready for review June 27, 2024 18:37
@Luthaf (Member) left a comment:

I could not find how this code handles dtype/device. The model might be on a different device than the positions/cell.

For the dtype, the model might require a specific one, and the system should be converted to this dtype.

Comment on lines 10 to 12
.. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
    :show-inheritance:
    :members:
Member:

Suggested change
- .. autoclass:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force
-     :show-inheritance:
-     :members:
+ .. autofunction:: metatensor.torch.atomistic.openmm_interface.get_metatensor_force

    forceGroup: int = 0,
    atoms: Optional[Iterable[int]] = None,
    check_consistency: bool = False,
) -> openmm.System:
Member:

this needs a docstring & doc for all parameters
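
For illustration, something along these lines would address it (a sketch only; the exact parameter list and wording are reconstructed from other snippets in this PR and may not match the final code):

def get_metatensor_force(
    system,
    topology,
    path,
    extensions_directory=None,
    forceGroup=0,
    atoms=None,
    check_consistency=False,
):
    """
    Create an OpenMM force that evaluates an exported metatensor model.

    :param system: the OpenMM System the force will be added to
    :param topology: the OpenMM Topology of the simulated system
    :param path: path to the exported metatensor model
    :param extensions_directory: path to the TorchScript extensions the model requires
    :param forceGroup: OpenMM force group to assign this force to
    :param atoms: optional subset of atoms the model should be applied to
    :param check_consistency: whether to run consistency checks during evaluation
    """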

    path: str,
    extensions_directory: Optional[str] = None,
    forceGroup: int = 0,
    atoms: Optional[Iterable[int]] = None,
Member:

Suggested change
atoms: Optional[Iterable[int]] = None,
selected_atoms: Optional[Iterable[int]] = None,

def __init__(
    self,
    model: torch.jit._script.RecursiveScriptModule,
    atomic_numbers: List[int],
Member:

Suggested change
atomic_numbers: List[int],
atomic_types: List[int],

cell = system.cell

# Get the neighbor pairs, shifts and edge indices.
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
Member:

Suggested change
neighbors, interatomic_vectors, _, _ = getNeighborPairs(
neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs(

for clarity

neighbors, interatomic_vectors, _, _ = getNeighborPairs(
    system.positions,
    self.requested_neighbor_list.engine_cutoff("nm"),
    -1,
Member:

could you give the name of this argument? (and maybe others as well)
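
For example, something like this (a sketch; the keyword names cutoff, max_num_pairs and box_vectors are assumptions about the NNPOps signature and should be checked against its documentation):

neighbors, interatomic_vectors, _, _ = NNPOps.neighbors.getNeighborPairs(
    system.positions,
    cutoff=self.requested_neighbor_list.engine_cutoff("nm"),
    max_num_pairs=-1,  # assumed keyword name; -1 means "do not limit the number of pairs"
    box_vectors=cell,  # assumed keyword name
)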

    -1,
    cell,
)
mask = neighbors[0] >= 0
Member:

will there be negative values in neighbors[0] if n_pairs=-1 above?

Contributor Author:

Yes, I think the neighbor list returns a huge tensor with all pairs, where only a certain number are not -1, i.e., the rest are masked

Member:

I understand this behavior if we are setting max_n_pairs=300: if there are fewer pairs, the other ones are filled with -1. You are saying this is also the case when giving max_n_pairs=-1?

Contributor Author:

Yes, I tried a few tests (with max_n_pairs=-1) and the mask is still necessary. I think they always return a number of pairs that is a multiple of 4, e.g. if you have 27 pairs, 28 will be returned and one of them is masked with center=neighbor=-1
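
In other words, something like this is always needed after the call (a sketch based on the masking already present in this PR; the shapes are assumed to be (2, n_pairs) for neighbors and (n_pairs, 3) for interatomic_vectors):

# drop the padded entries, which have center = neighbor = -1
mask = neighbors[0] >= 0
neighbors = neighbors[:, mask]
interatomic_vectors = interatomic_vectors[mask]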

Comment on lines 206 to 214
energy = (
    self.model(
        [system],
        self.evaluation_options,
        check_consistency=self.check_consistency,
    )["energy"]
    .block()
    .values.reshape(())
)
Member:

Suggested change
- energy = (
-     self.model(
-         [system],
-         self.evaluation_options,
-         check_consistency=self.check_consistency,
-     )["energy"]
-     .block()
-     .values.reshape(())
- )
+ outputs = self.model(
+     [system],
+     self.evaluation_options,
+     check_consistency=self.check_consistency,
+ )
+ energy = outputs["energy"].block().values.reshape(())

)

model = load_atomistic_model(path, extensions_directory=extensions_directory)

Member:

Could you print model.metadata() to the log somewhere? This prints information about the model.

Also, is there a mechanism to register papers to cite in OpenMM? If so, we should register the model references.

Contributor Author:

I think they have some MD loggers, but I don't think they're accessible from our function (they're attached to a Simulation object that is independent). I have no clue regarding the citation mechanism.

Contributor Author:

Perhaps the best solution is to just print() it

Member:

printing everything is a good starting point, and enough for this PR!


def __init__(
    self,
    model: torch.jit._script.RecursiveScriptModule,
Member:

Suggested change
model: torch.jit._script.RecursiveScriptModule,
model,

We don't really need the types in __init__, and this one is private

@frostedoyster (Contributor Author):

Typical code to start a simulation would look like this:

platform = openmm.Platform.getPlatformByName('CUDA')
properties = {'Precision': 'double'}
simulation = openmm.app.Simulation(topology, system, integrator, platform, properties)

The device is handled by openmm-torch (it moves the model to the desired device).
The dtype is trickier, as it is not handled by OpenMM (I think they're trying to allow model developers to do things in mixed precision, so they won't call to(dtype)). At the moment, if the precision is wrong, the model will crash at runtime, and if the user matches the dtypes correctly, the model will run.

Let me know what you think the best solution is (leave as is, check at each step, etc.)

@Luthaf (Member) commented Jul 2, 2024

Let me know what you think the best solution is (leave as is, check at each step, etc)

The way we are doing this in other engines is to always convert from the engine dtype to the model dtype for inputs, and the other way around for outputs. We could add a warning about mismatched dtype if we do the conversion, but I would rather allow the code to run instead of crashing.
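
A minimal sketch of that convention, using a hypothetical helper (the real force class in this PR would apply the conversion to the metatensor System it builds instead):

import torch

def evaluate_in_model_dtype(model, positions: torch.Tensor, cell: torch.Tensor, model_dtype: torch.dtype) -> torch.Tensor:
    # convert the engine inputs to the dtype the model expects...
    engine_dtype = positions.dtype
    energy = model(positions.to(model_dtype), cell.to(model_dtype))
    # ...and convert the output back to the engine dtype
    return energy.to(engine_dtype)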

@frostedoyster (Contributor Author):

Printing the metadata doesn't work because ::print() is not registered, it seems; I'll see if I can fix it

@frostedoyster frostedoyster changed the base branch from master to metadata-print July 5, 2024 06:44
@frostedoyster frostedoyster changed the base branch from metadata-print to master July 5, 2024 18:46
@Luthaf (Member) left a comment:

Looks good! We should find a way to run the tests for this. I guess the main issue is that openmm requires conda for installation, right?

Either we keep the code in this repo, and we find a way to use conda from tox; or we move the code to another repo, with its own CI setup.

Maybe adding a section in tox.ini for "tests that require conda", and in there creating a conda env by calling conda explicitly could work.

Comment on lines +14 to +15
``conda install -c conda-forge openmm-torch nnpops``. Subsequently,
metatensor can be installed with ``pip install metatensor[torch]``, and a minimal
Member:

Since there is #665, should we recommend building metatensor-torch from sources for now? Something like pip install --no-binary=metatensor-torch metatensor[torch]

Contributor Author:

Sounds good

@@ -17,10 +17,6 @@ follows [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
### Removed
-->

### Changed
Member:

this removal looks like a merge issue

:param topology: The OpenMM Topology object.
:param path: The path to the exported metatensor model.
:param extensions_directory: The path to the extensions for the model.
:param forceGroup: The force group to which the force should be added.
Member:

should we link to OpenMM docs for force group here? It might not be clear to everyone.

.. autofunction:: metatensor.torch.atomistic.openmm_force.get_metatensor_force

In order to run simulations with ``metatensor.torch.atomistic`` and ``OpenMM``,
we recommend installing ``OpenMM`` from conda, using
Member:

you should mention that NNPOps is a required dependency as well here

module = torch.jit.script(metatensor_force)

# create the OpenMM force
force = openmmtorch.TorchForce(module)
Member:

so model.to(device) is called by TorchForce? Could you leave a comment if this is the case, so future readers don't wonder what's happening?
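
For example, something like this (a sketch of the kind of comment being asked for, assuming openmm-torch does handle the device placement as described earlier in this thread):

# create the OpenMM force; openmm-torch moves the scripted module to the
# simulation's device/platform, so no explicit .to(device) is needed here
force = openmmtorch.TorchForce(module)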

)


def model_different_units():
Member:

this might not be so interesting to run here; I did it in ASE to ensure our code did unit conversions right, but we don't have to do it for all interfaces.

force = get_metatensor_force(system, topology, model_path, check_consistency=True)
system.addForce(force)
integrator = openmm.VerletIntegrator(0.001 * openmm.unit.picoseconds)
platform = openmm.Platform.getPlatformByName("CPU")
Member:

should we try to run this test on CUDA if it is available?
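
One way to do this (a sketch; it assumes torch is already imported in the test file):

# run on CUDA when it is available, otherwise fall back to the CPU platform
platform_name = "CUDA" if torch.cuda.is_available() else "CPU"
platform = openmm.Platform.getPlatformByName(platform_name)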

state.getForces(asNumpy=True).value_in_unit(
    openmm.unit.ev / (openmm.unit.angstrom * openmm.unit.mole)
)
/ 6.0221367e23
Member:

maybe put this in a constant on top?

Member:

also, why are you doing * openmm.unit.mole above, just to divide again here?
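
Putting the two remarks together, the suggestion could look like this (a sketch; the AVOGADRO and forces names are made up, the expression itself is copied from the test):

AVOGADRO = 6.0221367e23  # Avogadro's number, in 1/mol

forces = state.getForces(asNumpy=True).value_in_unit(
    openmm.unit.ev / (openmm.unit.angstrom * openmm.unit.mole)
) / AVOGADRO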

Comment on lines +13 to +20
try:
    import NNPOps  # noqa: F401
    import openmm
    import openmmtorch  # noqa: F401

    HAS_OPENMM = True
except ImportError:
    HAS_OPENMM = False
Member:

Since we want to skip the whole file if one of the dependencies is missing, we should use pytest.importorskip:

Suggested change
- try:
-     import NNPOps  # noqa: F401
-     import openmm
-     import openmmtorch  # noqa: F401
-     HAS_OPENMM = True
- except ImportError:
-     HAS_OPENMM = False
+ openmm = pytest.importorskip("openmm")
+ pytest.importorskip("NNPOps")
+ pytest.importorskip("openmmtorch")
