From f5e122018b9671df575ea11c3a7386f75ec3ad74 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Mon, 16 Dec 2024 08:50:58 -0500 Subject: [PATCH] [HREMD] Truncate extra trajectory frames when loading checkpoint (#32) * [HREMD] Truncate extra trajectory frames when loading checkpoint --- docs/migration.md | 5 +++-- femto/md/hremd.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index a5b12f8..b1ab727 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -22,8 +22,9 @@ field in OpenFF, Amber, and OpenMM FFXML formats. This re-write also introduced - Ligand force field parameters no longer need to be provided. The [femto.md.config.Prepare][] configuration now exposes a `default_ligand_ff` field that can be used to automatically parameterize ligands with an OpenFF based force field. -- HREMD now correctly stores coordinates as ``coords[i] = replica_i_coords`` rather than ``coords[i] = state_i_coords``. - Checkpoints from previous versions will likely be incorrect. +- HREMD now correctly stores coordinates in checkpoint files as ``coords[i] = replica_i_coords`` rather than + ``coords[i] = state_i_coords``. Checkpoints from previous versions will likely be incorrect. +- Trajectories and sample files are now correctly truncated when restarting from a checkpoint file. #### FE diff --git a/femto/md/hremd.py b/femto/md/hremd.py index 79d1585..7d82fc2 100644 --- a/femto/md/hremd.py +++ b/femto/md/hremd.py @@ -7,6 +7,7 @@ import pickle import typing +import mdtraj import numpy import openmm.app import openmm.unit @@ -103,11 +104,34 @@ def _create_storage( yield _HREMDStorage(file, writer, schema) +def _truncate_trajectories( + simulation: openmm.app.Simulation, + max_cycles: int, + trajectory_interval: int, + trajectory_paths: list[pathlib.Path], +): + """Truncate the trajectories to the maximum number of cycles.""" + + max_frames = (max_cycles - 1) // trajectory_interval + 1 if max_cycles > 0 else 0 + + topology_mdtraj = mdtraj.Topology.from_openmm(simulation.topology) + + for path in trajectory_paths: + if not path.exists(): + continue + + trajectory = mdtraj.load(str(path), top=topology_mdtraj) + trajectory = trajectory[:max_frames] + + trajectory.save_dcd(str(path)) + + def _create_trajectory_storage( simulation: openmm.app.Simulation, n_replicas: int, replica_idx_offset: int, n_steps_per_cycle: int, + max_cycles: int, trajectory_interval: int | None, output_dir: pathlib.Path | None, exit_stack: contextlib.ExitStack, @@ -119,6 +143,7 @@ def _create_trajectory_storage( n_replicas: The number of replicas being sampled on this process. replica_idx_offset: The index of the first replica being sampled on this process n_steps_per_cycle: The number of steps per cycle. + max_cycles: The maximum number of cycles to retain if the file already exists. trajectory_interval: The interval with which to write the trajectory. output_dir: The root output directory. Any trajectories will be written to `output_dir/trajectories/r{replica_idx}.dcd`. @@ -138,6 +163,10 @@ def _create_trajectory_storage( ] should_append = [path.exists() for path in trajectory_paths] + _truncate_trajectories( + simulation, max_cycles, trajectory_interval, trajectory_paths + ) + return [ openmm.app.DCDFile( exit_stack.enter_context(path.open("wb" if not append else "r+b")), @@ -667,6 +696,7 @@ def run_hremd( n_replicas, replica_idx_offset, config.n_steps_per_cycle, + start_cycle, config.trajectory_interval, output_dir, exit_stack, @@ -675,6 +705,7 @@ def run_hremd( for cycle in tqdm.tqdm( range(start_cycle, config.n_cycles), total=config.n_cycles - start_cycle, + initial=start_cycle, disable=mpi_comm.rank != 0, ): reduced_potentials = _propagate_replicas( @@ -703,6 +734,7 @@ def run_hremd( ) if should_save_trajectory: + mpi_comm.barrier() _store_trajectory(coords, trajectory_storage) should_analyze = (