-
Notifications
You must be signed in to change notification settings - Fork 53
/
from_pytorch_cv.py
59 lines (46 loc) · 1.7 KB
/
from_pytorch_cv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import torch
from pytorchcv.model_provider import get_model as ptcv_get_model
# Every exported .onnx file lands in this directory (relative to the current
# working directory). It is created at import time; exist_ok makes a repeat
# run a no-op instead of an error.
out_dir: str = './onnx_models'
os.makedirs(out_dir, mode=0o777, exist_ok=True)
def export(model, rgb, onnx_name, opset=10, output_dir=None):
    """Export ``model`` to an ONNX file using ``rgb`` as the example input.

    Args:
        model: The ``torch.nn.Module`` to export. Callers in this script put
            it in eval mode first; presumably that is required for a
            deterministic inference graph (confirm for new callers).
        rgb: Example input passed to ``torch.onnx.export``. The exported graph
            has a single input named ``'rgb'`` and a single output named
            ``'output'``.
        onnx_name: File name of the ``.onnx`` file to write.
        opset: ONNX opset version forwarded as ``opset_version``.
        output_dir: Directory to write into. Defaults to the module-level
            ``out_dir``. It is created if it does not exist.

    Returns:
        The path of the written ONNX file.
    """
    if output_dir is None:
        output_dir = out_dir
    # The module only pre-creates ``out_dir``; a caller-supplied directory
    # (or one removed since import) must be created here.
    os.makedirs(output_dir, exist_ok=True)
    onnx_file_path = os.path.join(output_dir, onnx_name)
    torch.onnx.export(model,
                      rgb,
                      onnx_file_path,
                      export_params=True,
                      input_names=['rgb'],
                      output_names=['output'],
                      do_constant_folding=True,
                      verbose=False,
                      opset_version=opset)
    print(f'exported {onnx_name}')
    return onnx_file_path
def main():
    """Build each segmentation network listed below and export it to ONNX.

    Each network gets a random float32 dummy input of shape (1, 3, H, W). The
    same (H, W) is passed to the model as ``in_size``.
    """
    # (pytorchcv model name, input height, input width, output file name)
    export_specs = (
        # DeepLabv3 with output stride 8
        ('deeplabv3_resnetd101b_cityscapes', 1024, 2048, 'deeplabv3.onnx'),
        # PSPNet with output stride 8; height and width need to be divisible
        # by 16 and 6
        ('pspnet_resnetd101b_cityscapes', 1008, 2016, 'pspnet.onnx'),
        # BiSENet
        ('bisenet_resnet18_celebamaskhq', 1024, 2048, 'bisenet.onnx'),
    )
    for model_name, H, W, onnx_name in export_specs:
        rgb = torch.rand(size=(1, 3, H, W), dtype=torch.float32)
        model = ptcv_get_model(model_name,
                               in_size=(H, W),
                               aux=False)
        model.eval()
        export(model, rgb, onnx_name)
        # Release this network and its dummy input before constructing the
        # next one, so only one model is held at a time.
        del model, rgb


if __name__ == '__main__':
    main()