From 6af8d9eeabb0e1fbf5a30ccd6dad7992cd3a0ee0 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 23 Apr 2024 15:42:41 -0700 Subject: [PATCH] added OCR --- README.md | 1 + vision_agent/agent/vision_agent.py | 1 + vision_agent/tools/__init__.py | 1 + vision_agent/tools/tools.py | 53 ++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+) diff --git a/README.md b/README.md index 7383af76..87db92ad 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,7 @@ to pick it based on the tool description and use it based on the usage provided. | ExtractFrames | ExtractFrames extracts frames with motion from a video. | | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image | | VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt | +| OCR | OCR returns the text detected in an image along with the location. | It also has a basic set of calculate tools such as add, subtract, multiply and divide. diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 9b02d4fd..76629f6a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -377,6 +377,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]] "dinov_", "zero_shot_counting_", "visual_prompt_counting_", + "ocr_", ]: continue diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 450c0b26..67248156 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,6 +1,7 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT from .tools import ( # Counter, CLIP, + OCR, TOOLS, BboxArea, BboxIoU, diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 06727fc2..64f97467 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,3 +1,4 @@ +import io import logging import tempfile from abc import ABC @@ -868,6 +869,57 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]: return result +class OCR(Tool): + name = "ocr_" + description = "'ocr_' extracts text from an image." + usage = { + "required_parameters": [ + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you extract the text from this image? Image name: image.png", + "parameters": {"image": "image.png"}, + }, + ], + } + _API_KEY = "land_sk_WVYwP00xA3iXely2vuar6YUDZ3MJT9yLX6oW5noUkwICzYLiDV" + _URL = "https://app.landing.ai/ocr/v1/detect-text" + + def __call__(self, image: str) -> dict: + pil_image = Image.open(image).convert("RGB") + image_size = pil_image.size[::-1] + image_buffer = io.BytesIO() + pil_image.save(image_buffer, format="PNG") + buffer_bytes = image_buffer.getvalue() + image_buffer.close() + + res = requests.post( + self._URL, + files={"images": buffer_bytes}, + data={"language": "en"}, + headers={"contentType": "multipart/form-data", "apikey": self._API_KEY}, + ) + if res.status_code != 200: + _LOGGER.error(f"Request failed: {res.text}") + raise ValueError(f"Request failed: {res.text}") + + data = res.json() + output = {"labels": [], "bboxes": [], "scores": []} + for det in data[0]: + output["labels"].append(det["text"]) + box = [ + det["location"][0]["x"], + det["location"][0]["y"], + det["location"][2]["x"], + det["location"][2]["y"], + ] + box = normalize_bbox(box, image_size) + output["bboxes"].append(box) + output["scores"].append(det["score"]) + return output + + class Calculator(Tool): r"""Calculator is a tool that can perform basic arithmetic operations.""" @@ -913,6 +965,7 @@ def __call__(self, equation: str) -> float: SegIoU, BboxContains, BoxDistance, + OCR, Calculator, ] )