-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathestimate_model.py
74 lines (67 loc) · 2.61 KB
/
estimate_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from datasets import prepare_valid_loaders
from utils import *
from models import *
import torch
from torch.nn import functional as F
import gc
import matplotlib.pyplot as plt
from fastprogress import progress_bar
class Model_pred:
    """Iterate over per-sample sigmoid predictions of a model on a data loader.

    Each step of iteration yields ``(prediction, target, name)`` for one sample:

    * ``prediction`` - ``sigmoid(model(x))`` for that sample, optionally averaged
      over flip test-time augmentation, moved to CPU as float32 and permuted
      with ``permute(0, 2, 3, 1)`` (so a model output assumed to be
      (N, C, H, W) becomes (H, W, C) per sample - TODO confirm layout).
    * ``target`` - ``y[i]`` on CPU, or ``None`` when the loader yields no target.
    * ``name`` - the matching entry of ``dl.dataset.graph_list``.

    Names are matched purely by running position, so this assumes ``dl`` does
    not shuffle - NOTE(review): confirm for the loaders passed in.
    """

    def __init__(self, model, dl, tta: bool = True, half: bool = False, config=None):
        """
        Args:
            model: network called as ``model(x)`` on each input batch.
            dl: iterable of ``(x, y)`` batches; ``dl.dataset`` must expose
                ``graph_list`` (per-sample names) and support ``len()``.
            tta: when True, also average predictions over flips of the last
                axis, the second-to-last axis, and both.
            half: when True, feed the model float16 input. The model itself
                must already be in half precision (e.g. via ``model.half()``).
            config: object with a ``device`` attribute. ``None`` or a missing
                attribute means run on CPU.
        """
        self.model = model
        self.dl = dl
        self.tta = tta
        self.half = half
        self.config = config

    def _predict(self, x):
        """Return sigmoid probabilities for batch ``x``, flip-TTA averaged if enabled."""
        probs = torch.sigmoid(self.model(x))
        if not self.tta:
            return probs
        # Flip the last axis, the second-to-last axis, then both; un-flip each
        # output so it lines up with the un-augmented prediction before summing.
        flips = ([-1], [-2], [-2, -1])
        for dims in flips:
            flipped_out = self.model(torch.flip(x, dims))
            probs = probs + torch.sigmoid(torch.flip(flipped_out, dims))
        return probs / (1 + len(flips))

    def __iter__(self):
        """Yield ``(prediction, target, name)`` for every sample in ``self.dl``."""
        self.model.eval()
        names = self.dl.dataset.graph_list
        # Tolerate config=None (the constructor default) by falling back to CPU.
        device = getattr(self.config, "device", "cpu")
        count = 0
        with torch.no_grad():
            for x, y in self.dl:
                if device != "cpu":
                    x = x.to(device)
                # Cast exactly once: float16 when requested, float32 otherwise.
                x = x.half() if self.half else x.float()
                py = self._predict(x)
                # Resize predictions to the target's spatial size when a 4-D
                # target is present and shapes disagree.
                if y is not None and y.ndim == 4 and py.shape != y.shape:
                    py = F.interpolate(
                        py,
                        size=(y.shape[-2], y.shape[-1]),
                        mode="bilinear",
                        align_corners=False,
                    )
                py = py.permute(0, 2, 3, 1).float().cpu()
                for i in range(len(py)):
                    target = y[i].detach().cpu() if y is not None else None
                    yield py[i], target, names[count]
                    count += 1

    def __len__(self):
        """Number of samples (not batches) that iteration will yield."""
        return len(self.dl.dataset)
def predict_model(model, cfg):
    """Evaluate ``model`` on the validation loader and plot Dice vs. threshold.

    Wraps the loader from ``prepare_valid_loaders(cfg)`` in ``Model_pred``
    (flip TTA on by default) and feeds every ``(prediction, target)`` pair to a
    ``Dice_th_pred`` built over thresholds ``np.arange(0.2, 0.7, 0.01)``. The
    resulting scores (``dice.value``) are plotted against ``dice.ths``, with the
    best threshold marked by a vertical line. The figure is saved to
    ``{cfg.save_weights_dir}/save.jpg``.

    NOTE(review): ``Dice_th_pred`` is defined elsewhere; it is assumed to
    expose one score per threshold in ``.value`` - confirm.

    Args:
        model: network passed through to ``Model_pred``.
        cfg: config used for the loader, the device and ``save_weights_dir``.

    Returns:
        None. The best Dice score and threshold appear only in the saved figure.
    """
    valid_loader = prepare_valid_loaders(cfg)
    predictor = Model_pred(model, valid_loader, config=cfg)
    dice = Dice_th_pred(np.arange(0.2, 0.7, 0.01))
    for pred, target, _name in progress_bar(predictor):
        dice.accumulate(pred, target)
    # One full collection after the pass; collecting after every sample only
    # adds a heavy per-sample cost.
    gc.collect()

    scores = dice.value
    thresholds = dice.ths
    best_dice = scores.max()
    best_thr = thresholds[scores.argmax()]
    spread = scores.max() - scores.min()

    fig = plt.figure(figsize=(8, 4))
    try:
        plt.plot(thresholds, scores, color='blue')
        plt.vlines(x=best_thr, ymin=scores.min(), ymax=scores.max(), colors='black')
        # Stack the two labels just below the peak, scaled by the curve's range.
        plt.text(thresholds[-1] - 0.1, best_dice - 0.1 * spread, f'DICE = {best_dice:.3f}', fontsize=12)
        plt.text(thresholds[-1] - 0.1, best_dice - 0.2 * spread, f'TH = {best_thr:.3f}', fontsize=12)
        plt.xlabel('threshold')
        plt.ylabel('Dice')
        plt.savefig(f'{cfg.save_weights_dir}/save.jpg')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    return None