Skip to content

Commit

Permalink
use best model if val/test/predict after train
Browse files Browse the repository at this point in the history
  • Loading branch information
cw-tan committed Oct 9, 2024
1 parent 85f02ce commit 9bae4dd
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions nequip/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def main(config: DictConfig) -> None:
runs = [config.run]
assert all([run_type in ["train", "val", "test", "predict"] for run_type in runs])

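# ensure there is at most a single `train` run, to protect the restart and checkpointing logic later
# NOTE(review): the assertion below checks `== 1` (exactly one `train`), which would reject val/test/predict-only runs -- confirm whether `<= 1` was intended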
assert (
sum([run_type == "train" for run_type in runs]) == 1
), "only a single `train` instance can be present in `run`"

if "train" in runs:
assert (
"loss" in config.train
Expand Down Expand Up @@ -164,14 +169,17 @@ def training_data_stats(stat_name: str):
trainer = instantiate(trainer_cfg, inference_mode=False)

# === loop of run types ===
ckpt_path = config.get("ckpt_path", None)
# restart behavior is such that
# - train from ckpt uses the correct ckpt file to restore training state (so it is given a specific `ckpt_path`)
# - val/test/predict from ckpt would use the `nequip_module` from the ckpt (and so uses `ckpt_path=None`)
# - if we train, then val/test/predict, we set `ckpt_path="best"` after training so val/test/predict tasks after that will use the "best" model
ckpt_path = None
for run_type in runs:
if run_type == "train":
ckpt_path = config.get("ckpt_path", None)
logger.info("TRAIN RUN START")
trainer.fit(nequip_module, datamodule=datamodule, ckpt_path=ckpt_path)
# `fit` is the only task that changes the model, so we remove the ckpt_path if train ever gets called
# this means that the latest model information is used for other tasks (val, test, predict)
ckpt_path = None
ckpt_path = "best"
logger.info("TRAIN RUN END")
elif run_type == "val":
logger.info("VAL RUN START")
Expand Down

0 comments on commit 9bae4dd

Please sign in to comment.