Benchmark framework implementation and 3 models added:
* benchmark framework: benchmarks based on configs

* added impl and benchmark for YuNet (face detection)

* added impl and benchmark for DB (text detection)

* added impl and benchmark for CRNN (text recognition)
fengyuentau committed Sep 17, 2021
1 parent af1afb3 commit bfac311
Showing 26 changed files with 1,649 additions and 4 deletions.
5 changes: 2 additions & 3 deletions .gitignore
@@ -1,6 +1,5 @@
*.pyc

benchmark/data
benchmark/data/**
**/__pycache__
**/__pycache__/**

.vscode
34 changes: 33 additions & 1 deletion README.md
@@ -2,6 +2,38 @@

A zoo for models tuned for OpenCV DNN with benchmarks on different platforms.

Guidelines:
- To clone this repo, please install [git-lfs](https://git-lfs.github.com/), run `git lfs install` and use `git lfs clone https://github.com/opencv/opencv_zoo`.
- To run the benchmarks on your own hardware, please refer to [benchmark/README](./benchmark/README.md).

## Models & Benchmarks

Hardware Setup:
- `CPU x86_64`: Intel Core i7-5930K @ 3.50GHz, 6 cores, 12 threads.
- `CPU ARM`: Raspberry Pi 4B, BCM2711B0 @ 1.5GHz (Cortex-A72), 4 cores, 4 threads.
<!--
- `GPU CUDA`: NVIDIA Jetson Nano B01, 128-core Maxwell, Quad-core ARM A57 @ 1.43 GHz.
-->

***Important Notes***:
- The times shown in the following tables cover preprocessing (resize excluded), a forward pass of the network, and the postprocessing that produces the final results.
- The times shown in the following tables are averaged over 100 runs.
- View [benchmark/config](./benchmark/config) for more details on benchmarking different models.

<!--
| Model | Input Size | CPU x86_64 (ms) | CPU ARM (ms) | GPU CUDA (ms) |
|-------|------------|-----------------|--------------|---------------|
| [YuNet](./models/face_detection_yunet) | 160x120 | 2.17 | 8.87 | 14.95 |
| [DB](./models/text_detection_db) | 640x480 | 148.65 | 2759.88 | 218.25 |
| [CRNN](./models/text_recognition_crnn) | 100x32 | 23.23 | 235.87 | 195.20 |
-->
| Model | Input Size | CPU x86_64 (ms) | CPU ARM (ms) |
|-------|------------|-----------------|--------------|
| [YuNet](./models/face_detection_yunet) | 160x120 | 2.17 | 8.87 |
| [DB](./models/text_detection_db) | 640x480 | 148.65 | 2759.88 |
| [CRNN](./models/text_recognition_crnn) | 100x32 | 23.23 | 235.87 |


## License

OpenCV Zoo is licensed under the [Apache 2.0 license](./LICENCE). Please refer to the licenses of different models for model weights.
OpenCV Zoo is licensed under the [Apache 2.0 license](./LICENSE). Please refer to licenses of different models.
32 changes: 32 additions & 0 deletions benchmark/README.md
@@ -0,0 +1,32 @@
# OpenCV Zoo Benchmark

Benchmarking different models in the zoo.

Data for benchmarking is downloaded into [data](./data) according to the given config.
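
Each config's `Data` section carries a `sha` field of 40 hex characters, which is the length of a SHA-1 digest. A minimal sketch of how a downloaded archive could be checked against such a digest is shown below. The helper names `sha1_of_file` and `is_intact` and the chunk size are assumptions for illustration, not the repository's `Downloader` implementation.

```python
import hashlib

def sha1_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-1 digest of the file at `path` (hypothetical helper)."""
    digest = hashlib.sha1()
    with open(path, 'rb') as f:
        # Read in chunks so large archives do not have to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()

def is_intact(path, expected_sha):
    """Compare the computed digest with the `sha` value from the config."""
    return sha1_of_file(path) == expected_sha.lower()
```

A caller would skip the download when `is_intact(archive_path, cfg['Data']['sha'])` already holds and fetch the file again otherwise.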

Time is measured from data preprocessing (resize excluded), through a forward pass of the network, to the postprocessing that produces the final results. The reported time is the average over 100 runs.
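
The measurement described above can be sketched with the standard library alone. The snippet times a callable over a fixed number of runs and reports the mean in milliseconds. `average_latency_ms` and `fake_infer` are hypothetical names, and `time.perf_counter` stands in for the `cv.TickMeter` that `benchmark.py` uses.

```python
import time

def average_latency_ms(fn, repeat=100):
    """Call `fn` `repeat` times and return the mean wall-clock time in ms."""
    total = 0.0
    for _ in range(repeat):
        start = time.perf_counter()
        fn()  # preprocess + forward pass + postprocess would happen here
        total += time.perf_counter() - start
    return total / repeat * 1000.0

def fake_infer():
    # Placeholder workload standing in for model.infer(image).
    sum(i * i for i in range(10_000))

print('{:.4f} ms'.format(average_latency_ms(fake_infer, repeat=10)))
```

Resizing happens before the timed call, which is why it does not contribute to the reported latency.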

## Preparation

1. Install `python >= 3.6`.
2. Install dependencies: `pip install -r requirements.txt`.

## Benchmarking

Run the following command to benchmark on a given config:

```shell
PYTHONPATH=.. python benchmark.py --cfg ./config/face_detection_yunet.yaml
```

If you are a Windows user and want to run in CMD, use these commands instead (in PowerShell, set the variable with `$env:PYTHONPATH=".."` rather than `set`):
```shell
set PYTHONPATH=..
python benchmark.py --cfg ./config/face_detection_yunet.yaml
```
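
Paths inside a config, such as `benchmark/data` or `models/...`, are relative to the repository root, which is why the commands above set `PYTHONPATH` to `..`: `benchmark.py` joins that value onto each relative path before use. A minimal sketch of this resolution follows, assuming a hypothetical helper `resolve_paths` and a plain dict in place of the parsed YAML.

```python
import os

def resolve_paths(cfg, root):
    """Return a copy of `cfg` with relative paths joined onto `root`."""
    keys = [('Data', 'parentPath'), ('Benchmark', 'parentPath'), ('Model', 'modelPath')]
    resolved = {section: dict(values) for section, values in cfg.items()}
    for section, key in keys:
        path = resolved[section][key]
        if not os.path.isabs(path):  # absolute paths are left untouched
            resolved[section][key] = os.path.join(root, path)
    return resolved

cfg = {
    'Data': {'parentPath': 'benchmark/data'},
    'Benchmark': {'parentPath': 'benchmark/data/face_detection'},
    'Model': {'modelPath': 'models/face_detection_yunet/face_detection_yunet.onnx'},
}
print(resolve_paths(cfg, '..')['Model']['modelPath'])
```

Unlike the in-place update in `benchmark.py`, this sketch returns a copy so the original config stays unchanged.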
<!--
Omit `--cfg` if you want to benchmark all included models:
```shell
PYTHONPATH=.. python benchmark.py
```
-->
182 changes: 182 additions & 0 deletions benchmark/benchmark.py
@@ -0,0 +1,182 @@
import os
import argparse

import yaml
import tqdm
import numpy as np
import cv2 as cv

from models import MODELS
from download import Downloader

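# Pass the text as `description`; the first positional argument of ArgumentParser is `prog`.
parser = argparse.ArgumentParser(description="Benchmarks for OpenCV Zoo.")
# `--cfg` is required: the main block below calls `args.cfg.endswith(...)`.
parser.add_argument('--cfg', '-c', type=str, required=True,
                    help='Benchmarking on the given config.')
args = parser.parse_args()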

class Timer:
    def __init__(self):
        self._tm = cv.TickMeter()

        self._time_record = []
        self._average_time = 0
        self._calls = 0

    def start(self):
        self._tm.start()

    def stop(self):
        self._tm.stop()
        self._calls += 1
        self._time_record.append(self._tm.getTimeMilli())
        self._average_time = sum(self._time_record) / self._calls
        self._tm.reset()

    def reset(self):
        self._time_record = []
        self._average_time = 0
        self._calls = 0

    def getAverageTime(self):
        return self._average_time


class Benchmark:
    def __init__(self, **kwargs):
        self._fileList = kwargs.pop('fileList', None)
        assert self._fileList, 'fileList cannot be empty'

        backend_id = kwargs.pop('backend', 'default')
        available_backends = dict(
            default=cv.dnn.DNN_BACKEND_DEFAULT,
            # halide=cv.dnn.DNN_BACKEND_HALIDE,
            # inference_engine=cv.dnn.DNN_BACKEND_INFERENCE_ENGINE,
            opencv=cv.dnn.DNN_BACKEND_OPENCV,
            # vkcom=cv.dnn.DNN_BACKEND_VKCOM,
            cuda=cv.dnn.DNN_BACKEND_CUDA
        )
        self._backend = available_backends[backend_id]

        target_id = kwargs.pop('target', 'cpu')
        available_targets = dict(
            cpu=cv.dnn.DNN_TARGET_CPU,
            # opencl=cv.dnn.DNN_TARGET_OPENCL,
            # opencl_fp16=cv.dnn.DNN_TARGET_OPENCL_FP16,
            # myriad=cv.dnn.DNN_TARGET_MYRIAD,
            # vulkan=cv.dnn.DNN_TARGET_VULKAN,
            # fpga=cv.dnn.DNN_TARGET_FPGA,
            cuda=cv.dnn.DNN_TARGET_CUDA,
            cuda_fp16=cv.dnn.DNN_TARGET_CUDA_FP16,
            # hddl=cv.dnn.DNN_TARGET_HDDL
        )
        self._target = available_targets[target_id]

        self._sizes = kwargs.pop('sizes', None)
        self._repeat = kwargs.pop('repeat', 100)
        self._parentPath = kwargs.pop('parentPath', 'benchmark/data')
        self._useGroundTruth = kwargs.pop('useDetectionLabel', False) # If enabled, 'sizes' must not be set
        # 'sizes' and 'useDetectionLabel' are mutually exclusive; omitting both runs at the original image size
        assert not (self._sizes and self._useGroundTruth), 'If \'useDetectionLabel\' is True, \'sizes\' should not exist.'

        self._timer = Timer()
        # One result dict per image; dict.fromkeys(keys, dict()) would share a single dict across all keys
        self._benchmark_results = {imgName: dict() for imgName in self._fileList}

        if self._useGroundTruth:
            self.loadLabel()

    def loadLabel(self):
        self._labels = dict.fromkeys(self._fileList, None)
        for imgName in self._fileList:
            self._labels[imgName] = np.loadtxt(os.path.join(self._parentPath, '{}.txt'.format(imgName[:-4])))

    def run(self, model):
        model.setBackend(self._backend)
        model.setTarget(self._target)

        for imgName in self._fileList:
            img = cv.imread(os.path.join(self._parentPath, imgName))
            if self._useGroundTruth:
                for idx, gt in enumerate(self._labels[imgName]):
                    self._benchmark_results[imgName]['gt{}'.format(idx)] = self._run(
                        model,
                        img,
                        gt,
                        pbar_msg=' {}, gt{}'.format(imgName, idx)
                    )
            else:
                if self._sizes is None:
                    h, w, _ = img.shape
                    model.setInputSize([w, h])
                    self._benchmark_results[imgName][str([w, h])] = self._run(
                        model,
                        img,
                        pbar_msg=' {}, original size {}'.format(imgName, str([w, h]))
                    )
                else:
                    for size in self._sizes:
                        imgResized = cv.resize(img, size)
                        model.setInputSize(size)
                        self._benchmark_results[imgName][str(size)] = self._run(
                            model,
                            imgResized,
                            pbar_msg=' {}, size {}'.format(imgName, str(size))
                        )

    def printResults(self):
        print(' Results:')
        for imgName, results in self._benchmark_results.items():
            print(' image: {}'.format(imgName))
            total_latency = 0
            for key, latency in results.items():
                total_latency += latency
                print(' {}, latency: {:.4f} ms'.format(key, latency))
            print(' Average latency: {:.4f} ms'.format(total_latency / len(results)))

    def _run(self, model, *args, **kwargs):
        self._timer.reset()
        pbar = tqdm.tqdm(range(self._repeat))
        for _ in pbar:
            pbar.set_description(kwargs.get('pbar_msg', None))

            self._timer.start()
            results = model.infer(*args)
            self._timer.stop()
        return self._timer.getAverageTime()


def build_from_cfg(cfg, registry):
    obj_name = cfg.pop('name')
    obj = registry.get(obj_name)
    return obj(**cfg)

def prepend_pythonpath(cfg, key1, key2):
    pythonpath = os.environ['PYTHONPATH']
    if cfg[key1][key2].startswith('/'):
        return
    cfg[key1][key2] = os.path.join(pythonpath, cfg[key1][key2])

if __name__ == '__main__':
    assert args.cfg.endswith('yaml'), 'Only configs in YAML format are currently supported.'
    with open(args.cfg, 'r') as f:
        cfg = yaml.safe_load(f)

    # prepend PYTHONPATH to each path
    prepend_pythonpath(cfg, key1='Data', key2='parentPath')
    prepend_pythonpath(cfg, key1='Benchmark', key2='parentPath')
    prepend_pythonpath(cfg, key1='Model', key2='modelPath')


    # Download data if it does not exist
    print('Loading data:')
    downloader = Downloader(**cfg['Data'])
    downloader.get()

    # Instantiate benchmarking
    benchmark = Benchmark(**cfg['Benchmark'])

    # Instantiate model
    model = build_from_cfg(cfg=cfg['Model'], registry=MODELS)

    # Run benchmarking
    print('Benchmarking {}:'.format(model.name))
    benchmark.run(model)
    benchmark.printResults()
28 changes: 28 additions & 0 deletions benchmark/config/face_detection_yunet.yaml
@@ -0,0 +1,28 @@
Data:
  name: "Images for Face Detection"
  url: "https://drive.google.com/u/0/uc?id=1lOAliAIeOv4olM65YDzE55kn6XjiX2l6&export=download"
  sha: "0ba67a9cfd60f7fdb65cdb7c55a1ce76c1193df1"
  filename: "face_detection.zip"
  parentPath: "benchmark/data"

Benchmark:
  name: "Face Detection Benchmark"
  parentPath: "benchmark/data/face_detection"
  fileList:
    - "group.jpg"
    - "concerts.jpg"
    - "dance.jpg"
  backend: "default"
  target: "cpu"
  sizes: # [w, h], Omit to run at original scale
    - [160, 120]
    - [640, 480]
  repeat: 100 # default 100

Model:
  name: "YuNet"
  modelPath: "models/face_detection_yunet/face_detection_yunet.onnx"
  confThreshold: 0.6
  nmsThreshold: 0.3
  topK: 5000
  keepTopK: 750
27 changes: 27 additions & 0 deletions benchmark/config/text_detection_db.yaml
@@ -0,0 +1,27 @@
Data:
  name: "Images for Text Detection"
  url: "https://drive.google.com/u/0/uc?id=1lTQdZUau7ujHBqp0P6M1kccnnJgO-dRj&export=download"
  sha: "a40cf095ceb77159ddd2a5902f3b4329696dd866"
  filename: "text.zip"
  parentPath: "benchmark/data"

Benchmark:
  name: "Text Detection Benchmark"
  parentPath: "benchmark/data/text"
  fileList:
    - "1.jpg"
    - "2.jpg"
    - "3.jpg"
  backend: "default"
  target: "cpu"
  sizes: # [w, h], default original scale
    - [640, 480]
  repeat: 100

Model:
  name: "DB"
  modelPath: "models/text_detection_db/text_detection_db.onnx"
  binaryThreshold: 0.3
  polygonThreshold: 0.5
  maxCandidates: 200
  unclipRatio: 2.0
22 changes: 22 additions & 0 deletions benchmark/config/text_recognition_crnn.yaml
@@ -0,0 +1,22 @@
Data:
  name: "Images for Text Detection"
  url: "https://drive.google.com/u/0/uc?id=1lTQdZUau7ujHBqp0P6M1kccnnJgO-dRj&export=download"
  sha: "a40cf095ceb77159ddd2a5902f3b4329696dd866"
  filename: "text.zip"
  parentPath: "benchmark/data"

Benchmark:
  name: "Text Recognition Benchmark"
  parentPath: "benchmark/data/text"
  fileList:
    - "1.jpg"
    - "2.jpg"
    - "3.jpg"
  backend: "default"
  target: "cpu"
  useDetectionLabel: True
  repeat: 100

Model:
  name: "CRNN"
  modelPath: "models/text_recognition_crnn/text_recognition_crnn.onnx"
2 changes: 2 additions & 0 deletions benchmark/data/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore