-
Notifications
You must be signed in to change notification settings - Fork 4
/
mobilenetv4.py
269 lines (232 loc) · 9.15 KB
/
mobilenetv4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
"""
Creates a MobileNetV4 Model as defined in:
Danfeng Qin, Chas Leichner, Manolis Delakis, Marco Fornoni, Shixin Luo, Fan Yang, Weijun Wang, Colby Banbury, Chengxi Ye, Berkin Akin, Vaibhav Aggarwal, Tenghui Zhu, Daniele Moro, Andrew Howard. (2024).
MobileNetV4 - Universal Models for the Mobile Ecosystem
arXiv preprint arXiv:2404.10518.
"""
import torch
import torch.nn as nn
import math
# Public model builders defined in this module. The hybrid variants
# ('mobilenetv4_hybrid_medium', 'mobilenetv4_hybrid_large') are not implemented
# in this file; exporting them made `from mobilenetv4 import *` raise
# AttributeError, so they are omitted until they exist.
__all__ = ['mobilenetv4_conv_small', 'mobilenetv4_conv_medium', 'mobilenetv4_conv_large']
def make_divisible(value, divisor, min_value=None, round_down_protect=True):
    """Round ``value`` to a multiple of ``divisor``.

    For non-negative ``value`` this picks the nearest multiple (ties round up),
    clamps it from below, and optionally bumps it up one step if rounding
    discarded too much.

    Args:
        value: quantity to round (e.g. ``channels * expand_ratio``).
        divisor: the multiple the result should be a multiple of.
        min_value: lower bound on the result; ``None`` means ``divisor``.
        round_down_protect: when True, add one ``divisor`` if the rounded
            result ended up below 90% of ``value``.

    Returns:
        The rounded value.
    """
    lower_bound = divisor if min_value is None else min_value
    nearest_multiple = (int(value + divisor / 2) // divisor) * divisor
    result = max(lower_bound, nearest_multiple)
    # Rounding must not shrink the value by more than 10%.
    shrank_too_much = result < 0.9 * value
    if round_down_protect and shrank_too_much:
        result += divisor
    return result
class ConvBN(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU.

    Padding is ``(kernel_size - 1) // 2``, so odd kernels at stride 1 keep the
    spatial size. The three layers are stored in ``self.block`` (an
    ``nn.Sequential``), giving parameter names ``block.0.*`` for the conv and
    ``block.1.*`` for the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        return self.block(x)
class UniversalInvertedBottleneck(nn.Module):
    """Universal Inverted Bottleneck (UIB) block.

    Layer sequence (stages marked optional are skipped when their kernel size
    is 0 / falsy):
      1. start depthwise conv + BatchNorm, no activation        [optional]
      2. 1x1 expansion conv + BatchNorm + ReLU
      3. middle depthwise conv + BatchNorm + ReLU               [optional]
      4. 1x1 projection conv + BatchNorm, no activation
      5. per-channel layer scale by learnable ``gamma``          [optional]
      6. residual add when ``stride == 1`` and ``in_channels == out_channels``

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        expand_ratio: expansion factor; the hidden width is
            ``make_divisible(in_channels * expand_ratio, 8)``.
        start_dw_kernel_size: kernel size of the leading depthwise conv
            (0 disables it).
        middle_dw_kernel_size: kernel size of the middle depthwise conv
            (0 disables it).
        stride: spatial stride. It is applied by the middle depthwise conv when
            ``middle_dw_downsample`` is True, otherwise by the start depthwise
            conv. NOTE: if the selected depthwise conv is disabled, the stride
            is not applied anywhere.
        middle_dw_downsample: selects which depthwise conv carries the stride.
        use_layer_scale: if True, scale the projected output per channel by a
            learnable ``gamma`` of shape ``(out_channels,)``.
        layer_scale_init_value: initial value of every entry of ``gamma``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 start_dw_kernel_size,
                 middle_dw_kernel_size,
                 stride,
                 middle_dw_downsample: bool = True,
                 use_layer_scale: bool = False,
                 layer_scale_init_value: float = 1e-5):
        super(UniversalInvertedBottleneck, self).__init__()
        self.start_dw_kernel_size = start_dw_kernel_size
        self.middle_dw_kernel_size = middle_dw_kernel_size

        # Optional leading depthwise conv (norm only, no activation); it carries
        # the stride only when the middle depthwise conv is not chosen to.
        if start_dw_kernel_size:
            self.start_dw_conv = nn.Conv2d(in_channels, in_channels, start_dw_kernel_size,
                                           stride if not middle_dw_downsample else 1,
                                           (start_dw_kernel_size - 1) // 2,
                                           groups=in_channels, bias=False)
            self.start_dw_norm = nn.BatchNorm2d(in_channels)

        # 1x1 expansion to a hidden width rounded to a multiple of 8.
        expand_channels = make_divisible(in_channels * expand_ratio, 8)
        self.expand_conv = nn.Conv2d(in_channels, expand_channels, 1, 1, bias=False)
        self.expand_norm = nn.BatchNorm2d(expand_channels)
        self.expand_act = nn.ReLU(inplace=True)

        # Optional middle depthwise conv; carries the stride when
        # middle_dw_downsample is True.
        if middle_dw_kernel_size:
            self.middle_dw_conv = nn.Conv2d(expand_channels, expand_channels, middle_dw_kernel_size,
                                            stride if middle_dw_downsample else 1,
                                            (middle_dw_kernel_size - 1) // 2,
                                            groups=expand_channels, bias=False)
            self.middle_dw_norm = nn.BatchNorm2d(expand_channels)
            self.middle_dw_act = nn.ReLU(inplace=True)

        # 1x1 projection to out_channels (norm only, no activation).
        self.proj_conv = nn.Conv2d(expand_channels, out_channels, 1, 1, bias=False)
        self.proj_norm = nn.BatchNorm2d(out_channels)

        if use_layer_scale:
            # One learnable scale per output channel, stored with shape (out_channels,).
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(out_channels),
                                      requires_grad=True)
        self.use_layer_scale = use_layer_scale
        # Residual only when the block keeps both resolution and channel count.
        self.identity = stride == 1 and in_channels == out_channels

    def forward(self, x):
        shortcut = x
        if self.start_dw_kernel_size:
            x = self.start_dw_conv(x)
            x = self.start_dw_norm(x)
        x = self.expand_conv(x)
        x = self.expand_norm(x)
        x = self.expand_act(x)
        if self.middle_dw_kernel_size:
            x = self.middle_dw_conv(x)
            x = self.middle_dw_norm(x)
            x = self.middle_dw_act(x)
        x = self.proj_conv(x)
        x = self.proj_norm(x)
        if self.use_layer_scale:
            # gamma is (C,); reshape it to (1, C, 1, 1) so it scales the channel
            # axis of the 4-D (N, C, H, W) activation required by BatchNorm2d.
            # A raw (C,) vector would broadcast against the trailing W axis
            # instead: that raises when W != C and scales the wrong axis when W == C.
            x = self.gamma.view(1, -1, 1, 1) * x
        return x + shortcut if self.identity else x
class MobileNetV4(nn.Module):
    """MobileNetV4 classifier assembled from a flat list of block specs.

    Args:
        block_specs: iterable of tuples, each one of
            ``('conv_bn', kernel_size, stride, out_channels)`` or
            ``('uib', start_dw_kernel_size, middle_dw_kernel_size, stride,
            out_channels, expand_ratio)``. Blocks are chained in order,
            starting from a 3-channel input; each block's ``out_channels``
            becomes the next block's input width.
        num_classes: output size of the final linear classifier.

    Raises:
        NotImplementedError: if a spec names a block type other than
            ``'conv_bn'`` or ``'uib'``.
    """

    def __init__(self, block_specs, num_classes=1000):
        super(MobileNetV4, self).__init__()
        c = 3
        layers = []
        for block_type, *block_cfg in block_specs:
            if block_type == 'conv_bn':
                k, s, f = block_cfg
                layers.append(ConvBN(c, f, k, s))
            elif block_type == 'uib':
                start_k, middle_k, s, f, e = block_cfg
                layers.append(UniversalInvertedBottleneck(c, f, e, start_k, middle_k, s))
            else:
                raise NotImplementedError(
                    f"Unsupported block type {block_type!r}; expected 'conv_bn' or 'uib'")
            c = f
        self.features = nn.Sequential(*layers)
        # Head: global average pool -> 1x1 ConvBN to hidden width -> linear classifier.
        # NOTE: this ConvBN runs on a 1x1 feature map, so in training mode its
        # BatchNorm sees only batch-size values per channel and needs batch size > 1.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        hidden_channels = 1280
        self.conv = ConvBN(c, hidden_channels, 1)
        self.classifier = nn.Linear(hidden_channels, num_classes)
        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.conv(x)
        # (N, hidden_channels, 1, 1) -> (N, hidden_channels)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Initialize parameters in place.

        Conv2d weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))), with zero biases;
        BatchNorm2d weight 1 and bias 0; Linear weights ~ N(0, 0.01), with zero biases.
        Layer-scale ``gamma`` parameters are not touched here.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
def mobilenetv4_conv_small(**kwargs):
    """Build a MobileNetV4-Conv-Small model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
    """
    # Spec tuple layouts consumed by MobileNetV4:
    #   ('conv_bn', kernel_size, stride, out_channels)
    #   ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)
    # Resolution comments assume a 224x224 input.
    stem = [
        ('conv_bn', 3, 2, 32),   # -> 112px
        ('conv_bn', 3, 2, 32),   # -> 56px
        ('conv_bn', 1, 1, 32),
        ('conv_bn', 3, 2, 96),   # -> 28px
        ('conv_bn', 1, 1, 64),
    ]
    stage_14px = (
        [('uib', 5, 5, 2, 96, 3.0)]           # ExtraDW
        + [('uib', 0, 3, 1, 96, 2.0)] * 4     # IB x4
        + [('uib', 3, 0, 1, 96, 4.0)]         # ConvNext
    )
    stage_7px = [
        ('uib', 3, 3, 2, 128, 6.0),  # ExtraDW
        ('uib', 5, 5, 1, 128, 4.0),  # ExtraDW
        ('uib', 0, 5, 1, 128, 4.0),  # IB
        ('uib', 0, 5, 1, 128, 3.0),  # IB
        ('uib', 0, 3, 1, 128, 4.0),  # IB
        ('uib', 0, 3, 1, 128, 4.0),  # IB
    ]
    final_conv = [('conv_bn', 1, 1, 960)]
    return MobileNetV4(stem + stage_14px + stage_7px + final_conv, **kwargs)
def mobilenetv4_conv_medium(**kwargs):
    """Build a MobileNetV4-Conv-Medium model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
    """
    # Spec tuple layouts consumed by MobileNetV4:
    #   ('conv_bn', kernel_size, stride, out_channels)
    #   ('uib', start_dw_kernel_size, middle_dw_kernel_size, stride, out_channels, expand_ratio)
    stem = [
        ('conv_bn', 3, 2, 32),
        # NOTE(review): this 3x3 (stride 2) + 1x1 ConvBN pair appears to stand in
        # for a fused inverted bottleneck; the 1x1 here ends with ReLU — confirm
        # against the reference implementation.
        ('conv_bn', 3, 2, 128),
        ('conv_bn', 1, 1, 48),
    ]
    stage3 = [
        ('uib', 3, 5, 2, 80, 4.0),
        ('uib', 3, 3, 1, 80, 2.0),
    ]
    stage4 = [
        ('uib', 3, 5, 2, 160, 6.0),
        ('uib', 3, 3, 1, 160, 4.0),
        ('uib', 3, 3, 1, 160, 4.0),
        ('uib', 3, 5, 1, 160, 4.0),
        ('uib', 3, 3, 1, 160, 4.0),
        ('uib', 3, 0, 1, 160, 4.0),
        ('uib', 0, 0, 1, 160, 2.0),
        ('uib', 3, 0, 1, 160, 4.0),
    ]
    stage5 = [
        ('uib', 5, 5, 2, 256, 6.0),
        ('uib', 5, 5, 1, 256, 4.0),
        ('uib', 3, 5, 1, 256, 4.0),
        ('uib', 3, 5, 1, 256, 4.0),
        ('uib', 0, 0, 1, 256, 4.0),
        ('uib', 3, 0, 1, 256, 4.0),
        ('uib', 3, 5, 1, 256, 2.0),
        ('uib', 5, 5, 1, 256, 4.0),
        ('uib', 0, 0, 1, 256, 4.0),
        ('uib', 0, 0, 1, 256, 4.0),
        ('uib', 5, 0, 1, 256, 2.0),
    ]
    final_conv = [('conv_bn', 1, 1, 960)]
    specs = stem + stage3 + stage4 + stage5 + final_conv
    return MobileNetV4(specs, **kwargs)
def mobilenetv4_conv_large(**kwargs):
    """Build a MobileNetV4-Conv-Large model.

    Keyword arguments (e.g. ``num_classes``) are forwarded to ``MobileNetV4``.
    """
    def uib(start_k, middle_k, stride, out_channels, expand_ratio=4.0):
        # Every UIB in the Large variant uses expansion ratio 4.0.
        return ('uib', start_k, middle_k, stride, out_channels, expand_ratio)

    # ('conv_bn', kernel_size, stride, out_channels)
    stem = [
        ('conv_bn', 3, 2, 24),
        # NOTE(review): this 3x3 (stride 2) + 1x1 ConvBN pair appears to stand in
        # for a fused inverted bottleneck; confirm against the reference.
        ('conv_bn', 3, 2, 96),
        ('conv_bn', 1, 1, 48),
    ]
    # Each group below opens with its single stride-2 block.
    stage3 = [uib(3, 5, 2, 96), uib(3, 3, 1, 96)]
    stage4 = (
        [uib(3, 5, 2, 192)]
        + [uib(3, 3, 1, 192)] * 3
        + [uib(3, 5, 1, 192)]
        + [uib(5, 3, 1, 192)] * 5
        + [uib(3, 0, 1, 192)]
    )
    stage5 = (
        [uib(5, 5, 2, 512)]
        + [uib(5, 5, 1, 512)] * 3
        + [uib(5, 0, 1, 512), uib(5, 3, 1, 512)]
        + [uib(5, 0, 1, 512)] * 2
        + [uib(5, 3, 1, 512), uib(5, 5, 1, 512)]
        + [uib(5, 0, 1, 512)] * 3
    )
    final_conv = [('conv_bn', 1, 1, 960)]
    return MobileNetV4(stem + stage3 + stage4 + stage5 + final_conv, **kwargs)