-
Notifications
You must be signed in to change notification settings - Fork 53
/
lednet.patch
108 lines (95 loc) · 3.87 KB
/
lednet.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
diff -x .git -x .gitignore -x '*.png' -r -N -r -u ./train/lednet.py esanet/src/models/external_code/LEDNet/train/lednet.py
--- ./train/lednet.py 2020-12-06 10:02:54.000000000 +0100
+++ esanet/src/models/external_code/LEDNet/train/lednet.py 2020-12-04 17:56:52.000000000 +0100
@@ -58,11 +58,11 @@
x1 = self.pool(input)
x2 = self.conv(input)
- diffY = x2.size()[2] - x1.size()[2]
- diffX = x2.size()[3] - x1.size()[3]
-
- x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
- diffY // 2, diffY - diffY // 2])
+ # diffY = x2.size()[2] - x1.size()[2]
+ # diffX = x2.size()[3] - x1.size()[3]
+ #
+ # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+ # diffY // 2, diffY - diffY // 2])
output = torch.cat([x2, x1], 1)
output = self.bn(output)
@@ -205,7 +205,7 @@
self.size = size
self.mode = mode
def forward(self,x):
- x = self.interp(x,size=self.size,mode=self.mode,align_corners=True)
+ x = self.interp(x,size=self.size,mode=self.mode,align_corners=False)
return x
@@ -235,12 +235,12 @@
def forward(self, x):
- h = x.size()[2]
- w = x.size()[3]
+ h = int(x.size()[2])
+ w = int(x.size()[3])
b1 = self.branch1(x)
# b1 = Interpolate(size=(h, w), mode="bilinear")(b1)
- b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=True)
+ b1= interpolate(b1, size=(h, w), mode="bilinear", align_corners=False)
mid = self.mid(x)
@@ -248,16 +248,18 @@
x2 = self.down2(x1)
x3 = self.down3(x2)
# x3 = Interpolate(size=(h // 4, w // 4), mode="bilinear")(x3)
- x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear", align_corners=True)
+ x3= interpolate(x3, size=(h // 4, w // 4), mode="bilinear",
+ align_corners=False)
x2 = self.conv2(x2)
x = x2 + x3
# x = Interpolate(size=(h // 2, w // 2), mode="bilinear")(x)
- x= interpolate(x, size=(h // 2, w // 2), mode="bilinear", align_corners=True)
+ x= interpolate(x, size=(h // 2, w // 2), mode="bilinear",
+ align_corners=False)
x1 = self.conv1(x1)
x = x + x1
# x = Interpolate(size=(h, w), mode="bilinear")(x)
- x= interpolate(x, size=(h, w), mode="bilinear", align_corners=True)
+ x= interpolate(x, size=(h, w), mode="bilinear", align_corners=False)
x = torch.mul(x, mid)
@@ -282,7 +284,8 @@
def forward(self, input):
output = self.apn(input)
- out = interpolate(output, size=(512, 1024), mode="bilinear", align_corners=True)
+ out = interpolate(output, size=(512, 1024), mode="bilinear",
+ align_corners=False)
# out = self.upsample(output)
# print(out.shape)
return out
diff -x .git -x .gitignore -x '*.png' -r -N -r -u ./train/to_onnx.py esanet/src/models/external_code/LEDNet/train/to_onnx.py
--- ./train/to_onnx.py 1970-01-01 01:00:00.000000000 +0100
+++ esanet/src/models/external_code/LEDNet/train/to_onnx.py 2020-12-04 17:56:52.000000000 +0100
@@ -0,0 +1,27 @@
+import os
+import torch
+from lednet import Net
+
+N_CLASSES = 19
+H = 512
+W = 1024
+
+model = Net(N_CLASSES)
+model.eval()
+
+rgb = torch.rand(size=(1, 3, H, W), dtype=torch.float32)
+
+out_dir = '../../onnx_models'
+os.makedirs(out_dir, exist_ok=True)
+onnx_file_path = os.path.join(out_dir, 'lednet.onnx')
+
+torch.onnx.export(model,
+ rgb,
+ onnx_file_path,
+ export_params=True,
+ input_names=['rgb'],
+ output_names=['output'],
+ do_constant_folding=True,
+ verbose=False,
+ opset_version=10
+ )