Skip to content

Commit

Permalink
[HREMD] Truncate extra trajectory frames when loading checkpoint (#32)
Browse files Browse the repository at this point in the history
* [HREMD] Truncate extra trajectory frames when loading checkpoint
  • Loading branch information
SimonBoothroyd authored Dec 16, 2024
1 parent 0f1ae8f commit f5e1220
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
5 changes: 3 additions & 2 deletions docs/migration.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ field in OpenFF, Amber, and OpenMM FFXML formats. This re-write also introduced
- Ligand force field parameters no longer need to be provided. The [femto.md.config.Prepare][] configuration now
exposes a `default_ligand_ff` field that can be used to automatically parameterize ligands with an OpenFF-based
force field.
- HREMD now correctly stores coordinates as ``coords[i] = replica_i_coords`` rather than ``coords[i] = state_i_coords``.
Checkpoints from previous versions will likely be incorrect.
- HREMD now correctly stores coordinates in checkpoint files as ``coords[i] = replica_i_coords`` rather than
``coords[i] = state_i_coords``. Checkpoints from previous versions will likely be incorrect.
- Trajectories and sample files are now correctly truncated when restarting from a checkpoint file.

#### FE

Expand Down
32 changes: 32 additions & 0 deletions femto/md/hremd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pickle
import typing

import mdtraj
import numpy
import openmm.app
import openmm.unit
Expand Down Expand Up @@ -103,11 +104,34 @@ def _create_storage(
yield _HREMDStorage(file, writer, schema)


def _truncate_trajectories(
    simulation: openmm.app.Simulation,
    max_cycles: int,
    trajectory_interval: int,
    trajectory_paths: list[pathlib.Path],
) -> None:
    """Truncate any existing trajectories to the maximum number of cycles.

    Each trajectory file that exists is loaded, sliced to its leading frames, and
    written back to the same path in DCD format. Paths that do not exist are
    skipped.

    Args:
        simulation: The simulation whose topology is used to load the trajectories.
        max_cycles: The maximum number of cycles to retain frames for. No frames
            are retained if this is not positive.
        trajectory_interval: The interval, in cycles, with which frames were
            written. The retained frame count assumes a frame was stored at
            cycle 0 and then every ``trajectory_interval`` cycles - TODO confirm
            against the code that writes the trajectories.
        trajectory_paths: The paths to the trajectories to truncate in place.
    """
    max_frames = (max_cycles - 1) // trajectory_interval + 1 if max_cycles > 0 else 0

    existing_paths = [path for path in trajectory_paths if path.exists()]

    if not existing_paths:
        # nothing on disk to truncate (e.g. a fresh run), so avoid converting the
        # topology for no reason.
        return

    # the topology is the same for every trajectory, so only convert it once.
    topology_mdtraj = mdtraj.Topology.from_openmm(simulation.topology)

    for path in existing_paths:
        trajectory = mdtraj.load(str(path), top=topology_mdtraj)
        trajectory[:max_frames].save_dcd(str(path))


def _create_trajectory_storage(
simulation: openmm.app.Simulation,
n_replicas: int,
replica_idx_offset: int,
n_steps_per_cycle: int,
max_cycles: int,
trajectory_interval: int | None,
output_dir: pathlib.Path | None,
exit_stack: contextlib.ExitStack,
Expand All @@ -119,6 +143,7 @@ def _create_trajectory_storage(
n_replicas: The number of replicas being sampled on this process.
replica_idx_offset: The index of the first replica being sampled on this process
n_steps_per_cycle: The number of steps per cycle.
max_cycles: The maximum number of cycles to retain if the file already exists.
trajectory_interval: The interval with which to write the trajectory.
output_dir: The root output directory. Any trajectories will be written to
`output_dir/trajectories/r{replica_idx}.dcd`.
Expand All @@ -138,6 +163,10 @@ def _create_trajectory_storage(
]
should_append = [path.exists() for path in trajectory_paths]

_truncate_trajectories(
simulation, max_cycles, trajectory_interval, trajectory_paths
)

return [
openmm.app.DCDFile(
exit_stack.enter_context(path.open("wb" if not append else "r+b")),
Expand Down Expand Up @@ -667,6 +696,7 @@ def run_hremd(
n_replicas,
replica_idx_offset,
config.n_steps_per_cycle,
start_cycle,
config.trajectory_interval,
output_dir,
exit_stack,
Expand All @@ -675,6 +705,7 @@ def run_hremd(
for cycle in tqdm.tqdm(
range(start_cycle, config.n_cycles),
total=config.n_cycles - start_cycle,
initial=start_cycle,
disable=mpi_comm.rank != 0,
):
reduced_potentials = _propagate_replicas(
Expand Down Expand Up @@ -703,6 +734,7 @@ def run_hremd(
)

if should_save_trajectory:
mpi_comm.barrier()
_store_trajectory(coords, trajectory_storage)

should_analyze = (
Expand Down

0 comments on commit f5e1220

Please sign in to comment.