Skip to content

Commit

Permalink
Deep Supervision, optimized network structure, support for hdf5
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdfgfthgr-fox committed Oct 16, 2024
1 parent 3327924 commit 887ff5f
Show file tree
Hide file tree
Showing 16 changed files with 1,154 additions and 728 deletions.
10 changes: 5 additions & 5 deletions Augmentation Parameters.csv
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
Augmentation,Probability,Low Bound,High Bound,Value
Rotate xy,0.25,0,360,
Rotate xz,0,0,360,
Rotate yz,0,0,360,
Rotate xy,0.5,0,360,
Rotate xz,0.0,0,360,
Rotate yz,0.0,0,360,
Rescaling,0.25,0.75,1.25,
Edge Replicate Pad,0,0,0,0.075
Edge Replicate Pad,0.0,0,0,0.05
Vertical Flip,0.5,,,
Horizontal Flip,0.5,,,
Depth Flip,0.5,,,
Expand All @@ -14,6 +14,6 @@ Gradient Contrast,0.25,0.75,1.5,
Adjust Contrast,0.75,0.5,1.5,
Adjust Gamma,0.75,0.5,1.5,
Adjust Brightness,0.75,0.75,1.4,
Salt And Pepper,0.25,0.005,0.01,
Salt And Pepper,0.0,0.005,0.01,
Label Blur,0,1,1,5
Contour Blur,0,0.5,0.5,7
11 changes: 6 additions & 5 deletions Components/Augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from scipy.ndimage import distance_transform_edt
import torch.multiprocessing as mp


device = "cuda" if torch.cuda.is_available() else "cpu"
# Various customize image augmentation implementations specialised in 4 dimensional tensors.
# (Channel, Depth, Height, Width).
# The expected range should be between 0 and 1.
Expand Down Expand Up @@ -320,7 +322,6 @@ def random_gradient(tensor, range=(0.5, 1.5), gamma=True):
gradient = torch.linspace(range[0], range[1], tensor.shape[1])
gradient = gradient.view(1, -1, 1, 1)
gradient = gradient.flip(dims=(1,))

if gamma:
return tensor ** gradient
else:
Expand All @@ -339,14 +340,14 @@ def salt_and_pepper_noise(tensor, prob=0.01):
Returns:
torch.Tensor: Tensor with salt and pepper noise added.
"""
noisy_tensor = tensor.clone()
noisy_tensor = tensor.detach()

# Add salt noise
salt_mask = torch.rand_like(tensor) < prob
salt_mask = torch.rand_like(tensor, device=device) < prob
noisy_tensor[salt_mask] = 1.0

# Add pepper noise
pepper_mask = torch.rand_like(tensor) < prob
pepper_mask = torch.rand_like(tensor, device=device) < prob
noisy_tensor[pepper_mask] = 0.0

return noisy_tensor
Expand Down Expand Up @@ -557,6 +558,6 @@ def edge_replicate_pad(input_tensors, padding_percentile=0.1):
output_tensors = []
for input_tensor in input_tensors:
output_tensor = input_tensor[:, D_crop:-D_crop, H_crop:-H_crop, W_crop:-W_crop]
output_tensor = T_F.pad(output_tensor, [W_crop, W_crop, H_crop, H_crop, D_crop, D_crop], padding_mode='replicate')
output_tensor = F.pad(output_tensor, [W_crop, W_crop, H_crop, H_crop, D_crop, D_crop], mode='replicate')
output_tensors.append(output_tensor)
return output_tensors
156 changes: 91 additions & 65 deletions Components/DataComponents.py

Large diffs are not rendered by default.

137 changes: 119 additions & 18 deletions Components/Metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,13 @@ def calculate_other_metrices(inputs: torch.Tensor, targets: torch.Tensor):
false_positives = (inputs*(1-targets)).sum().detach()
return true_positives, false_negatives, true_negatives, false_positives

def forward(self, inputs: torch.Tensor, targets: torch.Tensor, sparse_label=False):
def forward(self, predict: torch.Tensor, target: torch.Tensor, sparse_label=False):
"""
Calculate binary classification metrics and loss based on the provided inputs and targets.
Args:
inputs (torch.Tensor): The predicted binary classification values (B, 1, D, H, W).
targets (torch.Tensor): The target labels (B, 1, D, H, W).
predict (torch.Tensor): The predicted binary classification values (B, 1, D, H, W).
target (torch.Tensor): The target labels (B, 1, D, H, W).
When `sparse_label` is True: 0 for unlabeled, 1 for foreground, 2 for background
When `sparse_label` is False: 0.0 for background, 1.0 for foreground
sparse_label (bool): A flag indicating whether the target labels are sparse (default is False).
Expand All @@ -213,35 +213,136 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor, sparse_label=Fals
# Input: 1 = Foreground, 0 = Background. Can be any number in between.
# Target: 1 = Foreground, 0 = Background. Can be any number in between.
if sparse_label:
inputs, targets = self.sparse_label_transform(inputs, targets)
return self.dice_loss(inputs, targets)
predict, target = self.sparse_preprocessing(predict, target)
#return self.dice_loss(inputs, targets)

if self.loss_mode == "focal":
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
inputs = torch.sigmoid(inputs)
BCE_loss = F.binary_cross_entropy_with_logits(predict, target, reduction='none')
predict = torch.sigmoid(predict)
pt = torch.exp(-BCE_loss)
F_loss = (1-pt) ** 1.333 * BCE_loss
with torch.no_grad():
intersection, union, _ = self.calculate_iou_loss(inputs, targets)
tp, fn, tn, fp = self.calculate_other_metrices(inputs, targets)
intersection, union, _ = self.calculate_iou_loss(predict, target)
tp, fn, tn, fp = self.calculate_other_metrices(predict, target)
return F_loss.mean(), intersection, union, tp, fn, tn, fp
elif self.loss_mode == "bce_no_dice":
# Scale down to 20% since it's used for unsupervised learning and is often much higher than supervised
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') * 0.2
BCE_loss = F.binary_cross_entropy_with_logits(predict, target, reduction='none') * 0.2
return BCE_loss.mean(), torch.nan, torch.nan, torch.nan, torch.nan, torch.nan, torch.nan
elif self.loss_mode == "dice":
inputs = torch.sigmoid(inputs)
intersection, union, loss = self.calculate_iou_loss(inputs, targets)
predict = torch.sigmoid(predict)
intersection, union, loss = self.calculate_iou_loss(predict, target)
with torch.no_grad():
tp, fn, tn, fp = self.calculate_other_metrices(inputs, targets)
tp, fn, tn, fp = self.calculate_other_metrices(predict, target)
return loss, intersection, union, tp, fn, tn, fp
elif self.loss_mode == "dice+bce":
bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none').mean()
inputs = torch.sigmoid(inputs)
intersection, union, dice_loss = self.calculate_iou_loss(inputs, targets)
bce_loss = F.binary_cross_entropy_with_logits(predict, target, reduction='none').mean()
predict = torch.sigmoid(predict)
intersection, union, dice_loss = self.calculate_iou_loss(predict, target)
with torch.no_grad():
tp, fn, tn, fp = self.calculate_other_metrices(inputs, targets)
total_loss = (dice_loss*0.1+bce_loss*1.9)/2
tp, fn, tn, fp = self.calculate_other_metrices(predict, target)
total_loss = 0.1*dice_loss+1.9*bce_loss
return total_loss, intersection, union, tp, fn, tn, fp
else:
raise ValueError("Invalid loss. Use 'boundary' or 'focal' or 'dice' or 'dice+bce'.")


class BinaryMetricsForList(nn.Module):
    """Loss/metric module for deep supervision: scores a LIST of predictions
    against a single target.

    Each element of ``predicts`` is a logits tensor of shape (B, 1, D, H, W);
    ``target`` has the same shape. The differentiable loss is averaged over
    the list, while the integer metrics (intersection/union and confusion
    counts) are summed over it.

    Supported modes: "focal" and "dice+bce".
    """

    def __init__(self, loss_mode: str, smooth=1024):
        """
        Args:
            loss_mode (str): Either focal loss ("focal") or combined dice+bce ("dice+bce").
            smooth (float): Smoothing factor for numerical stability of the dice
                term (default 1024 — deliberately very large, see calculate_iou_loss).
        """
        super(BinaryMetricsForList, self).__init__()
        self.loss_mode = loss_mode
        self.smooth = smooth

    @staticmethod
    def sparse_preprocessing(predicts, target):
        """Convert sparse labels to dense ones and drop unlabelled voxels.

        Sparse target encoding: 0 = unlabelled, 1 = foreground, 2 = background.
        Dense output encoding: 1 = foreground, 0 = background.

        Args:
            predicts (list[torch.Tensor]): predicted tensors, any (identical) shape.
            target (torch.Tensor): sparse label tensor of the same shape.
        Returns:
            (list[torch.Tensor], torch.Tensor): flattened predictions and target,
            both restricted to the labelled voxels only.
        """
        flat_target = target.flatten()
        flat_predicts = [predict.flatten() for predict in predicts]
        # Mark unlabelled voxels with NaN so they can be filtered out below.
        # NOTE(review): assumes `target` is a float tensor — writing NaN into an
        # integer tensor would fail; confirm against callers.
        flat_target = torch.where(flat_target == 0, math.nan, flat_target)
        # Remap 1 -> 1 (foreground) and 2 -> 0 (background); NaN stays NaN.
        flat_target = 1 - (flat_target - 1)
        labelled = ~torch.isnan(flat_target)
        flat_predicts = [predict[labelled] for predict in flat_predicts]
        return flat_predicts, flat_target[labelled]

    def calculate_iou_loss(self, predict: torch.Tensor, target: torch.Tensor):
        """Smoothed soft-dice loss plus hard (0.5-thresholded) overlap counts.

        Args:
            predict (torch.Tensor): probabilities in [0, 1].
            target (torch.Tensor): labels, same shape as predict.
        Returns:
            intersection (torch.Tensor): 2 * |predict_map * target| at threshold 0.5.
            union (torch.Tensor): |predict_map| + |target| at threshold 0.5.
            loss (torch.Tensor): differentiable smoothed dice loss.
        """
        # Huge smooth term prevents the loss from jumping when the ground-truth
        # foreground is very small or zero.
        intersection_s = 2 * torch.sum(target * predict) + self.smooth
        union_s = torch.sum(predict) + torch.sum(target) + self.smooth
        loss = 1 - (intersection_s / union_s)
        with torch.no_grad():
            predict_map = torch.where(predict >= 0.5, 1, 0).to(torch.int8)
            intersection = 2 * torch.sum(target * predict_map)
            union = torch.sum(predict_map) + torch.sum(target)
        return intersection, union, loss

    @staticmethod
    def calculate_other_metrices(inputs: torch.Tensor, targets: torch.Tensor):
        """Confusion-matrix counts (tp, fn, tn, fp) at a 0.5 threshold."""
        inputs = torch.where(inputs >= 0.5, 1, 0).to(torch.int8)
        true_positives = (inputs * targets).sum().detach()
        false_negatives = ((1 - inputs) * targets).sum().detach()
        true_negatives = ((1 - inputs) * (1 - targets)).sum().detach()
        false_positives = (inputs * (1 - targets)).sum().detach()
        return true_positives, false_negatives, true_negatives, false_positives

    def forward(self, predicts: list[torch.Tensor], target: torch.Tensor, sparse_label=False):
        """
        Calculate binary classification metrics and loss for a list of predictions
        against a single target.
        Args:
            predicts (list of torch.Tensor): The predicted binary classification logits (B, 1, D, H, W).
            target (torch.Tensor): The target labels (B, 1, D, H, W).
                When `sparse_label` is True: 0 for unlabeled, 1 for foreground, 2 for background
                When `sparse_label` is False: 0.0 for background, 1.0 for foreground
            sparse_label (bool): A flag indicating whether the target labels are sparse (default is False).
        Returns:
            loss (torch.Tensor): The calculated loss value based on the chosen loss function.
            intersection (torch.Tensor)
            union (torch.Tensor)
            true_positives (torch.Tensor)
            false_negatives (torch.Tensor)
            true_negatives (torch.Tensor)
            false_positives (torch.Tensor)
        Raises:
            ValueError: if `predicts` is empty or `loss_mode` is unsupported.
        """
        if not predicts:
            # Guard: averaging over an empty list would otherwise raise an
            # opaque ZeroDivisionError.
            raise ValueError("`predicts` must contain at least one tensor.")
        # In non-sparse label cases:
        # Predict: 1 = Foreground, 0 = Background. Can be any number in between.
        # Target: 1 = Foreground, 0 = Background. Can be any number in between.
        if sparse_label:
            predicts, target = self.sparse_preprocessing(predicts, target)

        if self.loss_mode == "focal":
            bce_losses = [F.binary_cross_entropy_with_logits(predict, target, reduction='none') for predict in predicts]
            pts = [torch.exp(-bce_loss) for bce_loss in bce_losses]
            f_losses = [((1 - pt) ** 1.333 * bce_loss).mean() for pt, bce_loss in zip(pts, bce_losses)]
            # Average the focal loss across the supervision heads.
            f_loss = sum(f_losses) / len(f_losses)

            with torch.no_grad():
                # torch.sigmoid replaces the deprecated F.sigmoid.
                iou_outs = [self.calculate_iou_loss(torch.sigmoid(predict), target) for predict in predicts]
                other_metrics = [self.calculate_other_metrices(torch.sigmoid(predict), target) for predict in predicts]
                intersection = sum(out[0] for out in iou_outs)
                union = sum(out[1] for out in iou_outs)
                tp, fn, tn, fp = (sum(metrics[i] for metrics in other_metrics) for i in range(4))
            return f_loss, intersection, union, tp, fn, tn, fp
        elif self.loss_mode == "dice+bce":
            bce_losses = [F.binary_cross_entropy_with_logits(predict, target, reduction='none').mean() for predict in predicts]
            bce_loss = sum(bce_losses) / len(bce_losses)
            iou_outs = [self.calculate_iou_loss(torch.sigmoid(predict), target) for predict in predicts]
            # Metrics are summed across heads; the dice loss is averaged.
            intersection = sum(out[0] for out in iou_outs)
            union = sum(out[1] for out in iou_outs)
            dice_loss = sum(out[2] for out in iou_outs) / len(predicts)
            with torch.no_grad():
                other_metrics = [self.calculate_other_metrices(torch.sigmoid(predict), target) for predict in predicts]
                tp, fn, tn, fp = (sum(metrics[i] for metrics in other_metrics) for i in range(4))
            # Weighted blend: bce dominates, dice contributes a small share.
            total_loss = 0.1 * dice_loss + 1.9 * bce_loss
            return total_loss, intersection, union, tp, fn, tn, fp
        else:
            # Fixed error message: this class only implements 'focal' and
            # 'dice+bce' (the old message also advertised 'boundary'/'dice').
            raise ValueError("Invalid loss_mode. Use 'focal' or 'dice+bce'.")
Expand Down
Loading

0 comments on commit 887ff5f

Please sign in to comment.