1
1
import copy
2
2
import math
3
- from typing import Optional , Dict , Any
3
+ from typing import Optional , Dict , Any , Tuple
4
4
5
5
import cv2
6
6
import numpy as np
7
7
from skimage import measure
8
-
8
+ import matplotlib . pyplot as plt
9
9
from wired_table_rec .utils import OrtInferSession , resize_img
10
10
from wired_table_rec .utils_table_line_rec import (
11
11
get_table_line ,
@@ -31,22 +31,31 @@ def __init__(self, model_path: Optional[str] = None):
31
31
32
32
self .session = OrtInferSession (model_path )
33
33
34
def __call__(
    self, img: np.ndarray, **kwargs
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Detect table cells in ``img`` and return them in reading order.

    Args:
        img: Input image handed to ``preprocess`` and ``postprocess``.
        **kwargs: Forwarded unchanged to ``postprocess``.

    Returns:
        ``(polygons, rotated_polygons)``, each reshaped to ``(N, 4, 2)`` and
        reordered with the same index, or ``(None, None)`` when
        ``postprocess`` yields no polygons.
    """
    model_input = self.preprocess(img)
    line_pred = self.infer(model_input)
    polygons, rotated_polygons = self.postprocess(img, line_pred, **kwargs)
    if polygons.size == 0:
        return None, None

    polygons = polygons.reshape(polygons.shape[0], 4, 2)
    # Exchange vertices 1 and 3 of every quad (fancy indexing copies the
    # right-hand side first, so the swap is safe in place).
    polygons[:, [1, 3], :] = polygons[:, [3, 1], :]

    rotated_polygons = rotated_polygons.reshape(rotated_polygons.shape[0], 4, 2)
    rotated_polygons[:, [1, 3], :] = rotated_polygons[:, [3, 1], :]

    # The ordering is computed on the rotated boxes and then applied to both
    # arrays so that they stay index-aligned.
    rotated_boxes = [box_4_2_poly_to_box_4_1(quad) for quad in rotated_polygons]
    _, order = sorted_ocr_boxes(rotated_boxes, threhold=0.4)
    return polygons[order], rotated_polygons[order]
50
59
51
60
def preprocess (self , img ) -> Dict [str , Any ]:
52
61
scale = (self .inp_height , self .inp_width )
@@ -86,7 +95,8 @@ def postprocess(self, img, pred, **kwargs):
86
95
extend_line = (
87
96
kwargs .get ("extend_line" , enhance_box_line ) if kwargs else enhance_box_line
88
97
) # 是否进行线段延长使得端点连接
89
-
98
+ # 是否进行旋转修正
99
+ rotated_fix = kwargs .get ("rotated_fix" ) if kwargs else True
90
100
ori_shape = img .shape
91
101
pred = np .uint8 (pred )
92
102
hpred = copy .deepcopy (pred ) # 横线
@@ -120,8 +130,109 @@ def postprocess(self, img, pred, **kwargs):
120
130
colboxes += rboxes_col_
121
131
if extend_line :
122
132
rowboxes , colboxes = final_adjust_lines (rowboxes , colboxes )
123
- tmp = np .zeros (img .shape [:2 ], dtype = "uint8" )
124
- tmp = draw_lines (tmp , rowboxes + colboxes , color = 255 , lineW = 2 )
133
+ line_img = np .zeros (img .shape [:2 ], dtype = "uint8" )
134
+ line_img = draw_lines (line_img , rowboxes + colboxes , color = 255 , lineW = 2 )
135
+ rotated_angle = self .cal_rotate_angle (line_img )
136
+ if rotated_fix and abs (rotated_angle ) > 0.3 :
137
+ rotated_line_img = self .rotate_image (line_img , rotated_angle )
138
+ rotated_polygons = self .cal_region_boxes (rotated_line_img )
139
+ polygons = self .unrotate_polygons (
140
+ rotated_polygons , rotated_angle , line_img .shape
141
+ )
142
+ else :
143
+ polygons = self .cal_region_boxes (line_img )
144
+ rotated_polygons = polygons .copy ()
145
+ return polygons , rotated_polygons
146
+
147
def find_max_corners(self, line_img, visualize=True):
    """Return the four corners of the largest external contour's bounding box.

    The largest contour (by ``cv2.contourArea``) found in ``line_img`` is
    wrapped in its minimum-area rectangle, and the rectangle's corners are
    sorted by their polar angle around the rectangle centre.

    Args:
        line_img: Single-channel image passed to ``cv2.findContours``.
        visualize: When true (the default, matching the previous behaviour)
            the contour and corners are drawn with matplotlib and
            ``plt.show()`` is called. Pass ``False`` to skip the plot, e.g.
            in batch or headless runs.

    Returns:
        ``[top_left, top_right, bottom_right, bottom_left]`` as integer
        coordinate arrays, or ``[]`` when no contour is found. The labels
        follow the angle sort; they match the geometric corners in image
        coordinates (y pointing down) for a roughly axis-aligned rectangle.
    """
    contours, _ = cv2.findContours(
        line_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return []

    max_contour = max(contours, key=cv2.contourArea)
    rect = cv2.minAreaRect(max_contour)

    # np.int0 was removed in NumPy 2.0; it was an alias of np.intp.
    box = cv2.boxPoints(rect).astype(np.intp)

    # Sort the corners by the angle of the vector from the centre to each
    # corner (ascending arctan2, i.e. from -pi to pi).
    center = np.mean(box, axis=0)
    angles = np.arctan2(box[:, 1] - center[1], box[:, 0] - center[0])
    top_left, top_right, bottom_right, bottom_left = box[np.argsort(angles)]

    if visualize:
        # Debug plot: contour outline and corner markers on a black canvas.
        black_img = np.zeros_like(line_img)
        plt.figure(figsize=(10, 10))
        plt.imshow(black_img, cmap="gray")
        plt.title("Max Contour and Corners on Black Background")

        outline = max_contour.reshape(-1, 2)
        plt.plot(outline[:, 0], outline[:, 1], "b-", linewidth=2)

        plt.scatter(
            [top_left[0], top_right[0], bottom_right[0], bottom_left[0]],
            [top_left[1], top_right[1], bottom_right[1], bottom_left[1]],
            c="g",
            s=100,
            marker="o",
        )

        plt.axis("off")
        plt.show()

    return [top_left, top_right, bottom_right, bottom_left]
208
+
209
def extend_image_and_adjust_coordinates(self, img, corners, polygons):
    """Pad ``img`` with zeros so all ``corners`` fit, shifting coordinates to match.

    Args:
        img: Image array of shape ``(H, W)`` or ``(H, W, C)``.
        corners: Sequence of ``(x, y)`` points that may lie outside ``img``.
            May be empty (``find_max_corners`` returns ``[]`` when it finds
            no contour); the image is then returned unpadded.
        polygons: Array of flattened ``x, y, x, y, ...`` coordinates, e.g.
            shape ``(N, 8)``. An empty array is accepted.

    Returns:
        ``(extended_img, adjusted_corners, adjusted_polygons)``: the
        zero-padded image, the corners as a list of ``(x, y)`` tuples in the
        padded frame, and a shifted copy of ``polygons`` (the input array is
        not modified).
    """
    adjusted_polygons = polygons.copy()

    # Without corners there is nothing to fit, so no padding is required.
    if len(corners) == 0:
        return img.copy(), [], adjusted_polygons

    img_h, img_w = img.shape[:2]
    xs = [point[0] for point in corners]
    ys = [point[1] for point in corners]

    # Padding needed on each side; ceil + int keeps the values usable as
    # array sizes and slice bounds even when the corners are floats.
    left = max(0, int(math.ceil(-min(xs))))
    top = max(0, int(math.ceil(-min(ys))))
    right = max(0, int(math.ceil(max(xs) - img_w)))
    bottom = max(0, int(math.ceil(max(ys) - img_h)))

    # Trailing dimensions (colour channels) are carried over unchanged.
    extended_shape = (img_h + top + bottom, img_w + left + right) + tuple(
        img.shape[2:]
    )
    extended_img = np.zeros(extended_shape, dtype=img.dtype)
    extended_img[top : top + img_h, left : left + img_w] = img

    adjusted_corners = [(point[0] + left, point[1] + top) for point in corners]

    # The fresh copy is C-contiguous, so this reshape is a view onto it and
    # the in-place shifts below land in ``adjusted_polygons``.
    points = adjusted_polygons.reshape(-1, 2)
    points[:, 0] += left
    points[:, 1] += top
    return extended_img, adjusted_corners, adjusted_polygons
234
+
235
+ def cal_region_boxes (self , tmp ):
125
236
labels = measure .label (tmp < 255 , connectivity = 2 ) # 8连通区域标记
126
237
regions = measure .regionprops (labels )
127
238
ceilboxes = min_area_rect_box (
@@ -133,3 +244,52 @@ def postprocess(self, img, pred, **kwargs):
133
244
adjust_box = False ,
134
245
) # 最后一个参数改为False
135
246
return np .array (ceilboxes )
247
+
248
def cal_rotate_angle(self, tmp):
    """Estimate the skew angle of the largest external contour in ``tmp``.

    The angle is taken from ``cv2.minAreaRect`` of the contour with the
    largest area and folded into the range [-45, 45].

    Args:
        tmp: Single-channel image passed to ``cv2.findContours``.

    Returns:
        The folded angle as reported by OpenCV, or ``0`` when no contour
        is found.
    """
    outlines, _ = cv2.findContours(tmp, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not outlines:
        return 0

    biggest = max(outlines, key=cv2.contourArea)
    # minAreaRect returns (center, size, angle); only the angle is needed.
    _, _, angle = cv2.minAreaRect(biggest)

    # Fold angles beyond +/-45 back by a quarter turn.
    if angle < -45:
        return angle + 90
    if angle > 45:
        return angle - 90
    return angle
262
+
263
def rotate_image(self, image, angle):
    """Rotate ``image`` about its centre by ``angle`` without changing its size.

    Args:
        image: Image array; only its first two dimensions are read.
        angle: Rotation angle passed to ``cv2.getRotationMatrix2D``.

    Returns:
        The rotated image with the same width and height as the input,
        resampled with nearest-neighbour interpolation and replicated
        borders.
    """
    height, width = image.shape[:2]
    pivot = (width // 2, height // 2)
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1.0)
    return cv2.warpAffine(
        image,
        rotation,
        (width, height),
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_REPLICATE,
    )
277
+
278
def unrotate_polygons(
    self, polygons: np.ndarray, angle: float, img_shape: tuple
) -> np.ndarray:
    """Map polygons from the rotated image back to the original frame.

    Builds the rotation about the image centre with ``-angle`` (the
    opposite of what ``rotate_image`` applies for the same ``angle``) and
    applies it to every polygon vertex.

    Args:
        polygons: Flattened quads, reshaped internally to ``(N, 4, 2)``.
            An empty array is returned as an empty ``(0, 8)`` array.
        angle: The angle that was used to rotate the image.
        img_shape: Shape of the rotated image; only the first two entries
            (height, width) are used, so a 3-tuple is accepted as well.

    Returns:
        Array of shape ``(N, 8)`` holding the transformed coordinates.
    """
    # Nothing to transform; keep the (N, 8) layout and skip OpenCV.
    if polygons.size == 0:
        return polygons.reshape(-1, 8)

    # Same centre computation as rotate_image, which also reads shape[:2].
    h, w = img_shape[:2]
    center = (w // 2, h // 2)
    inverse_matrix = cv2.getRotationMatrix2D(center, -angle, 1.0)

    # (N, 8) -> (N, 4, 2) so every vertex is one 2-channel element.
    points = polygons.reshape(-1, 4, 2)
    restored = cv2.transform(points, inverse_matrix)

    # Back to the flattened (N, 8) layout.
    return restored.reshape(-1, 8)
0 commit comments