-
Notifications
You must be signed in to change notification settings - Fork 53
/
ssma.patch
135 lines (121 loc) · 5.42 KB
/
ssma.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
diff -x .git -x .gitignore -x '*.pyc' -r -N -u ./src/adapnet.py esanet/src/models/external_code/ssma_pytorch/src/adapnet.py
--- ./src/adapnet.py 2020-12-05 23:38:10.000000000 +0100
+++ esanet/src/models/external_code/ssma_pytorch/src/adapnet.py 2020-12-04 17:56:53.000000000 +0100
@@ -9,7 +9,7 @@
class AdapNet(nn.Module):
"""PyTorch module for 'AdapNet++' and 'AdapNet++ with fusion architecture' """
- def __init__(self, C, encoders=[]):
+ def __init__(self, C, encoders=1):
"""Constructor
:param C: number of categories
@@ -23,12 +23,12 @@
self.num_categories = C
self.fusion = False
- if len(encoders) > 0:
+ if encoders == 2:
self.encoder_mod1 = Encoder()
- self.encoder_mod1.load_state_dict(encoders[0].state_dict())
+ # self.encoder_mod1.load_state_dict(encoders[0].state_dict())
self.encoder_mod1.res_n50_enc.layer3[2].dropout = False
self.encoder_mod2 = Encoder()
- self.encoder_mod2.load_state_dict(encoders[1].state_dict())
+ # self.encoder_mod2.load_state_dict(encoders[1].state_dict())
self.encoder_mod2.res_n50_enc.layer3[2].dropout = False
self.ssma_s1 = SSMA(24, 6)
self.ssma_s2 = SSMA(24, 6)
@@ -60,6 +60,6 @@
m1_x = self.eASPP(m1_x)
- aux1, aux2, res = self.decoder(m1_x, skip1, skip2)
+ res = self.decoder(m1_x, skip1, skip2)
- return aux1, aux2, res
+ return res
diff -x .git -x .gitignore -x '*.pyc' -r -N -u ./src/components/decoder.py esanet/src/models/external_code/ssma_pytorch/src/components/decoder.py
--- ./src/components/decoder.py 2020-12-05 23:38:10.000000000 +0100
+++ esanet/src/models/external_code/ssma_pytorch/src/components/decoder.py 2020-12-04 17:56:53.000000000 +0100
@@ -82,7 +82,7 @@
"""
# stage 1
x = torch.relu(self.deconv1_bn(self.deconv1(x)))
- y1 = self.aux(x, self.aux_conv1, self.aux_conv1_bn, 8)
+ # y1 = self.aux(x, self.aux_conv1, self.aux_conv1_bn, 8)
if self.fusion:
# integrate fusion skip
int_fuse_skip = self.integrate_fuse_skip(x, skip1, self.fuse_conv1, self.fuse_conv1_bn)
@@ -92,7 +92,7 @@
# stage 2
x = self.stage2(x)
- y2 = self.aux(x, self.aux_conv2, self.aux_conv2_bn, 4)
+ # y2 = self.aux(x, self.aux_conv2, self.aux_conv2_bn, 4)
if self.fusion:
# integrate fusion skip
int_fuse_skip = self.integrate_fuse_skip(x, skip2, self.fuse_conv2, self.fuse_conv2_bn)
@@ -103,7 +103,7 @@
# stage 3
y3 = self.stage3(x)
- return y1, y2, y3
+ return y3
def aux(self, x, conv, bn, scale):
"""Compute auxiliary output"""
@@ -113,6 +113,7 @@
def integrate_fuse_skip(self, x, fuse_skip, conv, bn):
"""Integrate fuse skip connection with decoder"""
+ x = nn.AdaptiveAvgPool2d((2, 2))(x)
x = nn.AdaptiveAvgPool2d((1, 1))(x)
x = torch.relu(bn(conv(x)))
diff -x .git -x .gitignore -x '*.pyc' -r -N -u ./src/components/easpp.py esanet/src/models/external_code/ssma_pytorch/src/components/easpp.py
--- ./src/components/easpp.py 2020-12-05 23:38:10.000000000 +0100
+++ esanet/src/models/external_code/ssma_pytorch/src/components/easpp.py 2020-12-04 17:56:53.000000000 +0100
@@ -57,6 +57,7 @@
:param x: input from encoder (in stage 1) or from fused encoders (in stage 2 and 3)
:return: feature maps to be forwarded to decoder
"""
+ h, w = x.size()[2:]
# branch 1: 1x1 convolution
out = torch.relu(self.branch1_bn(self.branch1_conv(x)))
@@ -69,9 +70,11 @@
out = torch.cat((out, y), 1)
# branch 5: image pooling
+ x = nn.AdaptiveAvgPool2d((2, 2))(x)
x = nn.AdaptiveAvgPool2d((1, 1))(x)
x = torch.relu(self.branch5_bn(self.branch5_conv(x)))
- x = nn.Upsample((24, 48), mode="bilinear")(x)
+ x = nn.Upsample((int(h), int(w)), mode="bilinear",
+ align_corners=False)(x)
out = torch.cat((out, x), 1)
return torch.relu(self.eASPP_fin_bn(self.eASPP_fin_conv(out)))
diff -x .git -x .gitignore -x '*.pyc' -r -N -u ./src/to_onnx.py esanet/src/models/external_code/ssma_pytorch/src/to_onnx.py
--- ./src/to_onnx.py 1970-01-01 01:00:00.000000000 +0100
+++ esanet/src/models/external_code/ssma_pytorch/src/to_onnx.py 2020-12-04 17:56:53.000000000 +0100
@@ -0,0 +1,32 @@
+from torch import nn
+import os
+import torch
+
+from adapnet import AdapNet
+
+
+def to_onnx(n_classes, h, w, outname):
+ model = AdapNet(C=n_classes, encoders=2)
+ model.eval()
+
+ rgb = torch.rand(size=(1, 3, h, w), dtype=torch.float32)
+ depth = torch.rand(size=(1, 3, h, w), dtype=torch.float32)
+
+ out_dir = '../../onnx_models'
+ os.makedirs(out_dir, exist_ok=True)
+ onnx_file_path = os.path.join(out_dir, outname)
+ torch.onnx.export(model,
+ (rgb, depth),
+ onnx_file_path,
+ export_params=True,
+ input_names=['rgb', 'hha'],
+ output_names=['output'],
+ do_constant_folding=True,
+ verbose=False,
+ opset_version=10
+ )
+ print(f'exported {onnx_file_path}')
+
+
+to_onnx(n_classes=40, h=480, w=640, outname='ssma_nyuv2.onnx')
+to_onnx(n_classes=19, h=1024, w=2048, outname='ssma_cityscapes.onnx')