-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_conv_cifar.py
59 lines (50 loc) · 1.76 KB
/
train_conv_cifar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Train conv nets on the CIFAR.
"""
from data_loader import load_cifar
from keras.callbacks import ReduceLROnPlateau, CSVLogger, ModelCheckpoint
from models import get_cifar_convnet
import os
import sys
import numpy as np
# import resnet
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# log_root = 'log'
log_root = '/tigress/qlu/logs/keras-resnet/log'
print('Train: ', sys.argv[1])
data_name = sys.argv[1]
# data_name = 'cifar10'
model_name = 'conv'
batch_size = 32
n_epochs = 50
n_subjs = 10
# Train one freshly initialised network per "subject" (repeated run). Each run
# writes to <log_root>/<data_name>/<model_name>/subjNN:
#   weights.000.hdf5        -- the random initial weights (before training)
#   weights.EEE.hdf5        -- checkpoint after every epoch EEE (1..n_epochs)
#   history.csv             -- per-epoch training/validation metrics
from keras import backend as K

for subj_id in range(n_subjs):
    # Discard graph/session state left over from the previous run so the
    # models built in earlier iterations do not pile up in memory.
    K.clear_session()

    # per-run log directory, e.g. .../cifar10/conv/subj03
    log_dir = os.path.join(log_root, data_name, model_name, 'subj%.2d' % (subj_id))
    print(log_dir)
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(log_dir, exist_ok=True)

    # callbacks
    # Multiply the learning rate by sqrt(0.1) after 5 epochs without
    # improvement in the monitored metric (Keras default monitor: 'val_loss').
    lr_reducer = ReduceLROnPlateau(
        factor=np.sqrt(0.1), cooldown=0, patience=5, min_lr=1e-5)
    csv_logger = CSVLogger(os.path.join(log_dir, 'history.csv'))
    # NOTE: save_weights_only is not set, so these per-epoch files hold the
    # full model (architecture + weights + optimizer state), whereas
    # weights.000.hdf5 below is weights-only. Keras' load_weights reads both.
    checkpointer = ModelCheckpoint(
        filepath=os.path.join(log_dir, 'weights.{epoch:03d}.hdf5'),
        verbose=1, save_best_only=False, period=1)

    # load data
    # NOTE(review): reloaded on every run; if load_cifar is deterministic this
    # could be hoisted above the loop.
    X_train, X_test, Y_train, Y_test, y_train, y_test, data_info = load_cifar(data_name)
    n_classes, img_rows, img_cols, img_channels = data_info
    input_shape = (img_rows, img_cols, img_channels)

    # build a freshly initialised model for this run
    model = get_cifar_convnet(input_shape, n_classes)
    # save the untrained (random) weights as "epoch 0"
    model.save_weights(os.path.join(log_dir, 'weights.%.3d.hdf5' % (0)))

    # Fit on the in-memory arrays (no data augmentation), evaluating on the
    # test split after every epoch.
    model.fit(X_train, Y_train,
              shuffle=True,
              batch_size=batch_size,
              validation_data=(X_test, Y_test),
              epochs=n_epochs, verbose=2,
              callbacks=[lr_reducer, csv_logger, checkpointer])