-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathengine.py
113 lines (85 loc) · 4.23 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from tqdm import tqdm
import torch
import gc, sys, math
import torch.nn as nn
import torch.nn.functional as F
from utils.losses import DiceScore, DiceBCELoss, build_target, dice_loss
from utils.metrics import Metrics
from utils import utils
def criterion(inputs, target, loss_weight=None, num_classes: int = 2, dice: bool = True, ignore_index: int = -100):
    """Segmentation loss: cross-entropy, optionally combined with multiclass Dice loss.

    Args:
        inputs: raw logits of shape (N, C, ...) as accepted by F.cross_entropy.
        target: integer class-index target tensor.
        loss_weight: optional per-class weight tensor for cross-entropy.
        num_classes: number of classes, forwarded to build_target for the Dice term.
        dice: when truthy, add the multiclass Dice loss on top of cross-entropy.
        ignore_index: target value excluded from both loss terms.

    Returns:
        Scalar loss tensor (CE, or CE + Dice when `dice` is truthy).
    """
    # Use the module-level alias `F` instead of spelling out nn.functional.
    loss = F.cross_entropy(inputs, target, ignore_index=ignore_index, weight=loss_weight)
    if dice:  # truthiness test instead of `is True` identity comparison
        dice_target = build_target(target, num_classes, ignore_index)
        loss = loss + dice_loss(inputs, dice_target, multiclass=True, ignore_index=ignore_index)
    return loss
# def criterion(inputs, target):
# losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]
# total_loss = sum(losses)
#
# return total_loss
def train_one_epoch(model, optimizer, dataloader,
                    epoch, device, print_freq, clip_grad, clip_mode, loss_scaler, writer=None, args=None):
    """Run one AMP (mixed-precision) training epoch.

    Args:
        model: network; called as model(img) to produce logits.
        optimizer: torch optimizer; lr is read from param_groups[0].
        dataloader: yields (img, lbl) batches.
        epoch: current epoch index, used for the tensorboard global step.
        device: CUDA device to move batches to.
        print_freq: logging interval in iterations.
        clip_grad, clip_mode: gradient-clipping config forwarded to loss_scaler.
        loss_scaler: timm-style scaler callable; performs scaled backward + optimizer step.
        writer: optional tensorboard SummaryWriter (expected only on rank 0).
        args: namespace providing nb_classes, dice, ignore_index, local_rank.

    Returns:
        (global average loss over the epoch, last observed learning rate)

    Exits the process if the loss becomes non-finite.
    """
    model.train()
    num_steps = len(dataloader)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    if args.nb_classes == 2:
        # TODO set CrossEntropy loss-weights for object & background according to your datasets
        loss_weight = torch.as_tensor([1.0, 2.0], device=device)
    else:
        loss_weight = None
    # Initialize lr up front so the return value is defined even for an empty dataloader
    # (the original only assigned lr inside the loop -> NameError on empty input).
    lr = optimizer.param_groups[0]["lr"]
    for idx, (img, lbl) in enumerate(metric_logger.log_every(dataloader, print_freq, header)):
        # NOTE(review): casting inputs to float16 manually while also using autocast is
        # redundant (autocast manages op precision itself); kept to preserve behavior — confirm.
        img = img.to(device, dtype=torch.float16, non_blocking=True)
        lbl = lbl.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            logits = model(img)
            loss = criterion(logits, lbl, loss_weight, num_classes=args.nb_classes, dice=args.dice,
                             ignore_index=args.ignore_index)
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        # FIX: backward/step must run OUTSIDE autocast — the PyTorch AMP docs require
        # autocast around the forward pass and loss only, never around .backward()/step.
        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters(), create_graph=is_second_order)
        torch.cuda.synchronize()
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(loss=loss_value, lr=lr)
        # FIX: guard against writer being None (its default) in addition to the rank check;
        # log the already-extracted float instead of the live tensor.
        if writer is not None and idx % print_freq == 0 and args.local_rank == 0:
            iter_all_count = epoch * num_steps + idx
            writer.add_scalar('train_loss', loss_value, iter_all_count)
            writer.add_scalar('train_lr', lr, iter_all_count)
    metric_logger.synchronize_between_processes()
    torch.cuda.empty_cache()
    gc.collect()
    return metric_logger.meters["loss"].global_avg, lr
@torch.inference_mode()
def valid_one_epoch(args, model, dataloader, device, print_freq, writer=None):
    """Evaluate the model over one pass of `dataloader`.

    Args:
        args: namespace providing nb_classes, ignore_label, device, local_rank.
        model: network; called as model(images) to produce logits.
        dataloader: yields (images, labels) batches.
        device: device to move batches to (inputs cast to float32).
        print_freq: logging interval in iterations.
        writer: optional tensorboard SummaryWriter.

    Returns:
        (confmat, metric): the populated ConfusionMatrix and Metrics objects,
        both reduced across distributed processes.
    """
    model.eval()
    # NOTE(review): the training path uses args.ignore_index while this uses
    # args.ignore_label — confirm these are two distinct, intentional settings.
    metric = Metrics(args.nb_classes, args.ignore_label, args.device)
    confmat = utils.ConfusionMatrix(args.nb_classes)
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'
    for idx, (images, labels) in enumerate(metric_logger.log_every(dataloader, print_freq, header)):
        images = images.to(device, dtype=torch.float32, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        # compute output
        # with torch.cuda.amp.autocast(): # TODO: ConfusionMatrix not implemented for 'Half' data
        outputs = model(images)
        confmat.update(labels.flatten(), outputs.argmax(1).flatten())
        metric.update(outputs, labels.flatten())
        if writer:
            # FIX: original used bitwise '&', which binds tighter than '==' and parsed as
            # the chained comparison `idx % print_freq == (0 & args.local_rank) == 0`;
            # logical 'and' expresses the intended rank-0, every-print_freq condition.
            if idx % print_freq == 0 and args.local_rank == 0:
                writer.add_scalar('valid_mf1', metric.compute_f1()[1])
                writer.add_scalar('valid_acc', metric.compute_pixel_acc()[1])
                writer.add_scalar('valid_mIOU', metric.compute_iou()[1])
    confmat.reduce_from_all_processes()
    metric.reduce_from_all_processes()
    torch.cuda.empty_cache()
    gc.collect()
    return confmat, metric