-
Notifications
You must be signed in to change notification settings - Fork 0
/
NAFRSSR_M_Net.py
360 lines (270 loc) · 11.5 KB
/
NAFRSSR_M_Net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
'''
@misc{chen2024nafrssr,
title={NAFRSSR: a Lightweight Recursive Network for Efficient Stereo Image Super-Resolution},
author={Yihong Chen and Zhen Fan and Shuai Dong and Zhiwei Chen and Wenjie Li and Minghui Qin and Min Zeng and Xubing Lu and Guofu Zhou and Xingsen Gao and Jun-Ming Liu},
year={2024},
eprint={2405.08423},
archivePrefix={arXiv},
primaryClass={eess.IV}
}
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.models.archs.arch_util import MySequential
from basicsr.models.archs.arch_util import LayerNorm2d
from basicsr.models.archs.local_arch import Local_Base
class SimpleGate(nn.Module):
    '''
    Parameter-free gating unit.

    The input is split into two equal halves along dim 1 and the halves are
    multiplied element-wise, so the output has half as many channels.
    '''

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
class GaussianBlur(nn.Module):
    '''
    Learnable 3x3 filter applied to every channel independently.

    Despite the class name, the kernel is initialised to a 4-neighbour
    Laplacian (an edge detector), not a Gaussian. One single-channel kernel
    of shape (1, 1, 3, 3) is shared by all channels and is trainable.
    '''

    def __init__(self):
        super(GaussianBlur, self).__init__()
        # 4-neighbour Laplacian edge-detection kernel. (The 8-neighbour
        # variant -- all -1 with 8 in the centre -- is an edge detector too.)
        laplacian = [[0., -1., 0.],
                     [-1., 4., -1.],
                     [0., -1., 0.]]
        kernel = torch.tensor(laplacian).unsqueeze(0).unsqueeze(0)
        self.weight = nn.Parameter(data=kernel, requires_grad=True)

    def forward(self, x):
        '''
        Filter each channel of a (B, C, H, W) tensor with the shared kernel.

        Any number of channels is supported. The channels are folded into the
        batch axis so that one convolution call handles all of them; the
        result has the same shape as ``x``.
        '''
        batch, channels, height, width = x.shape
        planes = x.reshape(batch * channels, 1, height, width)
        filtered = F.conv2d(planes, self.weight, padding=1)
        return filtered.reshape(batch, channels, height, width)
# my module
class NAFBlock(nn.Module):
    '''
    Residual block with a gated convolution branch followed by a gated
    feed-forward branch.

    Each branch is added back through a learnable per-channel scale
    (``beta`` for the first, ``gamma`` for the second); both scales are
    initialised to zero.

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion factor of the convolution branch.
        FFN_Expand: channel expansion factor of the feed-forward branch.
        drop_out_rate: dropout probability; values <= 0 disable dropout.
    '''

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()

        # Convolution branch: 1x1 expand -> 3x3 depthwise -> gate -> 3x3 grouped.
        expanded = c * DW_Expand
        self.conv1 = nn.Conv2d(c, expanded, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
        self.conv2 = nn.Conv2d(expanded, expanded, kernel_size=3, stride=1, padding=1, groups=expanded, bias=True)
        # The gate halves the channel count, hence ``expanded // 2`` inputs.
        self.conv3 = nn.Conv2d(expanded // 2, c, kernel_size=3, stride=1, padding=1, groups=c // 4, bias=True)

        self.sg = SimpleGate()

        # Feed-forward branch: 1x1 expand -> gate -> 1x1 project.
        hidden = FFN_Expand * c
        self.conv4 = nn.Conv2d(c, hidden, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
        self.conv5 = nn.Conv2d(hidden // 2, c, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        if drop_out_rate > 0.:
            self.dropout1 = nn.Dropout(drop_out_rate)
            self.dropout2 = nn.Dropout(drop_out_rate)
        else:
            self.dropout1 = nn.Identity()
            self.dropout2 = nn.Identity()

        # Zero-initialised residual scales.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # First residual: gated convolution branch.
        conv_branch = self.conv2(self.conv1(self.norm1(inp)))
        conv_branch = self.conv3(self.sg(conv_branch))
        y = inp + self.dropout1(conv_branch) * self.beta

        # Second residual: gated feed-forward branch.
        ffn_branch = self.sg(self.conv4(self.norm2(y)))
        ffn_branch = self.dropout2(self.conv5(ffn_branch))
        return y + ffn_branch * self.gamma
class SCAM(nn.Module):
    '''
    Stereo Cross Attention Module (SCAM)

    Exchanges information between the left and right feature maps with
    attention computed along the last (width) axis. Both views share one
    normalisation layer and one depthwise 3x3 projection. The attended
    features are added back through zero-initialised per-channel scales.
    '''

    def __init__(self, c):
        super().__init__()
        self.scale = c ** -0.5

        # A single norm + projection pair serves both views.
        self.norm_l = LayerNorm2d(c)
        self.l_proj1 = nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1, groups=c)

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, x_l, x_r):
        query_left = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1)     # B, H, W, c
        query_right_t = self.l_proj1(self.norm_l(x_r)).permute(0, 2, 1, 3)  # B, H, c, W (transposed)

        value_left = x_l.permute(0, 2, 3, 1)   # B, H, W, c
        value_right = x_r.permute(0, 2, 3, 1)  # B, H, W, c

        # (B, H, W, c) x (B, H, c, W) -> (B, H, W, W)
        scores = torch.matmul(query_left, query_right_t) * self.scale

        # Right-to-left uses the scores as-is; left-to-right uses their transpose.
        right_to_left = torch.matmul(scores.softmax(dim=-1), value_right)                  # B, H, W, c
        left_to_right = torch.matmul(scores.transpose(-1, -2).softmax(dim=-1), value_left)  # B, H, W, c

        # Back to channel-first layout and apply the learnable scales.
        right_to_left = right_to_left.permute(0, 3, 1, 2) * self.beta
        left_to_right = left_to_right.permute(0, 3, 1, 2) * self.gamma
        return x_l + right_to_left, x_r + left_to_right
class DropPath(nn.Module):
    '''
    Stochastic-depth wrapper around ``module``.

    In training mode the wrapped module is skipped with probability
    ``drop_rate`` (the inputs are returned unchanged). When it is not
    skipped, the change it makes to each input is rescaled by
    ``1 / (1 - drop_rate)``. In eval mode the module is always applied
    without rescaling.
    '''

    def __init__(self, drop_rate, module):
        super().__init__()
        self.drop_rate = drop_rate
        self.module = module

    def forward(self, *feats):
        # Evaluation: plain pass-through, no random draw.
        if not self.training:
            return self.module(*feats)

        # Training: one random draw decides whether the module is skipped.
        if np.random.rand() < self.drop_rate:
            return feats

        outputs = self.module(*feats)
        factor = 1. / (1 - self.drop_rate)
        if factor == 1.:
            return outputs
        return tuple(old + factor * (new - old) for old, new in zip(feats, outputs))
class NAFBlockSR(nn.Module):
    '''
    NAFBlock for Super-Resolution

    Applies one shared NAFBlock to every input feature map and, when
    ``fusion`` is enabled, passes the results through a SCAM module.
    '''

    def __init__(self, c, fusion=False, drop_out_rate=0.):
        super().__init__()
        self.blk = NAFBlock(c, drop_out_rate=drop_out_rate)
        if fusion:
            self.fusion = SCAM(c)
        else:
            self.fusion = None

    def forward(self, *feats):
        # The same block (same weights) processes every view.
        processed = tuple(self.blk(view) for view in feats)
        if self.fusion is None:
            return processed
        return self.fusion(*processed)
class NAFBlock_mid(nn.Module):
    '''
    Reduced NAFBlock with a single residual branch:
    norm -> 1x1 expand -> 3x3 depthwise -> gate -> 3x3 grouped -> 1x1.

    The branch is added back through ``gamma``, a per-channel scale
    initialised to zero.

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion factor of the branch.
        FFN_Expand: unused; kept so the signature matches NAFBlock.
        drop_out_rate: unused; kept so the signature matches NAFBlock.
    '''

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
                               groups=dw_channel, bias=True)
        self.sg = SimpleGate()
        # conv3 runs AFTER conv4 in forward(), so it must take conv4's ``c``
        # output channels. It was previously declared with ``dw_channel // 2``
        # channels, which only equals ``c`` when DW_Expand == 2; for the
        # default the shapes are unchanged.
        self.conv3 = nn.Conv2d(in_channels=c, out_channels=c, kernel_size=1, padding=0,
                               stride=1, groups=1, bias=True)
        # The gate halves the channel count, hence ``dw_channel // 2`` inputs.
        self.conv4 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=3, padding=1, stride=1,
                               groups=4, bias=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.norm1 = LayerNorm2d(c)

    def forward(self, inp):
        x = self.norm1(inp)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        # Note the order: the grouped 3x3 conv4 comes before the 1x1 conv3.
        x = self.conv4(x)
        x = self.conv3(x)
        return inp + x * self.gamma
class NAFBlockSR_mid(nn.Module):
    '''
    NAFBlock_mid for Super-Resolution

    Applies one shared NAFBlock_mid to every input feature map and, when
    ``fusion`` is enabled, passes the results through a SCAM module.
    '''

    def __init__(self, c, fusion=False, drop_out_rate=0.):
        super().__init__()
        self.blk = NAFBlock_mid(c, drop_out_rate=drop_out_rate)
        if fusion:
            self.fusion = SCAM(c)
        else:
            self.fusion = None

    def forward(self, *feats):
        # The same block (same weights) processes every view.
        processed = tuple(self.blk(view) for view in feats)
        if self.fusion is None:
            return processed
        return self.fusion(*processed)
class NAFNetSR(nn.Module):
    '''
    NAFNet for Super-Resolution

    The network has three stages: ``body1`` (2 blocks), ``mid`` (4 blocks)
    and ``body2`` (4 blocks). In forward() ``body1`` is applied twice and
    ``mid`` three times, reusing the same weights on each pass. The
    pixel-shuffled output is sharpened by a learnable filtered residual and
    added to the bicubically upsampled input.

    Args:
        up_scale: upsampling factor.
        width: number of feature channels.
        num_blks: accepted for interface compatibility but not used; the
            stage depths are fixed at 2 / 4 / 4.
        img_channel: channels per view.
        drop_path_rate: probability of skipping a block during training.
        drop_out_rate: dropout probability forwarded to the blocks.
        fusion_from, fusion_to: inclusive range of block indices that get a
            SCAM fusion module. The index restarts at 0 in every stage.
        dual: if True the input holds two views concatenated along dim 1.
    '''

    def __init__(self, up_scale=4, width=64, num_blks=16, img_channel=3, drop_path_rate=0., drop_out_rate=0.,
                 fusion_from=-1, fusion_to=-1, dual=False):
        super().__init__()
        self.dual = dual  # dual input for stereo SR (left view, right view)
        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1,
                               groups=1, bias=True)

        start_block = 2
        mid_block = 4
        end_block = 4

        def make_stage(block_cls, depth):
            # One DropPath-wrapped block per index; ``i`` is local to the stage.
            return MySequential(
                *[DropPath(
                    drop_path_rate,
                    block_cls(
                        width,
                        fusion=(fusion_from <= i <= fusion_to),
                        drop_out_rate=drop_out_rate
                    )) for i in range(depth)]
            )

        # Creation order (body1, body2, mid) is kept so parameter order and
        # random initialisation match the previous implementation.
        self.body1 = make_stage(NAFBlockSR, start_block)
        self.body2 = make_stage(NAFBlockSR, end_block)
        self.mid = make_stage(NAFBlockSR_mid, mid_block)

        self.up = nn.Sequential(
            nn.Conv2d(in_channels=width, out_channels=img_channel * up_scale ** 2, kernel_size=3, padding=1, stride=1,
                      groups=1, bias=True),
            nn.PixelShuffle(up_scale)
        )
        self.up_scale = up_scale
        self.blur = GaussianBlur()
        # Per-channel scale of the filtered residual. It is sized to the real
        # number of output channels (one image per view) rather than a fixed
        # 6; for the stereo RGB case this is still (1, 6, 1, 1).
        out_channels = img_channel * (2 if dual else 1)
        self.weight = nn.Parameter(torch.zeros((1, out_channels, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # Bicubic upsampling of the raw input serves as the global skip path.
        inp_hr = F.interpolate(inp, scale_factor=self.up_scale, mode='bicubic')
        views = inp.chunk(2, dim=1) if self.dual else (inp,)
        feats = [self.intro(view) for view in views]

        # Recursive schedule: body1 x2, mid x3, body2 x1 (weights are shared
        # between repeated passes of the same stage).
        for stage in (self.body1, self.body1, self.mid, self.mid, self.mid, self.body2):
            feats = stage(*feats)

        out = torch.cat([self.up(feat) for feat in feats], dim=1)
        out = out + self.blur(out) * self.weight
        out = out + inp_hr
        return out
class NAFSSR(Local_Base, NAFNetSR):
    '''
    Stereo (dual-view, 3 channels per view) NAFNetSR combined with Local_Base.

    After construction the model is put in eval mode and ``convert`` is
    called under ``torch.no_grad()`` with a base size of 1.5x the training
    height and width. ``convert`` is inherited (presumably from Local_Base,
    which is not defined in this file).
    '''

    def __init__(self, *args, train_size=(1, 6, 30, 90), fast_imp=False, fusion_from=-1, fusion_to=1000, **kwargs):
        Local_Base.__init__(self)
        NAFNetSR.__init__(
            self, *args,
            img_channel=3, fusion_from=fusion_from, fusion_to=fusion_to, dual=True,
            **kwargs
        )

        # Only the spatial dimensions of the training size are needed here.
        _, _, height, width = train_size
        base_size = (int(height * 1.5), int(width * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
if __name__ == '__main__':
    # Model configuration for the complexity / speed report.
    num_blks = 16
    width = 64
    droppath = 0.
    train_size = (1, 6, 30, 90)

    net = NAFSSR(
        up_scale=4,
        train_size=train_size,
        fast_imp=True,
        width=width,
        num_blks=num_blks,
        drop_path_rate=droppath,
    )

    # Complexity report (ptflops) for one 6-channel 64x64 input.
    inp_shape = (6, 64, 64)

    from ptflops import get_model_complexity_info

    FLOPS = 0
    macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=True)
    print(params)

    # NOTE(review): assumes ``macs`` is a string ending in a 4-character unit
    # suffix (e.g. 'GMac'); confirm against the ptflops version in use.
    macs = float(macs[:-4]) + FLOPS / 10 ** 9
    print('mac', macs, params)

    # Inference-speed measurement on the GPU.
    from basicsr.models.archs.arch_util import measure_inference_speed

    net = net.cuda()
    data = torch.randn((1, 6, 128, 128)).cuda()
    measure_inference_speed(net, (data,))