forked from PhilippeNguyen/nested_dropout
pretrain_model.py
import argparse

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.callbacks import ModelCheckpoint

# Repo-local modules: mnist.py provides the data/encoder/decoder helpers,
# special.py the nested-dropout building blocks and the custom loss.
import mnist
from special import (tanh_crossentropy, build_repeat_block,
                     FixedModelCheckpoint, build_dropout_block,
                     build_latent_params, build_lin_to_tanh_converter)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--latent_size", action="store", dest="latent_size",
        default=None, type=int,
        help="number of neurons in the latent layer"
    )
    parser.add_argument(
        "--dataset", action="store", dest="dataset",
        default='mnist',
        help="dataset to use (mnist or imagenet)"
    )
    parser.add_argument(
        "--batch_size", action="store", dest="batch_size",
        default=8, type=int,
        help="batch size (total batch size = batch_size*batch_repeats)"
    )
    parser.add_argument(
        "--save_model", action="store", dest="save_model",
        required=True,
        help="filename to save the model as (.hdf5)"
    )
    parser.add_argument(
        "--epochs", action="store", dest="epochs",
        default=100, type=int,
        help="number of epochs to train"
    )
    parser.add_argument(
        "--patience", action="store", dest="patience",
        default=5, type=int,
        help="early stopping patience"
    )
    args = parser.parse_args()
    dataset_name = args.dataset
    latent_size = args.latent_size
    batch_size = args.batch_size
    patience = args.patience
    epochs = args.epochs
    out = args.save_model if args.save_model.endswith('.hdf5') else args.save_model + '.hdf5'

    if dataset_name == 'mnist':
        dataset = mnist
        if latent_size is None:
            latent_size = 100
    else:
        # Only mnist is wired up here; anything else would leave `dataset`
        # undefined and crash later, so fail early with a clear error.
        raise ValueError('unsupported dataset: {}'.format(dataset_name))

    ### Config stuff
    # Overwrite the parser args
    batch_size = 64
    ###
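    # Note: the hard-coded batch_size above overrides whatever --batch_size
    # was passed on the command line; drop the config block to honor the flag.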
    # Set up data
    x_train, x_test, _, _ = dataset.get_data()
    data_shape = x_test.shape[1:]
    train_samples, test_samples = x_train.shape[0], x_test.shape[0]
    if (train_samples % batch_size) != 0:
        train_samples = train_samples - (train_samples % batch_size)
        x_train = x_train[:train_samples]
    if (test_samples % batch_size) != 0:
        test_samples = test_samples - (test_samples % batch_size)
        x_test = x_test[:test_samples]
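    # Trimming both splits to a multiple of batch_size guarantees every batch
    # is full-sized; the custom blocks in special.py appear to assume a fixed
    # batch dimension, so a ragged final batch could break them.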
    # Set up model
    input_layer = keras.layers.Input(shape=data_shape)
    encoder = dataset.build_encoder(data_shape)
    encoder_out = encoder(input_layer)
    encoder_out_shape = encoder_out.shape.as_list()[1:]
    latent_param_block = build_latent_params(encoder_out_shape,
                                             latent_size=latent_size,
                                             activation='linear')
    latent_params_out = latent_param_block(encoder_out)
    latent_params_shape = latent_params_out.shape.as_list()[1:]
    tanh_converter = build_lin_to_tanh_converter(latent_params_shape)
    tanh_out = tanh_converter(latent_params_out)
    drop_block = build_dropout_block(latent_params_shape)
    drop_out = drop_block(tanh_out)
    latent_shape = drop_out.shape.as_list()[1:]
    decoder = dataset.build_decoder(latent_shape)
    decoder_out = decoder(drop_out)
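    # The graph above forms an autoencoder: the encoder output is projected
    # to latent_size linear units, squashed through tanh, passed through the
    # nested-dropout block, and finally reconstructed by the decoder.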
    # pretrain_model and saving_model wrap the same tensors, so they share
    # weights; only pretrain_model is compiled and trained below.
    pretrain_model = keras.models.Model([input_layer],
                                        [decoder_out])
    saving_model = keras.models.Model([input_layer],
                                      [decoder_out])
    pretrain_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.001),
        loss=tanh_crossentropy,
    )
    # Start training
    early_stopping = keras.callbacks.EarlyStopping(patience=patience)
    model_check = ModelCheckpoint(out, save_best_only=True)
    pretrain_model.fit(x_train, x_train, validation_data=(x_test, x_test),
                       batch_size=batch_size,
                       epochs=epochs,
                       callbacks=[early_stopping, model_check],
                       )
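
A typical invocation (all flags are defined by the parser above; --save_model is the only required one):

    python pretrain_model.py --save_model pretrained --latent_size 100 --epochs 100 --patience 5

The checkpoint written by ModelCheckpoint includes the compiled custom loss, so reloading it needs that loss registered. A minimal sketch, assuming mnist.py and special.py from this repo are importable; depending on how the blocks in special.py are implemented, their layer classes may need to be added to custom_objects as well:

    import tensorflow.keras as keras
    import mnist
    from special import tanh_crossentropy

    # Register the custom loss so load_model can deserialize the compiled
    # model; passing compile=False instead would also work for inference.
    model = keras.models.load_model(
        'pretrained.hdf5',
        custom_objects={'tanh_crossentropy': tanh_crossentropy},
    )

    _, x_test, _, _ = mnist.get_data()
    reconstructions = model.predict(x_test[:64], batch_size=64)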