@@ -33,6 +33,11 @@ def main(config: DictConfig) -> None:
33
33
runs = [config .run ]
34
34
assert all ([run_type in ["train" , "val" , "test" , "predict" ] for run_type in runs ])
35
35
36
+ # ensure there is at most one `train` run, to protect the restart and checkpointing logic later
37
+ assert (
38
+ sum ([run_type == "train" for run_type in runs ]) == 1
39
+ ), "only a single `train` instance can be present in `run`"
40
+
36
41
if "train" in runs :
37
42
assert (
38
43
"loss" in config .train
@@ -164,14 +169,17 @@ def training_data_stats(stat_name: str):
164
169
trainer = instantiate (trainer_cfg , inference_mode = False )
165
170
166
171
# === loop of run types ===
167
- ckpt_path = config .get ("ckpt_path" , None )
172
+ # restart behavior is such that
173
+ # - train from ckpt uses the correct ckpt file to restore training state (so it is given a specific `ckpt_path`)
174
+ # - val/test/predict from ckpt would use the `nequip_module` from the ckpt (and so uses `ckpt_path=None`)
175
+ # - if we train, then val/test/predict, we set `ckpt_path="best"` after training so val/test/predict tasks after that will use the "best" model
176
+ ckpt_path = None
168
177
for run_type in runs :
169
178
if run_type == "train" :
179
+ ckpt_path = config .get ("ckpt_path" , None )
170
180
logger .info ("TRAIN RUN START" )
171
181
trainer .fit (nequip_module , datamodule = datamodule , ckpt_path = ckpt_path )
172
- # `fit` is the only task that changes the model, so we remove the ckpt_path if train ever gets called
173
- # this means that the latest model information is used for other tasks (val, test, predict)
174
- ckpt_path = None
182
+ ckpt_path = "best"
175
183
logger .info ("TRAIN RUN END" )
176
184
elif run_type == "val" :
177
185
logger .info ("VAL RUN START" )
0 commit comments