# base_model.py (generated from cavalleria/pytorch-template)
import torch
import torch.nn as nn
import torchsummary

from utils import add_flops_counting_methods, flops_to_string

class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()

    def summary(self, input_shape, batch_size=1, device='cpu', print_flops=False):
        print("[%s] Network summary..." % (self.__class__.__name__))
        torchsummary.summary(self, input_size=input_shape, batch_size=batch_size, device=device)
        if print_flops:
            dummy_input = torch.randn([1, *input_shape], dtype=torch.float)
            counter = add_flops_counting_methods(self)
            counter.eval().start_flops_count()
            counter(dummy_input)
            print('Flops: {}'.format(flops_to_string(counter.compute_average_flops_cost())))
            print('----------------------------------------------------------------')

    def init_weights(self):
        print("[%s] Initialize weights..." % (self.__class__.__name__))
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()

    def load_pretrained_model(self, pretrained):
        if isinstance(pretrained, str):
            print("[%s] Load pretrained model from %s" % (self.__class__.__name__, pretrained))
            pretrain_dict = torch.load(pretrained, map_location='cpu')
            if 'state_dict' in pretrain_dict:
                pretrain_dict = pretrain_dict['state_dict']
        elif isinstance(pretrained, dict):
            print("[%s] Load pretrained model" % (self.__class__.__name__))
            pretrain_dict = pretrained
        else:
            raise TypeError("pretrained must be a checkpoint path or a state dict")
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                if state_dict[k].shape == v.shape:
                    model_dict[k] = v
                else:
                    print("[%s]" % (self.__class__.__name__), k, "is ignored due to mismatching shape")
            else:
                print("[%s]" % (self.__class__.__name__), k, "is ignored due to mismatching key")
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
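
# Usage sketch (illustrative only): exercising BaseModel.summary, init_weights
# and load_pretrained_model on a toy subclass. `_ToyNet` and `_demo_base_model`
# are hypothetical names introduced for this example, not part of the template.
def _demo_base_model():
    class _ToyNet(BaseModel):
        def __init__(self):
            super(_ToyNet, self).__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    net = _ToyNet()
    net.init_weights()                      # Kaiming init for the conv layer
    net.summary((3, 32, 32), batch_size=1)  # prints a torchsummary table
    # Weights can also come from an in-memory state dict; entries with
    # mismatching keys or shapes are skipped with a console message.
    net.load_pretrained_model(_ToyNet().state_dict())
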
class BaseBackbone(BaseModel):
    def __init__(self):
        super(BaseBackbone, self).__init__()

    def load_pretrained_model_extended(self, pretrained):
        """
        Load pretrained weights into a model whose first layer has a different
        number of in_channels (e.g. RGB weights into an RGB-D stem): for
        mismatching tensors, the first three input channels are copied from
        the checkpoint and the remaining channels keep their current values.
        """
        if isinstance(pretrained, str):
            print("[%s] Load pretrained model from %s" % (self.__class__.__name__, pretrained))
            pretrain_dict = torch.load(pretrained, map_location='cpu')
            if 'state_dict' in pretrain_dict:
                pretrain_dict = pretrain_dict['state_dict']
        elif isinstance(pretrained, dict):
            print("[%s] Load pretrained model" % (self.__class__.__name__))
            pretrain_dict = pretrained
        else:
            raise TypeError("pretrained must be a checkpoint path or a state dict")
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                if state_dict[k].shape != v.shape:
                    # Keep the current tensor, overwrite its first three input
                    # channels with the pretrained values.
                    model_dict[k] = state_dict[k]
                    model_dict[k][:, :3, ...] = v
                else:
                    model_dict[k] = v
            else:
                print("[%s]" % (self.__class__.__name__), k, "is ignored")
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)
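
# Usage sketch (illustrative only): load_pretrained_model_extended copies RGB
# pretrained weights into the first three input channels of a wider stem
# (e.g. RGB-D input). `_ToyBackbone` and `_demo_backbone` are hypothetical
# names introduced for this example.
def _demo_backbone():
    class _ToyBackbone(BaseBackbone):
        def __init__(self, in_channels):
            super(_ToyBackbone, self).__init__()
            self.stem = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)

        def forward(self, x):
            return self.stem(x)

    rgb_weights = _ToyBackbone(in_channels=3).state_dict()
    rgbd = _ToyBackbone(in_channels=4)
    # The first three input channels of `stem.weight` are overwritten with the
    # pretrained RGB kernels; the fourth channel keeps its current values.
    rgbd.load_pretrained_model_extended(rgb_weights)
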
class BaseBackboneWrapper(BaseBackbone):
    def __init__(self):
        super(BaseBackboneWrapper, self).__init__()

    def train(self, mode=True):
        if mode:
            print("[%s] Switch to train mode" % (self.__class__.__name__))
        else:
            print("[%s] Switch to eval mode" % (self.__class__.__name__))
        super(BaseBackboneWrapper, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for module in self.modules():
                # trick: eval() only has an effect on normalization layers
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
                elif isinstance(module, nn.Sequential):
                    for m in module:
                        if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                            m.eval()

    def init_from_imagenet(self, archname):
        pass

    def _freeze_stages(self):
        pass
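
# Usage sketch (illustrative only): a wrapper with norm_eval=True keeps
# BatchNorm layers in eval mode (frozen running statistics) while the rest of
# the network trains. `_ToyWrapper` and `_demo_wrapper` are hypothetical names;
# `norm_eval` is the attribute the train() override above expects subclasses
# to define, and `_freeze_stages` is inherited here as a no-op.
def _demo_wrapper():
    class _ToyWrapper(BaseBackboneWrapper):
        def __init__(self):
            super(_ToyWrapper, self).__init__()
            self.norm_eval = True
            self.body = nn.Sequential(
                nn.Conv2d(3, 8, kernel_size=3, padding=1),
                nn.BatchNorm2d(8),
            )

        def forward(self, x):
            return self.body(x)

    net = _ToyWrapper()
    net.train()  # the conv trains normally, BatchNorm stays in eval mode
    assert not net.body[1].training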