utils.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import torch
import torch.nn as nn
from torch.nn import init

def mkdir(paths):
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        if not os.path.isdir(path):
            os.makedirs(path)

def cuda_devices(gpu_ids):
    gpu_ids = [str(i) for i in gpu_ids]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_ids)

def cuda(xs):
    # move a tensor (or a list/tuple of tensors) to the GPU when available;
    # return the input unchanged on CPU-only machines instead of None
    if torch.cuda.is_available():
        if not isinstance(xs, (list, tuple)):
            return xs.cuda()
        else:
            return [x.cuda() for x in xs]
    return xs

def save_checkpoint(state, save_path, is_best=False, max_keep=None):
    # save checkpoint
    torch.save(state, save_path)

    # update the list of recent checkpoints kept in 'latest_checkpoint'
    save_dir = os.path.dirname(save_path)
    list_path = os.path.join(save_dir, 'latest_checkpoint')
    save_name = os.path.basename(save_path)
    if os.path.exists(list_path):
        with open(list_path) as f:
            ckpt_list = f.readlines()
        ckpt_list = [save_name + '\n'] + ckpt_list
    else:
        ckpt_list = [save_name + '\n']

    # deal with max_keep: drop and delete checkpoints beyond the limit
    if max_keep is not None:
        for ckpt in ckpt_list[max_keep:]:
            ckpt = os.path.join(save_dir, ckpt[:-1])
            if os.path.exists(ckpt):
                os.remove(ckpt)
        ckpt_list[max_keep:] = []
    with open(list_path, 'w') as f:
        f.writelines(ckpt_list)

    # copy best (use the full save_path so the copy works from any working directory)
    if is_best:
        shutil.copyfile(save_path, os.path.join(save_dir, 'best_model.ckpt'))

def load_checkpoint(ckpt_dir_or_file, map_location=None, load_best=False):
    if os.path.isdir(ckpt_dir_or_file):
        if load_best:
            ckpt_path = os.path.join(ckpt_dir_or_file, 'best_model.ckpt')
        else:
            with open(os.path.join(ckpt_dir_or_file, 'latest_checkpoint')) as f:
                ckpt_path = os.path.join(ckpt_dir_or_file, f.readline().strip())
    else:
        ckpt_path = ckpt_dir_or_file
    ckpt = torch.load(ckpt_path, map_location=map_location)
    print(' [*] Loading checkpoint from %s succeeded!' % ckpt_path)
    return ckpt

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | constant
        init_gain (float) -- scaling factor for normal and xavier initialization

    'normal' is used in the original pix2pix and CycleGAN paper, but xavier and kaiming may
    work better for some applications. Feel free to experiment.
    """
    def init_func(m):  # the initialization function applied to every module
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'constant':
                init.constant_(m.weight.data, 1.0)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm weight is not a matrix; only normal distribution applies
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
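

# Minimal usage sketch (assumption: the toy model, paths and checkpoint keys below are
# only illustrative and are not part of these utilities). It shows how init_weights,
# mkdir, save_checkpoint and load_checkpoint are expected to fit together.
if __name__ == '__main__':
    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
    init_weights(net, init_type='kaiming')

    mkdir('./checkpoints')
    save_checkpoint({'model': net.state_dict(), 'epoch': 0},
                    './checkpoints/epoch_0.ckpt', is_best=True, max_keep=3)

    ckpt = load_checkpoint('./checkpoints', load_best=True)
    net.load_state_dict(ckpt['model'])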