-
Notifications
You must be signed in to change notification settings - Fork 217
/
test.py
138 lines (118 loc) · 5.02 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Test a model and generate submission CSV.
Usage:
> python test.py --split SPLIT --load_path PATH --name NAME
where
> SPLIT is either "dev" or "test"
> PATH is a path to a checkpoint (e.g., save/train/model-01/best.pth.tar)
> NAME is a name to identify the test run
Author:
Chris Chute ([email protected])
"""
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import util
from args import get_test_args
from collections import OrderedDict
from json import dumps
from models import BiDAF
from os.path import join
from tensorboardX import SummaryWriter
from tqdm import tqdm
from ujson import load as json_load
from util import collate_fn, SQuAD
def main(args):
    """Evaluate a BiDAF checkpoint on one split and write a submission CSV.

    Steps:
      1. Resolve a save directory and logger.
      2. Load word embeddings, build the model and load the checkpoint.
      3. Run the model over ``args.split`` with gradients disabled.
      4. For splits other than ``'test'``, log NLL / F1 / EM (plus AvNA when
         ``args.use_squad_v2``) to the console and write visualizations to
         TensorBoard.
      5. Write ``<split>_<sub_file>`` in ``args.save_dir`` with an
         ``Id,Predicted`` header and one row per predicted id, sorted by id.

    Args:
        args: Parsed option namespace (from ``get_test_args()``). Fields read
            here are ``save_dir``, ``name``, ``batch_size``, ``word_emb_file``,
            ``hidden_size``, ``load_path``, ``split``, ``use_squad_v2``,
            ``num_workers``, ``max_ans_len``, ``num_visuals``, ``sub_file`` and
            the per-split ``{split}_record_file`` / ``{split}_eval_file``
            entries. ``args.save_dir`` and ``args.batch_size`` are overwritten
            in place.
    """
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    device, gpu_ids = util.get_available_devices()
    # The model is wrapped in nn.DataParallel below, which splits each batch
    # across the GPUs, so scale the batch size by the number of GPUs.
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)

    # Get model
    log.info('Building model...')
    model = BiDAF(word_vectors=word_vectors,
                  hidden_size=args.hidden_size)
    model = nn.DataParallel(model, gpu_ids)
    log.info(f'Loading checkpoint from {args.load_path}...')
    model = util.load_model(model, args.load_path, gpu_ids, return_step=False)
    model = model.to(device)
    model.eval()  # inference mode (e.g. dropout disabled)

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)[f'{args.split}_record_file']
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info(f'Evaluating on {args.split} split...')
    # No labels for the test set, so NLL (and the metrics) would be invalid.
    has_labels = args.split != 'test'
    nll_meter = util.AverageMeter()
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}   # Predictions for submission
    eval_file = vars(args)[f'{args.split}_eval_file']
    # Explicit UTF-8 (matching the submission writer below) instead of the
    # platform-dependent default encoding.
    with open(eval_file, 'r', encoding='utf-8') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward; cc_idxs / qc_idxs (presumably character
            # indices) are not passed to this model.
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            batch_size = cw_idxs.size(0)

            # Forward
            log_p1, log_p2 = model(cw_idxs, qw_idxs)

            if has_labels:
                # Only compute the loss when gold spans exist; it is never
                # reported for the test split.
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                nll_meter.update(loss.item(), batch_size)

            # Get F1 and EM scores
            p1, p2 = log_p1.exp(), log_p2.exp()
            starts, ends = util.discretize(p1, p2, args.max_ans_len, args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if has_labels:
                progress_bar.set_postfix(NLL=nll_meter.avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict,
                                                      ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    # Log results (except for test set, since it does not come with labels)
    if has_labels:
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meter.avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in results.items())
        log.info(f'{args.split.title()} {results_str}')

        # Log to TensorBoard. Close the writer afterwards so buffered events
        # are flushed to disk and its resources are released.
        tbx = SummaryWriter(args.save_dir)
        try:
            util.visualize(tbx,
                           pred_dict=pred_dict,
                           eval_path=eval_file,
                           step=0,
                           split=args.split,
                           num_visuals=args.num_visuals)
        finally:
            tbx.close()

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info(f'Writing submission file to {sub_path}...')
    with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        # Sorted ids give a deterministic row order.
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
if __name__ == '__main__':
    # Script entry point: parse the command-line options, then evaluate.
    cli_args = get_test_args()
    main(cli_args)