Merge pull request #7 from kntkb/fix/use_u_qm_to_reshape_graphs
Fix/add option to use u_qm to reshape graphs
kntkb authored Mar 16, 2024
2 parents bf4f819 + abe1c3a commit 0cd025f
Showing 3 changed files with 78 additions and 16 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -12,20 +12,18 @@ Infrastructure to train espaloma with experimental observables
### Installation
>mamba create -n espfit python=3.11
>mamba install -c conda-forge espaloma=0.3.2
>#uninstall openff-toolkit and install a customized version to support dgl graphs created using openff-toolkit=0.10.6
>conda uninstall --force openff-toolkit
>pip install git+https://github.com/kntkb/openff-toolkit.git@7e9d0225782ef723083407a1cbf1f4f70631f934
>#install openeye-toolkit
>mamba install openeye-toolkits -c openeye
>#uninstall openmmforcefields if < 0.12.0
>conda uninstall --force openmmforcefields
>#use pip instead of mamba to avoid dependency issues with ambertools and python
>pip install git+https://github.com/openmm/[email protected]
>#install openmmtools
>mamba install openmmtools
>#install barnaba
>mamba install barnaba
#### Notes
- `openff-toolkit` is reinstalled from a customized fork to support DGL graphs created with `openff-toolkit=0.10.6`
- `openmmforcefields` is reinstalled with pip if its version is `<0.12.0`, to avoid dependency issues with `ambertools` and `python`; espaloma functionality is better supported in `>=0.12.0`
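
To verify the pinned stack after installation, a quick check (assuming each package exposes a standard `__version__` attribute):

```python
# Sanity-check the environment described above.
import espaloma
import openff.toolkit
import openmmforcefields

print("espaloma:", espaloma.__version__)                    # expect 0.3.2
print("openff-toolkit:", openff.toolkit.__version__)        # customized fork build
print("openmmforcefields:", openmmforcefields.__version__)  # expect >= 0.12.0
```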


### Quick Usage
```python
...
```
56 changes: 52 additions & 4 deletions espfit/app/train.py
@@ -349,7 +349,7 @@ def report_loss(self, epoch, loss_dict):

        log_file_path = os.path.join(self.output_directory_path, 'reporter.log')
        df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T
        df_new = df_new.mul(100)  # Multiply each loss component by 100
        df_new = df_new.mul(100)  # Multiply each loss component by 100. Is this large enough?
        df_new.insert(0, 'epoch', epoch)

        if os.path.exists(log_file_path):
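
For context, the remainder of this routine appends the new row to any existing log; a minimal sketch of that append logic (the separator and exact behavior are assumptions, not the committed code):

```python
import os
import pandas as pd

def report_loss_sketch(output_directory_path, epoch, loss_dict):
    # Scale loss components by 100 and append one row per epoch to reporter.log.
    log_file_path = os.path.join(output_directory_path, 'reporter.log')
    df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T.mul(100)
    df_new.insert(0, 'epoch', epoch)
    if os.path.exists(log_file_path):
        df_old = pd.read_csv(log_file_path, sep='\t')  # assumed tab-separated
        df_new = pd.concat([df_old, df_new], ignore_index=True)
    df_new.to_csv(log_file_path, sep='\t', index=False)
```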
@@ -455,7 +455,14 @@ def train_sampler(self, sampler_patience=800, neff_threshold=0.2, sampler_weight
        with torch.autograd.set_detect_anomaly(True):
            for i in range(self.restart_epoch, self.epochs):
                epoch = i + 1  # Start from 1 (not zero-indexing)

                """
                # torch.cuda.OutOfMemoryError: CUDA out of memory.
                # Tried to allocate 80.00 MiB (GPU 0; 10.75 GiB total capacity;
                # 9.76 GiB already allocated; 7.62 MiB free; 10.40 GiB reserved in total by PyTorch)
                # If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.
                # See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
                loss = torch.tensor(0.0)
                if torch.cuda.is_available():
                    loss = loss.cuda("cuda:0")
@@ -496,7 +503,48 @@
                # Back propagate
                loss.backward()
                optimizer.step()

                """

                # Gradient accumulation
                accumulation_steps = len(ds_tr_loader)
                optimizer.zero_grad()  # zero once per epoch so that per-batch gradients accumulate
                for g in ds_tr_loader:
                    if torch.cuda.is_available():
                        g = g.to("cuda:0")
                    g.nodes["n1"].data["xyz"].requires_grad = True

                    loss, loss_dict = self.net(g)
                    loss = loss / accumulation_steps
                    loss.backward()
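                # Each per-batch loss is scaled by 1/len(ds_tr_loader), so the gradient
                # accumulated across the epoch equals the gradient of the mean batch loss;
                # optimizer.step() runs once per epoch, after the sampler loss below.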

                if epoch > self.sampler_patience:
                    # Save checkpoint as local model (net.pt)
                    # `neff_min` is -1 if SamplerReweight.samplers is None
                    samplers = self._setup_local_samplers(epoch, net_copy, debug)
                    neff_min = SamplerReweight.get_effective_sample_size(temporary_samplers=samplers)

                    # If the effective sample size is below the threshold, update SamplerReweight.samplers and re-run the simulation
                    if neff_min < self.neff_threshold:
                        _logger.info(f'Minimum effective sample size ({neff_min:.3f}) below threshold ({self.neff_threshold})')
                        SamplerReweight.samplers = samplers
                        SamplerReweight.run()
                    del samplers

                    # Compute sampler loss
                    loss_list = SamplerReweight.compute_loss()  # list of torch.tensor
                    loss = loss.detach()  # per-batch graphs above were already freed by loss.backward()
                    for sampler_index, sampler_loss in enumerate(loss_list):
                        sampler = SamplerReweight.samplers[sampler_index]
                        loss = loss + sampler_loss * sampler_weight
                        loss_dict[f'{sampler.target_name}'] = sampler_loss.item()
                    loss.backward()
                    loss_dict['neff'] = neff_min

                loss_dict['loss'] = loss.item()
                self.report_loss(epoch, loss_dict)

                # Update
                optimizer.step()

                if epoch % self.checkpoint_frequency == 0:
                    # Note: returned loss is a joint loss of different units.
                    #_loss = HARTREE_TO_KCALPERMOL * loss.pow(0.5).item()
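
For reference, the effective sample size that gates re-simulation above is conventionally the Kish estimate over the reweighting weights. A minimal sketch, normalized to [0, 1]; this is an illustration, not necessarily the exact formula inside `SamplerReweight.get_effective_sample_size`:

```python
import torch

def normalized_effective_sample_size(log_weights: torch.Tensor) -> float:
    # Kish effective sample size from per-snapshot importance log-weights,
    # divided by the number of snapshots so that 1.0 means uniform weights.
    w = torch.exp(log_weights - log_weights.max())  # stabilized exponentiation
    w = w / w.sum()
    return float(1.0 / (w.pow(2).sum() * len(w)))
```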
@@ -577,7 +625,7 @@ def _save_local_model(self, epoch, net_copy):
        _logger.info(f'Save ckpt{epoch}.pt as temporary espaloma model (net.pt)')
        self._save_checkpoint(epoch)
        local_model = os.path.join(self.output_directory_path, f"ckpt{epoch}.pt")
        self.save_model(net=net_copy, best_model=local_model, model_name=f"net.pt", output_directory_path=self.output_directory_path)
        self.save_model(net=net_copy, checkpoint_file=local_model, output_model=f"net.pt", output_directory_path=self.output_directory_path)


    def _setup_local_samplers(self, epoch, net_copy, debug):
28 changes: 22 additions & 6 deletions espfit/utils/graphs.py
@@ -33,7 +33,7 @@ class CustomGraphDataset(GraphDataset):
    compute_baseline_energy_force(forcefield_list=['openff-2.1.0']):
        Compute energies and forces using other force fields.
    reshape_conformation_size(n_confs=50, include_min_energy_conf=False):
    reshape_conformation_size(n_confs=50, include_min_energy_conf=False, keyname='u_ref'):
        Reshape conformation size.
    compute_relative_energy():
@@ -514,7 +514,7 @@ def compute_relative_energy(self):
        del new_graphs


    def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
    def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False, keyname='u_ref'):
        """Reshape conformation size.

        This is a workaround to handle graphs of different size (shape); DGL requires at least one dimension to have the same size.
@@ -539,6 +539,11 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
        include_min_energy_conf : boolean, default=False
            If True, the minimum energy conformer will be included in all split graphs.
        keyname : str, default='u_ref'
            Key name used to define the energy minimum. This is usually `u_ref` or `u_qm`.
            Note that depending on how the dataset was prepared, nonbonded energies may have been subtracted from `u_ref`,
            whereas `u_qm` may hold the raw QM energies.

        Returns
        -------
        None
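
A hypothetical call illustrating the new argument (`ds` stands in for a `CustomGraphDataset` instance; only the signature above is from the source):

```python
# Split each molecule into 50-conformer graphs and prepend the conformer that
# minimizes the raw QM energy (u_qm) instead of u_ref.
ds.reshape_conformation_size(n_confs=50, include_min_energy_conf=True, keyname='u_qm')
```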
@@ -553,8 +558,9 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
        import copy
        import torch

        # Remove node features that are not used during training
        self._remove_node_features()
        # Check that the keyname is supported
        if include_min_energy_conf and keyname not in ['u_ref', 'u_qm']:
            raise Exception(f'Key name {keyname} not supported. Supported keynames are u_ref and u_qm')

        new_graphs = []
        n_confs_cache = n_confs
@@ -584,7 +590,14 @@

# Get index for minimum energy conformer
if include_min_energy_conf:
index_min = [g.nodes['g'].data['u_ref'].argmin().item()]
index_min = [g.nodes['g'].data[keyname].argmin().item()]

# DEBUG PURPOSE
#_index_min_uref = [g.nodes['g'].data['u_ref'].argmin().item()]
#_index_min_uqm = [g.nodes['g'].data['u_qm'].argmin().item()]
#_logger.info(f'(u_ref:{_index_min_uref[0]} and u_qm:{_index_min_uqm[0]})')
#_logger.info(f'Index for minima energy conformer {keyname}: {index_min[0]}')

n_confs = n_confs - 1
_logger.info(f"Mol #{i} ({n} conformers): Shuffle and split into {n_confs} conformers and add minimum energy conformer (index #{index_min[0]})")
else:
@@ -603,7 +616,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
                    _logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers and add minimum energy conformer")
                else:
                    _logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers")

                _g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'][:, index], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
                _g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'][:, index], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1)
                _g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'][:, index, :], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1)
@@ -628,6 +641,9 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):

        # Update in place
        self.graphs = new_graphs

        # Remove node features that are not used during training (moved after the
        # reshape, likely so that `u_qm` is still available when keyname='u_qm')
        self._remove_node_features()

        del new_graphs


