diff --git a/README.md b/README.md index c09e5823..5ff3840f 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ pip install vision-agent ``` Ensure you have an OpenAI API key and set it as an environment variable (if you are -using Azure OpenAI please see the additional setup section): +using Azure OpenAI please see the Azure setup section): ```bash export OPENAI_API_KEY="your-api-key" @@ -96,6 +96,31 @@ you. For example: }] ``` +#### Custom Tools +You can also add your own custom tools for your vision agent to use: + +```python +>>> from vision_agent.tools import Tool, register_tool +>>> @register_tool +>>> class NumItems(Tool): +>>> name = "num_items_" +>>> description = "Returns the number of items in a list." +>>> usage = { +>>> "required_parameters": [{"name": "prompt", "type": "list"}], +>>> "examples": [ +>>> { +>>> "scenario": "How many items are in this list? ['a', 'b', 'c']", +>>> "parameters": {"prompt": "['a', 'b', 'c']"}, +>>> } +>>> ], +>>> } +>>> def __call__(self, prompt: list[str]) -> int: +>>> return len(prompt) +``` +This will register it with the list of tools Vision Agent has access to. It will be able +to pick it based on the tool description and use it based on the usage provided. + +#### Tool List | Tool | Description | | --- | --- | | CLIP | CLIP is a tool that can classify or tag any image given a set of input classes or tags. | @@ -111,11 +136,12 @@ you. For example: | ExtractFrames | ExtractFrames extracts frames with motion from a video. | | ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image | | VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt | +| OCR | OCR returns the text detected in an image along with the location. | It also has a basic set of calculate tools such as add, subtract, multiply and divide. -### Additional Setup +### Azure Setup If you want to use Azure OpenAI models, you can set the environment variable: ```bash diff --git a/examples/custom_tools/pid.png b/examples/custom_tools/pid.png new file mode 100644 index 00000000..713b7317 Binary files /dev/null and b/examples/custom_tools/pid.png differ diff --git a/examples/custom_tools/pid_template.png b/examples/custom_tools/pid_template.png new file mode 100644 index 00000000..c736c6cb Binary files /dev/null and b/examples/custom_tools/pid_template.png differ diff --git a/examples/custom_tools/run_custom_tool.py b/examples/custom_tools/run_custom_tool.py new file mode 100644 index 00000000..beaa9eca --- /dev/null +++ b/examples/custom_tools/run_custom_tool.py @@ -0,0 +1,49 @@ +from template_match import template_matching_with_rotation + +import vision_agent as va +from vision_agent.image_utils import get_image_size, normalize_bbox +from vision_agent.tools import Tool, register_tool + + +@register_tool +class TemplateMatch(Tool): + name = "template_match_" + description = "'template_match_' takes a template image and finds all locations where that template appears in the input image." + usage = { + "required_parameters": [ + {"name": "target_image", "type": "str"}, + {"name": "template_image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you detect the location of the template in the target image? 
Image name: target.png Reference image: template.png", + "parameters": { + "target_image": "target.png", + "template_image": "template.png", + }, + }, + ], + } + + def __call__(self, target_image: str, template_image: str) -> dict: + image_size = get_image_size(target_image) + matches = template_matching_with_rotation(target_image, template_image) + matches["bboxes"] = [ + normalize_bbox(box, image_size) for box in matches["bboxes"] + ] + return matches + + +if __name__ == "__main__": + agent = va.agent.VisionAgent(verbose=True) + resp, tools = agent.chat_with_workflow( + [ + { + "role": "user", + "content": "Can you find the locations of the pid_template.png in pid.png and tell me if any are nearby 'NOTE 5'?", + } + ], + image="pid.png", + reference_data={"image": "pid_template.png"}, + visualize_output=True, + ) diff --git a/examples/custom_tools/template_match.py b/examples/custom_tools/template_match.py new file mode 100644 index 00000000..1dd9fbe0 --- /dev/null +++ b/examples/custom_tools/template_match.py @@ -0,0 +1,96 @@ +import cv2 +import numpy as np +import torch +from torchvision.ops import nms + + +def rotate_image(mat, angle): + """ + Rotates an image (angle in degrees) and expands image to avoid cropping + """ + + height, width = mat.shape[:2] # image shape has 3 dimensions + image_center = ( + width / 2, + height / 2, + ) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape + + rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + + # rotation calculates the cos and sin, taking absolutes of those. + abs_cos = abs(rotation_mat[0, 0]) + abs_sin = abs(rotation_mat[0, 1]) + + # find the new width and height bounds + bound_w = int(height * abs_sin + width * abs_cos) + bound_h = int(height * abs_cos + width * abs_sin) + + # subtract old image center (bringing image back to origo) and adding the new image center coordinates + rotation_mat[0, 2] += bound_w / 2 - image_center[0] + rotation_mat[1, 2] += bound_h / 2 - image_center[1] + + # rotate image with the new bounds and translated rotation matrix + rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h)) + return rotated_mat + + +def template_matching_with_rotation( + main_image_path: str, + template_path: str, + max_rotation: int = 360, + step: int = 90, + threshold: float = 0.75, + visualize: bool = False, +) -> dict: + main_image = cv2.imread(main_image_path) + template = cv2.imread(template_path) + template_height, template_width = template.shape[:2] + + # Convert images to grayscale + main_image_gray = cv2.cvtColor(main_image, cv2.COLOR_BGR2GRAY) + template_gray = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY) + + boxes = [] + scores = [] + + for angle in range(0, max_rotation, step): + # Rotate the template + rotated_template = rotate_image(template_gray, angle) + + # Perform template matching + result = cv2.matchTemplate( + main_image_gray, + rotated_template, + cv2.TM_CCOEFF_NORMED, + ) + + y_coords, x_coords = np.where(result >= threshold) + for x, y in zip(x_coords, y_coords): + boxes.append( + (x, y, x + rotated_template.shape[1], y + rotated_template.shape[0]) + ) + scores.append(result[y, x]) + + indices = ( + nms( + torch.tensor(boxes).float(), + torch.tensor(scores).float(), + 0.2, + ) + .numpy() + .tolist() + ) + boxes = [boxes[i] for i in indices] + scores = [scores[i] for i in indices] + + if visualize: + # Draw a rectangle around the best match + for box in boxes: + cv2.rectangle(main_image, (box[0], box[1]), (box[2], box[3]), 255, 2) + + # Display the 
result + cv2.imshow("Best Match", main_image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + return {"bboxes": boxes, "scores": scores} diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 12c21347..6de8d6c8 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -2,8 +2,10 @@ import tempfile import numpy as np +import pytest from PIL import Image +from vision_agent.tools import TOOLS, Tool, register_tool from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU @@ -65,3 +67,71 @@ def test_box_distance(): box1 = [0, 0, 2, 2] box2 = [1, 1, 3, 3] assert box_dist(box1, box2) == 0.0 + + +def test_register_tool(): + assert TOOLS[len(TOOLS) - 1]["name"] != "test_tool_" + + @register_tool + class TestTool(Tool): + name = "test_tool_" + description = "Test Tool" + usage = { + "required_parameters": [{"name": "prompt", "type": "str"}], + "examples": [ + { + "scenario": "Test", + "parameters": {"prompt": "Test Prompt"}, + } + ], + } + + def __call__(self, prompt: str) -> str: + return prompt + + assert TOOLS[len(TOOLS) - 1]["name"] == "test_tool_" + + +def test_register_tool_incorrect(): + with pytest.raises(ValueError): + + @register_tool + class NoAttributes(Tool): + pass + + with pytest.raises(ValueError): + + @register_tool + class NoName(Tool): + description = "Test Tool" + usage = { + "required_parameters": [{"name": "prompt", "type": "str"}], + "examples": [ + { + "scenario": "Test", + "parameters": {"prompt": "Test Prompt"}, + } + ], + } + + with pytest.raises(ValueError): + + @register_tool + class NoDescription(Tool): + name = "test_tool_" + usage = { + "required_parameters": [{"name": "prompt", "type": "str"}], + "examples": [ + { + "scenario": "Test", + "parameters": {"prompt": "Test Prompt"}, + } + ], + } + + with pytest.raises(ValueError): + + @register_tool + class NoUsage(Tool): + name = "test_tool_" + description = "Test Tool" diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 8627b06c..76629f6a 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -377,6 +377,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]] "dinov_", "zero_shot_counting_", "visual_prompt_counting_", + "ocr_", ]: continue @@ -508,20 +509,20 @@ def chat_with_workflow( if image: question += f" Image name: {image}" if reference_data: - if not ( - "image" in reference_data - and ("mask" in reference_data or "bbox" in reference_data) - ): - raise ValueError( - f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. 
but got {reference_data}" - ) - visual_prompt_data = ( - f"Reference mask: {reference_data['mask']}" + question += ( + f" Reference image: {reference_data['image']}" + if "image" in reference_data + else "" + ) + question += ( + f" Reference mask: {reference_data['mask']}" if "mask" in reference_data - else f"Reference bbox: {reference_data['bbox']}" + else "" ) question += ( - f" Reference image: {reference_data['image']}, {visual_prompt_data}" + f" Reference bbox: {reference_data['bbox']}" + if "bbox" in reference_data + else "" ) reflections = "" diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 38bb08d4..67248156 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,6 +1,7 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT from .tools import ( # Counter, CLIP, + OCR, TOOLS, BboxArea, BboxIoU, @@ -11,9 +12,10 @@ GroundingDINO, GroundingSAM, ImageCaption, - ZeroShotCounting, - VisualPromptCounting, SegArea, SegIoU, Tool, + VisualPromptCounting, + ZeroShotCounting, + register_tool, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index c964a895..a53e3c7f 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,8 +1,9 @@ +import io import logging import tempfile from abc import ABC from pathlib import Path -from typing import Any, Dict, List, Tuple, Union, cast +from typing import Any, Dict, List, Tuple, Type, Union, cast import numpy as np import requests @@ -11,10 +12,10 @@ from vision_agent.image_utils import ( convert_to_b64, + denormalize_bbox, get_image_size, - rle_decode, normalize_bbox, - denormalize_bbox, + rle_decode, ) from vision_agent.tools.video import extract_frames_from_video from vision_agent.type_defs import LandingaiAPIKey @@ -29,6 +30,9 @@ class Tool(ABC): description: str usage: Dict + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + class NoOp(Tool): name = "noop_" @@ -865,6 +869,57 @@ def __call__(self, video_uri: str) -> List[Tuple[str, float]]: return result +class OCR(Tool): + name = "ocr_" + description = "'ocr_' extracts text from an image." + usage = { + "required_parameters": [ + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you extract the text from this image? 
Image name: image.png", + "parameters": {"image": "image.png"}, + }, + ], + } + _API_KEY = "land_sk_WVYwP00xA3iXely2vuar6YUDZ3MJT9yLX6oW5noUkwICzYLiDV" + _URL = "https://app.landing.ai/ocr/v1/detect-text" + + def __call__(self, image: str) -> dict: + pil_image = Image.open(image).convert("RGB") + image_size = pil_image.size[::-1] + image_buffer = io.BytesIO() + pil_image.save(image_buffer, format="PNG") + buffer_bytes = image_buffer.getvalue() + image_buffer.close() + + res = requests.post( + self._URL, + files={"images": buffer_bytes}, + data={"language": "en"}, + headers={"contentType": "multipart/form-data", "apikey": self._API_KEY}, + ) + if res.status_code != 200: + _LOGGER.error(f"Request failed: {res.text}") + raise ValueError(f"Request failed: {res.text}") + + data = res.json() + output: Dict[str, List] = {"labels": [], "bboxes": [], "scores": []} + for det in data[0]: + output["labels"].append(det["text"]) + box = [ + det["location"][0]["x"], + det["location"][0]["y"], + det["location"][2]["x"], + det["location"][2]["y"], + ] + box = normalize_bbox(box, image_size) + output["bboxes"].append(box) + output["scores"].append(round(det["score"], 2)) + return output + + class Calculator(Tool): r"""Calculator is a tool that can perform basic arithmetic operations.""" @@ -910,6 +965,7 @@ def __call__(self, equation: str) -> float: SegIoU, BboxContains, BoxDistance, + OCR, Calculator, ] ) @@ -917,6 +973,31 @@ def __call__(self, equation: str) -> float: } +def register_tool(tool: Type[Tool]) -> Type[Tool]: + r"""Add a tool to the list of available tools. + + Parameters: + tool: The tool to add. + """ + + if ( + not hasattr(tool, "name") + or not hasattr(tool, "description") + or not hasattr(tool, "usage") + ): + raise ValueError( + "The tool must have 'name', 'description' and 'usage' attributes." + ) + + TOOLS[len(TOOLS)] = { + "name": tool.name, + "description": tool.description, + "usage": tool.usage, + "class": tool, + } + return tool + + def _send_inference_request( payload: Dict[str, Any], endpoint_name: str ) -> Dict[str, Any]: