Skip to content

Commit 9bae4dd

Browse files
committed
use best model if val/test/predict after train
1 parent 85f02ce commit 9bae4dd

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

nequip/scripts/train.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def main(config: DictConfig) -> None:
3333
runs = [config.run]
3434
assert all([run_type in ["train", "val", "test", "predict"] for run_type in runs])
3535

36+
# ensure only single train at most, to protect restart and checkpointing logic later
37+
assert (
38+
sum([run_type == "train" for run_type in runs]) == 1
39+
), "only a single `train` instance can be present in `run`"
40+
3641
if "train" in runs:
3742
assert (
3843
"loss" in config.train
@@ -164,14 +169,17 @@ def training_data_stats(stat_name: str):
164169
trainer = instantiate(trainer_cfg, inference_mode=False)
165170

166171
# === loop of run types ===
167-
ckpt_path = config.get("ckpt_path", None)
172+
# restart behavior is such that
173+
# - train from ckpt uses the correct ckpt file to restore training state (so it is given a specific `ckpt_path`)
174+
# - val/test/predict from ckpt would use the `nequip_module` from the ckpt (and so uses `ckpt_path=None`)
175+
# - if we train, then val/test/predict, we set `ckpt_path="best"` after training so val/test/predict tasks after that will use the "best" model
176+
ckpt_path = None
168177
for run_type in runs:
169178
if run_type == "train":
179+
ckpt_path = config.get("ckpt_path", None)
170180
logger.info("TRAIN RUN START")
171181
trainer.fit(nequip_module, datamodule=datamodule, ckpt_path=ckpt_path)
172-
# `fit` is the only task that changes the model, so we remove the ckpt_path if train ever gets called
173-
# this means that the latest model information is used for other tasks (val, test, predict)
174-
ckpt_path = None
182+
ckpt_path = "best"
175183
logger.info("TRAIN RUN END")
176184
elif run_type == "val":
177185
logger.info("VAL RUN START")

0 commit comments

Comments
 (0)