-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtrain_search.py
355 lines (293 loc) · 14.7 KB
/
train_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import os
import sys
import time
import glob
import math
import numpy as np
import torch
import utils
import logging
import argparse
import torch.nn as nn
import torch.utils
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torch.distributions.categorical as cate
import torchvision.utils as vutils
from torch.autograd import Variable
from model_search import Network
from architect import Architect
from tensorboardX import SummaryWriter
# ---------------------------------------------------------------------------
# Command-line configuration (the flag order below is also the --help order).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser("cifar")

# Data location and batch sizing.
parser.add_argument('--data', type=str, default='../data',
                    help='location of the data corpus')
parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size')
parser.add_argument('--batch_increase', type=int, default=8,
                    help='how much does the batch size increase after making a decision')

# Weight optimiser.
parser.add_argument('--learning_rate', type=float, default=0.025,
                    help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001,
                    help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4,
                    help='weight decay')
parser.add_argument('--report_freq', type=float, default=50,
                    help='report frequency')
parser.add_argument('--gpu', type=int, default=0,
                    help='gpu device id')

# Network shape, augmentation and bookkeeping.
parser.add_argument('--init_channels', type=int, default=16,
                    help='num of init channels')
parser.add_argument('--layers', type=int, default=8,
                    help='total number of layers')
parser.add_argument('--model_path', type=str, default='saved_models',
                    help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False,
                    help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16,
                    help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.3,
                    help='drop path probability')
parser.add_argument('--save', type=str, default='EXP',
                    help='experiment name')
parser.add_argument('--seed', type=int, default=2,
                    help='random seed')
parser.add_argument('--grad_clip', type=float, default=5,
                    help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=0.5,
                    help='portion of training data')

# Architecture optimiser.
parser.add_argument('--unrolled', action='store_true', default=False,
                    help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=3e-4,
                    help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3,
                    help='weight decay for arch encoding')

# Edge-decision schedule (consumed by main() and edge_decision()).
parser.add_argument('--warmup_dec_epoch', type=int, default=9,
                    help='warmup decision epoch')
parser.add_argument('--decision_freq', type=int, default=5,
                    help='decision freq epoch')
parser.add_argument('--use_history', action='store_true',
                    help='use history for decision')
parser.add_argument('--history_size', type=int, default=4,
                    help='number of stored epoch scores')
parser.add_argument('--post_val', action='store_true', default=False,
                    help='validate after each decision')

args = parser.parse_args()

# Each run writes into its own timestamped experiment directory; the current
# *.py files are handed to utils.create_exp_dir (presumably to snapshot them).
args.save = 'search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Log to stdout and to <save>/log.txt with one shared format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout,
                    level=logging.INFO,
                    format=log_format,
                    datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# TensorBoard writer used by score_image() and main().
writer = SummaryWriter(log_dir=args.save, max_queue=50)

CIFAR_CLASSES = 10
def histogram_average(history, probs):
    """Average the histogram intersection of ``probs`` with each past snapshot.

    Args:
        history: list of earlier probability tensors (as stored by
            ``edge_decision``); may be empty.
        probs: current probability tensor; one output value is produced per
            row (``probs.shape[0]``).

    Returns:
        A float tensor of length ``probs.shape[0]`` on the same device as
        ``probs``. It holds the mean of
        ``utils.histogram_intersection(hist, probs)`` over ``history``, or all
        zeros when ``history`` is empty.
    """
    # Allocate next to ``probs`` rather than hard-coding ``.cuda()`` so the
    # helper also works for CPU tensors (CUDA inputs behave exactly as before).
    histogram_inter = torch.zeros(probs.shape[0], dtype=torch.float,
                                  device=probs.device)
    if not history:
        return histogram_inter
    for hist in history:
        # assumes histogram_intersection returns one value per row of probs
        histogram_inter += utils.histogram_intersection(hist, probs)
    histogram_inter /= len(history)
    return histogram_inter
def score_image(type, score, epoch):
    """Write ``score`` to TensorBoard as a grid of tiny image tiles.

    Three trailing singleton dimensions are appended to ``score``. For a 1-D
    score this makes every element its own 1x1 single-channel image. The tiles
    are arranged 7 per row, normalized, padded with 0.5, and logged through the
    module-level ``writer`` under the tag ``<type>_score`` at step ``epoch``.
    """
    tiles = score
    for axis in (1, 2, 3):
        tiles = torch.unsqueeze(tiles, axis)
    grid = vutils.make_grid(tiles, nrow=7, normalize=True, pad_value=0.5)
    writer.add_image(type + '_score', grid, epoch)
def edge_decision(type, alphas, selected_idxs, candidate_flags, probs_history, epoch, model, args):
    """Score the edges of one cell type and, on a decision epoch, fix one edge.

    Each edge's score is the product of ``utils.normalize``-d criteria:

    * importance: the summed softmax weight of the non-"none" ops (column 0 of
      the softmax is treated as the none op);
    * certainty: ``1 - entropy`` of the renormalized op distribution, where the
      entropy is divided by ``log(#ops)``;
    * with ``args.use_history`` only, the average histogram intersection with
      earlier snapshots in ``probs_history``. The current snapshot is then
      appended, and the list is trimmed to ``args.history_size``.

    An epoch is a decision epoch when some edge is still a candidate,
    ``epoch >= args.warmup_dec_epoch``, and
    ``(epoch - args.warmup_dec_epoch) % args.decision_freq == 0``. On such an
    epoch the highest-scoring candidate edge is assigned its most probable
    non-none op. It is marked as no longer a candidate, its alpha stops
    requiring grad, and ``model.check_edges`` may update the flags and indices
    further.

    Args:
        type: ``'normal'`` or ``'reduce'``. Any other value raises
            ``Exception('Unknown Cell Type')`` on a decision epoch.
        alphas: list of per-edge architecture-weight tensors.
        selected_idxs: per-edge chosen op index (modified in place on a decision).
        candidate_flags: per-edge bool "still undecided" mask (modified in place).
        probs_history: list of past probability snapshots (mutated when
            ``args.use_history`` is set).
        epoch: current epoch number.
        model: network exposing ``check_edges``.
        args: parsed command-line arguments.

    Returns:
        ``(decided, selected_idxs, candidate_flags)``, where ``decided`` is
        ``True`` only on a decision epoch.
    """
    mat = F.softmax(torch.stack(alphas, dim=0), dim=-1).detach()
    print(mat)

    # Drop the "none" column and renormalize the remaining op weights per edge.
    op_weights = mat[:, 1:]
    importance = torch.sum(op_weights, dim=-1)
    probs = op_weights / importance[:, None]
    # Dividing by log(#ops) scales the entropy into [0, 1].
    entropy = cate.Categorical(probs=probs).entropy() / math.log(probs.size(1))

    if not args.use_history:
        # SGAS Cri.1
        score = utils.normalize(importance) * utils.normalize(1 - entropy)
    else:
        # SGAS Cri.2: compare against earlier snapshots before recording this one.
        histogram_inter = histogram_average(probs_history, probs)
        probs_history.append(probs)
        if len(probs_history) > args.history_size:
            probs_history.pop(0)
        score = utils.normalize(importance) * utils.normalize(1 - entropy)
        score = score * utils.normalize(histogram_inter)

    decided = bool(candidate_flags.int().sum() > 0
                   and epoch >= args.warmup_dec_epoch
                   and (epoch - args.warmup_dec_epoch) % args.decision_freq == 0)

    if decided:
        # Candidates keep their score (min with +inf); decided edges go to -inf.
        bound = (2 * candidate_flags.float() - 1) * np.inf
        selected_edge_idx = torch.argmax(torch.min(score, bound))
        # +1 re-offsets into the full op list, which starts with the none op.
        selected_op_idx = torch.argmax(probs[selected_edge_idx]) + 1
        selected_idxs[selected_edge_idx] = selected_op_idx
        candidate_flags[selected_edge_idx] = False
        alphas[selected_edge_idx].requires_grad = False

        if type not in ('normal', 'reduce'):
            raise Exception('Unknown Cell Type')
        reduction = type == 'reduce'
        candidate_flags, selected_idxs = model.check_edges(
            candidate_flags, selected_idxs, reduction=reduction)

        logging.info("#" * 30 + " Decision Epoch " + "#" * 30)
        logging.info("epoch {}, {}_selected_idxs {}, added edge {} with op idx {}".format(
            epoch, type, selected_idxs, selected_edge_idx, selected_op_idx))
    else:
        logging.info("#" * 30 + " Not a Decision Epoch " + "#" * 30)
        logging.info("epoch {}, {}_selected_idxs {}".format(epoch, type, selected_idxs))

    print(type + "_candidate_flags {}".format(candidate_flags))
    score_image(type, score, epoch)
    return decided, selected_idxs, candidate_flags
def _make_search_queues(train_data, indices, split, batch_size, num_workers=0):
    """Build the (train_queue, valid_queue) loaders over the CIFAR-10 train set.

    ``indices[:split]`` feeds the weight-training queue and ``indices[split:]``
    feeds the search/validation queue. Both use ``SubsetRandomSampler`` and
    pinned memory.
    """
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=num_workers)
    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
        pin_memory=True, num_workers=num_workers)
    return train_queue, valid_queue


def main():
    """Run the SGAS architecture search on CIFAR-10.

    Each epoch:

    1. Trains the weights and architecture parameters.
    2. Evaluates on the held-out split.
    3. Calls ``edge_decision`` for the normal and reduce cells.
    4. After any decision, rebuilds the loaders with a batch size grown by
       ``args.batch_increase``.

    The model is saved to ``<save>/weights.pt`` every epoch. Exits with
    status 1 if no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Only the CIFAR-10 *training* set is used: its first ``train_portion``
    # trains the weights and the remainder serves as the search/validation split.
    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    train_queue, valid_queue = _make_search_queues(train_data, indices, split, args.batch_size)

    # Decision epochs are warmup_dec_epoch + k * decision_freq. This budget
    # covers num_edges of them (k = 0 .. num_edges - 1) plus post_train
    # further epochs (presumably one edge per cell type is fixed per decision).
    num_edges = model._steps * 2
    post_train = 5
    epochs = args.warmup_dec_epoch + args.decision_freq * (num_edges - 1) + post_train + 1
    logging.info("total epochs: %d", epochs)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    # Per-edge decision state: -1 = no op selected yet; True = still a candidate.
    normal_selected_idxs = torch.tensor(len(model.alphas_normal) * [-1], requires_grad=False, dtype=torch.int).cuda()
    reduce_selected_idxs = torch.tensor(len(model.alphas_reduce) * [-1], requires_grad=False, dtype=torch.int).cuda()
    normal_candidate_flags = torch.tensor(len(model.alphas_normal) * [True], requires_grad=False, dtype=torch.bool).cuda()
    reduce_candidate_flags = torch.tensor(len(model.alphas_reduce) * [True], requires_grad=False, dtype=torch.bool).cuda()
    logging.info('normal_selected_idxs: {}'.format(normal_selected_idxs))
    logging.info('reduce_selected_idxs: {}'.format(reduce_selected_idxs))
    logging.info('normal_candidate_flags: {}'.format(normal_candidate_flags))
    logging.info('reduce_candidate_flags: {}'.format(reduce_candidate_flags))
    model.normal_selected_idxs = normal_selected_idxs
    model.reduce_selected_idxs = reduce_selected_idxs
    model.normal_candidate_flags = normal_candidate_flags
    model.reduce_candidate_flags = reduce_candidate_flags

    print(F.softmax(torch.stack(model.alphas_normal, dim=0), dim=-1).detach())
    print(F.softmax(torch.stack(model.alphas_reduce, dim=0), dim=-1).detach())

    count = 0
    normal_probs_history = []
    reduce_probs_history = []
    for epoch in range(epochs):
        scheduler.step()
        # Read the rate SGD will actually use this epoch. Calling
        # ``scheduler.get_lr()`` outside ``step()`` is unreliable: the chainable
        # CosineAnnealingLR (PyTorch >= 1.1) re-applies the decay factor to the
        # already-updated rate, so the logged lr and the lr passed to the
        # Architect would drift from the optimizer's real value.
        lr = optimizer.param_groups[0]['lr']
        logging.info('epoch %d lr %e', epoch, lr)

        # training
        train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)

        # validation
        with torch.no_grad():
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        # edge decisions for both cell types
        (saved_memory_normal, model.normal_selected_idxs,
         model.normal_candidate_flags) = edge_decision(
            'normal', model.alphas_normal, model.normal_selected_idxs,
            model.normal_candidate_flags, normal_probs_history, epoch, model, args)
        (saved_memory_reduce, model.reduce_selected_idxs,
         model.reduce_candidate_flags) = edge_decision(
            'reduce', model.alphas_reduce, model.reduce_selected_idxs,
            model.reduce_candidate_flags, reduce_probs_history, epoch, model, args)

        # After any decision, rebuild both loaders with a larger batch size.
        if saved_memory_normal or saved_memory_reduce:
            del train_queue, valid_queue
            torch.cuda.empty_cache()

            count += 1
            new_batch_size = args.batch_size + args.batch_increase * count
            logging.info("new_batch_size = {}".format(new_batch_size))
            train_queue, valid_queue = _make_search_queues(
                train_data, indices, split, new_batch_size, num_workers=2)

            # post validation
            if args.post_val:
                with torch.no_grad():
                    post_valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('post_valid_acc %f', post_valid_acc)

        logging.info('genotype = %s', model.get_genotype(force=True))
        writer.add_scalar('stats/train_acc', train_acc, epoch)
        writer.add_scalar('stats/valid_acc', valid_acc, epoch)
        utils.save(model, os.path.join(args.save, 'weights.pt'))

    logging.info("#" * 30 + " Done " + "#" * 30)
    logging.info('genotype = %s', model.get_genotype())
def train(train_queue, valid_queue, model, architect, criterion, optimizer, lr):
    """Run one search epoch over ``train_queue``.

    For every training batch, the function:

    1. Draws one batch from ``valid_queue`` and calls ``architect.step``
       (presumably an architecture-parameter update).
    2. Takes an SGD step on the network weights, with gradients clipped to
       ``args.grad_clip``.

    Progress is logged every ``args.report_freq`` steps.

    Returns:
        ``(top1.avg, objs.avg)``: the running top-1 accuracy and the average
        loss over the epoch, as tracked by ``utils.AverageMeter``.
    """
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    for step, (input, target) in enumerate(train_queue):
        model.train()
        n = input.size(0)

        # ``non_blocking`` replaces the old ``async=`` keyword, which is a
        # SyntaxError since Python 3.7 made ``async`` a reserved word.
        input = Variable(input, requires_grad=False).cuda()
        target = Variable(target, requires_grad=False).cuda(non_blocking=True)

        # get a random minibatch from the search queue with replacement
        # NOTE(review): iter() builds a fresh iterator each step; once main()
        # rebuilds the loaders with num_workers=2, that restarts worker
        # processes every step. Consider a persistent iterator if this is slow.
        input_search, target_search = next(iter(valid_queue))
        input_search = Variable(input_search, requires_grad=False).cuda()
        target_search = Variable(target_search, requires_grad=False).cuda(non_blocking=True)

        architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

        optimizer.zero_grad()
        logits = model(input)
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg
def infer(valid_queue, model, criterion):
    """Evaluate ``model`` on every batch of ``valid_queue``.

    The model is switched to eval mode and left there; ``train`` switches it
    back at each step. Callers in this file wrap the call in
    ``torch.no_grad()``. Progress is logged every ``args.report_freq`` steps.

    Returns:
        ``(top1.avg, objs.avg)``: the running top-1 accuracy and the average
        loss, as tracked by ``utils.AverageMeter``.
    """
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    model.eval()

    for step, (input, target) in enumerate(valid_queue):
        # ``non_blocking`` replaces the old ``async=`` keyword, which is a
        # SyntaxError since Python 3.7.
        input = Variable(input).cuda()
        target = Variable(target).cuda(non_blocking=True)

        logits = model(input)
        loss = criterion(logits, target)

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        if step % args.report_freq == 0:
            logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

    return top1.avg, objs.avg
# Start the search only when this file is executed directly.
if __name__ == '__main__':
    main()