Skip to content

Commit

Permalink
Add Custom Tools (#61)
Browse files Browse the repository at this point in the history
* added custom tools

* updated readme

* register tool returns tool

* Add a new tool: determine if a bbox is contained within another bbox (#59)

* Add a new bounding box contains tool

* Fix format

* [skip ci] chore(release): vision-agent 0.1.5

* Add Count tools (#56)

* Adding counting tools to vision agent

* fixed heatmap overlay and addressed PR comments

* adding the counting tool to take both absolute coordinate and normalized coordinates, refactoring code, adding llm generate counter tool

* fix linting

* Remove torch and cuda dependencies (#60)

Resolve merge conflicts

* [skip ci] chore(release): vision-agent 0.2.1

* make it easier to use custom tools

* ran isort

* fix linting error

* added OCR

* added example template matching use case

* formatting and typing fix

* round scores

* fix readme typo

---------

Co-authored-by: Asia <[email protected]>
Co-authored-by: GitHub Actions Bot <[email protected]>
Co-authored-by: Shankar <[email protected]>
  • Loading branch information
4 people authored Apr 24, 2024
1 parent e46cff5 commit c0dde36
Show file tree
Hide file tree
Showing 9 changed files with 343 additions and 18 deletions.
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pip install vision-agent
```

Ensure you have an OpenAI API key and set it as an environment variable (if you are
using Azure OpenAI please see the additional setup section):
using Azure OpenAI please see the Azure setup section):

```bash
export OPENAI_API_KEY="your-api-key"
Expand Down Expand Up @@ -96,6 +96,31 @@ you. For example:
}]
```

#### Custom Tools
You can also add your own custom tools for your vision agent to use:

```python
>>> from vision_agent.tools import Tool, register_tool
>>> @register_tool
... class NumItems(Tool):
...     name = "num_items_"
...     description = "Returns the number of items in a list."
...     usage = {
...         "required_parameters": [{"name": "prompt", "type": "list"}],
...         "examples": [
...             {
...                 "scenario": "How many items are in this list? ['a', 'b', 'c']",
...                 "parameters": {"prompt": "['a', 'b', 'c']"},
...             }
...         ],
...     }
...     def __call__(self, prompt: list[str]) -> int:
...         return len(prompt)
```
This will register the tool with the list of tools Vision Agent has access to. Vision
Agent will then be able to pick the tool based on its description and call it according
to the usage provided.

#### Tool List
| Tool | Description |
| --- | --- |
| CLIP | CLIP is a tool that can classify or tag any image given a set of input classes or tags. |
Expand All @@ -114,11 +139,12 @@ you. For example:
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt. |
| OCR | OCR returns the text detected in an image along with the location. |


It also has a basic set of calculator tools such as add, subtract, multiply and divide.

### Additional Setup
### Azure Setup
If you want to use Azure OpenAI models, you can set the environment variable:

```bash
Expand Down
Binary file added examples/custom_tools/pid.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/custom_tools/pid_template.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
49 changes: 49 additions & 0 deletions examples/custom_tools/run_custom_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from template_match import template_matching_with_rotation

import vision_agent as va
from vision_agent.image_utils import get_image_size, normalize_bbox
from vision_agent.tools import Tool, register_tool


@register_tool
class TemplateMatch(Tool):
name = "template_match_"
description = "'template_match_' takes a template image and finds all locations where that template appears in the input image."
usage = {
"required_parameters": [
{"name": "target_image", "type": "str"},
{"name": "template_image", "type": "str"},
],
"examples": [
{
"scenario": "Can you detect the location of the template in the target image? Image name: target.png Reference image: template.png",
"parameters": {
"target_image": "target.png",
"template_image": "template.png",
},
},
],
}

def __call__(self, target_image: str, template_image: str) -> dict:
image_size = get_image_size(target_image)
matches = template_matching_with_rotation(target_image, template_image)
matches["bboxes"] = [
normalize_bbox(box, image_size) for box in matches["bboxes"]
]
return matches


if __name__ == "__main__":
    # Single-turn conversation asking the agent to use the template-matching
    # tool registered above on the bundled example images.
    user_turn = {
        "role": "user",
        "content": "Can you find the locations of the pid_template.png in pid.png and tell me if any are nearby 'NOTE 5'?",
    }
    # The template image is handed over as reference data.
    reference = {"image": "pid_template.png"}

    agent = va.agent.VisionAgent(verbose=True)
    resp, tools = agent.chat_with_workflow(
        [user_turn],
        image="pid.png",
        reference_data=reference,
        visualize_output=True,
    )
96 changes: 96 additions & 0 deletions examples/custom_tools/template_match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import cv2
import numpy as np
import torch
from torchvision.ops import nms


def rotate_image(mat, angle):
    """Rotate ``mat`` by ``angle`` degrees, enlarging the canvas to avoid cropping.

    Args:
        mat: Image array; only its first two shape entries (height, width)
            are used here.
        angle: Rotation angle in degrees, passed to ``cv2.getRotationMatrix2D``.

    Returns:
        The rotated image produced by ``cv2.warpAffine`` on the enlarged canvas.
    """
    src_h, src_w = mat.shape[:2]
    # OpenCV expects the centre as (x, y), the reverse of the (rows, cols) shape.
    center = (src_w / 2, src_h / 2)

    transform = cv2.getRotationMatrix2D(center, angle, 1.0)

    # Magnitudes of the cosine and sine stored in the rotation matrix.
    cos_mag = abs(transform[0, 0])
    sin_mag = abs(transform[0, 1])

    # Size of the axis-aligned box that encloses the rotated image.
    out_w = int(src_h * sin_mag + src_w * cos_mag)
    out_h = int(src_h * cos_mag + src_w * sin_mag)

    # Shift the translation part so the old centre lands on the new canvas centre.
    transform[0, 2] += out_w / 2 - center[0]
    transform[1, 2] += out_h / 2 - center[1]

    return cv2.warpAffine(mat, transform, (out_w, out_h))


def template_matching_with_rotation(
    main_image_path: str,
    template_path: str,
    max_rotation: int = 360,
    step: int = 90,
    threshold: float = 0.75,
    visualize: bool = False,
) -> dict:
    """Find all occurrences of a template in an image, trying several rotations.

    The template is rotated through ``range(0, max_rotation, step)`` degrees.
    Each rotation is matched against the main image with
    ``cv2.TM_CCOEFF_NORMED``, and every location scoring at least
    ``threshold`` becomes a candidate box. Overlapping candidates are then
    pruned with non-maximum suppression (IoU threshold 0.2).

    Args:
        main_image_path: Path of the image to search in.
        template_path: Path of the template image to search for.
        max_rotation: Exclusive upper bound, in degrees, of the rotations tried.
        step: Increment, in degrees, between successive rotations.
        threshold: Minimum match score for a location to count as a match.
        visualize: When true, draw the kept boxes on the main image and show
            it in an OpenCV window until a key is pressed.

    Returns:
        ``{"bboxes": [...], "scores": [...]}`` where each box is an
        ``(x1, y1, x2, y2)`` tuple of ints in pixel coordinates and each score
        is a float. Both lists are empty when nothing matches.

    Raises:
        ValueError: If either image cannot be read.
    """
    main_image = cv2.imread(main_image_path)
    if main_image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {main_image_path}")
    template = cv2.imread(template_path)
    if template is None:
        raise ValueError(f"Could not read image: {template_path}")

    # Convert images to grayscale
    main_image_gray = cv2.cvtColor(main_image, cv2.COLOR_BGR2GRAY)
    template_gray = cv2.cvtColor(template, cv2.COLOR_BGR2GRAY)
    main_height, main_width = main_image_gray.shape[:2]

    boxes = []
    scores = []

    for angle in range(0, max_rotation, step):
        rotated_template = rotate_image(template_gray, angle)
        rotated_height, rotated_width = rotated_template.shape[:2]

        # matchTemplate rejects templates larger than the image; such a
        # rotation cannot match anywhere, so skip it.
        if rotated_height > main_height or rotated_width > main_width:
            continue

        result = cv2.matchTemplate(
            main_image_gray,
            rotated_template,
            cv2.TM_CCOEFF_NORMED,
        )

        y_coords, x_coords = np.where(result >= threshold)
        for x, y in zip(x_coords, y_coords):
            # Convert numpy scalars to builtin types so the returned data and
            # the drawing calls below work with plain ints/floats.
            x, y = int(x), int(y)
            boxes.append((x, y, x + rotated_width, y + rotated_height))
            scores.append(float(result[y, x]))

    # nms requires an (N, 4) tensor; an empty candidate list would produce a
    # 1-D tensor, so only run it when there is something to suppress.
    if boxes:
        keep = (
            nms(
                torch.tensor(boxes).float(),
                torch.tensor(scores).float(),
                0.2,
            )
            .numpy()
            .tolist()
        )
        boxes = [boxes[i] for i in keep]
        scores = [scores[i] for i in keep]

    if visualize:
        # Draw a rectangle around every match that survived suppression
        for box in boxes:
            cv2.rectangle(main_image, (box[0], box[1]), (box[2], box[3]), 255, 2)

        # Display the result
        cv2.imshow("Best Match", main_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    return {"bboxes": boxes, "scores": scores}
70 changes: 70 additions & 0 deletions tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import tempfile

import numpy as np
import pytest
from PIL import Image

from vision_agent.tools import TOOLS, Tool, register_tool
from vision_agent.tools.tools import BboxIoU, BoxDistance, SegArea, SegIoU


Expand Down Expand Up @@ -65,3 +67,71 @@ def test_box_distance():
box1 = [0, 0, 2, 2]
box2 = [1, 1, 3, 3]
assert box_dist(box1, box2) == 0.0


def test_register_tool():
    """Registering a valid tool makes it the last entry of ``TOOLS``."""
    expected_name = "test_tool_"

    # Before registration the most recent entry must be some other tool.
    # NOTE(review): TOOLS is looked up by the integer len(TOOLS) - 1 —
    # presumably keyed by registration index; confirm against its definition.
    assert TOOLS[len(TOOLS) - 1]["name"] != expected_name

    @register_tool
    class TestTool(Tool):
        name = expected_name
        description = "Test Tool"
        usage = {
            "required_parameters": [{"name": "prompt", "type": "str"}],
            "examples": [
                {
                    "scenario": "Test",
                    "parameters": {"prompt": "Test Prompt"},
                }
            ],
        }

        def __call__(self, prompt: str) -> str:
            return prompt

    # After decoration the new tool is the most recent entry.
    assert TOOLS[len(TOOLS) - 1]["name"] == expected_name


def test_register_tool_incorrect():
    """``register_tool`` raises ``ValueError`` when a required attribute is missing.

    Covers a class with no attributes at all, then classes lacking exactly
    one of ``name``, ``description`` and ``usage``.
    """
    # Well-formed usage spec shared by the cases that omit a different attribute.
    valid_usage = {
        "required_parameters": [{"name": "prompt", "type": "str"}],
        "examples": [
            {
                "scenario": "Test",
                "parameters": {"prompt": "Test Prompt"},
            }
        ],
    }

    # Nothing defined.
    with pytest.raises(ValueError):

        @register_tool
        class NoAttributes(Tool):
            pass

    # Missing `name`.
    with pytest.raises(ValueError):

        @register_tool
        class NoName(Tool):
            description = "Test Tool"
            usage = valid_usage

    # Missing `description`.
    with pytest.raises(ValueError):

        @register_tool
        class NoDescription(Tool):
            name = "test_tool_"
            usage = valid_usage

    # Missing `usage`.
    with pytest.raises(ValueError):

        @register_tool
        class NoUsage(Tool):
            name = "test_tool_"
            description = "Test Tool"
23 changes: 12 additions & 11 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
"dinov_",
"zero_shot_counting_",
"visual_prompt_counting_",
"ocr_",
]:
continue

Expand Down Expand Up @@ -523,20 +524,20 @@ def chat_with_workflow(
if image:
question += f" Image name: {image}"
if reference_data:
if not (
"image" in reference_data
and ("mask" in reference_data or "bbox" in reference_data)
):
raise ValueError(
f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
)
visual_prompt_data = (
f"Reference mask: {reference_data['mask']}"
question += (
f" Reference image: {reference_data['image']}"
if "image" in reference_data
else ""
)
question += (
f" Reference mask: {reference_data['mask']}"
if "mask" in reference_data
else f"Reference bbox: {reference_data['bbox']}"
else ""
)
question += (
f" Reference image: {reference_data['image']}, {visual_prompt_data}"
f" Reference bbox: {reference_data['bbox']}"
if "bbox" in reference_data
else ""
)

reflections = ""
Expand Down
6 changes: 4 additions & 2 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import ( # Counter,
CLIP,
OCR,
TOOLS,
BboxArea,
BboxIoU,
Expand All @@ -11,9 +12,10 @@
GroundingDINO,
GroundingSAM,
ImageCaption,
ZeroShotCounting,
VisualPromptCounting,
SegArea,
SegIoU,
Tool,
VisualPromptCounting,
ZeroShotCounting,
register_tool,
)
Loading

0 comments on commit c0dde36

Please sign in to comment.