@@ -33,6 +33,11 @@ def main(config: DictConfig) -> None:
3333 runs = [config .run ]
3434 assert all ([run_type in ["train" , "val" , "test" , "predict" ] for run_type in runs ])
3535
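36+ # ensure at most a single `train` run, to protect the restart and checkpointing logic later (NOTE(review): the `== 1` check below also rejects a `run` list with no `train` at all -- confirm whether `<= 1` was intended)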
37+ assert (
38+ sum ([run_type == "train" for run_type in runs ]) == 1
39+ ), "only a single `train` instance can be present in `run`"
40+
3641 if "train" in runs :
3742 assert (
3843 "loss" in config .train
@@ -164,14 +169,17 @@ def training_data_stats(stat_name: str):
164169 trainer = instantiate (trainer_cfg , inference_mode = False )
165170
166171 # === loop of run types ===
167- ckpt_path = config .get ("ckpt_path" , None )
172+ # restart behavior is such that
173+ # - train from ckpt uses the correct ckpt file to restore training state (so it is given a specific `ckpt_path`)
174+ # - val/test/predict from ckpt use the `nequip_module` from the ckpt (and so use `ckpt_path=None`)
175+ # - if we train, then val/test/predict, we set `ckpt_path="best"` after training so val/test/predict tasks after that will use the "best" model
176+ ckpt_path = None
168177 for run_type in runs :
169178 if run_type == "train" :
179+ ckpt_path = config .get ("ckpt_path" , None )
170180 logger .info ("TRAIN RUN START" )
171181 trainer .fit (nequip_module , datamodule = datamodule , ckpt_path = ckpt_path )
172- # `fit` is the only task that changes the model, so we remove the ckpt_path if train ever gets called
173- # this means that the latest model information is used for other tasks (val, test, predict)
174- ckpt_path = None
182+ ckpt_path = "best"
175183 logger .info ("TRAIN RUN END" )
176184 elif run_type == "val" :
177185 logger .info ("VAL RUN START" )
0 commit comments