forked from JDAI-CV/fast-reid
-
Notifications
You must be signed in to change notification settings - Fork 0
/
caffe_export.py
87 lines (70 loc) · 2.26 KB
/
caffe_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# encoding: utf-8
"""
@author: xingyu liao
@contact: [email protected]
"""
import argparse
import logging
import sys
import torch
sys.path.append('../../')
import pytorch_to_caffe
from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.file_io import PathManager
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
# To use modules added by a project, import them as shown below:
# sys.path.append('../projects/FastCls')
# from fastcls import *
# Set up the "fastreid" logger at import time through the project's
# setup_logger helper, before any messages are emitted below.
setup_logger(name='fastreid')
# Logger used by this script; by the logging module's dotted-name rule it is a
# child of the "fastreid" logger set up on the previous line.
logger = logging.getLogger("fastreid.caffe_export")
def setup_cfg(args):
    """Create the frozen config used for this export run.

    Starts from the config returned by ``get_cfg()``, merges in the file named
    by ``args.config_file``, then merges the ``args.opts`` override list, and
    finally freezes the result.

    Args:
        args: parsed command-line namespace providing ``config_file`` and
            ``opts``.

    Returns:
        The merged config object, frozen.
    """
    config = get_cfg()
    # Project-specific config extensions would be registered here, e.g.:
    # add_cls_config(config)
    # The file is merged first; the command-line overrides are merged after it.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def get_parser():
    """Build the command-line parser for the PyTorch-to-Caffe export script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config-file``, ``--name``,
        ``--output`` and ``--opts`` (which collects all remaining arguments).
    """
    cli = argparse.ArgumentParser(description="Convert Pytorch to Caffe model")

    # (flag, add_argument keyword arguments), registered in declaration order.
    option_specs = (
        ("--config-file", dict(metavar="FILE", help="path to config file")),
        ("--name", dict(default="baseline", help="name for converted model")),
        (
            "--output",
            dict(default='caffe_model', help='path to save converted caffe model'),
        ),
        (
            "--opts",
            dict(
                help="Modify config options using the command-line 'KEY VALUE' pairs",
                default=[],
                nargs=argparse.REMAINDER,
            ),
        ),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    return cli
if __name__ == '__main__':
    # Script entry point: build the model described by the config, load its
    # weights, convert it with pytorch_to_caffe, and write
    # <output>/<name>.prototxt and <output>/<name>.caffemodel.
    parser = get_parser()
    args = parser.parse_args()
    # --config-file has no default, so when it is omitted args.config_file is
    # None and would be handed to cfg.merge_from_file() inside setup_cfg.
    # Exit with a clear usage error instead of a traceback from config loading.
    if not args.config_file:
        parser.error("--config-file is required")
    cfg = setup_cfg(args)

    # setup_cfg returns a frozen config; unfreeze it to apply the export-only
    # overrides below.
    cfg.defrost()
    # Weights are loaded from cfg.MODEL.WEIGHTS further down.
    # NOTE(review): presumably this skips fetching backbone pretrain weights - confirm.
    cfg.MODEL.BACKBONE.PRETRAIN = False
    # NOTE(review): presumably 'fastavgpool' is not supported by the converter,
    # hence the swap to 'avgpool' - confirm against pytorch_to_caffe.
    if cfg.MODEL.HEADS.POOL_LAYER == 'fastavgpool':
        cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
    # NOTE(review): WITH_NL presumably toggles non-local blocks - confirm.
    cfg.MODEL.BACKBONE.WITH_NL = False

    model = build_model(cfg)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
    model.eval()
    logger.info(model)

    # Random single-image batch with 3 channels, sized from cfg.INPUT.SIZE_TEST
    # and placed on the configured device; passed to trans_net with the model.
    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(torch.device(cfg.MODEL.DEVICE))
    PathManager.mkdirs(args.output)
    pytorch_to_caffe.trans_net(model, inputs, args.name)
    pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")
    pytorch_to_caffe.save_caffemodel(f"{args.output}/{args.name}.caffemodel")

    logger.info(f"Export caffe model in {args.output} sucessfully!")