diff --git a/docs/workflows/blocks.md b/docs/workflows/blocks.md
index 06cea3d4c6..56e2748daa 100644
--- a/docs/workflows/blocks.md
+++ b/docs/workflows/blocks.md
@@ -13,24 +13,6 @@ hide:
-

-

-

-

-

-

-

-

-

-

-

-

-

-

-

-

-

-

@@ -49,59 +31,77 @@ hide:

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

+

-

-

-

+

+

+

-

-

-

-

-

+

+

+

+

-

-

-

-

-

-

+

+

+

+

-

+

+

+

+

-

-

-

-

-

-

-

-

-

-

+

-

-

-

+

diff --git a/docs/workflows/kinds.md b/docs/workflows/kinds.md
index d910b85277..6be93488fe 100644
--- a/docs/workflows/kinds.md
+++ b/docs/workflows/kinds.md
@@ -37,36 +37,37 @@ for the presence of a mask in the input.
 
 ## Kinds declared in Roboflow plugins
 
-* [`image_metadata`](/workflows/kinds/image_metadata): Dictionary with image metadata required by supervision
+* [`integer`](/workflows/kinds/integer): Integer value
+* [`roboflow_model_id`](/workflows/kinds/roboflow_model_id): Roboflow model id
+* [`object_detection_prediction`](/workflows/kinds/object_detection_prediction): Prediction with detected bounding boxes in form of sv.Detections(...) object
+* [`video_metadata`](/workflows/kinds/video_metadata): Video image metadata
 * [`string`](/workflows/kinds/string): String value
-* [`numpy_array`](/workflows/kinds/numpy_array): Numpy array
-* [`parent_id`](/workflows/kinds/parent_id): Identifier of parent for step output
-* [`qr_code_detection`](/workflows/kinds/qr_code_detection): Prediction with QR code detection
-* [`float`](/workflows/kinds/float): Float value
-* [`dictionary`](/workflows/kinds/dictionary): Dictionary
+* [`roboflow_api_key`](/workflows/kinds/roboflow_api_key): Roboflow API key
+* [`detection`](/workflows/kinds/detection): Single element of detections-based prediction (like `object_detection_prediction`)
+* [`list_of_values`](/workflows/kinds/list_of_values): List of values of any type
+* [`instance_segmentation_prediction`](/workflows/kinds/instance_segmentation_prediction): Prediction with detected bounding boxes and segmentation masks in form of sv.Detections(...) object
 * [`float_zero_to_one`](/workflows/kinds/float_zero_to_one): `float` value in range `[0.0, 1.0]`
-* [`object_detection_prediction`](/workflows/kinds/object_detection_prediction): Prediction with detected bounding boxes in form of sv.Detections(...) object
-* [`*`](/workflows/kinds/*): Equivalent of any element
+* [`image`](/workflows/kinds/image): Image in workflows
+* [`image_metadata`](/workflows/kinds/image_metadata): Dictionary with image metadata required by supervision
+* [`image_keypoints`](/workflows/kinds/image_keypoints): Image keypoints detected by classical Computer Vision method
 * [`bar_code_detection`](/workflows/kinds/bar_code_detection): Prediction with barcode detection
-* [`roboflow_model_id`](/workflows/kinds/roboflow_model_id): Roboflow model id
+* [`bytes`](/workflows/kinds/bytes): This kind represents bytes
+* [`roboflow_project`](/workflows/kinds/roboflow_project): Roboflow project name
+* [`dictionary`](/workflows/kinds/dictionary): Dictionary
+* [`numpy_array`](/workflows/kinds/numpy_array): Numpy array
+* [`qr_code_detection`](/workflows/kinds/qr_code_detection): Prediction with QR code detection
+* [`classification_prediction`](/workflows/kinds/classification_prediction): Predictions from classifier
 * [`contours`](/workflows/kinds/contours): List of numpy arrays where each array represents contour points
 * [`serialised_payloads`](/workflows/kinds/serialised_payloads): Serialised element that is usually accepted by sink
-* [`video_metadata`](/workflows/kinds/video_metadata): Video image metadata
+* [`prediction_type`](/workflows/kinds/prediction_type): String value with type of prediction
+* [`zone`](/workflows/kinds/zone): Definition of polygon zone
+* [`keypoint_detection_prediction`](/workflows/kinds/keypoint_detection_prediction): Prediction with detected bounding boxes and detected keypoints in form of sv.Detections(...) object
+* [`boolean`](/workflows/kinds/boolean): Boolean flag
+* [`float`](/workflows/kinds/float): Float value
+* [`point`](/workflows/kinds/point): Single point in 2D
 * [`top_class`](/workflows/kinds/top_class): String value representing top class predicted by classification model
 * [`language_model_output`](/workflows/kinds/language_model_output): LLM / VLM output
-* [`image`](/workflows/kinds/image): Image in workflows
-* [`roboflow_api_key`](/workflows/kinds/roboflow_api_key): Roboflow API key
+* [`parent_id`](/workflows/kinds/parent_id): Identifier of parent for step output
+* [`*`](/workflows/kinds/*): Equivalent of any element
 * [`rgb_color`](/workflows/kinds/rgb_color): RGB color
-* [`boolean`](/workflows/kinds/boolean): Boolean flag
-* [`roboflow_project`](/workflows/kinds/roboflow_project): Roboflow project name
-* [`image_keypoints`](/workflows/kinds/image_keypoints): Image keypoints detected by classical Computer Vision method
-* [`list_of_values`](/workflows/kinds/list_of_values): List of values of any type
-* [`zone`](/workflows/kinds/zone): Definition of polygon zone
-* [`point`](/workflows/kinds/point): Single point in 2D
-* [`prediction_type`](/workflows/kinds/prediction_type): String value with type of prediction
-* [`instance_segmentation_prediction`](/workflows/kinds/instance_segmentation_prediction): Prediction with detected bounding boxes and segmentation masks in form of sv.Detections(...) object
-* [`integer`](/workflows/kinds/integer): Integer value
-* [`keypoint_detection_prediction`](/workflows/kinds/keypoint_detection_prediction): Prediction with detected bounding boxes and detected keypoints in form of sv.Detections(...) object
-* [`classification_prediction`](/workflows/kinds/classification_prediction): Predictions from classifier
-* [`detection`](/workflows/kinds/detection): Single element of detections-based prediction (like `object_detection_prediction`)
diff --git a/inference/core/version.py b/inference/core/version.py
index 26bce6ff9b..a98e64b54c 100644
--- a/inference/core/version.py
+++ b/inference/core/version.py
@@ -1,4 +1,4 @@
-__version__ = "0.24.0"
+__version__ = "0.25.0"
 
 
 if __name__ == "__main__":
diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py
index d672b7a005..efc67723f2 100644
--- a/inference/core/workflows/core_steps/loader.py
+++ b/inference/core/workflows/core_steps/loader.py
@@ -227,6 +227,9 @@
 from inference.core.workflows.core_steps.transformations.stitch_images.v1 import (
     StitchImagesBlockV1,
 )
+from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import (
+    StitchOCRDetectionsBlockV1,
+)
 
 # Visualizers
 from inference.core.workflows.core_steps.visualizations.background_color.v1 import (
@@ -425,6 +428,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
         StabilityAIInpaintingBlockV1,
         StabilizeTrackedDetectionsBlockV1,
         StitchImagesBlockV1,
+        StitchOCRDetectionsBlockV1,
         TemplateMatchingBlockV1,
         TimeInZoneBlockV1,
         TimeInZoneBlockV2,
diff --git a/inference/core/workflows/core_steps/transformations/stitch_ocr_detections/__init__.py b/inference/core/workflows/core_steps/transformations/stitch_ocr_detections/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py b/inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
new file mode 100644
index 0000000000..4141f8de03
--- /dev/null
+++ b/inference/core/workflows/core_steps/transformations/stitch_ocr_detections/v1.py
@@ -0,0 +1,294 @@
+from enum import Enum
+from typing import Dict, List, Literal, Optional, Tuple, Type, Union
+
+import numpy as np
+import supervision as sv
+from pydantic import AliasChoices, ConfigDict, Field, field_validator
+
+from inference.core.workflows.execution_engine.entities.base import (
+    Batch,
+    OutputDefinition,
+)
+from inference.core.workflows.execution_engine.entities.types import (
+    INTEGER_KIND,
+    OBJECT_DETECTION_PREDICTION_KIND,
+    STRING_KIND,
+    StepOutputSelector,
+    WorkflowParameterSelector,
+)
+from inference.core.workflows.prototypes.block import (
+    BlockResult,
+    WorkflowBlock,
+    WorkflowBlockManifest,
+)
+
+LONG_DESCRIPTION = """
+Combines OCR detection results into a coherent text string by organizing detections spatially.
+This transformation is perfect for turning individual OCR results into structured, readable text!
+
+#### How It Works
+
+This transformation reconstructs the original text from OCR detection results by:
+
+1. 📐 **Grouping** text detections into rows based on their vertical (`y`) positions
+
+2. 📏 **Sorting** detections within each row by horizontal (`x`) position
+
+3. 📜 **Concatenating** the text in reading order (left-to-right, top-to-bottom)
+
+#### Parameters
+
+- **`tolerance`**: Controls how close detections need to be vertically to be considered part of the same line of text.
+A higher tolerance will group detections that are further apart vertically.
+
+- **`reading_direction`**: Determines the order in which text is read. Available options:
+
+    * **"left_to_right"**: Standard left-to-right reading (e.g., English) ➡️
+
+    * **"right_to_left"**: Right-to-left reading (e.g., Arabic) ⬅️
+
+    * **"vertical_top_to_bottom"**: Vertical reading from top to bottom ⬇️
+
+    * **"vertical_bottom_to_top"**: Vertical reading from bottom to top ⬆️
+
+#### Why Use This Transformation?
+
+This is especially useful for:
+
+- 📖 Converting individual character/word detections into a readable text block
+
+- 📝 Reconstructing multi-line text from OCR results
+
+- 🔀 Maintaining proper reading order for detected text elements
+
+- 🌏 Supporting different writing systems and text orientations
+
+#### Example Usage
+
+Use this transformation after an OCR model that outputs individual words or characters, so you can reconstruct the
+original text layout in its intended format.
+"""
+
+SHORT_DESCRIPTION = "Combines OCR detection results into a coherent text string by organizing detections spatially."
+
+
+class ReadingDirection(str, Enum):
+    LEFT_TO_RIGHT = "left_to_right"
+    RIGHT_TO_LEFT = "right_to_left"
+    VERTICAL_TOP_TO_BOTTOM = "vertical_top_to_bottom"
+    VERTICAL_BOTTOM_TO_TOP = "vertical_bottom_to_top"
+
+
+class BlockManifest(WorkflowBlockManifest):
+    model_config = ConfigDict(
+        json_schema_extra={
+            "name": "Stitch OCR Detections",
+            "version": "v1",
+            "short_description": SHORT_DESCRIPTION,
+            "long_description": LONG_DESCRIPTION,
+            "license": "Apache-2.0",
+            "block_type": "transformation",
+            "ui_manifest": {
+                "section": "advanced",
+                "icon": "fal fa-reel",
+                "blockPriority": 2,
+            },
+        }
+    )
+    type: Literal["roboflow_core/stitch_ocr_detections@v1"]
+    predictions: StepOutputSelector(
+        kind=[
+            OBJECT_DETECTION_PREDICTION_KIND,
+        ]
+    ) = Field(
+        title="OCR Detections",
+        description="The output of an OCR detection model.",
+        examples=["$steps.my_ocr_detection_model.predictions"],
+    )
+    reading_direction: Literal[
+        "left_to_right",
+        "right_to_left",
+        "vertical_top_to_bottom",
+        "vertical_bottom_to_top",
+    ] = Field(
+        title="Reading Direction",
+        description="The direction of the text in the image.",
+        examples=["right_to_left"],
+        json_schema_extra={
+            "values_metadata": {
+                "left_to_right": {
+                    "name": "Left To Right",
+                    "description": "Standard left-to-right reading (e.g., English language)",
+                },
+                "right_to_left": {
+                    "name": "Right To Left",
+                    "description": "Right-to-left reading (e.g., Arabic)",
+                },
+                "vertical_top_to_bottom": {
+                    "name": "Top To Bottom (Vertical)",
+                    "description": "Vertical reading from top to bottom",
+                },
+                "vertical_bottom_to_top": {
+                    "name": "Bottom To Top (Vertical)",
+                    "description": "Vertical reading from bottom to top",
+                },
+            }
+        },
+    )
+    tolerance: Union[int, WorkflowParameterSelector(kind=[INTEGER_KIND])] = Field(
+        title="Tolerance",
+        description="The tolerance for grouping detections into the same line of text.",
+        default=10,
+        examples=[10, "$inputs.tolerance"],
+    )
+
+    @field_validator("tolerance")
+    @classmethod
+    def ensure_tolerance_greater_than_zero(
+        cls, value: Union[int, str]
+    ) -> Union[int, str]:
+        if isinstance(value, int) and value <= 0:
+            raise ValueError(
+                "Stitch OCR detections block expects `tolerance` to be greater than zero."
+            )
+        return value
+
+    @classmethod
+    def accepts_batch_input(cls) -> bool:
+        return True
+
+    @classmethod
+    def describe_outputs(cls) -> List[OutputDefinition]:
+        return [
+            OutputDefinition(name="ocr_text", kind=[STRING_KIND]),
+        ]
+
+    @classmethod
+    def get_execution_engine_compatibility(cls) -> Optional[str]:
+        return ">=1.0.0,<2.0.0"
+
+
+class StitchOCRDetectionsBlockV1(WorkflowBlock):
+    @classmethod
+    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
+        return BlockManifest
+
+    def run(
+        self,
+        predictions: Batch[sv.Detections],
+        reading_direction: str,
+        tolerance: int,
+    ) -> BlockResult:
+        return [
+            stitch_ocr_detections(
+                detections=detections,
+                reading_direction=reading_direction,
+                tolerance=tolerance,
+            )
+            for detections in predictions
+        ]
+
+
+def stitch_ocr_detections(
+    detections: sv.Detections,
+    reading_direction: str = "left_to_right",
+    tolerance: int = 10,
+) -> Dict[str, str]:
+    """
+    Stitch OCR detections into coherent text based on spatial arrangement.
+
+    Args:
+        detections: Supervision Detections object containing OCR results
+        reading_direction: Direction to read text ("left_to_right", "right_to_left",
+            "vertical_top_to_bottom", "vertical_bottom_to_top")
+        tolerance: Vertical tolerance for grouping text into lines
+
+    Returns:
+        Dict containing stitched OCR text under 'ocr_text' key
+    """
+    if len(detections) == 0:
+        return {"ocr_text": ""}
+
+    xyxy = detections.xyxy.round().astype(dtype=int)
+    class_names = detections.data["class_name"]
+
+    # Prepare coordinates based on reading direction
+    xyxy = prepare_coordinates(xyxy, reading_direction)
+
+    # Group detections into lines
+    boxes_by_line = group_detections_by_line(xyxy, reading_direction, tolerance)
+    # Sort lines based on reading direction
+    lines = sorted(
+        boxes_by_line.keys(), reverse=reading_direction in ["vertical_bottom_to_top"]
+    )
+
+    # Build final text
+    ordered_class_names = []
+    for i, key in enumerate(lines):
+        line_data = boxes_by_line[key]
+        line_xyxy = np.array(line_data["xyxy"])
+        line_idx = np.array(line_data["idx"])
+
+        # Sort detections within line
+        sort_idx = sort_line_detections(line_xyxy, reading_direction)
+
+        # Add sorted class names for this line
+        ordered_class_names.extend(class_names[line_idx[sort_idx]])
+
+        # Add line separator if not last line
+        if i < len(lines) - 1:
+            ordered_class_names.append(get_line_separator(reading_direction))
+
+    return {"ocr_text": "".join(ordered_class_names)}
+
+
+def prepare_coordinates(
+    xyxy: np.ndarray,
+    reading_direction: str,
+) -> np.ndarray:
+    """Prepare coordinates based on reading direction."""
+    if reading_direction in ["vertical_top_to_bottom", "vertical_bottom_to_top"]:
+        # Swap x and y coordinates: [x1, y1, x2, y2] -> [y1, x1, y2, x2]
+        return xyxy[:, [1, 0, 3, 2]]
+    return xyxy
+
+
+def group_detections_by_line(
+    xyxy: np.ndarray,
+    reading_direction: str,
+    tolerance: int,
+) -> Dict[float, Dict[str, List]]:
+    """Group detections into lines based on primary coordinate."""
+    # After prepare_coordinates swap, we always group by y ([:, 1])
+    primary_coord = xyxy[:, 1]  # This is y for horizontal, swapped x for vertical
+
+    # Round primary coordinate to group into lines
+    rounded_primary = np.round(primary_coord / tolerance) * tolerance
+
+    boxes_by_line = {}
+    # Group bounding boxes and associated indices by line
+    for i, (bbox, line_pos) in enumerate(zip(xyxy, rounded_primary)):
+        if line_pos not in boxes_by_line:
+            boxes_by_line[line_pos] = {"xyxy": [bbox], "idx": [i]}
+        else:
+            boxes_by_line[line_pos]["xyxy"].append(bbox)
+            boxes_by_line[line_pos]["idx"].append(i)
+
+    return boxes_by_line
+
+
+def sort_line_detections(
+    line_xyxy: np.ndarray,
+    reading_direction: str,
+) -> np.ndarray:
+    """Sort detections within a line based on reading direction."""
+    # After prepare_coordinates swap, we always sort by x ([:, 0])
+    if reading_direction in ["left_to_right", "vertical_top_to_bottom"]:
+        return line_xyxy[:, 0].argsort()  # Sort by x1 (original x or swapped y)
+    else:  # right_to_left or vertical_bottom_to_top
+        return (-line_xyxy[:, 0]).argsort()  # Sort by -x1 (original -x or swapped -y)
+
+
+def get_line_separator(reading_direction: str) -> str:
+    """Get the appropriate separator based on reading direction."""
+    return "\n" if reading_direction in ["left_to_right", "right_to_left"] else " "
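For reference, here is a minimal usage sketch (illustration only, not part of the patch; it mirrors the unit tests later in this diff). It exercises `stitch_ocr_detections` directly and shows how `tolerance` buckets detections into lines via `np.round(y / tolerance) * tolerance`:

```python
import numpy as np
import supervision as sv

from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import (
    stitch_ocr_detections,
)

# Two rows of characters: box tops at y≈0 and y≈20, with 2px jitter within each row.
detections = sv.Detections(
    xyxy=np.array(
        [
            [10, 0, 20, 10],   # "H"
            [30, 2, 40, 12],   # "I" (2px lower than "H")
            [10, 20, 20, 30],  # "O"
            [30, 22, 40, 32],  # "K" (2px lower than "O")
        ],
        dtype=float,
    ),
    data={"class_name": np.array(["H", "I", "O", "K"])},
)

# tolerance=5 rounds y-tops {0, 2} and {20, 22} onto the same buckets -> two lines...
print(stitch_ocr_detections(detections, tolerance=5))  # {'ocr_text': 'HI\nOK'}

# ...while tolerance=1 keeps every detection on its own line.
print(stitch_ocr_detections(detections, tolerance=1))  # {'ocr_text': 'H\nI\nO\nK'}
```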
diff --git a/tests/inference/models_predictions_tests/test_owlv2.py b/tests/inference/models_predictions_tests/test_owlv2.py
index 3ad5913abb..6bbcbcfd5b 100644
--- a/tests/inference/models_predictions_tests/test_owlv2.py
+++ b/tests/inference/models_predictions_tests/test_owlv2.py
@@ -15,7 +15,14 @@ def test_owlv2():
             {
                 "image": image,
                 "boxes": [
-                    {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False},
+                    {
+                        "x": 223,
+                        "y": 306,
+                        "w": 40,
+                        "h": 226,
+                        "cls": "post",
+                        "negative": False,
+                    },
                 ],
             }
         ],
@@ -42,7 +49,6 @@ def test_owlv2():
     assert abs(532 - posts[3].x) < 1.5
     assert abs(572 - posts[4].x) < 1.5
 
-
     # test we can handle multiple (positive and negative) prompts for the same image
     request = OwlV2InferenceRequest(
         image=image,
@@ -50,9 +56,30 @@ def test_owlv2():
             {
                 "image": image,
                 "boxes": [
-                    {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False},
-                    {"x": 247, "y": 294, "w": 25, "h": 165, "cls": "post", "negative": True},
-                    {"x": 264, "y": 327, "w": 21, "h": 74, "cls": "post", "negative": False},
+                    {
+                        "x": 223,
+                        "y": 306,
+                        "w": 40,
+                        "h": 226,
+                        "cls": "post",
+                        "negative": False,
+                    },
+                    {
+                        "x": 247,
+                        "y": 294,
+                        "w": 25,
+                        "h": 165,
+                        "cls": "post",
+                        "negative": True,
+                    },
+                    {
+                        "x": 264,
+                        "y": 327,
+                        "w": 21,
+                        "h": 74,
+                        "cls": "post",
+                        "negative": False,
+                    },
                 ],
             }
         ],
@@ -76,7 +103,14 @@ def test_owlv2():
             {
                 "image": image,
                 "boxes": [
-                    {"x": 223, "y": 306, "w": 40, "h": 226, "cls": "post", "negative": False}
+                    {
+                        "x": 223,
+                        "y": 306,
+                        "w": 40,
+                        "h": 226,
+                        "cls": "post",
+                        "negative": False,
+                    }
                 ],
             },
             {
@@ -89,4 +123,4 @@ def test_owlv2():
     )
     response = OwlV2().infer_from_request(request)
 
-    assert len(response.predictions) == 5
\ No newline at end of file
+    assert len(response.predictions) == 5
diff --git a/tests/workflows/integration_tests/execution/assets/image_credits.txt b/tests/workflows/integration_tests/execution/assets/image_credits.txt
index 22e35dca4f..8c24852f81 100644
--- a/tests/workflows/integration_tests/execution/assets/image_credits.txt
+++ b/tests/workflows/integration_tests/execution/assets/image_credits.txt
@@ -2,3 +2,4 @@ crowd.jpg: https://pixabay.com/users/wal_172619-12138562
 license_plate.jpg: https://www.pexels.com/photo/kia-niros-driving-on-the-road-11320632/
 dogs.jpg: https://www.pexels.com/photo/brown-and-white-dogs-sitting-on-field-3568134/
 multi-fruit.jpg: https://www.freepik.com/free-photo/front-close-view-organic-nutrition-source-fresh-bananas-bundle-red-apples-orange-with-stem-dark-background_17119128.htm
+multi_line_text.jpg: https://www.pexels.com/photo/illuminated-qoute-board-2255441/
\ No newline at end of file
diff --git a/tests/workflows/integration_tests/execution/assets/multi_line_text.jpg b/tests/workflows/integration_tests/execution/assets/multi_line_text.jpg
new file mode 100644
index 0000000000..5b932fd22b
Binary files /dev/null and b/tests/workflows/integration_tests/execution/assets/multi_line_text.jpg differ
diff --git a/tests/workflows/integration_tests/execution/conftest.py b/tests/workflows/integration_tests/execution/conftest.py
index dbac6b1a73..bf10c136be 100644
--- a/tests/workflows/integration_tests/execution/conftest.py
+++ b/tests/workflows/integration_tests/execution/conftest.py
@@ -35,6 +35,11 @@ def fruit_image() -> np.ndarray:
     return cv2.imread(os.path.join(ASSETS_DIR, "multi-fruit.jpg"))
 
 
+@pytest.fixture(scope="function")
+def multi_line_text_image() -> np.ndarray:
+    return cv2.imread(os.path.join(ASSETS_DIR, "multi_line_text.jpg"))
+
+
 @pytest.fixture(scope="function")
 def stitch_left_image() -> np.ndarray:
     return cv2.imread(os.path.join(ASSETS_DIR, "stitch", "v_left.jpeg"))
diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py b/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py
index c3ce3ff949..268bdc1243 100644
--- a/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py
+++ b/tests/workflows/integration_tests/execution/test_workflow_with_keypoint_visualization.py
@@ -1,12 +1,11 @@
-import numpy as np
 import cv2
+import numpy as np
 
 from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
 from inference.core.managers.base import ModelManager
 from inference.core.workflows.core_steps.common.entities import StepExecutionMode
 from inference.core.workflows.execution_engine.core import ExecutionEngine
-
 
 WORKFLOW_KEYPOINT_VISUALIZATION = {
     "version": "1.1",
     "inputs": [
diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py b/tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py
new file mode 100644
index 0000000000..b370602b37
--- /dev/null
+++ b/tests/workflows/integration_tests/execution/test_workflow_with_ocr_detections_stitching.py
@@ -0,0 +1,96 @@
+import numpy as np
+
+from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
+from inference.core.managers.base import ModelManager
+from inference.core.workflows.core_steps.common.entities import StepExecutionMode
+from inference.core.workflows.execution_engine.core import ExecutionEngine
+from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import (
+    add_to_workflows_gallery,
+)
+
+WORKFLOW_STITCHING_OCR_DETECTIONS = {
+    "version": "1.0",
+    "inputs": [
+        {"type": "WorkflowImage", "name": "image"},
+        {
+            "type": "WorkflowParameter",
+            "name": "model_id",
+            "default_value": "ocr-oy9a7/1",
+        },
+        {"type": "WorkflowParameter", "name": "tolerance", "default_value": 10},
+        {"type": "WorkflowParameter", "name": "confidence", "default_value": 0.4},
+    ],
+    "steps": [
+        {
+            "type": "roboflow_core/roboflow_object_detection_model@v1",
+            "name": "ocr_detection",
+            "image": "$inputs.image",
+            "model_id": "$inputs.model_id",
+            "confidence": "$inputs.confidence",
+        },
+        {
+            "type": "roboflow_core/stitch_ocr_detections@v1",
+            "name": "detections_stitch",
+            "predictions": "$steps.ocr_detection.predictions",
+            "reading_direction": "left_to_right",
+            "tolerance": "$inputs.tolerance",
+        },
+    ],
+    "outputs": [
+        {
+            "type": "JsonField",
+            "name": "ocr_text",
+            "selector": "$steps.detections_stitch.ocr_text",
+        },
+    ],
+}
+
+
+@add_to_workflows_gallery(
+    category="Workflows for OCR",
+    use_case_title="Workflow with model detecting individual characters and text stitching",
+    use_case_description="""
+This workflow extracts and organizes text from an image using OCR. It begins by analyzing the image with a detection
+model to detect individual characters or words and their positions.
+
+Then, it groups nearby text into lines based on a specified `tolerance` for spacing and arranges them in
+reading order (`left-to-right`).
+
+The final output is a JSON field containing the structured text in readable, logical order, accurately reflecting
+the layout of the original image.
+    """,
+    workflow_definition=WORKFLOW_STITCHING_OCR_DETECTIONS,
+    workflow_name_in_app="ocr-detections-stitch",
+)
+def test_workflow_with_ocr_detections_stitching_when_minimal_valid_input_provided(
+    model_manager: ModelManager,
+    multi_line_text_image: np.ndarray,
+    roboflow_api_key: str,
+) -> None:
+    # given
+    workflow_init_parameters = {
+        "workflows_core.model_manager": model_manager,
+        "workflows_core.api_key": roboflow_api_key,
+        "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+    }
+    execution_engine = ExecutionEngine.init(
+        workflow_definition=WORKFLOW_STITCHING_OCR_DETECTIONS,
+        init_parameters=workflow_init_parameters,
+        max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+    )
+
+    # when
+    result = execution_engine.run(
+        runtime_parameters={
+            "image": multi_line_text_image,
+            "tolerance": 20,
+            "confidence": 0.6,
+        }
+    )
+
+    assert isinstance(result, list), "Expected list to be delivered"
+    assert len(result) == 1, "Expected 1 element in the output for one input image"
+    assert set(result[0].keys()) == {
+        "ocr_text",
+    }, "Expected all declared outputs to be delivered"
+    assert result[0]["ocr_text"] == "MAKE\nTHISDAY\nGREAT"
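The same specification can also be exercised outside the test harness. A hedged sketch using `inference_sdk` against a locally running `inference` server follows; the client and method names reflect the SDK's documented API, but treat the exact signature as an assumption, and the file path and API key placeholder are hypothetical:

```python
from inference_sdk import InferenceHTTPClient

# Assumes a local server (e.g. started with `inference server start`) and a valid key.
client = InferenceHTTPClient(
    api_url="http://127.0.0.1:9001",
    api_key="<YOUR_ROBOFLOW_API_KEY>",
)

result = client.run_workflow(
    specification=WORKFLOW_STITCHING_OCR_DETECTIONS,  # definition from the test above
    images={"image": "path/to/multi_line_text.jpg"},
    parameters={"tolerance": 20, "confidence": 0.6},
)
print(result[0]["ocr_text"])  # e.g. "MAKE\nTHISDAY\nGREAT"
```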
diff --git a/tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py b/tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py
new file mode 100644
index 0000000000..94ff58ab8a
--- /dev/null
+++ b/tests/workflows/unit_tests/core_steps/transformations/test_stitch_ocr_detections.py
@@ -0,0 +1,195 @@
+import numpy as np
+import pytest
+import supervision as sv
+from pydantic import ValidationError
+
+from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import (
+    BlockManifest,
+    stitch_ocr_detections,
+)
+
+
+def test_stitch_ocr_detections_when_valid_manifest_is_given() -> None:
+    # given
+    data = {
+        "type": "roboflow_core/stitch_ocr_detections@v1",
+        "name": "some",
+        "predictions": "$steps.detection.predictions",
+        "reading_direction": "left_to_right",
+        "tolerance": "$inputs.tolerance",
+    }
+
+    # when
+    result = BlockManifest.model_validate(data)
+
+    # then
+    assert result == BlockManifest(
+        type="roboflow_core/stitch_ocr_detections@v1",
+        name="some",
+        predictions="$steps.detection.predictions",
+        reading_direction="left_to_right",
+        tolerance="$inputs.tolerance",
+    )
+
+
+def test_stitch_ocr_detections_when_invalid_tolerance_is_given() -> None:
+    # given
+    data = {
+        "type": "roboflow_core/stitch_ocr_detections@v1",
+        "name": "some",
+        "predictions": "$steps.detection.predictions",
+        "reading_direction": "left_to_right",
+        "tolerance": 0,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(data)
+
+
+def create_test_detections(xyxy: np.ndarray, class_names: list) -> sv.Detections:
+    """Helper function to create test detection objects."""
+    return sv.Detections(
+        xyxy=np.array(xyxy), data={"class_name": np.array(class_names)}
+    )
+
+
+def test_empty_detections():
+    """Test handling of empty detections."""
+    detections = create_test_detections(xyxy=np.array([]).reshape(0, 4), class_names=[])
+    result = stitch_ocr_detections(detections)
+    assert result == {"ocr_text": ""}
+
+
+def test_left_to_right_single_line():
+    """Test basic left-to-right reading of a single line."""
+    detections = create_test_detections(
+        xyxy=np.array(
+            [
+                [10, 0, 20, 10],  # "H"
+                [30, 0, 40, 10],  # "E"
+                [50, 0, 60, 10],  # "L"
+                [70, 0, 80, 10],  # "L"
+                [90, 0, 100, 10],  # "O"
+            ]
+        ),
+        class_names=["H", "E", "L", "L", "O"],
+    )
+    result = stitch_ocr_detections(detections, reading_direction="left_to_right")
+    assert result == {"ocr_text": "HELLO"}
+
+
+def test_left_to_right_multiple_lines():
+    """Test left-to-right reading with multiple lines."""
+    detections = create_test_detections(
+        xyxy=np.array(
+            [
+                [10, 0, 20, 10],  # "H"
+                [30, 0, 40, 10],  # "I"
+                [10, 20, 20, 30],  # "B"
+                [30, 20, 40, 30],  # "Y"
+                [50, 20, 60, 30],  # "E"
+            ]
+        ),
+        class_names=["H", "I", "B", "Y", "E"],
+    )
+    result = stitch_ocr_detections(detections, reading_direction="left_to_right")
+    assert result == {"ocr_text": "HI\nBYE"}
+
+
+def test_right_to_left_single_line():
+    """Test right-to-left reading of a single line."""
+    detections = create_test_detections(
+        xyxy=np.array(
+            [
+                [90, 0, 100, 10],  # "م"
+                [70, 0, 80, 10],  # "ر"
+                [50, 0, 60, 10],  # "ح"
+                [30, 0, 40, 10],  # "ب"
+                [10, 0, 20, 10],  # "ا"
+            ]
+        ),
+        class_names=["م", "ر", "ح", "ب", "ا"],
+    )
+    result = stitch_ocr_detections(detections, reading_direction="right_to_left")
+    assert result == {"ocr_text": "مرحبا"}
+
+
+def test_vertical_top_to_bottom():
+    """Test vertical reading from top to bottom."""
+    detections = create_test_detections(
+        xyxy=np.array(
+            [
+                # First column (rightmost)
+                [20, 10, 30, 20],  # "上"
+                [20, 30, 30, 40],  # "下"
+                # Second column (leftmost)
+                [0, 10, 10, 20],  # "左"
+                [0, 30, 10, 40],  # "右"
+            ]
+        ),
+        class_names=["上", "下", "左", "右"],
+    )
+    # With current logic, we'll group by original x-coord and sort by y
+    result = stitch_ocr_detections(
+        detections, reading_direction="vertical_top_to_bottom"
+    )
+    assert result == {"ocr_text": "左右 上下"}
+
+
+def test_tolerance_grouping():
+    """Test that tolerance parameter correctly groups lines."""
+    detections = create_test_detections(
+        xyxy=np.array(
+            [
+                [10, 0, 20, 10],  # "A"
+                [30, 2, 40, 12],  # "B" (slightly offset)
+                [10, 20, 20, 30],  # "C" (closer to D)
+                [30, 22, 40, 32],  # "D" (slightly offset from C)
+            ]
+        ),
+        class_names=["A", "B", "C", "D"],
+    )
+
+    # With small tolerance, should treat as 4 separate lines
+    result_small = stitch_ocr_detections(detections, tolerance=1)
+    assert result_small == {"ocr_text": "A\nB\nC\nD"}
+
+    # With larger tolerance, should group into 2 lines
+    result_large = stitch_ocr_detections(detections, tolerance=5)
+    assert result_large == {"ocr_text": "AB\nCD"}
+
+
+def test_unordered_input():
+    """Test that detections are correctly ordered regardless of input order."""
+    detections = create_test_detections(
+        xyxy=np.array(
+            [
+                [50, 0, 60, 10],  # "O"
+                [10, 0, 20, 10],  # "H"
+                [70, 0, 80, 10],  # "W"
+                [30, 0, 40, 10],  # "L"
+            ]
+        ),
+        class_names=["O", "H", "W", "L"],
+    )
+    result = stitch_ocr_detections(detections, reading_direction="left_to_right")
+    assert result == {"ocr_text": "HLOW"}
+
+
+@pytest.mark.parametrize(
+    "reading_direction",
+    [
+        "left_to_right",
+        "right_to_left",
+        "vertical_top_to_bottom",
+        "vertical_bottom_to_top",
+    ],
+)
+def test_reading_directions(reading_direction):
+    """Test that all reading directions are supported."""
+    detections = create_test_detections(
+        xyxy=np.array([[0, 0, 10, 10]]), class_names=["A"]  # Single detection
+    )
+    result = stitch_ocr_detections(detections, reading_direction=reading_direction)
+    assert result == {"ocr_text": "A"}  # Should work with any direction
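To round out the vertical reading paths: a small sketch (illustration only, same assumptions as the earlier snippet) of the coordinate swap in `prepare_coordinates`, which lets the vertical directions reuse the horizontal grouping and sorting code; it reproduces the `test_vertical_top_to_bottom` expectation above:

```python
import numpy as np
import supervision as sv

from inference.core.workflows.core_steps.transformations.stitch_ocr_detections.v1 import (
    prepare_coordinates,
    stitch_ocr_detections,
)

# [x1, y1, x2, y2] -> [y1, x1, y2, x2]: after the swap, "group by column 1" means
# grouping by the original x, i.e. by vertical text column.
print(prepare_coordinates(np.array([[20, 10, 30, 20]]), "vertical_top_to_bottom"))
# [[10 20 20 30]]

detections = sv.Detections(
    xyxy=np.array(
        [
            [20, 10, 30, 20],  # "上" (right column, top)
            [20, 30, 30, 40],  # "下" (right column, bottom)
            [0, 10, 10, 20],   # "左" (left column, top)
            [0, 30, 10, 40],   # "右" (left column, bottom)
        ],
        dtype=float,
    ),
    data={"class_name": np.array(["上", "下", "左", "右"])},
)

# Columns become "lines" (grouped by x, read top-to-bottom) and, per
# get_line_separator, vertical directions join lines with a space rather than
# a newline, so the left column (x=0) is emitted first: "左右 上下".
print(stitch_ocr_detections(detections, reading_direction="vertical_top_to_bottom"))
```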