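"""Training script for a MobileNetV2-based U-Net segmentation model.

Trains with Adam and cross-entropy loss, logs train/validation losses to
TensorBoard, checkpoints the best model by validation loss, and supports
optional early stopping.
"""
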
import numpy as np
import torch
from tensorboardX import SummaryWriter
from torch.optim import Adam

from dataset import gen_dataloaders
from nets.MobileNetV2_unet import MobileNetV2_unet


# Count the number of trainable model parameters.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def train(data_loader, model, optimizer, criterion):
    # `args` is the argparse namespace built in the __main__ block below.
    model.train()
    running_loss = 0.0
    count = 0
    for batch_idx, (inputs, labels) in enumerate(data_loader):
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss weighted by batch size so the epoch
        # average stays correct even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        count += inputs.size(0)

        if batch_idx % args.log_interval != 0:
            continue
        print('[{}/{} ({:0.0f}%)]\tLoss: {:0.3f}'.format(
            batch_idx * len(inputs),
            len(data_loader.dataset),
            100. * batch_idx / len(data_loader),
            loss.item()))
    epoch_loss = running_loss / count
    print('[End of train epoch]\tLoss: {:0.5f}'.format(epoch_loss))
    return epoch_loss


def test(data_loader, model, criterion):
    model.eval()
    running_loss = 0.0
    count = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            count += inputs.size(0)
    epoch_loss = running_loss / count
    print('[End of test epoch]\tLoss: {:0.5f}'.format(epoch_loss))
    return epoch_loss


def main():
    # TensorBoard writer
    writer = SummaryWriter(args.log_dir)
    save_filename = args.model_dir

    # Data
    train_loader, valid_loader = gen_dataloaders(args.data_folder,
                                                 val_split=0.05,
                                                 shuffle=True,
                                                 batch_size=args.batch_size,
                                                 seed=args.seed,
                                                 img_size=224,
                                                 cuda=args.cuda)

    # Model
    model = MobileNetV2_unet(pre_trained=args.pre_trained).to(args.device)
    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    test_losses = []
    best_loss = np.inf
    for epoch in range(1, args.num_epochs + 1):
        print("================== Epoch: {} ==================".format(epoch))
        train_loss = train(train_loader, model, optimizer, criterion)
        val_loss = test(valid_loader, model, criterion)

        # Logs
        writer.add_scalar('loss/train/loss', train_loss, epoch)
        writer.add_scalar('loss/test/loss', val_loss, epoch)
        test_losses.append(val_loss)

        # Save the model whenever validation loss improves
        if val_loss < best_loss:
            best_loss = val_loss
            with open('{0}/best.pt'.format(save_filename), 'wb') as f:
                torch.save(model.state_dict(), f)

        # Early stopping: quit once the last `patience` validation losses
        # all exceed the best loss by more than `early_stopping_eps`.
        if args.patience is not None and epoch > args.patience + 1:
            loss_array = np.array(test_losses)
            if all(loss_array[-args.patience:] - best_loss >
                   args.early_stopping_eps):
                break

    # Close the writer so pending TensorBoard events are flushed to disk.
    writer.close()
    print("Model saved at: {0}/best.pt".format(save_filename))
    print("# Parameters: {}".format(count_parameters(model)))


if __name__ == '__main__':
    import argparse
    import os

    parser = argparse.ArgumentParser(description='Patchy VAE')
    # Dataset
    parser.add_argument('--dataset', type=str, default='lfw',
                        help='name of the dataset (default: lfw)')
    parser.add_argument('--data-folder', type=str, default='./data',
                        help='name of the data folder (default: ./data)')
    parser.add_argument('--workers', type=int, default=1,
                        help='number of image preprocessing workers (default: 1)')
    # Model
    parser.add_argument('--arch', type=str, default='patchy',
                        help='model architecture (default: patchy)')
    # Optimization
    parser.add_argument('--batch-size', type=int, default=8,
                        help='batch size (default: 8)')
    parser.add_argument('--num-epochs', type=int, default=10,
                        help='number of epochs (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training (default: False)')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='learning rate for Adam optimizer (default: 3e-4)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10,
                        help='how many batches to wait before logging training status')
    # Early stopping
    parser.add_argument('--patience', type=int, default=None,
                        help='patience for early stopping (default: None)')
    parser.add_argument('--early-stopping-eps', type=float, default=1e-5,
                        help='epsilon for early stopping (default: 1e-5)')
    # Miscellaneous
    parser.add_argument('--pre-trained', type=str, default=None,
                        help='path of pre-trained weights (default: None)')
    parser.add_argument('--output-folder', type=str, default='./scratch',
                        help='name of the output folder (default: ./scratch)')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.device = torch.device("cuda" if args.cuda else "cpu")

    # Slurm
    if 'SLURM_JOB_NAME' in os.environ and 'SLURM_JOB_ID' in os.environ:
        # Running with sbatch (not an interactive srun shell): key the
        # output folder on the Slurm job id.
        if os.environ['SLURM_JOB_NAME'] != 'bash':
            args.output_folder = os.path.join(args.output_folder,
                                              os.environ['SLURM_JOB_ID'])
    else:
        args.output_folder = os.path.join(args.output_folder, str(os.getpid()))

    # Create logs and models folders if they don't exist
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)
    log_dir = os.path.join(args.output_folder, 'logs')
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    model_dir = os.path.join(args.output_folder, 'models')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    args.log_dir = log_dir
    args.model_dir = model_dir

    main()
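
# A minimal sketch of an invocation; the data path and early-stopping values
# below are assumptions for illustration, not values shipped with the repo:
#
#   python train.py --data-folder ./data/lfw --batch-size 8 \
#       --num-epochs 50 --patience 5 --output-folder ./scratch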