From 2f4bc7c29a8bf4e4010e510593d1dae832dc7ea0 Mon Sep 17 00:00:00 2001 From: Pulin Agrawal Date: Tue, 28 Jun 2022 13:21:13 +0530 Subject: [PATCH 1/4] enable resume training from checkpoint --- run.py | 1 + 1 file changed, 1 insertion(+) diff --git a/run.py b/run.py index 160ed762..993e1384 100644 --- a/run.py +++ b/run.py @@ -3,6 +3,7 @@ import argparse import numpy as np from pathlib import Path +from collections import OrderedDict from models import * from experiment import VAEXperiment import torch.backends.cudnn as cudnn From 85cd5fdbb0298004bd94b4c23973f0a350dea110 Mon Sep 17 00:00:00 2001 From: Pulin Agrawal Date: Tue, 28 Jun 2022 13:22:52 +0530 Subject: [PATCH 2/4] enable backward compatibility --- run.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/run.py b/run.py index 993e1384..0e5973e9 100644 --- a/run.py +++ b/run.py @@ -37,6 +37,15 @@ seed_everything(config['exp_params']['manual_seed'], True) model = vae_models[config['model_params']['name']](**config['model_params']) +if 'custom_params' in config: + if config['custom_params']['resume_training']: + checkpoint = torch.load(config['custom_params']['resume_chkpt_path']) + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + new_state_dict[k.replace("model.", "")] = v + model.load_state_dict(new_state_dict) + experiment = VAEXperiment(model, config['exp_params']) From 85403274c685d69176a2aad0bc2b3565a4664af5 Mon Sep 17 00:00:00 2001 From: Pulin Agrawal Date: Tue, 28 Jun 2022 13:26:21 +0530 Subject: [PATCH 3/4] example config setup for resume training --- configs/vae.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/configs/vae.yaml b/configs/vae.yaml index abd336b8..71dbb043 100644 --- a/configs/vae.yaml +++ b/configs/vae.yaml @@ -27,3 +27,6 @@ logging_params: save_dir: "logs/" name: "VanillaVAE" +custom_params: + resume_training: false + resume_chkpt_path: '' \ No newline at end of file From 070c2746988a7179916d626786123650b5e106fd Mon Sep 17 00:00:00 2001 From: Pulin Agrawal Date: Tue, 28 Jun 2022 15:55:39 +0530 Subject: [PATCH 4/4] fix quotes in yaml --- configs/vae.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/vae.yaml b/configs/vae.yaml index 71dbb043..7683f403 100644 --- a/configs/vae.yaml +++ b/configs/vae.yaml @@ -29,4 +29,4 @@ logging_params: custom_params: resume_training: false - resume_chkpt_path: '' \ No newline at end of file + resume_chkpt_path: ""