Skip to content

Commit 1b3b46a

Browse files
authored
Merge pull request #6 from CAAI/new_unet3d
New unet3d
2 parents 060453e + 06a98b9 commit 1b3b46a

12 files changed

+343
-357
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ __pycache__
33
wandb/
44
rhtorch.egg-info
55
.vscode/
6-
rhtorch/models/dev
6+
rhtorch/models/dev/*
7+
!rhtorch/models/dev/__init__.py

example_project/LowdosePETv1_config.yaml

+9-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
# main config to train models
1+
# main config to train your model
2+
# check default.config for all available options
3+
24
project_name: YOUR_PROJECT_NAME # Used for WANDB
3-
version_name: RESUNETv1 # Make run unique by changing this counter
5+
version_name: UNETv1 # Make run unique by changing this counter
46

57
# main model:
68
precision: 32
@@ -10,29 +12,24 @@ acc_grad_batches: 4
1012
module: LightningAE
1113

1214
# generator
13-
generator: Res3DUnet
15+
generator: UNet3D
16+
g_pooling_type: full_conv
17+
g_filters: [64, 128, 256, 512, 1024]
1418
# g_activation: ReLU
1519
g_optimizer: Adam
1620
g_lr: 1e-4
1721
#lr_scheduler: 'exponential_decay_0.01'
1822
g_loss: MeanAbsoluteError
19-
# pretrained_generator: absolute path to .pt or .ckpt
20-
21-
# discriminator - will be used if model is GAN
22-
# discriminator: ConvNetDiscriminator
23-
# d_optimizer: Adam
24-
# d_lr: 2e-4
25-
# d_loss: BCEWithLogitsLoss
2623

2724
# data:
2825
data_split_pkl: Data/data_5fold.pickle # inside data folder .json or .pickle
2926
data_generator: CustomDataLoader
3027
data_folder: Data/data_noblur_25_64x64x64 # inside project dir
3128
pet_normalization_constant: 32676
3229
augment: True
33-
data_shape: [64,64,64]
30+
data_shape: [128, 128, 128]
3431
color_channels_in: 1
35-
repeat_patient_list: 500
32+
repeat_patient_list: 100
3633

3734
# for plotting during training
3835
callback_image2image: 'ImagePredictionLogger'

rhtorch/callbacks/__init__.py

Whitespace-only changes.

rhtorch/config_utils.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -5,50 +5,51 @@
55
from rhtorch.version import __version__
66
import socket
77

8+
89
class UserConfig:
910
def __init__(self, rootdir, arguments=None):
1011
self.rootdir = rootdir
1112
self.config_file = self.is_path(arguments.config)
1213
self.args = arguments
13-
14+
1415
# load default configs
15-
default_config_file = Path(__file__).parent.joinpath('default_config.yaml')
16+
default_config_file = Path(__file__).parent.joinpath('default.config')
1617
with open(default_config_file) as dcf:
1718
self.default_params = yaml.load(dcf, Loader=yaml.Loader)
18-
19+
1920
# load user config file
2021
with open(self.config_file) as cf:
2122
self.hparams = yaml.load(cf, Loader=yaml.RoundTripLoader)
22-
23+
2324
# merge the two dicts
2425
self.merge_dicts()
25-
26+
2627
# sanity check on data_folder provided by user
2728
self.data_path = self.is_path(self.hparams['data_folder'])
28-
29+
2930
# make model name
3031
self.fill_additional_info()
3132
self.create_model_name()
32-
33+
3334
def is_path(self, path):
3435
# check for path - assuming absolute path was given
3536
filepath = Path(path)
3637
if not filepath.exists():
3738
# assuming path was given relative to rootdir
3839
filepath = self.rootdir.joinpath(filepath)
3940
if not filepath.exists():
40-
raise FileNotFoundError(f"{path} not found. Define relative to project directory or as absolute path in config file/argument passing.")
41-
41+
raise FileNotFoundError(
42+
f"{path} not found. Define relative to project directory or as absolute path in config file/argument passing.")
43+
4244
return filepath
4345

4446
def merge_dicts(self):
4547
""" adds to the user_params dictionnary any missing key from the default params """
46-
48+
4749
for key, value in self.default_params.items():
4850
# copy from default if value is not None/0/False and key not already in user config
4951
if value and key not in self.hparams:
5052
self.hparams[key] = value
51-
5253

5354
def fill_additional_info(self):
5455
# additional info from args and miscellaneous to save in config
@@ -58,21 +59,22 @@ def fill_additional_info(self):
5859
self.hparams['config_file'] = str(self.config_file)
5960
self.hparams['k_fold'] = self.args.kfold
6061
self.hparams['GPUs'] = torch.cuda.device_count()
61-
self.hparams['global_batch_size'] = self.hparams['batch_size'] * self.hparams['GPUs']
62+
self.hparams['global_batch_size'] = self.hparams['batch_size'] * \
63+
self.hparams['GPUs']
6264
self.hparams['rhtorch_version'] = __version__
6365
self.hparams['hostname'] = socket.gethostname()
64-
66+
6567
def create_model_name(self):
66-
68+
6769
data_shape = 'x'.join(map(str, self.hparams['data_shape']))
6870
base_name = f"{self.hparams['module']}_{self.hparams['version_name']}_{self.hparams['data_generator']}"
6971
dat_name = f"bz{self.hparams['batch_size']}_{data_shape}"
7072
self.hparams['model_name'] = f"{base_name}_{dat_name}_k{self.args.kfold}_e{self.hparams['epoch']}"
71-
73+
7274
def save_copy(self, output_dir, append_timestamp=False):
7375
model_name = self.hparams['model_name']
7476
timestamp = f"_{self.hparams['build_date']}" if append_timestamp else ""
75-
save_config_file_name = f"config_{model_name}{timestamp}"
77+
save_config_file_name = f"config_{model_name}{timestamp}"
7678
config_file = output_dir.joinpath(save_config_file_name + ".yaml")
7779
self.hparams.yaml_set_start_comment(f'Config file for {model_name}')
7880
with open(config_file, 'w') as file:

rhtorch/default_config.yaml rhtorch/default.config

+5-6
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ acc_grad_batches: 1
1010
module: LightningAE
1111

1212
# generator
13-
generator: UNet3DFullConv
14-
# depth: 4
15-
# initial_num_filters: 64
16-
g_activation: ReLU
13+
generator: UNet3D
14+
#g_pooling_type: full_conv # full_conv or max_pool
15+
#g_filters: [64, 128, 256, 512, 1024]
16+
#g_activation: ReLU
1717
g_optimizer: Adam
1818
g_lr: 1e-4
1919
#lr_scheduler: 'exponential_decay_0.01'
@@ -33,9 +33,8 @@ d_loss: null
3333
data_split_pkl: data_split_file_inside_data_folder.pickle # inside data folder .json or .pickle
3434
data_generator: DefaultDataLoader
3535
data_folder: data_folder_inside_project_dir # inside project dir
36-
pet_normalization_constant: 32676
3736
augment: True
38-
data_shape: [64,64,64]
37+
data_shape: [128, 128, 128]
3938
color_channels_in: 1
4039
repeat_patient_list: 1
4140

0 commit comments

Comments
 (0)