Initial code release
Max Seitzer committed Sep 14, 2018
1 parent 0471d96 commit 428a03f
Showing 72 changed files with 10,855 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@

__pycache__/
48 changes: 48 additions & 0 deletions configs/1-recnet.json
@@ -0,0 +1,48 @@
{
  "seed": 0,
  "run_name": "1-recnet",
  "description": "Train the RecNet baseline on ScarSeg512",

  "train_dataset": "ScarSeg",
  "validation_dataset": "ScarSeg",
  "split_ratio": [4, 1, 1],
  "input_mode": "2d",

  "undersampling": {
    "sampling_scheme": "varden",
    "acceleration_factor": 8,
    "variable_acceleration": false
  },

  "runner_type": "standard",
  "application": "reconstruction",

  "model": {
    "name": "RecNet",
    "num_blocks": 3,
    "num_convs": 3,
    "num_filters": 32
  },

  "loss_name": "MSE",

  "optimizer": {
    "name": "Adam",
    "learning_rate": 0.0002
  },

  "num_epochs": 1500,
  "batch_size": 20,
  "epochs_per_validation": 1,
  "epochs_per_checkpoint": 1,
  "steps_per_train_summary": 20,
  "num_periodic_checkpoints": 2,
  "num_data_workers": 4,

  "train_metrics": ["psnr"],
  "validation_metrics": ["psnr", "ssim"],
  "validation_checkpoint_metrics": ["loss_MSE"],

  "use_tensorboard": true,
  "num_image_summaries": 8
}
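Each file under configs/ describes one training run as plain JSON. The repository's actual entry point and option handling are not part of this excerpt; as a minimal sketch, assuming only that the file is standard JSON, a run configuration could be inspected like this (the path is illustrative):

import json

# Load a run configuration and read a few of its fields.
with open('configs/1-recnet.json') as f:
    conf = json.load(f)

print(conf['run_name'])                                      # '1-recnet'
print(conf['model']['name'], conf['model']['num_filters'])   # 'RecNet' 32
print(conf['undersampling']['acceleration_factor'])          # 8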
118 changes: 118 additions & 0 deletions configs/2-refinement.json
@@ -0,0 +1,118 @@
{
"seed": 1,
"run_name": "2-refinement",
"description": "Train the refinement network on top of pretrained RecNet",

"train_dataset": "ScarSeg",
"validation_dataset": "ScarSeg",
"split_ratio": [4, 1, 1],
"input_mode": "2d",

"undersampling": {
"sampling_scheme": "varden",
"acceleration_factor": 8,
"variable_acceleration": false
},

"runner_type": "adversarial",
"application": "reconstruction",

"generator_model": {
"name": "RefinementWrapper",
"mode": "real-penalty-add",
"input_mode": "output",
"pretrained_model": {
"name": "RecNet",
"num_blocks": 3,
"num_convs": 3,
"num_filters": 32,
"pretrained_weights": ["../resources/models/INSERT_CHECKPOINT_PATH_HERE", "model"]
},
"learnable_model": {
"name": "UNET",
"num_inputs": 2,
"num_outputs": 1,
"num_layers_per_scale": 2,
"encode_filters": [32, 64, 128],
"decode_filters": [64, 32],
"kernel_size": 4,
"upsampling_mode": "nn-resize-conv",
"output_activation": "none",
"padding": "reflection",
"decoder_act_upsampling_only": true,
"weight_init": {
"conv_weight": ["orthogonal", "relu"],
"batchnorm_weight": ["constant", 1.0]
}
}
},

"discriminator_model": {
"spatial_shape": [512, 512],
"num_inputs": 1,
"input_method": "simple-magnitude",

"num_filters_per_layer": [64, 128, 256, 512, 1024, 1024],
"strides": [2, 2, 2, 2, 2, 1],
"kernel_sizes": 4,
"final_conv_kernel_size": 4,
"padding": "reflection",

"act_fn": "lrelu",
"relu_leakiness": 0.2,

"dropout_after": [3, 4, 5],
"dropout_prob": 0.5,

"use_norm_layers": "not-first",
"norm_layer": "batch",

"compute_features": true,
"use_image_pool": true,
"image_pool_size": 80
},

"generator_adversarial_losses": ["gan", "FeatureMatching"],
"generator_losses": ["VGG19", "FeaturePenalty"],
"generator_loss_weights": {
"gan": 0.5,
"VGG19": 10,
"FeaturePenalty": 2
},
"discriminator_losses": ["gan"],
"discriminator_label_smoothing": 0.1,

"feature_penalty": {
"criterion": "L1",
"input_key": "prescaled_refinement"
},

"generator_optimizer": {
"name": "Adam",
"learning_rate": 0.0002,
"beta1": 0.5
},

"discriminator_optimizer": {
"name": "Adam",
"learning_rate": 0.0002,
"beta1": 0.5
},

"num_epochs": 200,
"batch_size": 5,
"validation_batch_size": 5,
"epochs_per_validation": 1,
"epochs_per_checkpoint": 1,
"steps_per_train_summary": 20,
"num_periodic_checkpoints": 20,
"num_data_workers": 4,

"train_generator_metrics": ["psnr"],
"train_discriminator_metrics": ["binary_accuracy"],
"validation_metrics": ["psnr", "ssim"],
"best_checkpoint_metrics": ["gen_psnr"],

"use_tensorboard": true,
"num_image_summaries": 5
}
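This run is adversarial: besides the plain "gan" loss the generator uses a FeatureMatching term computed from discriminator activations (presumably why "compute_features" is enabled), and the discriminator uses label smoothing of 0.1. The repository's own loss classes are not shown in this commit; the following is only an illustrative sketch of one common way such terms are computed, not the repo's exact code:

import torch
import torch.nn.functional as F

def feature_matching_loss(real_feats, fake_feats):
    # L1 distance between discriminator features of real and generated
    # images, averaged over layers (illustrative sketch).
    loss = 0.0
    for fr, ff in zip(real_feats, fake_feats):
        loss = loss + F.l1_loss(ff, fr.detach())
    return loss / len(real_feats)

def smoothed_real_labels(batch_size, smoothing=0.1):
    # One common reading of one-sided label smoothing: real targets become 0.9.
    return torch.full((batch_size, 1), 1.0 - smoothing)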
56 changes: 56 additions & 0 deletions configs/3-train-segmentation-unet.json
@@ -0,0 +1,56 @@
{
  "seed": 0,
  "run_name": "3-train-segmentation-unet",
  "description": "Train a UNET for segmentation on ground truth reconstructions",

  "train_dataset": "ScarSeg",
  "validation_dataset": "ScarSeg",
  "split_ratio": [4, 1, 1],
  "input_mode": "2d",
  "dataset_mode": "segmentation",

  "undersampling": {
    "comment": "Undersampling is not active in dataset_mode segmentation",
    "sampling_scheme": "varden",
    "acceleration_factor": 8,
    "variable_acceleration": false
  },

  "runner_type": "standard",
  "application": "segmentation",

  "model": {
    "name": "UNET",
    "num_inputs": 2,
    "num_outputs": 2,
    "num_layers_per_scale": 2,
    "encode_filters": [32, 64, 128, 256, 512],
    "decode_filters": [256, 128, 64, 32],
    "use_bn": true,
    "upsampling_mode": "pixelshuffle",
    "padding": "reflection",
    "output_activation": "none"
  },

  "loss_name": "CrossEntropy",

  "optimizer": {
    "name": "Adam",
    "learning_rate": 0.0002
  },

  "num_epochs": 200,
  "batch_size": 8,
  "epochs_per_validation": 1,
  "epochs_per_checkpoint": 1,
  "steps_per_train_summary": 20,
  "num_periodic_checkpoints": 1,
  "num_data_workers": 4,

  "train_metrics": ["dice_class_0", "dice_class_1"],
  "validation_metrics": ["dice_class_0", "dice_class_1"],
  "validation_checkpoint_metrics": ["dice_class_1"],

  "use_tensorboard": true,
  "num_image_summaries": 1
}
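The segmentation run is tracked with per-class Dice scores ("dice_class_0", "dice_class_1"). For reference, the Dice coefficient compares predicted and ground-truth masks for one class; a minimal sketch of the standard formula on hard label maps (not the repository's metric code) is:

import numpy as np

def dice_score(pred_labels, true_labels, class_id, eps=1e-8):
    # Dice = 2 * |P & T| / (|P| + |T|) for a single class.
    pred = (pred_labels == class_id)
    true = (true_labels == class_id)
    intersection = np.logical_and(pred, true).sum()
    return (2.0 * intersection + eps) / (pred.sum() + true.sum() + eps)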
13 changes: 13 additions & 0 deletions configs/segscore_unet.json
@@ -0,0 +1,13 @@
{
  "name": "UNET",
  "num_inputs": 2,
  "num_outputs": 2,
  "num_layers_per_scale": 2,
  "encode_filters": [32, 64, 128, 256, 512],
  "decode_filters": [256, 128, 64, 32],
  "use_bn": true,
  "upsampling_mode": "pixelshuffle",
  "padding": "reflection",
  "output_activation": "none",
  "pretrained_weights": ["../resources/models/INSERT_SEGMENTATION_MODEL_CHECKPOINT_HERE", "model"]
}
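The two-element "pretrained_weights" entry pairs a checkpoint path with the key under which the weights are stored. The checkpoint layout is not shown in this commit; a plausible loader, assuming a torch.save'd dict that keeps the state dict under that key, might look like:

import torch

def load_pretrained(model, checkpoint_path, state_dict_key='model'):
    # Hypothetical helper: assumes the checkpoint is a dict holding the
    # model's state dict under `state_dict_key`.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint[state_dict_key])
    return model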
51 changes: 51 additions & 0 deletions data/__init__.py
@@ -0,0 +1,51 @@
import importlib

_DATASET_MODULES = {
    'ScarSeg': 'data.reconstruction.scar_seg',
}


def is_dataset(dataset_name):
    return dataset_name in _DATASET_MODULES


def maybe_get_subset_sampler(num_samples, dataset):
    if num_samples is None or num_samples == 0:
        return None

    if num_samples > len(dataset):
        raise ValueError(('Requesting subset of {} samples, but '
                          'dataset has only {}').format(num_samples, len(dataset)))

    from torch.utils.data.sampler import SubsetRandomSampler
    return SubsetRandomSampler(range(num_samples))


def load_dataset(conf, data_dir, dataset_name, fold):
    """Load dataset

    Parameters
    ----------
    conf : Configuration
        Configuration to pass to the dataset loader
    data_dir : string
        Path to top level data folder
    dataset_name : string
        Dataset name
    fold : string
        Either `train`, `val`, or `test` fold
    """
    assert fold in ('train', 'val', 'test')
    assert dataset_name in _DATASET_MODULES, \
        'Unknown dataset {}'.format(dataset_name)

    module = importlib.import_module(_DATASET_MODULES[dataset_name])

    if fold == 'train':
        return module.get_train_set(conf, data_dir)
    elif fold == 'val':
        return module.get_val_set(conf, data_dir)
    elif fold == 'test':
        return module.get_test_set(conf, data_dir)

    return None
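A dataset loaded this way would typically be wrapped in a torch DataLoader. The sketch below is illustrative only: `conf` and `data_dir` stand in for the run configuration and data root, which are defined elsewhere in the repository, and the DataLoader arguments are placeholders rather than the project's actual settings.

from torch.utils.data import DataLoader

from data import load_dataset, maybe_get_subset_sampler

# `conf` and `data_dir` are assumed to be provided by the surrounding code.
train_set = load_dataset(conf, data_dir, 'ScarSeg', 'train')
sampler = maybe_get_subset_sampler(num_samples=100, dataset=train_set)

loader = DataLoader(train_set,
                    batch_size=20,
                    sampler=sampler,
                    shuffle=(sampler is None),
                    num_workers=4)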
5 changes: 5 additions & 0 deletions data/reconstruction/deep_med_lib/__init__.py
@@ -0,0 +1,5 @@
"""
MRI reconstruction transform library
Code from Jo Schlemper, with permission
"""
Empty file.
24 changes: 24 additions & 0 deletions data/reconstruction/deep_med_lib/my_pytorch/custom_loss.py
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn


def get_objective_loss(objective_loss):
    """Return the training criterion selected by name."""
    if objective_loss == 'weighted_cross_entropy':
        # Four-class weighting; the first (background) class is down-weighted.
        weight = torch.Tensor([.1, 1., 1., 1.])
        criterion = CrossEntropyLoss2d(weight)
    elif objective_loss == 'bce':  # binary cross entropy
        criterion = nn.BCELoss()
    else:  # default: L2
        criterion = nn.MSELoss()
    return criterion


class CrossEntropyLoss2d(nn.Module):
    """Pixel-wise cross entropy. `inputs` must be log-probabilities
    (e.g. the output of log_softmax), as expected by NLLLoss2d."""

    def __init__(self, weight=None, size_average=True):
        super(CrossEntropyLoss2d, self).__init__()
        self.nll_loss = nn.NLLLoss2d(weight, size_average)

    def forward(self, inputs, targets):
        return self.nll_loss(inputs, targets)
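Because CrossEntropyLoss2d wraps NLLLoss2d, callers need to pass log-probabilities rather than raw logits. A minimal usage sketch (shapes and inputs are illustrative):

import torch
import torch.nn.functional as F

from data.reconstruction.deep_med_lib.my_pytorch.custom_loss import get_objective_loss

criterion = get_objective_loss('weighted_cross_entropy')

logits = torch.randn(2, 4, 64, 64)           # (batch, classes, H, W)
targets = torch.randint(0, 4, (2, 64, 64))   # per-pixel class indices
loss = criterion(F.log_softmax(logits, dim=1), targets)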


11 changes: 11 additions & 0 deletions data/reconstruction/deep_med_lib/my_pytorch/misc.py
@@ -0,0 +1,11 @@
"""Hacky utils"""

def get_to_cuda(cuda):
def to_cuda(tensor):
return tensor.cuda() if cuda else tensor
return to_cuda

def get_params(model):
return [w for w in model.parameters()]

