Skip to content

Commit 725777a

Browse files
authored
Merge pull request #87 from RapidAI/optim_wired_table_rotated
feat: optim rotated wired table rec
2 parents 639f6f7 + 2eed579 commit 725777a

File tree

6 files changed

+206
-31
lines changed

6 files changed

+206
-31
lines changed

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
[English](README_en.md) | 简体中文
1515
</div>
1616

17-
### 最近更新
18-
- **2024.11.12**
19-
- 抽离模型识别和处理过程核心阈值,方便大家进行微调适配自己的场景[输入参数](#核心参数)
17+
### 最近更新
2018
- **2024.11.16**
2119
- 补充文档扭曲矫正方案,可作为前置处理 [RapidUnWrap](https://github.com/Joker1212/RapidUnWrap)
2220
- **2024.11.22**
2321
- 支持单字符匹配方案,需要RapidOCR>=1.4.0
22+
- **2024.11.28**
23+
- wiredV2模型提升对轻度旋转表格识别准确率,参见[输入参数](#核心参数)
2424

2525
### 简介
2626
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\
@@ -132,6 +132,7 @@ ocr_res = trans_char_ocr_res(ocr_res)
132132

133133
#### 表格旋转及透视修正
134134
##### 1.简单背景,小角度场景
135+
最新wiredV2模型自适应小角度旋转
135136
```python
136137
import cv2
137138

@@ -178,6 +179,9 @@ html, elasp, polygons, logic_points, ocr_res = wired_table_rec(
178179
ocr_result, # 输入rapidOCR识别结果,不传默认使用内部rapidocr模型
179180
version="v2", #默认使用v2线框模型,切换阿里读光模型可改为v1
180181
enhance_box_line=True, # 识别框切割增强(关闭避免多余切割,开启减少漏切割),默认为True
182+
col_threshold=15, # 识别框左边界x坐标差值小于col_threshold的默认同列
183+
row_threshold=10, # 识别框上边界y坐标差值小于row_threshold的默认同行
184+
rotated_fix=True, # wiredV2支持,轻度旋转(-45°~45°)矫正,默认为True
181185
need_ocr=True, # 是否进行OCR识别, 默认为True
182186
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
183187
)

tests/test_wired_table_rec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_squeeze_bug():
4343
ocr_result, _ = ocr_engine(img_path)
4444
table_str, *_ = table_recog(str(img_path), ocr_result)
4545
td_nums = get_td_nums(table_str)
46-
assert td_nums >= 192
46+
assert td_nums >= 160
4747

4848

4949
@pytest.mark.parametrize(

wired_table_rec/main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,23 @@ def __call__(
6060
s = time.perf_counter()
6161
rec_again = True
6262
need_ocr = True
63+
col_threshold = 15
64+
row_threshold = 10
6365
if kwargs:
6466
rec_again = kwargs.get("rec_again", True)
6567
need_ocr = kwargs.get("need_ocr", True)
68+
col_threshold = kwargs.get("col_threshold", 15)
69+
row_threshold = kwargs.get("row_threshold", 10)
6670
img = self.load_img(img)
67-
polygons = self.table_line_rec(img, **kwargs)
71+
polygons, rotated_polygons = self.table_line_rec(img, **kwargs)
6872
if polygons is None:
6973
logging.warning("polygons is None.")
7074
return "", 0.0, None, None, None
7175

7276
try:
73-
table_res, logi_points = self.table_recover(polygons)
77+
table_res, logi_points = self.table_recover(
78+
rotated_polygons, row_threshold, col_threshold
79+
)
7480
# 将坐标由逆时针转为顺时针方向,后续处理与无线表格对齐
7581
polygons[:, 1, :], polygons[:, 3, :] = (
7682
polygons[:, 3, :].copy(),

wired_table_rec/table_line_rec.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- encoding: utf-8 -*-
22
# @Author: SWHL
33
# @Contact: [email protected]
4-
from typing import Any, Dict, Optional
4+
from typing import Any, Dict, Optional, Tuple
55

66
import cv2
77
import numpy as np
@@ -36,12 +36,14 @@ def __init__(self, model_path: Optional[str] = None):
3636

3737
self.session = OrtInferSession(model_path)
3838

39-
def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
39+
def __call__(
40+
self, img: np.ndarray, **kwargs
41+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
4042
img_info = self.preprocess(img)
4143
pred = self.infer(img_info)
4244
polygons = self.postprocess(pred)
4345
if polygons.size == 0:
44-
return None
46+
return None, None
4547

4648
polygons = polygons.reshape(polygons.shape[0], 4, 2)
4749
del_idxs = filter_duplicated_box(
@@ -53,7 +55,7 @@ def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
5355
)
5456
polygons = polygons[idx]
5557
polygons = merge_adjacent_polys(polygons)
56-
return polygons
58+
return polygons, polygons
5759

5860
def preprocess(self, img) -> Dict[str, Any]:
5961
height, width = img.shape[:2]

wired_table_rec/table_line_rec_plus.py

Lines changed: 170 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import copy
22
import math
3-
from typing import Optional, Dict, Any
3+
from typing import Optional, Dict, Any, Tuple
44

55
import cv2
66
import numpy as np
77
from skimage import measure
8-
8+
import matplotlib.pyplot as plt
99
from wired_table_rec.utils import OrtInferSession, resize_img
1010
from wired_table_rec.utils_table_line_rec import (
1111
get_table_line,
@@ -31,22 +31,31 @@ def __init__(self, model_path: Optional[str] = None):
3131

3232
self.session = OrtInferSession(model_path)
3333

34-
def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
34+
def __call__(
35+
self, img: np.ndarray, **kwargs
36+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
3537
img_info = self.preprocess(img)
3638
pred = self.infer(img_info)
37-
polygons = self.postprocess(img, pred, **kwargs)
39+
polygons, rotated_polygons = self.postprocess(img, pred, **kwargs)
3840
if polygons.size == 0:
39-
return None
41+
return None, None
4042
polygons = polygons.reshape(polygons.shape[0], 4, 2)
4143
polygons[:, 3, :], polygons[:, 1, :] = (
4244
polygons[:, 1, :].copy(),
4345
polygons[:, 3, :].copy(),
4446
)
47+
rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
48+
rotated_polygons[:, 3, :], rotated_polygons[:, 1, :] = (
49+
rotated_polygons[:, 1, :].copy(),
50+
rotated_polygons[:, 3, :].copy(),
51+
)
4552
_, idx = sorted_ocr_boxes(
46-
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons], threhold=0.4
53+
[box_4_2_poly_to_box_4_1(poly_box) for poly_box in rotated_polygons],
54+
threhold=0.4,
4755
)
4856
polygons = polygons[idx]
49-
return polygons
57+
rotated_polygons = rotated_polygons[idx]
58+
return polygons, rotated_polygons
5059

5160
def preprocess(self, img) -> Dict[str, Any]:
5261
scale = (self.inp_height, self.inp_width)
@@ -86,7 +95,8 @@ def postprocess(self, img, pred, **kwargs):
8695
extend_line = (
8796
kwargs.get("extend_line", enhance_box_line) if kwargs else enhance_box_line
8897
) # 是否进行线段延长使得端点连接
89-
98+
# 是否进行旋转修正
99+
# Default to True even when other kwargs are supplied without "rotated_fix";
# a bare .get() returned None in that case and silently disabled the fix.
rotated_fix = kwargs.get("rotated_fix", True) if kwargs else True
90100
ori_shape = img.shape
91101
pred = np.uint8(pred)
92102
hpred = copy.deepcopy(pred) # 横线
@@ -120,8 +130,109 @@ def postprocess(self, img, pred, **kwargs):
120130
colboxes += rboxes_col_
121131
if extend_line:
122132
rowboxes, colboxes = final_adjust_lines(rowboxes, colboxes)
123-
tmp = np.zeros(img.shape[:2], dtype="uint8")
124-
tmp = draw_lines(tmp, rowboxes + colboxes, color=255, lineW=2)
133+
line_img = np.zeros(img.shape[:2], dtype="uint8")
134+
line_img = draw_lines(line_img, rowboxes + colboxes, color=255, lineW=2)
135+
rotated_angle = self.cal_rotate_angle(line_img)
136+
if rotated_fix and abs(rotated_angle) > 0.3:
137+
rotated_line_img = self.rotate_image(line_img, rotated_angle)
138+
rotated_polygons = self.cal_region_boxes(rotated_line_img)
139+
polygons = self.unrotate_polygons(
140+
rotated_polygons, rotated_angle, line_img.shape
141+
)
142+
else:
143+
polygons = self.cal_region_boxes(line_img)
144+
rotated_polygons = polygons.copy()
145+
return polygons, rotated_polygons
146+
147+
def find_max_corners(self, line_img):
    """Return the corners of the min-area rectangle around the largest contour.

    External contours are extracted from ``line_img`` with
    ``cv2.findContours``; the one with the largest ``cv2.contourArea`` is
    wrapped in ``cv2.minAreaRect`` and its four corner points are returned
    as integer coordinates.

    The corners are ordered by their polar angle around the rectangle
    centre (ascending ``arctan2``). In image coordinates (y pointing down)
    this yields top-left, top-right, bottom-right, bottom-left for a
    rectangle that is roughly axis-aligned.
    NOTE(review): for strongly rotated rectangles the "top-left" label is
    only nominal -- confirm against callers before relying on it.

    Args:
        line_img: single-channel image passed to ``cv2.findContours``
            (presumably the uint8 line mask drawn in ``postprocess``).

    Returns:
        ``[top_left, top_right, bottom_right, bottom_left]`` as four
        length-2 integer arrays, or ``[]`` when no contour is found.
    """
    contours, _ = cv2.findContours(
        line_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return []

    max_contour = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(max_contour)

    # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased, so
    # the truncation behaviour is unchanged.
    box = cv2.boxPoints(rect).astype(np.intp)

    # Sort the corners by their angle around the rectangle centre.
    center = np.mean(box, axis=0)
    angles = np.arctan2(box[:, 1] - center[1], box[:, 0] - center[0])
    top_left, top_right, bottom_right, bottom_left = box[np.argsort(angles)]

    # The former matplotlib preview (plt.figure / plt.show) was debug output:
    # it blocked the caller until the window was closed and failed on
    # headless hosts, so it is intentionally not reproduced here.
    return [top_left, top_right, bottom_right, bottom_left]
208+
209+
def extend_image_and_adjust_coordinates(self, img, corners, polygons):
    """Pad ``img`` so every corner lies inside it and shift coordinates to match.

    The image is zero-padded on whichever sides ``corners`` stick out of
    the current bounds. Corners and polygons are then translated by the
    left/top padding so they keep pointing at the same pixels.

    Args:
        img: 2-D array; the padded canvas is allocated as
            ``(new_height, new_width)`` with ``img.dtype``.
        corners: iterable of ``(x, y)`` points; may be negative or exceed
            the image size. Integer or float values are accepted.
        polygons: array whose columns alternate x, y (even columns are
            shifted by the left padding, odd columns by the top padding).
            An empty array is returned as an unchanged copy.

    Returns:
        ``(extended_img, adjusted_corners, adjusted_polygons)`` where
        ``adjusted_corners`` is a list of ``(x, y)`` tuples and
        ``adjusted_polygons`` is a shifted copy of ``polygons``.
    """
    height, width = img.shape[:2]
    xs = [point[0] for point in corners]
    ys = [point[1] for point in corners]

    # Padding is used as an array shape and slice bound, so it must be a
    # whole, non-negative number of pixels even when corners are floats.
    left = max(0, int(np.ceil(-min(xs))))
    top = max(0, int(np.ceil(-min(ys))))
    right = max(0, int(np.ceil(max(xs) - width)))
    bottom = max(0, int(np.ceil(max(ys) - height)))

    extended_img = np.zeros(
        (height + top + bottom, width + left + right), dtype=img.dtype
    )
    extended_img[top : top + height, left : left + width] = img

    adjusted_corners = [(point[0] + left, point[1] + top) for point in corners]
    adjusted_polygons = polygons.copy()
    # An empty array may be 1-D, where column slicing would raise IndexError.
    if adjusted_polygons.size:
        adjusted_polygons[:, 0::2] += left
        adjusted_polygons[:, 1::2] += top
    return extended_img, adjusted_corners, adjusted_polygons
234+
235+
def cal_region_boxes(self, tmp):
125236
labels = measure.label(tmp < 255, connectivity=2) # 8连通区域标记
126237
regions = measure.regionprops(labels)
127238
ceilboxes = min_area_rect_box(
@@ -133,3 +244,52 @@ def postprocess(self, img, pred, **kwargs):
133244
adjust_box=False,
134245
) # 最后一个参数改为False
135246
return np.array(ceilboxes)
247+
248+
def cal_rotate_angle(self, tmp):
    """Estimate the tilt of the largest outer contour in ``tmp``.

    The largest external contour (by ``cv2.contourArea``) is wrapped in
    ``cv2.minAreaRect`` and the rectangle's angle is folded by +/-90 when
    it lies outside [-45, 45], so the result is the smallest rotation
    that lines the rectangle up with the axes.

    Args:
        tmp: single-channel image handed to ``cv2.findContours``
            (presumably the uint8 line mask -- confirm with caller).

    Returns:
        The folded angle reported by ``cv2.minAreaRect``, or ``0`` when
        no contour is found.
    """
    outlines, _ = cv2.findContours(tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not outlines:
        return 0

    # minAreaRect yields ((cx, cy), (w, h), angle); only the angle matters.
    _, _, tilt = cv2.minAreaRect(max(outlines, key=cv2.contourArea))

    if tilt < -45:
        return tilt + 90
    if tilt > 45:
        return tilt - 90
    return tilt
262+
263+
def rotate_image(self, image, angle):
    """Rotate ``image`` about its centre, keeping the original canvas size.

    Args:
        image: array whose first two dimensions are (height, width).
        angle: rotation angle passed to ``cv2.getRotationMatrix2D``
            (degrees in OpenCV's convention), applied with scale 1.0.

    Returns:
        The rotated image with the same width and height as the input.
        Nearest-neighbour sampling is used and pixels outside the source
        are filled by replicating the border.
    """
    height, width = image.shape[:2]
    pivot = (width // 2, height // 2)
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(
        image,
        rotation,
        (width, height),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_REPLICATE,
    )
277+
278+
def unrotate_polygons(
    self, polygons: np.ndarray, angle: float, img_shape: tuple
) -> np.ndarray:
    """Map polygons from the rotated image back to the original image.

    Applies the rotation about the image centre by ``-angle``, which is
    the inverse of the transform built in ``rotate_image`` for ``angle``
    (same centre, scale 1.0).

    Args:
        polygons: array reshapeable to ``(N, 4, 2)``, i.e. four ``(x, y)``
            points per polygon (typically ``(N, 8)``).
            NOTE(review): ``cv2.transform`` does not accept every dtype;
            presumably these are float boxes -- confirm with caller.
        angle: the angle that was used to rotate the image.
        img_shape: shape of the rotated image; only the first two entries
            (height, width) are used, so a 3-channel shape is accepted.

    Returns:
        Array of shape ``(N, 8)`` with the un-rotated coordinates. An
        empty input yields an empty ``(0, 8)`` array.
    """
    # Nothing to transform: return early instead of handing an empty
    # array to cv2.transform.
    if polygons.size == 0:
        return polygons.reshape(-1, 8)

    h, w = img_shape[:2]
    center = (w // 2, h // 2)
    M_inv = cv2.getRotationMatrix2D(center, -angle, 1.0)

    # (N, 8) -> (N, 4, 2) so each point is treated as a 2-channel element.
    polygons_reshaped = polygons.reshape(-1, 4, 2)
    unrotated_polygons = cv2.transform(polygons_reshaped, M_inv)

    # Back to the flat (N, 8) layout used by the callers.
    return unrotated_polygons.reshape(-1, 8)

0 commit comments

Comments
 (0)