Skip to content

Commit d18dd2a

Browse files
Merge trainer revamp
2 parents d3a7763 + 145e138 commit d18dd2a

File tree

11 files changed

+270
-77
lines changed

11 files changed

+270
-77
lines changed

CHANGELOG.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
Most recent change on the bottom.
88

99

10-
## Unreleased
10+
## Unreleased - 0.7.0
11+
### Added
12+
- `--override` now supported as a `nequip-train` flag (similar to its use in `nequip-deploy`)
13+
- add SoftAdapt (https://arxiv.org/abs/2403.18122) callback option
14+
15+
### Changed
16+
- [Breaking] training restart behavior altered: file-wise consistency checks performed between original config and config passed to `nequip-train` on restart (instead of checking the config dicts)
17+
- [Breaking] config format for callbacks changed (see `configs/full.yaml` for an example)
1118

19+
### Fixed
20+
- fixed `wandb_watch` bug
1221

1322
## [0.6.1] - 2024-7-9
1423
### Added

configs/full.yaml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,17 @@ loss_coeffs:
256256
# In the "schedule" key each entry is a two-element list of:
257257
# - the 1-based epoch index at which to start the new loss coefficients
258258
# - the new loss coefficients as a dict
259-
#
260-
# start_of_epoch_callbacks:
261-
# - !!python/object:nequip.train.callbacks.loss_schedule.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
262-
#
259+
# callbacks:
260+
# start_of_epoch:
261+
# - !!python/object:nequip.train.callbacks.SimpleLossSchedule {"schedule": [[2, {"forces": 0.0, "total_energy": 1.0}]]}
262+
263+
# You can also try using the SoftAdapt strategy for adaptively changing loss coefficients
264+
# (see https://arxiv.org/abs/2403.18122)
265+
#callbacks:
266+
# end_of_batch:
267+
# - !!python/object:nequip.train.callbacks.SoftAdapt {"batches_per_update": 5, "beta": 1.1}
268+
269+
263270

264271
# output metrics
265272
metrics_components:

nequip/scripts/train.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import logging
44
import argparse
55
import warnings
6+
import shutil
7+
import difflib
8+
import yaml
69

710
# This is a weird hack to avoid Intel MKL issues on the cluster when this is called as a subprocess of a process that has itself initialized PyTorch.
811
# Since numpy gets imported later anyway for dataset stuff, this shouldn't affect performance.
@@ -29,6 +32,8 @@
2932
root="./",
3033
tensorboard=False,
3134
wandb=False,
35+
wandb_watch=False,
36+
wandb_watch_kwargs={},
3237
model_builders=[
3338
"SimpleIrrepsConfig",
3439
"EnergyModel",
@@ -46,7 +51,7 @@
4651
equivariance_test=False,
4752
grad_anomaly_mode=False,
4853
gpu_oom_offload=False,
49-
append=False,
54+
append=True,
5055
warn_unused=False,
5156
_jit_bailout_depth=2, # avoid 20 iters of pain, see https://github.com/pytorch/pytorch/issues/52286
5257
# Quote from eelison in PyTorch slack:
@@ -68,32 +73,61 @@
6873

6974

7075
def main(args=None, running_as_script: bool = True):
71-
config = parse_command_line(args)
76+
config, path_to_config, override_options = parse_command_line(args)
7277

7378
if running_as_script:
7479
set_up_script_logger(config.get("log", None), config.verbose)
7580

76-
found_restart_file = exists(f"{config.root}/{config.run_name}/trainer.pth")
81+
train_dir = f"{config.root}/{config.run_name}"
82+
found_restart_file = exists(f"{train_dir}/trainer.pth")
7783
if found_restart_file and not config.append:
7884
raise RuntimeError(
79-
f"Training instance exists at {config.root}/{config.run_name}; "
85+
f"Training instance exists at {train_dir}; "
8086
"either set append to True or use a different root or runname"
8187
)
82-
elif not found_restart_file and isdir(f"{config.root}/{config.run_name}"):
88+
elif not found_restart_file and isdir(train_dir):
8389
# output directory exists but no ``trainer.pth`` file, suggesting previous run crash during
8490
# first training epoch (usually due to memory):
8591
warnings.warn(
86-
f"Previous run folder at {config.root}/{config.run_name} exists, but a saved model "
92+
f"Previous run folder at {train_dir} exists, but a saved model "
8793
f"(trainer.pth file) was not found. This folder will be cleared and a fresh training run will "
8894
f"be started."
8995
)
90-
rmtree(f"{config.root}/{config.run_name}")
96+
rmtree(train_dir)
9197

92-
# for fresh new train
93-
if not found_restart_file:
98+
if not found_restart_file: # fresh start
99+
# update config with override parameters for setting up train-dir
100+
config.update(override_options)
94101
trainer = fresh_start(config)
95-
else:
96-
trainer = restart(config)
102+
# copy original config to training directory
103+
shutil.copyfile(path_to_config, f"{train_dir}/original_config.yaml")
104+
else: # restart
105+
# perform string matching for original config and restart config
106+
# throw error if they are different
107+
with (
108+
open(f"{train_dir}/original_config.yaml") as orig_f,
109+
open(path_to_config) as current_f,
110+
):
111+
diffs = [
112+
x
113+
for x in difflib.Differ().compare(
114+
orig_f.readlines(), current_f.readlines()
115+
)
116+
if x[0] in ("+", "-")
117+
]
118+
if diffs:
119+
raise RuntimeError(
120+
f"Config {path_to_config} used for restart differs from original config for training run in {train_dir}.\n"
121+
+ "The following differences were found:\n\n"
122+
+ "".join(diffs)
123+
+ "\n"
124+
+ "If you intend to override the original config parameters, use the --override flag. For example, use\n"
125+
+ f'`nequip-train {path_to_config} --override "max_epochs: 42"`\n'
126+
+ 'on the command line to override the config parameter "max_epochs"\n'
127+
+ "BE WARNED that use of the --override flag is not protected by consistency checks performed by NequIP."
128+
)
129+
else:
130+
trainer = restart(config, override_options)
97131

98132
# Train
99133
trainer.save()
@@ -157,6 +191,12 @@ def parse_command_line(args=None):
157191
help="Warn instead of error when the config contains unused keys",
158192
action="store_true",
159193
)
194+
parser.add_argument(
195+
"--override",
196+
help="Override top-level configuration keys from the `--train-dir`/`--model`'s config YAML file. This should be a valid YAML string. Unless you know why you need to, do not use this option.",
197+
type=str,
198+
default=None,
199+
)
160200
args = parser.parse_args(args=args)
161201

162202
config = Config.from_file(args.config, defaults=default_config)
@@ -169,10 +209,26 @@ def parse_command_line(args=None):
169209
):
170210
config[flag] = getattr(args, flag) or config[flag]
171211

172-
return config
212+
# Set override options before _set_global_options so that things like allow_tf32 are correctly handled
213+
if args.override is not None:
214+
override_options = yaml.load(args.override, Loader=yaml.Loader)
215+
assert isinstance(
216+
override_options, dict
217+
), "--override's YAML string must define a dictionary of top-level options"
218+
overridden_keys = set(config.keys()).intersection(override_options.keys())
219+
set_keys = set(override_options.keys()) - set(overridden_keys)
220+
logging.info(
221+
f"--override: overrode keys {list(overridden_keys)} and set new keys {list(set_keys)}"
222+
)
223+
del overridden_keys, set_keys
224+
else:
225+
override_options = {}
226+
227+
return config, args.config, override_options
173228

174229

175230
def fresh_start(config):
231+
176232
# we use add_to_config cause it's a fresh start and need to record it
177233
check_code_version(config, add_to_config=True)
178234
_set_global_options(config)
@@ -267,7 +323,7 @@ def _unused_check():
267323
return trainer
268324

269325

270-
def restart(config):
326+
def restart(config, override_options):
271327
# load the dictionary
272328
restart_file = f"{config.root}/{config.run_name}/trainer.pth"
273329
dictionary = load_file(
@@ -276,20 +332,6 @@ def restart(config):
276332
enforced_format="torch",
277333
)
278334

279-
# compare dictionary to config and update stop condition related arguments
280-
for k in config.keys():
281-
if config[k] != dictionary.get(k, ""):
282-
if k == "max_epochs":
283-
dictionary[k] = config[k]
284-
logging.info(f'Update "{k}" to {dictionary[k]}')
285-
elif k.startswith("early_stop"):
286-
dictionary[k] = config[k]
287-
logging.info(f'Update "{k}" to {dictionary[k]}')
288-
elif isinstance(config[k], type(dictionary.get(k, ""))):
289-
raise ValueError(
290-
f'Key "{k}" is different in config and the result trainer.pth file. Please double check'
291-
)
292-
293335
# note, "trainer.pth"/dictionary also store code versions,
294336
# which will not be stored in config and thus not checked here
295337
check_code_version(config)
@@ -299,6 +341,10 @@ def restart(config):
299341

300342
config = Config(dictionary, exclude_keys=["state_dict", "progress"])
301343

344+
# override configs loaded from save
345+
dictionary.update(override_options)
346+
config.update(override_options)
347+
302348
# dtype, etc.
303349
_set_global_options(config)
304350

nequip/train/callback_manager.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from nequip.utils import load_callable
2+
import dataclasses
3+
4+
5+
class CallbackManager:
6+
"""Parent callback class
7+
8+
Centralized object to manage various callbacks that can be added-on.
9+
"""
10+
11+
def __init__(
12+
self,
13+
callbacks={},
14+
):
15+
CALLBACK_TYPES = [
16+
"init",
17+
"start_of_epoch",
18+
"end_of_epoch",
19+
"end_of_batch",
20+
"end_of_train",
21+
"final",
22+
]
23+
# load all callbacks
24+
self.callbacks = {callback_type: [] for callback_type in CALLBACK_TYPES}
25+
26+
for callback_type in callbacks:
27+
if callback_type not in CALLBACK_TYPES:
28+
raise ValueError(
29+
f"{callback_type} is not a supported callback type.\nSupported callback types include "
30+
+ str(CALLBACK_TYPES)
31+
)
32+
# make sure callbacks are either dataclasses or functions
33+
for callback in callbacks[callback_type]:
34+
if not (dataclasses.is_dataclass(callback) or callable(callback)):
35+
raise ValueError(
36+
f"Callbacks must be of type dataclass or callable. Error found on the callback {callback} of type {callback_type}"
37+
)
38+
self.callbacks[callback_type].append(load_callable(callback))
39+
40+
def apply(self, trainer, callback_type: str):
41+
42+
for callback in self.callbacks.get(callback_type):
43+
callback(trainer)
44+
45+
def state_dict(self):
46+
return {"callback_manager_obj_callbacks": self.callbacks}
47+
48+
def load_state_dict(self, state_dict):
49+
self.callbacks = state_dict.get("callback_manager_obj_callbacks")

nequip/train/callbacks/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .adaptive_loss_weights import SoftAdapt
2+
from .loss_schedule import SimpleLossSchedule
3+
4+
__all__ = [SoftAdapt, SimpleLossSchedule]
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from dataclasses import dataclass
2+
3+
from nequip.train import Trainer
4+
5+
from nequip.train._key import ABBREV
6+
import torch
7+
8+
# Making this a dataclass takes care of equality operators, handing restart consistency checks
9+
10+
11+
@dataclass
12+
class SoftAdapt:
13+
"""Adaptively modify `loss_coeffs` through a training run using the SoftAdapt scheme (https://arxiv.org/abs/2403.18122)
14+
15+
To use this in a training, set in your YAML file:
16+
17+
end_of_batch_callbacks:
18+
- !!python/object:nequip.train.callbacks.adaptive_loss_weights.SoftAdapt {"batches_per_update": 20, "beta": 1.0}
19+
20+
This funny syntax tells PyYAML to construct an object of this class.
21+
22+
Main hyperparameters are:
23+
- how often the loss weights are updated, `batches_per_update`
24+
- how sensitive the new loss weights are to the change in loss components, `beta`
25+
"""
26+
27+
# user-facing parameters
28+
batches_per_update: int = None
29+
beta: float = None
30+
eps: float = 1e-8 # small epsilon to avoid division by zero
31+
# attributes for internal tracking
32+
batch_counter: int = -1
33+
prev_losses: torch.Tensor = None
34+
cached_weights = None
35+
36+
def __call__(self, trainer: Trainer):
37+
38+
# --- CORRECTNESS CHECKS ---
39+
assert self in trainer.callback_manager.callbacks["end_of_batch"]
40+
assert self.batches_per_update >= 1
41+
42+
# track batch number
43+
self.batch_counter += 1
44+
45+
# empty list of cached weights to store for next cycle
46+
if self.batch_counter % self.batches_per_update == 0:
47+
self.cached_weights = []
48+
49+
# --- MAIN LOGIC THAT RUNS EVERY EPOCH ---
50+
51+
# collect loss for each training target
52+
losses = []
53+
for key in trainer.loss.coeffs.keys():
54+
losses.append(trainer.batch_losses[f"loss_{ABBREV.get(key)}"])
55+
new_losses = torch.tensor(losses)
56+
57+
# compute and cache new loss weights over the update cycle
58+
if self.prev_losses is None:
59+
self.prev_losses = new_losses
60+
return
61+
else:
62+
# compute normalized loss change
63+
loss_change = new_losses - self.prev_losses
64+
loss_change = torch.nn.functional.normalize(
65+
loss_change, dim=0, eps=self.eps
66+
)
67+
self.prev_losses = new_losses
68+
# compute new weights with softmax
69+
exps = torch.exp(self.beta * loss_change)
70+
self.cached_weights.append(exps.div(exps.sum() + self.eps))
71+
72+
# --- average weights over previous cycle and update ---
73+
if self.batch_counter % self.batches_per_update == 1:
74+
softadapt_weights = torch.stack(self.cached_weights, dim=-1).mean(-1)
75+
counter = 0
76+
for key in trainer.loss.coeffs.keys():
77+
trainer.loss.coeffs.update({key: softadapt_weights[counter]})
78+
counter += 1

nequip/train/callbacks/loss_schedule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ class SimpleLossSchedule:
2525

2626
def __call__(self, trainer: Trainer):
2727
assert (
28-
self in trainer._start_of_epoch_callbacks
28+
self in trainer.callback_manager.callbacks["start_of_epoch"]
2929
), "must be start not end of epoch"
3030
# user-facing 1 based indexing of epochs rather than internal zero based
31+
3132
iepoch: int = trainer.iepoch + 1
3233
if iepoch < 1: # initial validation epoch is 0 in user-facing indexing
3334
return

nequip/train/loss.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,15 @@ def __call__(self, pred: dict, ref: dict):
109109

110110
return loss, contrib
111111

112+
def state_dict(self):
113+
# verbose key names to avoid repetition/clashes
114+
dictionary = {"loss_obj_coeffs": self.coeffs}
115+
return dictionary
116+
117+
def load_state_dict(self, state_dict):
118+
# only need to save/load loss weights (or coefficients)
119+
self.coeffs = state_dict.get("loss_obj_coeffs")
120+
112121

113122
class LossStat:
114123
"""

0 commit comments

Comments
 (0)