Skip to content

Commit 103602a

Browse files
authored
Merge pull request #9 from RapidAI/develop
fix: fixed issue #3 #7 #8
2 parents c2721b0 + 3cd44e1 commit 103602a

File tree

8 files changed

+171
-40
lines changed

8 files changed

+171
-40
lines changed

1.jpg

51.2 KB
Loading

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,18 @@
2929
| `yolov8n_layout_report`| 研报 | `yolov8n_layout_report.onnx` | `['Text', 'Title', 'Header', 'Footer', 'Figure', 'Table', 'Toc', 'Figure caption', 'Table caption']` |
3030
| `yolov8n_layout_publaynet`| 英文 | `yolov8n_layout_publaynet.onnx` | `["Text", "Title", "List", "Table", "Figure"]` |
3131
| `yolov8n_layout_general6`| 通用 | `yolov8n_layout_general6.onnx` | `["Text", "Title", "Figure", "Table", "Caption", "Equation"]` |
32-
| 🔥`doclayout_yolo`| 通用 | `doclayout_yolo_docstructbench_imgsz1024.onnx` | `['title', 'text', 'abandon', 'figure', 'figure_caption', 'table', 'table_caption', 'table_footnote', 'isolate_formula', 'formula_caption']` |
32+
| 🔥`doclayout_docstructbench`| 通用 | `doclayout_yolo_docstructbench_imgsz1024.onnx` | `['title', 'plain text', 'abandon', 'figure', 'figure_caption', 'table', 'table_caption', 'table_footnote', 'isolate_formula', 'formula_caption']` |
33+
| 🔥`doclayout_d4la`| 通用 | `doclayout_yolo_d4la_imgsz1600_docsynth_pretrain.onnx` | `['DocTitle', 'ParaTitle', 'ParaText', 'ListText', 'RegionTitle', 'Date', 'LetterHead', 'LetterDear', 'LetterSign', 'Question', 'OtherText', 'RegionKV', 'RegionList', 'Abstract', 'Author', 'TableName', 'Table', 'Figure', 'FigureName', 'Equation', 'Reference', 'Footer', 'PageHeader', 'PageFooter', 'Number', 'Catalog', 'PageNumber']` |
34+
| 🔥`doclayout_docsynth`| 通用 | `doclayout_yolo_doclaynet_imgsz1120_docsynth_pretrain.onnx` | `['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']` |
3335

3436
PP模型来源:[PaddleOCR 版面分析](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/layout/README_ch.md)
3537

3638
yolov8n系列来源:[360LayoutAnalysis](https://github.com/360AILAB-NLP/360LayoutAnalysis)
3739

38-
doclayout版本暂时有问题,不推荐使用。正在更新中....
39-
~~(推荐使用)🔥doclayout_yolo模型来源:[DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO),该模型是目前最为优秀的开源模型,支持学术论文、Textbook、Financial、Exam Paper、Fuzzy Scans、PPT和Poster 7种文档类型的版面检测。值得一提的是,该模型支持的类别中存在`abandon`一类,主要是文档页面的页眉页脚部分,便于后续快速舍弃。~~
4040

41-
模型下载地址为:[link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)
41+
(推荐使用)🔥doclayout_yolo模型来源:[DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO),该模型是目前最为优秀的开源模型,挑选了3个基于不同训练集训练得到的模型。其中`doclayout_docstructbench`来自[link](https://huggingface.co/juliozhao/DocLayout-YOLO-DocStructBench/tree/main)`doclayout_d4la`来自[link](https://huggingface.co/juliozhao/DocLayout-YOLO-D4LA-Docsynth300K_pretrained/blob/main/doclayout_yolo_d4la_imgsz1600_docsynth_pretrain.pt)`doclayout_docsynth`来自[link](https://huggingface.co/juliozhao/DocLayout-YOLO-DocLayNet-Docsynth300K_pretrained/tree/main)
42+
43+
DocLayout模型下载地址为:[link](https://github.com/RapidAI/RapidLayout/releases/tag/v0.0.0)
4244

4345
### 安装
4446

demo.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
from rapid_layout import RapidLayout, VisLayout
77

8-
layout_engine = RapidLayout(model_type="doclayout_yolo", conf_thres=0.1)
8+
layout_engine = RapidLayout(model_type="doclayout_docsynth")
99

1010
img_path = "tests/test_files/PMC3576793_00004.jpg"
1111
img = cv2.imread(img_path)
1212

13-
boxes, scores, class_names, elapse = layout_engine(img)
13+
boxes, scores, class_names, elapse = layout_engine(img_path)
14+
print(boxes.shape)
1415
ploted_img = VisLayout.draw_detections(img, boxes, scores, class_names)
1516
if ploted_img is not None:
1617
cv2.imwrite("layout_res.png", ploted_img)

rapid_layout/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
"yolov8n_layout_report": f"{ROOT_URL}/yolov8n_layout_report.onnx",
3636
"yolov8n_layout_publaynet": f"{ROOT_URL}/yolov8n_layout_publaynet.onnx",
3737
"yolov8n_layout_general6": f"{ROOT_URL}/yolov8n_layout_general6.onnx",
38-
"doclayout_yolo": f"{ROOT_URL}/doclayout_yolo_docstructbench_imgsz1024.onnx",
38+
"doclayout_docstructbench": f"{ROOT_URL}/doclayout_yolo_docstructbench_imgsz1024.onnx",
39+
"doclayout_d4la": f"{ROOT_URL}/doclayout_yolo_d4la_imgsz1600_docsynth_pretrain.onnx",
40+
"doclayout_docsynth": f"{ROOT_URL}/doclayout_yolo_doclaynet_imgsz1120_docsynth_pretrain.onnx",
3941
}
4042
DEFAULT_MODEL_PATH = str(ROOT_DIR / "models" / "layout_cdla.onnx")
4143

rapid_layout/utils/augment.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: SWHL
3+
# @Contact: [email protected]
4+
import cv2
5+
import numpy as np
6+
7+
8+
class LetterBox:
9+
"""Resize image and padding for detection, instance segmentation, pose."""
10+
11+
def __init__(
12+
self,
13+
new_shape=(640, 640),
14+
auto=False,
15+
scaleFill=False,
16+
scaleup=True,
17+
center=True,
18+
stride=32,
19+
):
20+
"""Initialize LetterBox object with specific parameters."""
21+
self.new_shape = new_shape
22+
self.auto = auto
23+
self.scaleFill = scaleFill
24+
self.scaleup = scaleup
25+
self.stride = stride
26+
self.center = center # Put the image in the middle or top-left
27+
28+
def __call__(self, labels=None, image=None):
29+
"""Return updated labels and image with added border."""
30+
if labels is None:
31+
labels = {}
32+
img = labels.get("img") if image is None else image
33+
shape = img.shape[:2] # current shape [height, width]
34+
new_shape = labels.pop("rect_shape", self.new_shape)
35+
if isinstance(new_shape, int):
36+
new_shape = (new_shape, new_shape)
37+
38+
# Scale ratio (new / old)
39+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
40+
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
41+
r = min(r, 1.0)
42+
43+
# Compute padding
44+
ratio = r, r # width, height ratios
45+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
46+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
47+
if self.auto: # minimum rectangle
48+
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
49+
elif self.scaleFill: # stretch
50+
dw, dh = 0.0, 0.0
51+
new_unpad = (new_shape[1], new_shape[0])
52+
ratio = (
53+
new_shape[1] / shape[1],
54+
new_shape[0] / shape[0],
55+
) # width, height ratios
56+
57+
if self.center:
58+
dw /= 2 # divide padding into 2 sides
59+
dh /= 2
60+
61+
if shape[::-1] != new_unpad: # resize
62+
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
63+
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
64+
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
65+
img = cv2.copyMakeBorder(
66+
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
67+
) # add border
68+
if labels.get("ratio_pad"):
69+
labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
70+
71+
if len(labels):
72+
labels = self._update_labels(labels, ratio, dw, dh)
73+
labels["img"] = img
74+
labels["resized_shape"] = new_shape
75+
return labels
76+
else:
77+
return img
78+
79+
def _update_labels(self, labels, ratio, padw, padh):
80+
"""Update labels."""
81+
labels["instances"].convert_bbox(format="xyxy")
82+
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
83+
labels["instances"].scale(*ratio)
84+
labels["instances"].add_padding(padw, padh)
85+
return labels

rapid_layout/utils/post_prepross.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def extract_boxes(self, predictions):
299299

300300

301301
class DocLayoutPostProcess:
302-
def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
302+
def __init__(self, labels: List[str], conf_thres=0.2, iou_thres=0.5):
303303
self.labels = labels
304304
self.conf_threshold = conf_thres
305305
self.iou_threshold = iou_thres
@@ -308,31 +308,18 @@ def __init__(self, labels: List[str], conf_thres=0.7, iou_thres=0.5):
308308

309309
def __call__(
310310
self,
311-
output,
311+
preds,
312312
ori_img_shape: Tuple[int, int],
313313
img_shape: Tuple[int, int] = (1024, 1024),
314314
):
315-
self.img_height, self.img_width = ori_img_shape
316-
self.input_height, self.input_width = img_shape
317-
318-
output = output[0].squeeze()
319-
boxes = output[:, :-2]
320-
confidences = output[:, -2]
321-
class_ids = output[:, -1].astype(int)
322-
323-
mask = confidences > self.conf_threshold
324-
boxes = boxes[mask, :]
325-
confidences = confidences[mask]
326-
class_ids = class_ids[mask]
327-
328-
# Rescale boxes to original image dimensions
329-
boxes = rescale_boxes(
330-
boxes,
331-
self.input_width,
332-
self.input_height,
333-
self.img_width,
334-
self.img_height,
335-
)
315+
preds = preds[0]
316+
mask = preds[..., 4] > self.conf_threshold
317+
preds = [p[mask[idx]] for idx, p in enumerate(preds)][0]
318+
preds[:, :4] = scale_boxes(list(img_shape), preds[:, :4], list(ori_img_shape))
319+
320+
boxes = preds[:, :4]
321+
confidences = preds[:, 4]
322+
class_ids = preds[:, 5].astype(int)
336323
labels = [self.labels[i] for i in class_ids]
337324
return boxes, confidences, labels
338325

@@ -345,6 +332,54 @@ def rescale_boxes(boxes, input_width, input_height, img_width, img_height):
345332
return boxes
346333

347334

335+
def scale_boxes(
336+
img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False
337+
):
338+
"""
339+
Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
340+
specified in (img1_shape) to the shape of a different image (img0_shape).
341+
342+
Args:
343+
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
344+
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
345+
img0_shape (tuple): the shape of the target image, in the format of (height, width).
346+
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
347+
calculated based on the size difference between the two images.
348+
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
349+
rescaling.
350+
xywh (bool): The box format is xywh or not, default=False.
351+
352+
Returns:
353+
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
354+
"""
355+
if ratio_pad is None: # calculate from img0_shape
356+
gain = min(
357+
img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]
358+
) # gain = old / new
359+
pad = (
360+
round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
361+
round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
362+
) # wh padding
363+
else:
364+
gain = ratio_pad[0][0]
365+
pad = ratio_pad[1]
366+
367+
if padding:
368+
boxes[..., 0] -= pad[0] # x padding
369+
boxes[..., 1] -= pad[1] # y padding
370+
if not xywh:
371+
boxes[..., 2] -= pad[0] # x padding
372+
boxes[..., 3] -= pad[1] # y padding
373+
boxes[..., :4] /= gain
374+
return clip_boxes(boxes, img0_shape)
375+
376+
377+
def clip_boxes(boxes, shape):
378+
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
379+
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
380+
return boxes
381+
382+
348383
def nms(boxes, scores, iou_threshold):
349384
# Sort by score
350385
sorted_indices = np.argsort(scores)[::-1]

rapid_layout/utils/pre_procss.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import cv2
88
import numpy as np
99

10+
from .augment import LetterBox
11+
1012
InputType = Union[str, np.ndarray, bytes, Path]
1113

1214

1315
class PPPreProcess:
14-
1516
def __init__(self, img_size: Tuple[int, int]):
1617
self.size = img_size
1718
self.mean = np.array([0.485, 0.456, 0.406])
@@ -41,7 +42,6 @@ def permute(self, img: np.ndarray) -> np.ndarray:
4142

4243

4344
class YOLOv8PreProcess:
44-
4545
def __init__(self, img_size: Tuple[int, int]):
4646
self.img_size = img_size
4747

@@ -54,14 +54,15 @@ def __call__(self, image: np.ndarray) -> np.ndarray:
5454

5555

5656
class DocLayoutPreProcess:
57-
5857
def __init__(self, img_size: Tuple[int, int]):
5958
self.img_size = img_size
59+
self.letterbox = LetterBox(new_shape=img_size, auto=False, stride=32)
6060

6161
def __call__(self, image: np.ndarray) -> np.ndarray:
62-
input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
63-
input_img = cv2.resize(image, self.img_size)
64-
input_img = input_img / 255.0
65-
input_img = input_img.transpose(2, 0, 1)
66-
input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
62+
input_img = self.letterbox(image=image)
63+
input_img = input_img[None, ...]
64+
input_img = input_img[..., ::-1].transpose(0, 3, 1, 2)
65+
input_img = np.ascontiguousarray(input_img)
66+
input_img = input_img / 255
67+
input_tensor = input_img.astype(np.float32)
6768
return input_tensor

tests/test_layout.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@
2626
[
2727
("yolov8n_layout_publaynet", 12),
2828
("yolov8n_layout_general6", 13),
29-
("doclayout_yolo", 14),
29+
(
30+
"doclayout_docstructbench",
31+
14,
32+
),
33+
("doclayout_d4la", 11),
34+
("doclayout_docsynth", 14),
3035
],
3136
)
32-
def test_yolov8n_layout(model_type, gt):
37+
def test_layout(model_type, gt):
3338
img_path = test_file_dir / "PMC3576793_00004.jpg"
3439
engine = RapidLayout(model_type=model_type)
3540
boxes, scores, class_names, *elapse = engine(img_path)

0 commit comments

Comments
 (0)