Skip to content

Commit

Permalink
removed old projects, formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ibro45 committed Jan 8, 2021
1 parent 7657165 commit e1d7218
Show file tree
Hide file tree
Showing 61 changed files with 208 additions and 1,317 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.ipynb linguist-vendored
10 changes: 5 additions & 5 deletions midaGAN/configs/utils/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ def init_config(conf, config_class=configs.training.TrainConfig):
# Allows the framework to find user-defined, project-specific, dataset classes and their configs
if conf.project_dir:
IMPORT_LOCATIONS["dataset"].append(conf.project_dir)
logger.info(
f"Project directory {conf.project_dir} added to path to allow imports of modules from it."
)
logger.info(f"Project directory {conf.project_dir} added to the"
" path to allow imports of modules from it.")

# Make yaml mergeable by instantiating the dataclasses
conf = instantiate_dataclasses_from_yaml(conf)
Expand All @@ -37,7 +36,7 @@ def init_config(conf, config_class=configs.training.TrainConfig):


def instantiate_dataclasses_from_yaml(conf):
"""Goes through a config and instantiates the fields that are dataclasses.
"""Goes through a config and instantiates the fields that are dataclasses.
A field is a dataclass if its key can be found in the keys of the IMPORT_LOCATIONS.
Each such dataclass should have an entry "name" which is used to import its dataclass
class using that "name" + "Config" as class name.
Expand Down Expand Up @@ -78,4 +77,5 @@ def get_all_conf_keys(conf):
"""Get all keys from a conf and order from them the deepest to the shallowest."""
conf = OmegaConf.to_container(conf)
keys = list(utils.iterate_nested_dict_keys(conf))
return keys[::-1] # order by depth
# Order deeper to shallower
return keys[::-1]
4 changes: 2 additions & 2 deletions midaGAN/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def build_loader(conf):
dataset,
shuffle=False,
num_replicas=communication.get_world_size(),
rank=communication.get_local_rank(
)) # TODO: verify that this indeed should be local rank and not rank
# TODO: should it be rank instead?
rank=communication.get_local_rank())

loader = DataLoader(dataset,
batch_size=conf.batch_size,
Expand Down
4 changes: 3 additions & 1 deletion midaGAN/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class ImageDatasetConfig(configs.base.BaseDatasetConfig):
name: str = "ImageDataset"
image_channels: int = 3
preprocess: str = "resize_and_crop" # scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]'
# Scaling and cropping of images at load time:
# [resize_and_crop | crop | scale_width | scale_width_and_crop | none]'
preprocess: str = "resize_and_crop"
load_size: int = 286
crop_size: int = 256
flip: bool = True
Expand Down
2 changes: 1 addition & 1 deletion midaGAN/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# - Docstring to match the rest of the library
# - Calls to other subroutines which do not exist in DIRECT.

import torch
import itertools
import torch

from torch.utils.data.sampler import Sampler
from midaGAN.utils import communication
Expand Down
3 changes: 2 additions & 1 deletion midaGAN/data/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

def pad(volume, target_shape):
assert len(target_shape) == len(volume.shape)
pad_width = [(0, 0) for _ in range(len(target_shape))] # by default no padding
# By default no padding
pad_width = [(0, 0) for _ in range(len(target_shape))]

for dim in range(len(target_shape)):
if target_shape[dim] > volume.shape[dim]:
Expand Down
8 changes: 5 additions & 3 deletions midaGAN/data/utils/body_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,11 @@ def get_body_mask_and_bound(image: np.ndarray, hu_threshold: int) -> np.ndarray:
return body_mask, bound



def apply_body_mask_and_bound(array: np.ndarray, masking_value: int =-1024, \
apply_mask: bool =False, apply_bound: bool=False, hu_threshold: int =-300) -> np.ndarray:
def apply_body_mask_and_bound(array: np.ndarray,
masking_value: int = -1024,
apply_mask: bool = False,
apply_bound: bool = False,
hu_threshold: int = -300) -> np.ndarray:
"""
Function to apply mask based filtering and bound the array
Expand Down
18 changes: 12 additions & 6 deletions midaGAN/data/utils/image_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __init__(self, pool_size):
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
"""
self.pool_size = pool_size
if self.pool_size > 0: # create an empty pool
# Create an empty pool
if self.pool_size > 0:
self.num_imgs = 0
self.images = []

Expand All @@ -37,18 +38,23 @@ def query(self, images):
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
# If the buffer is not full, keep inserting current images to the buffer
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
# By 50% chance, the buffer will return a previously stored image,
# and insert the current image into the buffer
if p > 0.5:
random_id = random.randint(0, self.pool_size - 1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else: # by another 50% chance, the buffer will return the current image
# By another 50% chance, the buffer will return the current image
else:
return_images.append(image)
return_images = torch.cat(return_images, 0) # collect all the images and return
# Collect all the images and return
return_images = torch.cat(return_images, 0)
return return_images
21 changes: 12 additions & 9 deletions midaGAN/data/utils/registration_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ def truncate_CT_to_scope_of_CBCT(CT, CBCT):
end_slice = int(round(mean(z_corners[4:])))
# When the registration fails, just return the original CT. Happens infrequently.
if start_slice < 0:
logger.info(
"Registration failed as the at least one corner is below 0 in one of the axes. Passing the whole CT volume."
)
logger.info("Registration failed as the at least one corner is below 0 in one of the axes."
" Passing the whole CT volume.")
return CT
return CT[:, :, start_slice:end_slice]

Expand All @@ -80,11 +79,12 @@ def register_CT_to_CBCT(CT, CBCT, registration_type="Rigid"):
def get_registration_transform(fixed_image, moving_image, registration_type="Rigid"):
"""Performs the registration and returns a SimpleITK's `Transform` class which can be
used to resample an image so that it is registered to another one. However, in our code
should be truncated so that it contains only the part of the body that is found in the `fixed_image`.
Registration parameters are hardcoded and picked for the specific task of CBCT to CT translation.
should be truncated so that it contains only the part of the body that is
found in the `fixed_image`.
Registration parameters are hardcoded and picked for the specific
task of CBCT to CT translation.
TODO: consider making the adjustable in config.
Parameters:
------------------------
fixed_image:
Expand All @@ -95,8 +95,10 @@ def get_registration_transform(fixed_image, moving_image, registration_type="Rig
"""

# Get seed from environment variable if set for registration randomness
seed = int(
os.environ.get('PYTHONHASHSEED')) if 'PYTHONHASHSEED' in os.environ else sitk.sitkWallClock
if 'PYTHONHASHSEED' in os.environ:
seed = int(os.environ.get('PYTHONHASHSEED'))
else:
seed = sitk.sitkWallClock

# SimpleITK registration's supported pixel types are sitkFloat32 and sitkFloat64
fixed_image = sitk.Cast(fixed_image, sitk.sitkFloat32)
Expand Down Expand Up @@ -130,7 +132,8 @@ def get_registration_transform(fixed_image, moving_image, registration_type="Rig
logger.warning("Unsupported transform provided, falling back to Rigid transformation")
registration_transform = REGISTRATION_MAP["Rigid"]

# Align the centers of the two volumes and set the center of rotation to the center of the fixed image
# Align the centers of the two volumes and set the
# center of rotation to the center of the fixed image
initial_transform = sitk.CenteredTransformInitializer(
fixed_image, moving_image, registration_transform,
sitk.CenteredTransformInitializerFilter.GEOMETRY)
Expand Down
29 changes: 16 additions & 13 deletions midaGAN/data/utils/slice_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
from midaGAN.data.utils import pad


# TODO: Differentiate from Stochasting Focal Patching
class SliceSampler:
""" Stochasting Focal Patching technique achieves spatial correspondance of patches extracted from a pair
""" Stochasting Focal Patching technique achieves spatial correspondance of patches extracted from a pair
of volumes by:
(1) Randomly selecting a slice from volume_A (slice_A)
(2) Calculating the relative start position of the slice_A
(2) Calculating the relative start position of the slice_A
(3) Translating the slice_A's relative position in volume_B
(4) Placing a focal region (a proportion of volume shape) around the focal point
(5) Randomly selecting a start point in the focal region and extracting the slice_B
The added stochasticity in steps (4) and (5) aims to account for possible differences in positioning
The added stochasticity in steps (4) and (5) aims to account for possible differences in positioning
of the object between volumes.
"""

Expand All @@ -24,7 +25,7 @@ def __init__(self,
self.patch_size = np.array(patch_size)
try:
assert len(self.patch_size) == 2
except AssertionError as error:
except AssertionError:
print("Patch size needs to be 2D, use StochasticFocalPatchSampler for 3D!")
exit()

Expand All @@ -40,9 +41,8 @@ def get_slice_pair(self, volume_A, volume_B):
def slice_and_focal_point_from_A(self, volume):
"""Return random patch from volume A and its relative start position."""
z, x, y = self.pick_random_start(volume)

x_end, y_end = [sum(pair) for pair in zip((x, y), self.patch_size)
] # start + patch size for each coord
# start + patch size for each coord
x_end, y_end = [sum(pair) for pair in zip((x, y), self.patch_size)]

volume = pad(volume, (0, x_end, y_end))

Expand All @@ -53,8 +53,8 @@ def slice_and_focal_point_from_A(self, volume):
def slice_from_B(self, volume, relative_focal_point):
"""Return random patch from volume B that is in relative neighborhood of patch_A."""
z, x, y = self.pick_stochastic_focal_start(volume, relative_focal_point)
x_end, y_end = [sum(pair) for pair in zip((x, y), self.patch_size)
] # start + patch size for each coord
# start + patch size for each coord
x_end, y_end = [sum(pair) for pair in zip((x, y), self.patch_size)]

volume = pad(volume, (0, x_end, y_end))

Expand All @@ -76,7 +76,8 @@ def pick_stochastic_focal_start(self, volume, relative_focal_point):
focal_region = self.focal_region_proportion * volume_size
focal_region = focal_region.astype(np.int64)

focal_point = relative_focal_point * volume_size # map relative point to corresponding point in this volume
# Map relative point to corresponding point in this volume
focal_point = relative_focal_point * volume_size
valid_start_region = self.calculate_valid_start_region(volume)

start_coordinates = self.apply_stochastic_focal_method(focal_point, focal_region,
Expand All @@ -99,10 +100,11 @@ def apply_stochastic_focal_method(self, focal_point, focal_region, valid_start_r
min_position = max(0, min_position)
max_position = min(max_position, valid_start_region[axis])

if min_position > max_position: # edge cases # TODO: is it because there's no min(min_position, valid_start_region[axis])
# Edge cases # TODO: is it because there's no min(min_position, valid_start_region[axis])
if min_position > max_position:
start_point.append(max_position)
else:
start_point.append(random.randint(min_position, max_position)) # regular case
start_point.append(random.randint(min_position, max_position))

return start_point

Expand All @@ -123,4 +125,5 @@ def calculate_valid_start_region(self, volume):
return valid_start_region

def get_size(self, volume):
return np.array(volume.shape[-3:]) # last three dimension (Z,X,Y)
# last three dimension (Z,X,Y)
return np.array(volume.shape[-3:])
27 changes: 15 additions & 12 deletions midaGAN/data/utils/stochastic_focal_patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@


class StochasticFocalPatchSampler:
""" Stochasting Focal Patching technique achieves spatial correspondance of patches extracted from a pair
""" Stochasting Focal Patching technique achieves spatial correspondance of patches extracted from a pair
of volumes by:
(1) Randomly selecting a patch from volume_A (patch_A)
(2) Calculating the relative start position of the patch_A
(2) Calculating the relative start position of the patch_A
(3) Translating the patch_A's relative position in volume_B
(4) Placing a focal region (a proportion of volume shape) around the focal point
(5) Randomly selecting a start point in the focal region and extracting the patch_B
The added stochasticity in steps (4) and (5) aims to account for possible differences in positioning
The added stochasticity in steps (4) and (5) aims to account for possible differences in positioning
of the object between volumes.
"""

Expand All @@ -28,8 +28,8 @@ def get_patch_pair(self, volume_A, volume_B):
def patch_and_focal_point_from_A(self, volume):
"""Return random patch from volume A and its relative start position."""
z, x, y = self.pick_random_start(volume)
z_end, x_end, y_end = [sum(pair) for pair in zip((z, x, y), self.patch_size)
] # start + patch size for each coord
# start + patch size for each coord
z_end, x_end, y_end = [sum(pair) for pair in zip((z, x, y), self.patch_size)]

patch = volume[z:z_end, x:x_end, y:y_end]
relative_focal_point = self.calculate_relative_focal_point(z, x, y, volume)
Expand All @@ -38,8 +38,8 @@ def patch_and_focal_point_from_A(self, volume):
def patch_from_B(self, volume, relative_focal_point):
"""Return random patch from volume B that is in relative neighborhood of patch_A."""
z, x, y = self.pick_stochastic_focal_start(volume, relative_focal_point)
z_end, x_end, y_end = [sum(pair) for pair in zip((z, x, y), self.patch_size)
] # start + patch size for each coord
# start + patch size for each coord
z_end, x_end, y_end = [sum(pair) for pair in zip((z, x, y), self.patch_size)]

patch = volume[z:z_end, x:x_end, y:y_end]
return patch
Expand All @@ -56,7 +56,8 @@ def pick_stochastic_focal_start(self, volume, relative_focal_point):
focal_region = self.focal_region_proportion * volume_size
focal_region = focal_region.astype(np.int64)

focal_point = relative_focal_point * volume_size # map relative point to corresponding point in this volume
# Map relative point to corresponding point in this volume
focal_point = relative_focal_point * volume_size
valid_start_region = self.calculate_valid_start_region(volume)

z, x, y = self.apply_stochastic_focal_method(focal_point, focal_region, valid_start_region)
Expand All @@ -71,14 +72,15 @@ def apply_stochastic_focal_method(self, focal_point, focal_region, valid_start_r
min_position = int(focal_point[axis] - focal_region[axis] / 2)
max_position = int(focal_point[axis] + focal_region[axis] / 2)

# if one of the boundaries of the focus is outside of the possible area to sample from, cap it
# If one of the boundaries of the focus is outside of the possible area to sample from, cap it
min_position = max(0, min_position)
max_position = min(max_position, valid_start_region[axis])

if min_position > max_position: # edge cases # TODO: is it because there's no min(min_position, valid_start_region[axis])
# Edge cases # TODO: is it because there's no min(min_position, valid_start_region[axis])
if min_position > max_position:
start_point.append(max_position)
else:
start_point.append(random.randint(min_position, max_position)) # regular case
start_point.append(random.randint(min_position, max_position))

return start_point

Expand All @@ -97,4 +99,5 @@ def calculate_valid_start_region(self, volume):
return valid_start_region

def get_size(self, volume):
return np.array(volume.shape[-3:]) # last three dimension (Z,X,Y)
# last three dimension (Z,X,Y)
return np.array(volume.shape[-3:])
3 changes: 2 additions & 1 deletion midaGAN/data/utils/transforms.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

Expand All @@ -12,7 +13,7 @@ def get_transform(conf, method=Image.BICUBIC):
transform_list = []

if 'resize' in preprocess:
osize = [load_size, load_size] # TODO: make it a tuple from config
osize = [load_size, load_size]
transform_list.append(transforms.Resize(osize, method))

elif 'scale_width' in preprocess:
Expand Down
19 changes: 9 additions & 10 deletions midaGAN/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,17 @@ def infer(self, data):
# Sliding window (i.e. patch-wise) inference
if self.sliding_window_inferer:
return self.sliding_window_inferer(data, self.model.infer)
else:
return self.model.infer(data)
return self.model.infer(data)

def _init_sliding_window_inferer(self):
if self.conf.sliding_window:
return SlidingWindowInferer(roi_size=self.conf.sliding_window.window_size,
sw_batch_size=self.conf.sliding_window.batch_size,
overlap=self.conf.sliding_window.overlap,
mode=self.conf.sliding_window.mode,
cval=-1)
else:
return None
if not self.conf.sliding_window:
return

return SlidingWindowInferer(roi_size=self.conf.sliding_window.window_size,
sw_batch_size=self.conf.sliding_window.batch_size,
overlap=self.conf.sliding_window.overlap,
mode=self.conf.sliding_window.mode,
cval=-1)

def calculate_metrics(self, pred, target):
# Check if dataset has scale_to_hu method defined,
Expand Down
Loading

0 comments on commit e1d7218

Please sign in to comment.