
Commit c49e2f0

support Parallel
1 parent 4ba0191 commit c49e2f0


4 files changed (+105, -25 lines)


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@ data
 model
 logs
 .idea/
-tests/
+tests/*
+asset/*
 *.egg-info/
 dist/

dlocr/__main__.py

Lines changed: 7 additions & 6 deletions
@@ -1,9 +1,10 @@
 import argparse
-from datetime import datetime
+import time

 import keras.backend as K

-from dlocr import default_dict_path, default_densenet_config_path, default_densenet_weight_path, default_ctpn_config_path, \
+from dlocr import default_dict_path, default_densenet_config_path, default_densenet_weight_path, \
+    default_ctpn_config_path, \
     default_ctpn_weight_path, get_session, TextDetectionApp

 if __name__ == '__main__':
@@ -30,7 +31,7 @@
         dict_path=args.dict_file_path,
         ctpn_config_path=args.ctpn_config_path,
         densenet_config_path=args.densenet_config_path)
-    start_time = datetime.now()
-    for rect, line in app.detect(args.image_path, args.adjust):
-        print(line)
-    print(f"cost {(datetime.now() - start_time).microseconds / 1000}ms")
+    start_time = time.time()
+    _, texts = app.detect(args.image_path, args.adjust)
+    print('\n'.join(texts))
+    print(f"cost {(time.time() - start_time) * 1000}ms")

dlocr/densenet/core.py

Lines changed: 50 additions & 2 deletions
@@ -1,5 +1,5 @@
 import json
-import os
+from concurrent.futures import ThreadPoolExecutor

 import keras.backend as K
 import numpy as np
@@ -57,6 +57,40 @@ def _ctc_loss(args):
     return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


+def single_img_process(img):
+    im = img.convert('L')
+    scale = im.size[1] * 1.0 / 32
+    w = im.size[0] / scale
+    w = int(w)
+
+    im = im.resize((w, 32), Image.ANTIALIAS)
+    img = np.array(im).astype(np.float32) / 255.0 - 0.5
+    img = img.reshape((32, w, 1))
+    return img
+
+
+def pad_img(img, len, value):
+    out = np.ones(shape=(32, len, 1)) * value
+    out[:, :img.shape[1], :] = img
+    return out
+
+
+def process_imgs(imgs):
+    tmp = []
+    with ThreadPoolExecutor() as executor:
+        for img in executor.map(single_img_process, imgs):
+            tmp.append(img)
+
+    max_len = max([img.shape[1] for img in tmp])
+
+    output = []
+    with ThreadPoolExecutor() as executor:
+        for img in executor.map(lambda img: pad_img(img, max_len, 0.5), tmp):
+            output.append(img)
+
+    return np.array(output)
+
+
 class DenseNetOCR:

     def __init__(self,
@@ -164,14 +198,28 @@ def predict(self, image, id_to_char):
         X = np.array([X])

         y_pred = self.base_model.predict(X)
-        argmax = np.argmax(y_pred, axis=2)[0]

         y_pred = y_pred[:, :, :]
         out = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1], )[0][0])[:, :]
         out = u''.join([id_to_char[x] for x in out[0]])

         return out, im

+    def predict_multi(self, images, id_to_char):
+
+        def single_text(out):
+            return u''.join(['' if x == -1 else id_to_char[x] for x in out])
+
+        X = process_imgs(images)
+        y_pred = self.base_model.predict_on_batch(X)
+        outs = K.get_value(K.ctc_decode(y_pred, input_length=np.ones(y_pred.shape[0]) * y_pred.shape[1], )[0][0])[:, :]
+        texts = []
+        with ThreadPoolExecutor() as executor:
+            for text in executor.map(single_text, outs):
+                texts.append(text)
+
+        return texts
+
     @staticmethod
     def save_config(obj, config_path: str):
         with open(config_path, 'w+') as outfile:
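The new preprocessing pipeline resizes each line crop to a height of 32 pixels, normalizes pixels to [-0.5, 0.5], then right-pads every crop to the width of the widest one so the whole set can go through predict_on_batch as a single tensor. A NumPy-only sketch of that batching step, with dummy crops in place of real line images (pad_to_width and the crop widths are illustrative assumptions, not part of the commit):

import numpy as np
from concurrent.futures import ThreadPoolExecutor

def pad_to_width(img, width, value=0.5):
    # img has shape (32, w, 1); right-pad with `value` up to (32, width, 1)
    out = np.ones((32, width, 1), dtype=img.dtype) * value
    out[:, :img.shape[1], :] = img
    return out

# dummy line crops of varying width, already normalized to [-0.5, 0.5]
crops = [np.random.rand(32, w, 1).astype(np.float32) - 0.5 for w in (80, 120, 95)]

max_w = max(c.shape[1] for c in crops)
with ThreadPoolExecutor() as executor:
    batch = np.array(list(executor.map(lambda c: pad_to_width(c, max_w), crops)))

print(batch.shape)  # (3, 32, 120, 1): one fixed-size batch for predict_on_batch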

dlocr/text_detection_app.py

Lines changed: 46 additions & 16 deletions
@@ -1,3 +1,5 @@
1+
import os
2+
from concurrent.futures import ThreadPoolExecutor
13
from math import *
24

35
import cv2
@@ -7,7 +9,6 @@
79
from dlocr.ctpn import CTPN
810
from dlocr.densenet import DenseNetOCR
911
from dlocr.densenet import load_dict
10-
import os
1112

1213

1314
def dumpRotateImage(img, degree, pt1, pt2, pt3, pt4):
@@ -65,21 +66,36 @@ def single_text_detect(rec, ocr, id_to_char, img, adjust):
6566
return image, text
6667

6768

68-
def model(ctpn, ocr, id_to_char, image_path, adjust):
69-
text_recs, img = ctpn.predict(image_path, mode=2) # 得到所有的检测框
70-
text_recs = sort_box(text_recs)
71-
results = []
69+
def clip_single_img(bbox, img, xDim, yDim, adjust):
70+
xlength = int((bbox[2] - bbox[0]) * 0.1)
71+
ylength = int((bbox[3] - bbox[1]) * 0.2)
72+
if adjust:
73+
pt1 = (max(1, bbox[0] - xlength), max(1, bbox[1] - ylength))
74+
pt2 = (bbox[2], bbox[3])
75+
pt3 = (min(bbox[6] + xlength, xDim - 2), min(yDim - 2, bbox[7] + ylength))
76+
pt4 = (bbox[4], bbox[5])
77+
else:
78+
pt1 = (max(1, bbox[0]), max(1, bbox[1]))
79+
pt2 = (bbox[2], bbox[3])
80+
pt3 = (min(bbox[6], xDim - 2), min(yDim - 2, bbox[7]))
81+
pt4 = (bbox[4], bbox[5])
7282

73-
for index, rec in enumerate(text_recs):
74-
image, text = single_text_detect(rec, ocr, id_to_char, img, adjust) # 识别文字
75-
# plt.subplot(len(text_recs), 1, index + 1)
76-
# plt.imshow(image)
77-
if text is not None and len(text) > 0:
78-
results.append((rec, text))
83+
degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0])) # 图像倾斜角度
7984

80-
# plt.show()
85+
partImg = dumpRotateImage(img, degree, pt1, pt2, pt3, pt4)
86+
image = Image.fromarray(partImg)
87+
return image
8188

82-
return results
89+
90+
def clip_imgs_with_bboxes(bboxes, img, adjust):
91+
xDim, yDim = img.shape[1], img.shape[0]
92+
93+
imgs = []
94+
with ThreadPoolExecutor() as executor:
95+
for img in executor.map(lambda t: clip_single_img(t[0], t[1], xDim, yDim, adjust),
96+
map(lambda bbox: (bbox, img), bboxes)):
97+
imgs.append(img)
98+
return imgs
8399

84100

85101
class TextDetectionApp:
@@ -117,16 +133,30 @@ def __init__(self,
117133
else:
118134
self.ocr = DenseNetOCR(num_classes=len(self.id_to_char))
119135

120-
def detect(self, image_path, adjust=True):
136+
def detect(self, image_path, adjust=True, parallel=True):
121137
"""
122138
139+
:param parallel: 是否并行处理
123140
:param image_path: 图像路径
124141
:param adjust: 是否调整检测框
125142
:return:
126143
"""
127144
if not os.path.exists(image_path):
128145
raise ValueError(f"The image path: {image_path} not exists!")
129-
return model(self.ctpn, self.ocr, self.id_to_char, image_path, adjust)
130-
146+
text_recs, img = self.ctpn.predict(image_path, mode=2) # 得到所有的检测框
147+
text_recs = sort_box(text_recs)
131148

149+
if parallel:
150+
imgs = clip_imgs_with_bboxes(text_recs, img, adjust)
132151

152+
texts = self.ocr.predict_multi(imgs, id_to_char=self.id_to_char)
153+
else:
154+
texts = []
155+
for index, rec in enumerate(text_recs):
156+
image, text = single_text_detect(rec, self.ocr, self.id_to_char, img, adjust) # 识别文字
157+
# plt.subplot(len(text_recs), 1, index + 1)
158+
# plt.imshow(image)
159+
if text is not None and len(text) > 0:
160+
texts.append(text)
161+
162+
return text_recs, texts
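With this commit, detect() returns a (text_recs, texts) pair and, when parallel=True, clips all detected boxes concurrently and recognizes them in one batch via predict_multi. A hedged usage sketch: the sample image path is an assumption, and the TextDetectionApp constructor keyword arguments for the weight files are guessed from the imported default_* names rather than shown in this diff:

import time

from dlocr import (default_dict_path, default_densenet_config_path, default_densenet_weight_path,
                   default_ctpn_config_path, default_ctpn_weight_path, TextDetectionApp)

# the weight-path kwargs are assumed; only dict_path and the *_config_path kwargs appear in this diff
app = TextDetectionApp(ctpn_weight_path=default_ctpn_weight_path,
                       densenet_weight_path=default_densenet_weight_path,
                       dict_path=default_dict_path,
                       ctpn_config_path=default_ctpn_config_path,
                       densenet_config_path=default_densenet_config_path)

start = time.time()
# parallel=True (the default) batches recognition through predict_multi
boxes, texts = app.detect("asset/demo.jpg", adjust=True, parallel=True)  # image path is an assumption
print('\n'.join(texts))
print(f"cost {(time.time() - start) * 1000}ms")

# parallel=False falls back to the original per-box loop via single_text_detect
_, texts_seq = app.detect("asset/demo.jpg", adjust=True, parallel=False)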
