From 7c9b059ee8bee03cf1e99fdc23e684ec68fd6e96 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 12 Aug 2024 17:50:51 -0700 Subject: [PATCH] Add Florence2 OCR (#194) --- tests/integ/test_tools.py | 17 ++++++++--- vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 47 +++++++++++++++++++++++++++++++ vision_agent/utils/image_utils.py | 17 +++++++++++ 4 files changed, 78 insertions(+), 4 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 68f5d396..1d99ff69 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -5,17 +5,18 @@ blip_image_caption, clip, closest_mask_distance, - florencev2_image_caption, depth_anything_v2, + detr_segmentation, dpt_hybrid_midas, + florencev2_image_caption, + florencev2_object_detection, + florencev2_roberta_vqa, + florencev2_ocr, generate_pose_image, generate_soft_edge_image, - florencev2_object_detection, - detr_segmentation, git_vqa_v2, grounding_dino, grounding_sam, - florencev2_roberta_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, ocr, @@ -182,6 +183,14 @@ def test_ocr() -> None: assert any("Region-based segmentation" in res["label"] for res in result) +def test_florencev2_ocr() -> None: + img = ski.data.page() + result = florencev2_ocr( + image=img, + ) + assert any("Region-based segmentation" in res["label"] for res in result) + + def test_mask_distance(): # Create two binary masks mask1 = np.zeros((100, 100), dtype=np.uint8) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 52681274..f9879626 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -19,6 +19,7 @@ florencev2_image_caption, florencev2_object_detection, florencev2_roberta_vqa, + florencev2_ocr, generate_pose_image, generate_soft_edge_image, get_tool_documentation, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index fc4a59f4..5d91a8ff 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -28,6 +28,7 @@ denormalize_bbox, get_image_size, normalize_bbox, + convert_quad_box_to_bbox, rle_decode, ) @@ -652,6 +653,51 @@ def florencev2_object_detection(image: np.ndarray, prompt: str) -> List[Dict[str return return_data +def florencev2_ocr(image: np.ndarray) -> List[Dict[str, Any]]: + """'florencev2_ocr' is a tool that can detect text and text regions in an image. + Each text region contains one line of text. It returns a list of detected text, + the text region as a bounding box with normalized coordinates, and confidence + scores. The results are sorted from top-left to bottom right. + + Parameters: + image (np.ndarray): The image to extract text from. + + Returns: + List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox + with nornmalized coordinates, and confidence score. + + Example + ------- + >>> florencev2_ocr(image) + [ + {'label': 'hello world', 'bbox': [0.1, 0.11, 0.35, 0.4], 'score': 0.99}, + ] + """ + + image_size = image.shape[:2] + image_b64 = convert_to_b64(image) + data = { + "image": image_b64, + "task": "", + "function_name": "florencev2_ocr", + } + + detections = send_inference_request(data, "florence2", v2=True) + detections = detections[""] + return_data = [] + for i in range(len(detections["quad_boxes"])): + return_data.append( + { + "label": detections["labels"][i], + "bbox": normalize_bbox( + convert_quad_box_to_bbox(detections["quad_boxes"][i]), image_size + ), + "score": 1.0, + } + ) + return return_data + + def detr_segmentation(image: np.ndarray) -> List[Dict[str, Any]]: """'detr_segmentation' is a tool that can segment common objects in an image without any text prompt. It returns a list of detected objects @@ -1248,6 +1294,7 @@ def overlay_heat_map( loca_visual_prompt_counting, florencev2_roberta_vqa, florencev2_image_caption, + florencev2_ocr, detr_segmentation, depth_anything_v2, generate_soft_edge_image, diff --git a/vision_agent/utils/image_utils.py b/vision_agent/utils/image_utils.py index 217c9fa2..ddbd14b3 100644 --- a/vision_agent/utils/image_utils.py +++ b/vision_agent/utils/image_utils.py @@ -140,6 +140,23 @@ def denormalize_bbox( return bbox +def convert_quad_box_to_bbox(quad_box: List[Union[int, float]]) -> List[float]: + r"""Convert a quadrilateral bounding box to a rectangular bounding box. + + Parameters: + quad_box: the quadrilateral bounding box + + Returns: + The rectangular bounding box + """ + x1, y1, x2, y2, x3, y3, x4, y4 = quad_box + x_min = min(x1, x2, x3, x4) + x_max = max(x1, x2, x3, x4) + y_min = min(y1, y2, y3, y4) + y_max = max(y1, y2, y3, y4) + return [x_min, y_min, x_max, y_max] + + def overlay_bboxes( image: Union[str, Path, np.ndarray, ImageType], bboxes: Dict ) -> ImageType: