diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 1f46273d..77a44069 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -269,7 +269,10 @@ def main(): trainer.fit(model, data, ckpt_path=None if args.reset_trainer else args.load_model) # run test set after completing the fit - model = LNNP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) + model = LNNP.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path, + hparams_file=f"{args.log_dir}/input.yaml", + ) trainer = pl.Trainer( logger=_logger, inference_mode=False,