Skip to content

Commit

Permalink
added OCR
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 23, 2024
1 parent 3aca8c7 commit 6af8d9e
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ to pick it based on the tool description and use it based on the usage provided.
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image |
| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt |
| OCR | OCR returns the text detected in an image along with its location. |


It also has a basic set of calculate tools such as add, subtract, multiply and divide.
Expand Down
1 change: 1 addition & 0 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
"dinov_",
"zero_shot_counting_",
"visual_prompt_counting_",
"ocr_",
]:
continue

Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import ( # Counter,
CLIP,
OCR,
TOOLS,
BboxArea,
BboxIoU,
Expand Down
53 changes: 53 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import io
import logging
import tempfile
from abc import ABC
Expand Down Expand Up @@ -868,6 +869,57 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]:
return result


class OCR(Tool):
    r"""OCR is a tool that sends an image to a hosted text-detection endpoint and
    returns the detected text, one bounding box per detection, and a score per
    detection.

    Example
    -------
        >>> ocr = OCR()
        >>> result = ocr("image.png")
        >>> sorted(result.keys())
        ['bboxes', 'labels', 'scores']
    """

    name = "ocr_"
    description = "'ocr_' extracts text from an image."
    usage = {
        "required_parameters": [
            {"name": "image", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Can you extract the text from this image? Image name: image.png",
                "parameters": {"image": "image.png"},
            },
        ],
    }
    # NOTE(review): a credential committed to source control should be treated as
    # compromised; rotate it and drop this default once every deployment sets the
    # environment variable named in _API_KEY_ENV_VAR. It is kept here only so that
    # existing callers keep working unchanged.
    _API_KEY = "land_sk_WVYwP00xA3iXely2vuar6YUDZ3MJT9yLX6oW5noUkwICzYLiDV"
    # When this environment variable is set to a non-empty value it takes
    # precedence over the embedded default above.
    _API_KEY_ENV_VAR = "VISION_AGENT_OCR_API_KEY"
    _URL = "https://app.landing.ai/ocr/v1/detect-text"
    # Upper bound (seconds) passed to requests; without it a stalled endpoint
    # would block the caller forever.
    _TIMEOUT_SECONDS = 60

    def __call__(self, image: str) -> dict:
        """Run text detection on an image file.

        Parameters:
            image: path of the image to read; it is opened with PIL, converted
                to RGB and uploaded as PNG bytes.

        Returns:
            A dict with three parallel lists: ``"labels"`` (detected text),
            ``"bboxes"`` (boxes passed through ``normalize_bbox``) and
            ``"scores"``. All three are empty when the response holds no
            detections.

        Raises:
            ValueError: if the endpoint answers with a status other than 200.
        """
        # Local import keeps this block self-contained; the module's import
        # section is not guaranteed to provide `os`.
        import os

        # Close the underlying file as soon as the RGB copy exists instead of
        # leaving the handle open until garbage collection.
        with Image.open(image) as raw_image:
            pil_image = raw_image.convert("RGB")
        # PIL reports (width, height); reversed here before being handed to
        # normalize_bbox.
        image_size = pil_image.size[::-1]
        with io.BytesIO() as image_buffer:
            pil_image.save(image_buffer, format="PNG")
            buffer_bytes = image_buffer.getvalue()

        api_key = os.environ.get(self._API_KEY_ENV_VAR) or self._API_KEY
        res = requests.post(
            self._URL,
            files={"images": buffer_bytes},
            data={"language": "en"},
            headers={"contentType": "multipart/form-data", "apikey": api_key},
            timeout=self._TIMEOUT_SECONDS,
        )
        if res.status_code != 200:
            _LOGGER.error(f"Request failed: {res.text}")
            raise ValueError(f"Request failed: {res.text}")

        data = res.json()
        output = {"labels": [], "bboxes": [], "scores": []}
        # Only one image is uploaded, so only the first entry of the response is
        # read. An empty response yields an empty result instead of IndexError.
        detections = data[0] if data else []
        for det in detections:
            corners = det["location"]
            # Points 0 and 2 are presumably opposite corners (top-left and
            # bottom-right) of the detected region -- TODO confirm against the
            # endpoint's response schema.
            box = [
                corners[0]["x"],
                corners[0]["y"],
                corners[2]["x"],
                corners[2]["y"],
            ]
            output["labels"].append(det["text"])
            output["bboxes"].append(normalize_bbox(box, image_size))
            output["scores"].append(det["score"])
        return output


class Calculator(Tool):
r"""Calculator is a tool that can perform basic arithmetic operations."""

Expand Down Expand Up @@ -913,6 +965,7 @@ def __call__(self, equation: str) -> float:
SegIoU,
BboxContains,
BoxDistance,
OCR,
Calculator,
]
)
Expand Down

0 comments on commit 6af8d9e

Please sign in to comment.