loss.py
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.losses import binary_crossentropy
from keras.optimizers import Adam

np.random.seed(123)
def el(y_true, y_pred):
    # Plain (unweighted) binary cross-entropy, averaged over all elements.
    return K.mean(binary_crossentropy(y_true, y_pred))
def get_optimizer():
    # Adam with a small learning rate; `lr` is the legacy Keras argument name.
    adam = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    return adam
def single_dice_coef(y_true, y_pred_bin):
    # shape of y_true and y_pred_bin: (height, width); NumPy version for evaluation.
    intersection = np.sum(y_true * y_pred_bin)
    if (np.sum(y_true) == 0) and (np.sum(y_pred_bin) == 0):
        # Both mask and prediction are empty: count as a perfect match.
        return 1
    return (2 * intersection) / (np.sum(y_true) + np.sum(y_pred_bin))
def dice_coef(y_true, y_pred):
    # Hard Dice coefficient: predictions are thresholded at 0.5, so this is an
    # evaluation metric rather than a differentiable training loss.
    y_true_f = K.flatten(y_true)
    y_pred = K.cast(y_pred, 'float32')
    y_pred_f = K.cast(K.greater(K.flatten(y_pred), 0.5), 'float32')
    intersection = y_true_f * y_pred_f
    score = 2. * K.sum(intersection) / (K.sum(y_true_f) + K.sum(y_pred_f))
    return score
def mean_dice_coef(y_true, y_pred_bin):
    # shape of y_true and y_pred_bin: (n_samples, height, width, n_channels)
    batch_size = y_true.shape[0]
    channel_num = y_true.shape[-1]
    mean_dice_channel = 0.
    for i in range(batch_size):
        for j in range(channel_num):
            channel_dice = single_dice_coef(y_true[i, :, :, j], y_pred_bin[i, :, :, j])
            mean_dice_channel += channel_dice / (channel_num * batch_size)
    return mean_dice_channel
def seg_loss(y_true, y_pred):
    # Combined loss: (1 - Dice) on the first channel plus per-pixel binary cross-entropy.
    dice_s = dice_coefficient_loss(y_true, y_pred)
    ce_loss = tf.keras.backend.binary_crossentropy(y_true, y_pred)
    return ce_loss + dice_s
def dice_loss(y_true, y_pred):
    # Soft (differentiable) Dice loss with additive smoothing.
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = y_true_f * y_pred_f
    score = (2. * K.sum(intersection) + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return 1. - score
def dice_coefficient_loss(y_true, y_pred):
    return 1. - dice_coefficient(y_true, y_pred)
def bce_logdice_loss(y_true, y_pred):
    # BCE plus -log(soft Dice score); the log term penalizes low Dice more sharply.
    return binary_crossentropy(y_true, y_pred) - K.log(1. - dice_loss(y_true, y_pred))
def weighted_bce_loss(y_true, y_pred, weight):
    # Per-pixel weighted binary cross-entropy, computed from logits in a
    # numerically stable form (avoids overflow in exp/log).
    epsilon = 1e-7
    y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
    logit_y_pred = K.log(y_pred / (1. - y_pred))
    loss = weight * (logit_y_pred * (1. - y_true) +
                     K.log(1. + K.exp(-K.abs(logit_y_pred))) + K.maximum(-logit_y_pred, 0.))
    return K.sum(loss) / K.sum(weight)
def weighted_dice(y_true, y_pred):
    # Smoothed Dice score with a fixed scalar weight of 0.7 (the weight cancels
    # in the ratio except through the smoothing term).
    smooth = 1.
    w, m1, m2 = 0.7, y_true, y_pred
    intersection = m1 * m2
    score = (2. * K.sum(w * intersection) + smooth) / (K.sum(w * m1) + K.sum(w * m2) + smooth)
    return score
def weighted_dice_loss(y_true, y_pred):
    return 1. - weighted_dice(y_true, y_pred)
def weighted_bce_dice_loss(y_true, y_pred):
    y_true = K.cast(y_true, 'float32')
    y_pred = K.cast(y_pred, 'float32')
    # Average-pool the mask with 'same' padding so the pooled map keeps the input size
    # (an odd kernel size keeps the window centered on each pixel).
    averaged_mask = K.pool2d(y_true, pool_size=(50, 50), strides=(1, 1), padding='same', pool_mode='avg')
    weight = K.ones_like(averaged_mask)
    w0 = K.sum(weight)
    # Pixels whose local average is near 0.5 lie close to a mask boundary and get a
    # larger weight; rescale so the total weight mass stays equal to w0.
    weight = 5. * K.exp(-5. * K.abs(averaged_mask - 0.5))
    w1 = K.sum(weight)
    weight *= (w0 / w1)
    loss = weighted_bce_loss(y_true, y_pred, weight) + dice_loss(y_true, y_pred)
    return loss
def dice_coefficient(y_true, y_pred, smooth=1):
    # Soft Dice coefficient computed on the first channel only.
    y_true_f = K.flatten(y_true[:, :, :, 0])
    y_pred_f = K.flatten(y_pred[:, :, :, 0])
    intersection = K.sum(y_true_f * y_pred_f)
    d1 = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return d1
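# Usage sketch (not part of the original file): one way these losses and the
# optimizer could be wired into a Keras segmentation model. `build_unet`, the
# input shape, and the training arrays are hypothetical placeholders; substitute
# the repository's actual model builder and data pipeline.
#
# from model import build_unet  # hypothetical helper
#
# model = build_unet(input_shape=(256, 256, 1))
# model.compile(optimizer=get_optimizer(),
#               loss=seg_loss,
#               metrics=[dice_coef])
# model.fit(train_images, train_masks,
#           batch_size=8, epochs=50,
#           validation_data=(val_images, val_masks))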