Skip to content

Commit

Permalink
fix: fix preprocess for yolo obj det
Browse files Browse the repository at this point in the history
  • Loading branch information
Joker1212 committed Nov 2, 2024
1 parent 71e14c4 commit 097e2df
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
🔍 使用在线体验找到适合你场景的模型组合

### 在线体验

[modelscope](https://www.modelscope.cn/studios/jockerK/RapidTableDetDemo)
### 效果展示

![res_show.jpg](readme_resource/res_show.jpg)![res_show2.jpg](readme_resource/res_show2.jpg)
Expand Down Expand Up @@ -102,6 +102,7 @@ print(
# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
#
# img = img_loader(img_path)
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# file_name_with_ext = os.path.basename(img_path)
# file_name, file_ext = os.path.splitext(file_name_with_ext)
# out_dir = "rapid_table_det/outputs"
Expand Down
45 changes: 23 additions & 22 deletions demo_onnx.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from rapid_table_det.inference import TableDetector

img_path = f"images/weixin.png"
img_path = f"images/WechatIMG149.jpeg"
table_det = TableDetector(
obj_model_type="paddle_obj_det_s", edge_model_type="paddle_edge_det_s"
edge_model_type="yolo_edge_det", obj_model_type="yolo_obj_det"
)

result, elapse = table_det(img_path)
Expand All @@ -11,23 +11,24 @@
f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
)
# 输出可视化
# import os
# import cv2
# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
#
# img = img_loader(img_path)
# file_name_with_ext = os.path.basename(img_path)
# file_name, file_ext = os.path.splitext(file_name_with_ext)
# out_dir = "rapid_table_det/outputs"
# if not os.path.exists(out_dir):
# os.makedirs(out_dir)
# extract_img = img.copy()
# for i, res in enumerate(result):
# box = res["box"]
# lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
# # 带识别框和左上角方向位置
# img = visuallize(img, box, lt, rt, rb, lb)
# # 透视变换提取表格图片
# wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
# cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
# Visualization output: write one annotated copy of the input image plus one
# perspective-corrected crop per detected table into `out_dir`.
import os
import cv2
from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img

img = img_loader(img_path)
# Swap the first and third colour channels before the image is written by OpenCV.
# NOTE(review): presumably img_loader returns RGB while cv2.imwrite expects BGR - confirm.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
file_name_with_ext = os.path.basename(img_path)
file_name, file_ext = os.path.splitext(file_name_with_ext)
out_dir = "rapid_table_det/outputs"
# exist_ok=True avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(out_dir, exist_ok=True)
# Copy taken before any overlay is drawn; the crops below are cut from this copy.
extract_img = img.copy()
for i, res in enumerate(result):
    box = res["box"]
    lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
    # Annotate with the detection box and the top-left corner (orientation) position.
    img = visuallize(img, box, lt, rt, rb, lb)
    # Extract the table image with a perspective transform.
    wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
    cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
# Written once after the loop so the single image carries every table's annotation.
cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
9 changes: 3 additions & 6 deletions rapid_table_det/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,7 @@ def __call__(self, img, **kwargs):
return result, time.time() - start

def img_preprocess(self, img, resize_shape=[928, 928]):
# im, new_w, new_h, left, top = ResizePad(img, resize_shape[0])
new_w, new_h = resize_shape
left, top = 0, 0
im = cv2.resize(img, resize_shape, cv2.INTER_LINEAR)
im, new_w, new_h, left, top = ResizePad(img, resize_shape[0])
im = im / 255.0
im = im.transpose((2, 0, 1)).copy()
im = im[None, :].astype("float32")
Expand All @@ -118,8 +115,8 @@ def img_postprocess(self, predict_maps, x_factor, y_factor, left, top, score):
# 从当前行提取边界框坐标
x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
# 计算边界框的缩放坐标
xmin = max(int((x - w / 2) * x_factor) - left, 0)
ymin = max(int((y - h / 2) * y_factor) - top, 0)
xmin = max(int((x - w / 2 - left) * x_factor), 0)
ymin = max(int((y - h / 2 - top) * y_factor), 0)
xmax = xmin + int(w * x_factor)
ymax = ymin + int(h * y_factor)
# 将类别ID、得分和框坐标添加到各自的列表中
Expand Down

0 comments on commit 097e2df

Please sign in to comment.