forked from opencv/opencv_zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tools for quantization and quantized models (opencv#36)
* add scripts for quantization * update path to pp-resnet50 * add quantized models * rename dict to models * add requirements and readme * fix typos
- Loading branch information
1 parent
1147b6c
commit cdfd1de
Showing
11 changed files
with
210 additions
and
0 deletions.
There are no files selected for viewing
3 changes: 3 additions & 0 deletions
3
models/face_detection_yunet/face_detection_yunet_2021dec-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
models/face_recognition_sface/face_recognition_sface_2021dec-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
...entation_pphumanseg/human_segmentation_pphumanseg_2021oct-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
...fication_ppresnet/image_classification_ppresnet50_2022jan-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
models/person_reid_youtureid/person_reid_youtu_2021nov-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
...ls/text_recognition_crnn/text_recognition_CRNN_CN_2021nov-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
...ls/text_recognition_crnn/text_recognition_CRNN_EN_2021sep-act_int8-wt_int8-quantized.onnx
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Quantization with ONNXRUNTIME | ||
|
||
ONNXRUNTIME is used for quantization in the Zoo. | ||
|
||
Install dependencies before trying quantization: | ||
```shell | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Usage | ||
|
||
Quantize all models in the Zoo: | ||
```shell | ||
python quantize.py | ||
``` | ||
|
||
Quantize one of the models in the Zoo: | ||
```shell | ||
# python quantize.py <key_in_models> | ||
python quantize.py yunet | ||
``` | ||
|
||
Customizing quantization configs: | ||
```python | ||
# add model into `models` dict in quantize.py | ||
models = dict( | ||
# ... | ||
model1=Quantize(model_path='/path/to/model1.onnx' | ||
calibration_image_dir='/path/to/images', | ||
transforms=Compose([''' transforms ''']), # transforms can be found in transforms.py | ||
per_channel=False, # set False to quantize in per-tensor style | ||
act_type='int8', # available types: 'int8', 'uint8' | ||
wt_type='int8' # available types: 'int8', 'uint8' | ||
) | ||
) | ||
# quantize the added models | ||
python quantize.py model1 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# This file is part of OpenCV Zoo project. | ||
# It is subject to the license terms in the LICENSE file found in the same directory. | ||
# | ||
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved. | ||
# Third party copyrights are property of their respective owners. | ||
|
||
import os | ||
import sys | ||
import numpy as ny | ||
import cv2 as cv | ||
|
||
import onnx | ||
from onnx import version_converter | ||
import onnxruntime | ||
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType | ||
|
||
from transform import Compose, Resize, ColorConvert | ||
|
||
class DataReader(CalibrationDataReader): | ||
def __init__(self, model_path, image_dir, transforms): | ||
model = onnx.load(model_path) | ||
self.input_name = model.graph.input[0].name | ||
self.transforms = transforms | ||
self.data = self.get_calibration_data(image_dir) | ||
self.enum_data_dicts = iter([{self.input_name: x} for x in self.data]) | ||
|
||
def get_next(self): | ||
return next(self.enum_data_dicts, None) | ||
|
||
def get_calibration_data(self, image_dir): | ||
blobs = [] | ||
for image_name in os.listdir(image_dir): | ||
if not image_name.endswith('jpg'): | ||
continue | ||
img = cv.imread(os.path.join(image_dir, image_name)) | ||
img = self.transforms(img) | ||
blob = cv.dnn.blobFromImage(img) | ||
blobs.append(blob) | ||
return blobs | ||
|
||
class Quantize: | ||
def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8'): | ||
self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8} | ||
|
||
self.model_path = model_path | ||
self.calibration_image_dir = calibration_image_dir | ||
self.transforms = transforms | ||
self.per_channel = per_channel | ||
self.act_type = act_type | ||
self.wt_type = wt_type | ||
|
||
# data reader | ||
self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms) | ||
|
||
def check_opset(self, convert=True): | ||
model = onnx.load(self.model_path) | ||
if model.opset_import[0].version != 11: | ||
print('\tmodel opset version: {}. Converting to opset 11'.format(model.opset_import[0].version)) | ||
# convert opset version to 11 | ||
model_opset11 = version_converter.convert_version(model, 11) | ||
# save converted model | ||
output_name = '{}-opset11.onnx'.format(self.model_path[:-5]) | ||
onnx.save_model(model_opset11, output_name) | ||
# update model_path for quantization | ||
self.model_path = output_name | ||
|
||
def run(self): | ||
print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type)) | ||
self.check_opset() | ||
output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type) | ||
quantize_static(self.model_path, output_name, self.dr, | ||
per_channel=self.per_channel, | ||
weight_type=self.type_dict[self.wt_type], | ||
activation_type=self.type_dict[self.act_type]) | ||
os.remove('augmented_model.onnx') | ||
os.remove('{}-opt.onnx'.format(self.model_path[:-5])) | ||
print('\tQuantized model saved to {}'.format(output_name)) | ||
|
||
|
||
models=dict( | ||
yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2021dec.onnx', | ||
calibration_image_dir='../../benchmark/data/face_detection'), | ||
sface=Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx', | ||
calibration_image_dir='../../benchmark/data/face_recognition', | ||
transforms=Compose([Resize(size=(112, 112))])), | ||
pphumenseg=Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx', | ||
calibration_image_dir='../../benchmark/data/human_segmentation', | ||
transforms=Compose([Resize(size=(192, 192))])), | ||
ppresnet50=Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx', | ||
calibration_image_dir='../../benchmark/data/image_classification', | ||
transforms=Compose([Resize(size=(224, 224))])), | ||
# TBD: DaSiamRPN | ||
youtureid=Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx', | ||
calibration_image_dir='../../benchmark/data/person_reid', | ||
transforms=Compose([Resize(size=(128, 256))])), | ||
# TBD: DB-EN & DB-CN | ||
crnn_en=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx', | ||
calibration_image_dir='../../benchmark/data/text', | ||
transforms=Compose([Resize(size=(100, 32)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)])), | ||
crnn_cn=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx', | ||
calibration_image_dir='../../benchmark/data/text', | ||
transforms=Compose([Resize(size=(100, 32))])) | ||
) | ||
|
||
if __name__ == '__main__': | ||
selected_models = [] | ||
for i in range(1, len(sys.argv)): | ||
selected_models.append(sys.argv[i]) | ||
if not selected_models: | ||
selected_models = list(models.keys()) | ||
print('Models to be quantized: {}'.format(str(selected_models))) | ||
|
||
for selected_model_name in selected_models: | ||
q = models[selected_model_name] | ||
q.run() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
opencv-python>=4.5.4.58 | ||
onnx | ||
onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# This file is part of OpenCV Zoo project. | ||
# It is subject to the license terms in the LICENSE file found in the same directory. | ||
# | ||
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved. | ||
# Third party copyrights are property of their respective owners. | ||
|
||
import numpy as numpy | ||
import cv2 as cv | ||
|
||
class Compose: | ||
def __init__(self, transforms=[]): | ||
self.transforms = transforms | ||
|
||
def __call__(self, img): | ||
for t in self.transforms: | ||
img = t(img) | ||
return img | ||
|
||
class Resize: | ||
def __init__(self, size, interpolation=cv.INTER_LINEAR): | ||
self.size = size | ||
self.interpolation = interpolation | ||
|
||
def __call__(self, img): | ||
return cv.resize(img, self.size) | ||
|
||
class ColorConvert: | ||
def __init__(self, ctype): | ||
self.ctype = ctype | ||
|
||
def __call__(self, img): | ||
return cv.cvtColor(img, self.ctype) |