diff --git a/configs/vae.yaml b/configs/vae.yaml index abd336b8..7683f403 100644 --- a/configs/vae.yaml +++ b/configs/vae.yaml @@ -27,3 +27,6 @@ logging_params: save_dir: "logs/" name: "VanillaVAE" +custom_params: + resume_training: false + resume_chkpt_path: "" diff --git a/run.py b/run.py index 160ed762..0e5973e9 100644 --- a/run.py +++ b/run.py @@ -3,6 +3,7 @@ import argparse import numpy as np from pathlib import Path +from collections import OrderedDict from models import * from experiment import VAEXperiment import torch.backends.cudnn as cudnn @@ -36,6 +37,15 @@ seed_everything(config['exp_params']['manual_seed'], True) model = vae_models[config['model_params']['name']](**config['model_params']) +if 'custom_params' in config: + if config['custom_params']['resume_training']: + checkpoint = torch.load(config['custom_params']['resume_chkpt_path']) + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_state_dict[k.replace("model.", "")] = v + model.load_state_dict(new_state_dict) + experiment = VAEXperiment(model, config['exp_params'])