trainer.py
import math
import typing
from pathlib import Path

from tqdm import tqdm
import torch
from torch.nn import Module
from torch.utils.data import DataLoader, random_split, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

from utils import import_data, save_on_master
from configs import Config
from dataset import NLUDataset, Collator


class Trainer(object):
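    """Training and evaluation loop for a joint intent classification and
    slot filling model.

    Splits the provided NLUDataset into training and validation loaders, runs
    the epoch loop, logs losses and accuracies to TensorBoard, and checkpoints
    the model whenever the validation loss improves.
    """
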
    args: Config
    device: torch.device
    output_dir: Path
    testing: bool
    model: Module
    optimizer: torch.optim.Optimizer
    data_loader_train: DataLoader
    data_loader_val: DataLoader
    _max_validation_loss: float
    _writer: SummaryWriter

    def __init__(self, args: Config, model: Module, optimizer: typing.Optional[torch.optim.Optimizer],
                 dataset: NLUDataset, collator: Collator, testing=False):
        self.args = args
        self.device = torch.device(args.device)
        self.output_dir = Path(args.output_dir)
        self.testing = testing
        self.model = model
        self.optimizer = optimizer
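        # Best (lowest) validation loss seen so far; a checkpoint is written
        # only after the validation loss drops below this initial threshold.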
        self._max_validation_loss = 6.0
        self._init_data_loaders(args, dataset, collator)
        if args.version != "":
            self.output_dir = self.output_dir / args.version
        summary_dir = self.output_dir / 'summary'
        self._writer = SummaryWriter(str(summary_dir))

    def _init_data_loaders(self, args: Config, dataset: NLUDataset, collator: Collator):
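        """Split the dataset 80/20 into training and validation subsets and
        build the corresponding DataLoaders: random sampling with incomplete
        batches dropped for training, sequential sampling for validation.
        """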
        total = len(dataset)
        train_size = math.ceil(total * 0.8)
        test_size = total - train_size
        dataset_train, dataset_val = random_split(dataset, [train_size, test_size])
        sampler_train = RandomSampler(dataset_train)
        sampler_val = SequentialSampler(dataset_val)
        batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)
        batch_sampler_val = torch.utils.data.BatchSampler(sampler_val, args.batch_size, drop_last=False)
        self.data_loader_train = DataLoader(dataset_train, num_workers=args.num_workers,
                                            batch_sampler=batch_sampler_train, collate_fn=collator)
        self.data_loader_val = DataLoader(dataset_val, num_workers=args.num_workers,
                                          batch_sampler=batch_sampler_val, collate_fn=collator)

    def _before_train(self):
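        """Move the model to the target device, report the number of trainable
        parameters, and persist the run configuration before training starts.
        """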
        self.model.to(self.device)
        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('number of params:', n_parameters)
        # train model
        print("Start training")
        # save the config before training starts
        self.args.save_config()

    def train(self):
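        """Run the full training loop: one training pass and one validation
        pass per epoch, with per-epoch loss and accuracy scalars written to
        TensorBoard and checkpointing delegated to _after_evaluation.
        """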
        self._before_train()
        for epoch in range(self.args.start_epoch, self.args.epochs):
            train_stats = self.train_one_epoch(epoch)
            validation_stats = self.evaluate()
            self._writer.add_scalars('epoch_loss', {
                "training": train_stats['avg_loss'],
                "validation": validation_stats['avg_loss'],
            }, epoch)
            self._writer.add_scalars('intent_accuracy', {
                "training": train_stats['intent_acc'],
                "validation": validation_stats['intent_acc'],
            }, epoch)
            self._writer.add_scalars('slot_accuracy', {
                "training": train_stats['slot_acc'],
                "validation": validation_stats['slot_acc'],
            }, epoch)
            self._after_evaluation(validation_stats, epoch)
            if self.testing:
                break

    def _after_train(self):
        return

    def _after_evaluation(self, stats: dict, epoch: int):
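        """Checkpoint the model, optimizer, epoch and args whenever the
        validation loss improves on the best value seen so far.
        """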
        if self.args.output_dir:
            checkpoint_path = self.output_dir / 'checkpoint.pth'
            if stats['avg_loss'] < self._max_validation_loss:
                self._max_validation_loss = stats['avg_loss']
                save_on_master({'model': self.model.state_dict(),
                                'optimizer': self.optimizer.state_dict(),
                                'epoch': epoch,
                                'args': self.args}, checkpoint_path)
                print('model saved to {} after epoch {} with validation loss {:.4f}'.format(
                    checkpoint_path, epoch, self._max_validation_loss))

    def train_one_epoch(self, epoch: int):
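        """Train for a single epoch and return the average loss and the
        intent/slot accuracies accumulated over the epoch.
        """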
        self.model.train()
        print_freq = 10
        loader_desc = 'Epoch [{:d}]: loss = {:.4f}, accuracy (intent = {:.4f}, slots = {:.4f})'
        train_iterator = tqdm(self.data_loader_train, desc=loader_desc.format(epoch, 0.0, 0.0, 0.0))
        n_batch = len(self.data_loader_train)
        intent_total = 0
        intent_correct = 0
        slot_total = 0
        slot_correct = 0
        total_loss = 0
        completed_batch = 0
        for idx, samples in enumerate(train_iterator, 1):
            inputs, targets = samples
            # data & target
            inputs = inputs.to(self.device)
            targets = {k: v.to(self.device) if type(v) is not str else v for k, v in targets.items()}
            inputs.update({
                "intent_label_ids": targets['intents'],
                "slot_labels_ids": targets['tags'],
            })
            outputs = self.model(**inputs)
            losses, (intent_logits, slot_logits) = outputs
            loss_val = losses.item()
            total_loss += loss_val
            intent_preds = torch.argmax(intent_logits.detach(), dim=1)
            intent_acc = torch.sum(intent_preds == targets['intents'])
            intent_total += targets['intents'].shape[0]
            intent_correct += intent_acc.item()
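            # Slot accuracy is counted over every token position in the padded
            # batch (batch_size * seq_len), so padding positions contribute too.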
            slot_preds = torch.argmax(slot_logits.detach(), dim=2)
            slot_acc = torch.sum(slot_preds == targets['tags'])
            slot_total += targets['tags'].shape[0] * targets['tags'].shape[1]
            slot_correct += slot_acc.item()
            # loss backward & optimizer step
            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            completed_batch += 1
            if idx % print_freq == 0:
                train_iterator.set_description(
                    loader_desc.format(epoch, loss_val, intent_correct / intent_total, slot_correct / slot_total))
                self._writer.add_scalar('training_loss', total_loss / completed_batch, epoch * n_batch + (idx + 1))
            if self.testing and idx == 10:
                break
        print('Total {}/{} correct intents, {}/{} correct slots trained with a total loss of {:.4f} over {} batches'
              .format(intent_correct, intent_total, slot_correct, slot_total, total_loss, completed_batch))
        return {
            'avg_loss': total_loss / completed_batch,
            'intent_acc': intent_correct / intent_total,
            'slot_acc': slot_correct / slot_total,
        }

    @torch.no_grad()
    def evaluate(self):
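        """Run a full pass over the validation loader without gradient
        tracking and return the average loss and intent/slot accuracies.
        """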
        self.model.eval()
        print_freq = 2
        loader_desc = 'Validation: loss = {:.4f}, accuracy (intent = {:.4f}, slots = {:.4f})'
        evaluation_iterator = tqdm(self.data_loader_val, desc=loader_desc.format(0.0, 0.0, 0.0))
        intent_total = 0
        intent_correct = 0
        slot_total = 0
        slot_correct = 0
        total_loss = 0
        completed_batch = 0
        for idx, samples in enumerate(evaluation_iterator, 1):
            inputs, targets = samples
            # data & target
            inputs = inputs.to(self.device)
            targets = {k: v.to(self.device) if type(v) is not str else v for k, v in targets.items()}
            inputs.update({
                "intent_label_ids": targets['intents'],
                "slot_labels_ids": targets['tags'],
            })
            outputs = self.model(**inputs)
            losses, (intent_logits, slot_logits) = outputs
            loss_val = losses.item()
            total_loss += loss_val
            intent_preds = torch.argmax(intent_logits.detach(), dim=1)
            intent_acc = torch.sum(intent_preds == targets['intents'])
            intent_total += targets['intents'].shape[0]
            intent_correct += intent_acc.item()
            slot_preds = torch.argmax(slot_logits.detach(), dim=2)
            slot_acc = torch.sum(slot_preds == targets['tags'])
            slot_total += targets['tags'].shape[0] * targets['tags'].shape[1]
            slot_correct += slot_acc.item()
            if idx % print_freq == 0:
                evaluation_iterator.set_description(
                    loader_desc.format(loss_val, intent_correct / intent_total, slot_correct / slot_total))
            completed_batch += 1
            if self.testing and idx == 10:
                break
        return {
            'avg_loss': total_loss / completed_batch,
            'intent_acc': intent_correct / intent_total,
            'slot_acc': slot_correct / slot_total,
        }
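

# Example usage (a minimal sketch, not part of the class above). The exact
# Config, NLUDataset and Collator constructors are defined in configs.py and
# dataset.py, so the calls below are assumptions; the model is expected to
# return (loss, (intent_logits, slot_logits)) when called with the collated
# inputs plus intent/slot label ids.
#
#     args = Config()                                   # assumed default constructor
#     dataset = NLUDataset(args)                        # assumed constructor
#     collator = Collator(args)                         # assumed constructor
#     model = ...                                       # joint intent/slot model (torch.nn.Module)
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     trainer = Trainer(args, model, optimizer, dataset, collator)
#     trainer.train()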