-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_resnet.py
77 lines (67 loc) · 3.08 KB
/
train_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
Adapted from keras example cifar10_cnn.py
Train ResNet-18 on the CIFAR10 small images dataset.
GPU run command with Theano backend (with TensorFlow, the GPU is automatically used):
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python train_resnet.py <subj_id> <data_name>
"""
from keras.preprocessing.image import ImageDataGenerator
from data_loader import load_cifar
from keras.callbacks import ReduceLROnPlateau, CSVLogger, ModelCheckpoint
import os
import sys
import numpy as np
import resnet
# Lower the verbosity of TensorFlow's native logging.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Root directory under which all run logs and checkpoints are written.
# (For a local run, use: log_root = 'log')
log_root = '/tigress/qlu/logs/keras-resnet/log'

# Command line: <subj_id> <data_name>.
# Exit with a usage message rather than a bare IndexError when one is missing.
if len(sys.argv) < 3:
    sys.exit('usage: %s <subj_id> <data_name>' % sys.argv[0])
print('Train subject: ', sys.argv[1])
subj_id = int(sys.argv[1])    # integer id; used to name the per-run log folder
data_name = sys.argv[2]       # dataset identifier handed to the data loader
# (e.g. data_name = 'cifar10')

# Fixed training configuration.
model_name = 'resnet18'
batch_size = 32
n_epochs = 100
# Per-run output directory: <log_root>/<data_name>/<model_name>/subjNN
log_dir = os.path.join(log_root, data_name, model_name, 'subj%.2d' % (subj_id))
try:
    os.makedirs(log_dir)
except OSError:
    # The directory may already exist (an earlier run, or another process
    # creating it at the same time). Checking first and creating afterwards
    # is racy, so attempt the creation and only re-raise genuine failures.
    if not os.path.isdir(log_dir):
        raise

# Training callbacks.
# Scale the learning rate by sqrt(0.1) when progress stalls (patience=5),
# never going below min_lr.
lr_reducer = ReduceLROnPlateau(
    factor=np.sqrt(0.1), cooldown=0, patience=5, min_lr=0.5e-6)
# early_stopper = EarlyStopping(min_delta=0.001, patience=10)
# Write the training history to <log_dir>/history.csv.
csv_logger = CSVLogger(os.path.join(log_dir, 'history.csv'))
# Save a weights file for every epoch (save_best_only=False, period=1),
# named by the zero-padded epoch number.
checkpointer = ModelCheckpoint(
    filepath=os.path.join(log_dir, 'weights.{epoch:03d}.hdf5'),
    verbose=1, save_best_only=False, period=1)
# Load the requested dataset. `data_info` bundles the class count and the
# image dimensions, which are unpacked below.
(X_train, X_test,
 Y_train, Y_test,
 y_train, y_test,
 data_info) = load_cifar(data_name)
n_classes, img_rows, img_cols, img_channels = data_info

# Build and compile ResNet-18. The builder is handed the input shape as
# (channels, rows, cols).
input_shape = (img_channels, img_rows, img_cols)
model = resnet.ResnetBuilder.build_resnet_18(input_shape, n_classes)
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Store the untrained (randomly initialised) weights as checkpoint 000 so the
# per-epoch checkpoints have a starting point to compare against.
initial_weights_path = os.path.join(log_dir, 'weights.%.3d.hdf5' % 0)
model.save_weights(initial_weights_path)
# Real-time data augmentation: small random shifts plus horizontal flips.
# Every normalisation / whitening option is switched off.
augmentation_options = dict(
    # --- normalisation (all disabled) ---
    featurewise_center=False,             # zero the input mean over the dataset
    samplewise_center=False,              # zero the mean of each sample
    featurewise_std_normalization=False,  # divide inputs by the dataset std
    samplewise_std_normalization=False,   # divide each input by its own std
    zca_whitening=False,                  # apply ZCA whitening
    # --- geometric augmentation ---
    rotation_range=0,         # random rotation range in degrees (0 to 180)
    width_shift_range=0.1,    # random horizontal shift, fraction of total width
    height_shift_range=0.1,   # random vertical shift, fraction of total height
    horizontal_flip=True,     # randomly flip images horizontally
    vertical_flip=False,      # randomly flip images vertically
)
datagen = ImageDataGenerator(**augmentation_options)

# Computes the dataset statistics used by featurewise normalisation and ZCA
# whitening (std, mean, principal components).
# NOTE(review): all of those options are disabled above, so this call looks
# redundant in the current configuration -- confirm before removing.
datagen.fit(X_train)
# Train on the augmented batches produced by datagen.flow(); the test split is
# passed unchanged as validation data.
model.fit_generator(
    datagen.flow(X_train, Y_train, batch_size=batch_size),
    # One epoch = one pass over the full batches of the training set.
    steps_per_epoch=X_train.shape[0] // batch_size,
    # steps_per_epoch=8,
    epochs=n_epochs,
    verbose=2,
    validation_data=(X_test, Y_test),
    callbacks=[lr_reducer, csv_logger, checkpointer],
    # `max_queue_size` is the current name of the deprecated `max_q_size`
    # keyword; the value is unchanged.
    max_queue_size=100,
    shuffle=True,
)