Commit 6dd4ce5

Merge pull request #2 from CAAI/raphael-dev
merge new UserConfig class to main
2 parents 822d3a7 + 1cd8b6f

File tree

5 files changed: +141, -79 lines

rhtorch/callbacks/plotting.py (+10, -8)

@@ -5,7 +5,7 @@
 import torch
 
 
-def plot_inline(d1, d2, d3, color_channel_axis=0):
+def plot_inline(d1, d2, d3, color_channel_axis=0, vmin=None, vmax=None):
     """
     Parameters
     ----------
@@ -19,28 +19,28 @@ def plot_inline(d1, d2, d3, color_channel_axis=0):
         Axis for color channel in the numpy array.
         Default is 0 for Pytorch models (cc, dimx, dimy, dimz)
         Use 3 for TF models (dimx, dimy, dimz, cc)
+    vmin : Lower bound for color channel. Default (None) used to plot full range
+    vmax : Upper bound for color channel. Default (None) used to plot full range
 
     """
     # If input has more than 1 color channel, use only the first
     if d1.shape[color_channel_axis] > 1:
         d1 = d1[0,...] if color_channel_axis == 0 else d1[...,0]
-        d1 = torch.unsqueeze(d1,color_channel_axis)
-    d_arr = d_arr = np.concatenate((d1, d2, d3), color_channel_axis)
+        d1 = torch.unsqueeze(d1, color_channel_axis)
+    d_arr = np.concatenate((d1, d2, d3), color_channel_axis)
     num_dat = d_arr.shape[color_channel_axis]
 
     fig, ax = plt.subplots(1, num_dat, gridspec_kw={'wspace': 0, 'hspace': 0})
     slice_i = int(d1.size(1) / 2)
     orient = 0
     text_pos = d1.size(2) * 0.98
 
-    # make a list of subplot titles - may need several input subtitles
-    titles = [f"Input{i+1}" for i in range(d1.size(color_channel_axis))]
-    titles.extend(['Target', 'Prediction'])
+    titles = ['Input', 'Target', 'Prediction']
 
     for idx in range(num_dat):
         single_data = d_arr.take(indices=idx, axis=color_channel_axis)
         ax[idx].imshow(single_data.take(indices=slice_i, axis=orient),
-                       cmap='gray', vmin=0, vmax=1)
+                       cmap='gray', vmin=vmin, vmax=vmax)
         ax[idx].axis('off')
         ax[idx].text(3, text_pos, titles[idx], color='white', fontsize=12)
 
@@ -50,10 +50,12 @@ def plot_inline(d1, d2, d3, color_channel_axis=0):
     return wandb_im
 
 class ImagePredictionLogger(Callback):
-    def __init__(self, val_dataloader):
+    def __init__(self, val_dataloader, config=None):
         super().__init__()
         self.X, self.y = next(iter(val_dataloader))
 
+        # TODO: Read config file to parse vmin, vmax or e.g. custom titles
+
     def on_validation_epoch_end(self, trainer, pl_module):
         # Dataloader loads on CPU --> pass to GPU
         X = self.X.to(device=pl_module.device)
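
Usage note: the new vmin/vmax arguments are passed straight through to imshow, replacing the previously hard-coded [0, 1] window. A minimal sketch of a call site, assuming single-channel volumes in the PyTorch layout (cc, dimx, dimy, dimz) described in the docstring, and that matplotlib and wandb are importable since plot_inline returns a wandb image:

    import torch
    from rhtorch.callbacks.plotting import plot_inline

    # dummy input / target / prediction volumes, shape (cc, dimx, dimy, dimz)
    x, y, p = (torch.rand(1, 64, 64, 64) for _ in range(3))

    wandb_im = plot_inline(x, y, p)                  # full dynamic range (vmin=vmax=None)
    wandb_im = plot_inline(x, y, p, vmin=0, vmax=1)  # pinned window, the old behaviour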

rhtorch/config_utils.py (+66, -58)

@@ -1,70 +1,78 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 import ruamel.yaml as yaml
 from datetime import datetime
 from pathlib import Path
 import torch
 from rhtorch.version import __version__
 import socket
 
-loss_map = {'MeanAbsoluteError': 'mae',
-            'MeanSquaredError': 'mse',
-            'huber_loss': 'huber',
-            'BCEWithLogitsLoss': 'BCE'}
+class UserConfig:
+    def __init__(self, rootdir, arguments=None):
+        self.rootdir = rootdir
+        self.config_file = self.is_path(arguments.config)
+        self.args = arguments
+
+        # load default configs
+        default_config_file = Path(__file__).parent.joinpath('default_config.yaml')
+        with open(default_config_file) as dcf:
+            self.default_params = yaml.load(dcf, Loader=yaml.Loader)
+
+        # load user config file
+        with open(self.config_file) as cf:
+            self.hparams = yaml.load(cf, Loader=yaml.RoundTripLoader)
+
+        # merge the two dicts
+        self.merge_dicts()
+
+        # sanity check on data_folder provided by user
+        self.data_path = self.is_path(self.hparams['data_folder'])
+
+        # make model name
+        self.fill_additional_info()
+        self.create_model_name()
+
+    def is_path(self, path):
+        # check for path - assuming absolute path was given
+        filepath = Path(path)
+        if not filepath.exists():
+            # assuming path was given relative to rootdir
+            filepath = self.rootdir.joinpath(filepath)
+            if not filepath.exists():
+                raise FileNotFoundError(f"{path} not found. Define relative to project directory or as absolute path in config file/argument passing.")
+
+        return filepath
 
+    def merge_dicts(self):
+        """ adds to the user_params dictionnary any missing key from the default params """
+
+        for k, v in self.default_params.items():
+            if k not in self.hparams:
+                self.hparams[k] = v
+        ### TO DO - ENSURE NOT COPYING IRRELEVANT DATA e.g. GAN parameters if model is AE
 
-def load_model_config(rootdir, arguments):
+    def fill_additional_info(self):
+        # additional info from args and miscellaneous to save in config
+        self.hparams['build_date'] = datetime.now().strftime("%Y%m%d-%H%M%S")
+        self.hparams['project_dir'] = str(self.rootdir)
+        self.hparams['data_folder'] = str(self.data_path)
+        self.hparams['config_file'] = str(self.config_file)
+        self.hparams['k_fold'] = self.args.kfold
+        self.hparams['GPUs'] = torch.cuda.device_count()
+        self.hparams['global_batch_size'] = self.hparams['batch_size'] * self.hparams['GPUs']
+        self.hparams['rhtorch_version'] = __version__
+        self.hparams['hostname'] = socket.gethostname()
 
-    # check for config_file
-    config_file = Path(arguments.config)
-    if not config_file.exists():
-        config_file = rootdir.joinpath(config_file)
-        if not config_file.exists():
-            raise FileNotFoundError("Config file not found. Define relative to project directory or as absolute path in config file")
-
-    # read the config file
-    with open(config_file) as file:
-        config = yaml.load(file, Loader=yaml.RoundTripLoader)
+    def create_model_name(self):
 
-    data_shape = 'x'.join(map(str, config['data_shape']))
-    base_name = f"{config['module']}_{config['version_name']}_{config['data_generator']}"
-    dat_name = f"bz{config['batch_size']}_{data_shape}"
-    full_name = f"{base_name}_{dat_name}_k{arguments.kfold}_e{config['epoch']}"
-
-    # check for data folder
-    data_folder = Path(config['data_folder'])
-    if not data_folder.exists():
-        # try relative to project dir - in this case overwrite config
-        data_folder = rootdir.joinpath(config['data_folder'])
-        if not data_folder.exists():
-            raise FileNotFoundError("Data path not found. Define relative to the project directory or as absolute path in config file")
+        data_shape = 'x'.join(map(str, self.hparams['data_shape']))
+        base_name = f"{self.hparams['module']}_{self.hparams['version_name']}_{self.hparams['data_generator']}"
+        dat_name = f"bz{self.hparams['batch_size']}_{data_shape}"
+        self.hparams['model_name'] = f"{base_name}_{dat_name}_k{self.args.kfold}_e{self.hparams['epoch']}"
 
-    # additional info from args and miscellaneous to save in config
-    config['build_date'] = datetime.now().strftime("%Y-%m-%d %H.%M.%S")
-    config['model_name'] = full_name
-    config['project_dir'] = str(rootdir)
-    config['data_folder'] = str(data_folder)
-    config['config_file'] = str(config_file)
-    config['k_fold'] = arguments.kfold
-    if 'precision' not in config:
-        config['precision'] = 32
-    config['GPUs'] = torch.cuda.device_count()
-    config['global_batch_size'] = config['batch_size'] * config['GPUs']
-    config['rhtorch_version'] = __version__
-    config['hostname'] = socket.gethostname()
-    if 'acc_grad_batches' not in config:
-        config['acc_grad_batches'] = 1
-
-    return config
-
-
-def copy_model_config(path, config, append_timestamp=False):
-    model_name = config['model_name']
-    if append_timestamp:
-        timestamp = config['build_date'].replace(' ','_')
-        config_file = path.joinpath(f"config_{model_name}_{timestamp}.yaml")
-    else:
-        config_file = path.joinpath(f"config_{model_name}.yaml")
-    config.yaml_set_start_comment(f'Config file for {model_name}')
-    with open(config_file, 'w') as file:
-        yaml.dump(config, file, Dumper=yaml.RoundTripDumper)
+    def save_copy(self, output_dir, append_timestamp=False):
+        model_name = self.hparams['model_name']
+        timestamp = f"_{self.hparams['build_date']}" if append_timestamp else ""
+        save_config_file_name = f"config_{model_name}{timestamp}"
+        config_file = output_dir.joinpath(save_config_file_name + ".yaml")
+        self.hparams.yaml_set_start_comment(f'Config file for {model_name}')
+        with open(config_file, 'w') as file:
+            yaml.dump(self.hparams, file, Dumper=yaml.RoundTripDumper)
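
Usage note: the class folds the two old module-level helpers into one object; load_model_config becomes the constructor plus the .hparams attribute, and copy_model_config becomes save_copy(). A hypothetical call site (the SimpleNamespace stands in for the argparse.Namespace built in torch_training.py; both paths are made up and must exist on disk, since the constructor validates the config file and data_folder):

    from pathlib import Path
    from types import SimpleNamespace
    from rhtorch.config_utils import UserConfig

    # 'config' and 'kfold' are the only attributes UserConfig reads from the args
    args = SimpleNamespace(config='config.yaml', kfold=0)
    project_dir = Path('/homes/user/my_project')  # hypothetical project root

    user_config = UserConfig(project_dir, arguments=args)
    print(user_config.hparams['model_name'])  # e.g. LightningAE_v0_DefaultDataLoader_bz1_64x64x64_k0_e100
    user_config.save_copy(project_dir, append_timestamp=True)

Note that build_date now uses the space-free format %Y%m%d-%H%M%S, which is why save_copy() can drop the .replace(' ', '_') that copy_model_config needed.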

rhtorch/default_config.yaml (new file, +49)

@@ -0,0 +1,49 @@
+# main config to train models
+project_name: YOUR_PROJECT_NAME # Used for WANDB
+version_name: v0 # Make run unique by changing this counter
+
+# main model:
+precision: 32
+epoch: 100
+batch_size: 1
+acc_grad_batches: 1
+module: LightningAE
+
+# generator
+generator: UNet3DFullConv
+# depth: 4
+# initial_num_filters: 64
+g_activation: ReLU
+g_optimizer: Adam
+g_lr: 1e-4
+#lr_scheduler: 'exponential_decay_0.01'
+g_loss: MeanAbsoluteError
+
+# transfer learning
+pretrained_generator: null # absolute path to .pt or .ckpt
+freeze_encoder: False
+
+# discriminator - will be used if model is GAN
+discriminator: ConvNetDiscriminator
+d_optimizer: Adam
+d_lr: 2e-4
+d_loss: BCEWithLogitsLoss
+
+# data:
+data_split_pkl: data_split.pickle # inside data folder .json or .pickle
+data_generator: DefaultDataLoader
+data_folder: Data/data_noblur_25_64x64x64 # inside project dir
+pet_normalization_constant: 32676
+augment: True
+data_shape: [64,64,64]
+color_channels_in: 1
+repeat_patient_list: 1
+
+# to implement
+# full_data_shape: [64,64,64]
+# loading_style: 'volume', 'patch'
+# input_data_shape: [8,8,8] for patch or [16,64,64] for slice
+# for plotting during training
+callback_image2image: ImagePredictionLogger
+
+# model-specific info (self-generated) - do not write anything beyond here
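
Every key in this file acts as a fallback: merge_dicts() copies a default into the user's config only when the user file omits the key, so a minimal user YAML needs only the project-specific fields. A toy illustration of that one-way merge, with plain dicts standing in for the loaded YAML:

    defaults = {'precision': 32, 'epoch': 100, 'batch_size': 1}
    user = {'epoch': 600, 'data_folder': 'Data/my_study'}

    # same rule as UserConfig.merge_dicts(): defaults never overwrite user values
    for k, v in defaults.items():
        if k not in user:
            user[k] = v

    print(user)  # {'epoch': 600, 'data_folder': 'Data/my_study', 'precision': 32, 'batch_size': 1}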

rhtorch/torch_training.py (+14, -10)

@@ -12,7 +12,7 @@
 # library package imports
 from rhtorch.models import modules
 from rhtorch.callbacks import plotting
-from rhtorch.config_utils import load_model_config, copy_model_config
+from rhtorch.config_utils import UserConfig
 
 def main():
     import argparse
@@ -30,8 +30,9 @@ def main():
     is_test = args.test
 
     # load configs from file + additional info from args
-    configs = load_model_config(project_dir, args)
-
+    user_configs = UserConfig(project_dir, args)
+    configs = user_configs.hparams ### WARNING TO CHECK IF THIS is 2 names for the same memory address or 2 distinct memory addresses (matters when saving copy in the end)
+
     # Set local data_generator
     sys.path.insert(1, args.input)
     import data_generator
@@ -61,19 +62,21 @@ def main():
     model = module(configs, shape_in)
 
     # transfer learning setup
-    if 'pretrained_generator' in configs:
+    if configs['pretrained_generator']:
         print("Setting up transfer learning")
         pretrained_model_path = Path(configs['pretrained_generator'])
         if pretrained_model_path.exists():
             if pretrained_model_path.name.endswith(".ckpt"):
                 # important to pass in new configs here as we want to load the weights but config may differ from pretrained model
                 model = module.load_from_checkpoint(pretrained_model_path, hparams=configs, in_shape=shape_in, strict=False)
-            elif pretrained_model_path.endswith(".pt"):
+            elif pretrained_model_path.name.endswith(".pt"):
                 # this works for both .pt and .ckpt actually
                 # WARNING I don't know which of the above or below method is the correct way to load ckpt
                 # this below method only load the weights. Above also load state of optimizer, etc...
                 ckpt = torch.load(pretrained_model_path)
-                pretrained_model = ckpt['state_dict']
+                # OBS, the 'state_dict' is not set during save?
+                # What if we are to save multiple models used later for pretrain? (e.g. a GAN with 3 networks?)
+                pretrained_model = ckpt['state_dict'] if 'state_dict' in ckpt.keys() else ckpt
                 model.load_state_dict(pretrained_model, strict=False)
             else:
                 raise ValueError("Expected model format: '.pt' or '.ckpt'.")
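
Reviewer note on the fallback above: ckpt['state_dict'] if 'state_dict' in ckpt.keys() else ckpt covers the two shapes a checkpoint can take. A Lightning .ckpt unpacks to a dict that nests the weights under 'state_dict' (next to optimizer state and so on), while torch.save(model.state_dict(), ...), as done at the bottom of this script, stores the bare mapping. A self-contained sketch of both cases:

    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in for the real generator

    bare = model.state_dict()                   # what torch.save(model.state_dict(), ...) stores
    nested = {'state_dict': model.state_dict(),
              'optimizer_states': []}           # Lightning-style checkpoint content

    for ckpt in (bare, nested):
        state_dict = ckpt['state_dict'] if 'state_dict' in ckpt else ckpt
        model.load_state_dict(state_dict, strict=False)  # one call handles both
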
@@ -117,15 +120,16 @@ def main():
 
     # Save the config prior to training the model - one for each time the script is started
     if not is_test:
-        copy_model_config(model_path, configs, append_timestamp=True)
+        user_configs.save_copy(model_path, append_timestamp=True)
         print("Saved config prior to model training")
 
     # set the trainer and fit
+    accelerator = 'ddp' if configs['GPUs'] > 1 else None
     trainer = pl.Trainer(max_epochs=configs['epoch'],
                          logger=wandb_logger,
                          callbacks=callbacks,
                          gpus=-1,
-                         accelerator='ddp',
+                         accelerator=accelerator,
                          resume_from_checkpoint=existing_checkpoint,
                          auto_select_gpus=True,
                          accumulate_grad_batches=configs['acc_grad_batches'],
@@ -136,12 +140,12 @@ def main():
     trainer.fit(model, train_dataloader, valid_dataloader)
 
     # add useful info to saved configs
-    configs['best_model'] = checkpoint_callback.best_model_path
+    user_configs.hparams['best_model'] = checkpoint_callback.best_model_path
 
     # save the model
     output_file = model_path.joinpath(f"{configs['model_name']}.pt")
     torch.save(model.state_dict(), output_file)
-    copy_model_config(model_path, configs)
+    user_configs.save_copy(model_path)
     print("Saved model and config file to disk")
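
On the WARNING left at the top of this file: configs = user_configs.hparams binds a second name to the same mapping object, not a copy, so keys written through configs (such as 'best_model') are visible to save_copy(), which dumps self.hparams. A quick sanity check, with a plain dict standing in for the ruamel round-trip mapping:

    class Holder:
        def __init__(self):
            self.hparams = {'epoch': 100}

    user_configs = Holder()
    configs = user_configs.hparams            # alias, not a copy

    configs['best_model'] = '/path/to/best.ckpt'
    assert configs is user_configs.hparams    # one object, two names
    assert user_configs.hparams['best_model'] == '/path/to/best.ckpt'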

rhtorch/version.py (+2, -3, file mode 100644 → 100755)

@@ -13,6 +13,5 @@
 VERSIONING (UPDATED WHEN PR ARE MERGED INTO MASTER BRANCH)
 0.0.1 # Added repository (CL 18-05-2021)
 0.0.2 # Cleaned up main, moved to torchmetrics in modules (RD 20-05-2021)
-0.0.3 # Added version control to config-logfiles
-
-"""
+0.0.3 # Added version control to config-logfiles, and default config yaml settings (CL, RD 23-05-2021)
+"""
