Add Custom Tools & Overlay Heatmap (#100)
* added overlay heatmap

* added overlays to utilities

* added register tool

* removed old examples

* fix readme for mask app

* fix issues for custom tools for example case

* fix issues for custom tool imports

* update readme

* fix azure doc

* fix type error
dillonalaird authored May 30, 2024
1 parent 503b3e0 commit 91206b0
Showing 16 changed files with 340 additions and 751 deletions.
35 changes: 32 additions & 3 deletions README.md
@@ -120,6 +120,34 @@ you. For example:
}]
```

You can also add custom tools to the agent:

```python
import vision_agent as va

@va.tools.register_tool(imports=["import numpy as np"])
def custom_tool(image_path: str) -> str:
    """My custom tool documentation.
    Parameters:
        image_path (str): The path to the image.
    Returns:
        str: The result of the tool.
    Example
    -------
        >>> custom_tool("image.jpg")
    """

    import numpy as np
    return np.zeros((10, 10))
```

You need to ensure you call `@va.tools.register_tool` with any imports your tool uses and
that the documentation follows the same format as above, with a description, `Parameters:`,
`Returns:`, and `Example\n-------`. You can find an example use case [here](examples/custom_tools/).
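
Once registered, the custom tool is available to the agent alongside the built-in tools. As a rough usage sketch (the prompt and file name here are only illustrative):

```python
import vision_agent as va

agent = va.agent.VisionAgent()
# The agent can now plan with and call custom_tool like any other tool.
code = agent("Use custom_tool on the image and describe the result.", media="image.jpg")
```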

### Azure Setup
If you want to use Azure OpenAI models, you can set the environment variables:

@@ -133,9 +161,10 @@ You can then run Vision Agent using the Azure OpenAI models:
```python
>>> import vision_agent as va
>>> agent = va.agent.VisionAgent(
>>> task_model=va.llm.AzureOpenAILLM(),
>>> answer_model=va.lmm.AzureOpenAILMM(),
>>> reflection_model=va.lmm.AzureOpenAILMM(),
>>> planner=va.llm.AzureOpenAILLM(),
>>> coder=va.lmm.AzureOpenAILMM(),
>>> tester=va.lmm.AzureOpenAILMM(),
>>> debugger=va.lmm.AzureOpenAILMM(),
>>> )
```

178 changes: 115 additions & 63 deletions docs/index.md
@@ -1,12 +1,15 @@
# 🔍🤖 Vision Agent
Vision Agent is a library that helps you utilize agent frameworks to generate code to
solve your vision task. Many current vision problems can easily take hours or days to
solve: you need to find the right model, figure out how to use it, and program it to
accomplish the task you want. Vision Agent aims to provide an in-seconds experience by
allowing users to describe their problem in text and have the agent framework generate
code to solve the task for them. Check out our Discord for updates and roadmaps!

## Documentation

- [Vision Agent Library Docs](https://landing-ai.github.io/vision-agent/)

Vision Agent is a library that helps you utilize agent frameworks for your vision tasks.
Many current vision problems can easily take hours or days to solve: you need to find the
right model, figure out how to use it, possibly write programming logic around it to
accomplish the task you want, or, even more expensive, train your own model. Vision Agent
aims to provide an in-seconds experience by allowing users to describe their problem in
text and utilizing agent frameworks to solve the task for them. Check out our Discord
for updates and roadmaps!

## Getting Started
### Installation
@@ -16,52 +19,79 @@ To get started, you can install the library using pip:
pip install vision-agent
```

Ensure you have an OpenAI API key and set it as an environment variable:
Ensure you have an OpenAI API key and set it as an environment variable (if you are
using Azure OpenAI, please see the Azure Setup section):

```bash
export OPENAI_API_KEY="your-api-key"
```

### Vision Agents
You can interact with the agents as you would with any LLM or LMM model:
### Vision Agent
You can interact with the agent as you would with any LLM or LMM model:

```python
>>> import vision_agent as va
>>> from vision_agent.agent import VisionAgent
>>> agent = VisionAgent()
>>> agent("How many apples are in this image?", image="apples.jpg")
"There are 2 apples in the image."
>>> code = agent("What percentage of the area of the jar is filled with coffee beans?", media="jar.jpg")
```

Which produces the following code:
```python
from vision_agent.tools import load_image, grounding_sam

def calculate_filled_percentage(image_path: str) -> float:
    # Step 1: Load the image
    image = load_image(image_path)

    # Step 2: Segment the jar
    jar_segments = grounding_sam(prompt="jar", image=image)

    # Step 3: Segment the coffee beans
    coffee_beans_segments = grounding_sam(prompt="coffee beans", image=image)

    # Step 4: Calculate the area of the segmented jar
    jar_area = 0
    for segment in jar_segments:
        jar_area += segment['mask'].sum()

    # Step 5: Calculate the area of the segmented coffee beans
    coffee_beans_area = 0
    for segment in coffee_beans_segments:
        coffee_beans_area += segment['mask'].sum()

    # Step 6: Compute the percentage of the jar area that is filled with coffee beans
    if jar_area == 0:
        return 0.0  # To avoid division by zero
    filled_percentage = (coffee_beans_area / jar_area) * 100

    # Step 7: Return the computed percentage
    return filled_percentage
```
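
The generated function can then be called directly; a minimal usage sketch using the same image from the prompt above:

```python
percentage = calculate_filled_percentage("jar.jpg")
print(f"The jar is {percentage:.1f}% filled with coffee beans.")
```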

To better understand how the model came up with its answer, you can also run it in
debug mode by passing in the verbose argument:
To better understand how the model came up with its answer, you can run it in debug
mode by passing in the verbose argument:

```python
>>> agent = VisionAgent(verbose=True)
>>> agent = VisionAgent(verbose=2)
```

You can also have it return the workflow it used to complete the task along with all
the individual steps and tools to get the answer:
You can also have it return more information by calling `chat_with_workflow`:

```python
>>> resp, workflow = agent.chat_with_workflow([{"role": "user", "content": "How many apples are in this image?"}], image="apples.jpg")
>>> print(workflow)
[{"task": "Count the number of apples using 'grounding_dino_'.",
"tool": "grounding_dino_",
"parameters": {"prompt": "apple", "image": "apples.jpg"},
"call_results": [[
{
"labels": ["apple", "apple"],
"scores": [0.99, 0.95],
"bboxes": [
[0.58, 0.2, 0.72, 0.45],
[0.94, 0.57, 0.98, 0.66],
]
}
]],
"answer": "There are 2 apples in the image.",
}]
>>> results = agent.chat_with_workflow([{"role": "user", "content": "What percentage of the area of the jar is filled with coffee beans?"}], media="jar.jpg")
>>> print(results)
{
    "code": "from vision_agent.tools import ...",
    "test": "calculate_filled_percentage('jar.jpg')",
    "test_result": "...",
    "plan": [{"code": "...", "test": "...", "plan": "..."}, ...],
    "working_memory": ...,
}
```

With this you can examine more detailed information such as the testing code, testing
results, plan, or working memory it used to complete the task.
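
For example, assuming the dictionary keys shown above, you could persist the generated code and its test for later inspection or re-runs (a sketch, not a built-in feature):

```python
# Save the generated code and test so they can be reviewed or re-run later.
with open("generated_code.py", "w") as f:
    f.write(results["code"] + "\n\n" + results["test"] + "\n")
```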

### Tools
There are a variety of tools for the model or the user to use. Some are executed locally
while others are hosted for you. You can also ask an LLM directly to build a tool for
@@ -70,37 +100,59 @@ you. For example:
```python
>>> import vision_agent as va
>>> llm = va.llm.OpenAILLM()
>>> detector = llm.generate_detector("Can you build an apple detector for me?")
>>> detector("apples.jpg")
[{"labels": ["apple", "apple"],
"scores": [0.99, 0.95],
>>> detector = llm.generate_detector("Can you build a jar detector for me?")
>>> detector("jar.jpg")
[{"labels": ["jar",],
"scores": [0.99],
"bboxes": [
[0.58, 0.2, 0.72, 0.45],
[0.94, 0.57, 0.98, 0.66],
]
}]
```
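
The bounding boxes in the example outputs above use normalized 0-1 coordinates. As a rough sketch (not part of the library), here is one way to convert a normalized box back to pixel coordinates, assuming the box format is `[x1, y1, x2, y2]` and the image size is known:

```python
def denormalize_bbox(bbox, image_size):
    # Assumed box format: [x1, y1, x2, y2] in normalized 0-1 coordinates.
    # image_size is assumed to be (height, width) in pixels.
    height, width = image_size
    x1, y1, x2, y2 = bbox
    return [x1 * width, y1 * height, x2 * width, y2 * height]

# Example with one of the boxes from the detector output above.
print(denormalize_bbox([0.58, 0.2, 0.72, 0.45], (1080, 1920)))
```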

| Tool | Description |
| --- | --- |
| CLIP | CLIP is a tool that can classify or tag any image given a set of input classes or tags. |
| ImageCaption| ImageCaption is a tool that can generate a caption for an image. |
| GroundingDINO | GroundingDINO is a tool that can detect arbitrary objects with inputs such as category names or referring expressions. |
| GroundingSAM | GroundingSAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions. |
| DINOv | DINOv is a tool that can detect arbitrary objects using a referring mask. |
| Crop | Crop crops an image given a bounding box and returns a file name of the cropped image. |
| BboxArea | BboxArea returns the area of the bounding box in pixels normalized to 2 decimal places. |
| SegArea | SegArea returns the area of the segmentation mask in pixels normalized to 2 decimal places. |
| BboxIoU | BboxIoU returns the intersection over union of two bounding boxes normalized to 2 decimal places. |
| SegIoU | SegIoU returns the intersection over union of two segmentation masks normalized to 2 decimal places. |
| BoxDistance | BoxDistance returns the minimum distance between two bounding boxes normalized to 2 decimal places. |
| MaskDistance | MaskDistance returns the minimum distance between two segmentation masks in pixel units. |
| BboxContains | BboxContains returns the intersection of two boxes over the target box area. It is good for checking if one box is contained within another box. |
| ExtractFrames | ExtractFrames extracts frames with motion from a video. |
| ZeroShotCounting | ZeroShotCounting returns the total number of objects belonging to a single class in a given image. |
| VisualPromptCounting | VisualPromptCounting returns the total number of objects belonging to a single class given an image and visual prompt. |
| VisualQuestionAnswering | VisualQuestionAnswering is a tool that can explain the contents of an image and answer questions about the image. |
| ImageQuestionAnswering | ImageQuestionAnswering is similar to VisualQuestionAnswering but does not rely on OpenAI and instead uses a dedicated model for the task. |
| OCR | OCR returns the text detected in an image along with the location. |

It also has a basic set of calculation tools such as add, subtract, multiply, and divide.
You can also add custom tools to the agent:

```python
import vision_agent as va

@va.tools.register_tool(imports=["import numpy as np"])
def custom_tool(image_path: str) -> str:
    """My custom tool documentation.
    Parameters:
        image_path (str): The path to the image.
    Returns:
        str: The result of the tool.
    Example
    -------
        >>> custom_tool("image.jpg")
    """

    import numpy as np
    return np.zeros((10, 10))
```

You need to ensure you call `@va.tools.register_tool` with any imports your tool uses and
that the documentation follows the same format as above, with a description, `Parameters:`,
`Returns:`, and `Example\n-------`. You can find an example use case [here](examples/custom_tools/).

### Azure Setup
If you want to use Azure OpenAI models, you can set the environment variables:

```bash
export AZURE_OPENAI_API_KEY="your-api-key"
export AZURE_OPENAI_ENDPOINT="your-endpoint"
```

You can then run Vision Agent using the Azure OpenAI models:

```python
>>> import vision_agent as va
>>> agent = va.agent.VisionAgent(
>>> planner=va.llm.AzureOpenAILLM(),
>>> coder=va.lmm.AzureOpenAILMM(),
>>> tester=va.lmm.AzureOpenAILMM(),
>>> debugger=va.lmm.AzureOpenAILMM(),
>>> )
```
27 changes: 24 additions & 3 deletions examples/custom_tools/README.md
@@ -1,5 +1,7 @@
# Template Matching Custom Tool

## Example

This demo shows you how to create a custom tool for template matching that your Vision
Agent can then use to help you answer questions. To get started, you can install the
requirements by running:
@@ -20,10 +22,29 @@ call out which tool you want to use. For example:
```python
import vision_agent as va

agent = va.agent.VisionAgent(verbose=True)
agent = va.agent.VisionAgent(verbosity=2)
agent(
    "Can you use the 'template_match_' tool to find the location of pid_template.png in pid.png?",
    image="pid.png",
    reference_data={"image": "pid_template.png"},
    media="pid.png",
)
```

## Details
Because we execute code in a separate process, we need to re-register the tools inside
the new process. To do this, `register_tools` copies the source code and prepends it to
the code that is executed in the new process. But there's a catch: it cannot copy the
imports needed to run the tool code. To solve this, you must pass the necessary
imports to `register_tool` like so:

```python
import vision_agent as va

@va.register_tool(
    imports=["import cv2"],
)
def custom_tool(*args):
    # Your tool code here
    pass
```

This way the code executed in the new process will have the necessary imports to run.
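
To make this more concrete, here is a rough sketch (not the library's actual implementation) of how source prepending could work; `build_exec_code` is a hypothetical helper and the exact mechanics inside Vision Agent may differ:

```python
import inspect

def build_exec_code(tool_fn, imports, generated_code):
    # Copy the tool's source so it can be re-defined in the new process.
    # inspect.getsource does not capture the imports the tool relies on,
    # which is why they must be declared explicitly in register_tool.
    tool_source = inspect.getsource(tool_fn)
    return "\n".join(imports) + "\n\n" + tool_source + "\n\n" + generated_code
```
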
71 changes: 36 additions & 35 deletions examples/custom_tools/run_custom_tool.py
@@ -1,49 +1,50 @@
import numpy as np
from template_match import template_matching_with_rotation

import vision_agent as va
from vision_agent.image_utils import get_image_size, normalize_bbox
from vision_agent.tools import Tool, register_tool


@register_tool
class TemplateMatch(Tool):
    name = "template_match_"
    description = "'template_match_' takes a template image and finds all locations where that template appears in the input image."
    usage = {
        "required_parameters": [
            {"name": "target_image", "type": "str"},
            {"name": "template_image", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Can you detect the location of the template in the target image? Image name: target.png Reference image: template.png",
                "parameters": {
                    "target_image": "target.png",
                    "template_image": "template.png",
                },
            },
        ],
    }
from vision_agent.utils.image_utils import get_image_size, normalize_bbox


@va.tools.register_tool(
    imports=[
        "import numpy as np",
        "from vision_agent.utils.image_utils import get_image_size, normalize_bbox",
        "from template_match import template_matching_with_rotation",
    ]
)
def template_match(target_image: np.ndarray, template_image: np.ndarray) -> dict:
    """'template_match' tool that finds the locations of the template image in the
    target image.
    Parameters:
        target_image (np.ndarray): The target image.
        template_image (np.ndarray): The template image.
    Returns:
        dict: A dictionary containing the bounding boxes of the matches.
    Example
    -------
        >>> import cv2
        >>> target_image = cv2.imread("pid.png")
        >>> template_image = cv2.imread("pid_template.png")
        >>> matches = template_match(target_image, template_image)
    """

    def __call__(self, target_image: str, template_image: str) -> dict:
        image_size = get_image_size(target_image)
        matches = template_matching_with_rotation(target_image, template_image)
        matches["bboxes"] = [
            normalize_bbox(box, image_size) for box in matches["bboxes"]
        ]
        return matches
    image_size = get_image_size(target_image)
    matches = template_matching_with_rotation(target_image, template_image)
    matches["bboxes"] = [normalize_bbox(box, image_size) for box in matches["bboxes"]]
    return matches


if __name__ == "__main__":
    agent = va.agent.VisionAgent(verbose=True)
    resp, tools = agent.chat_with_workflow(
    agent = va.agent.VisionAgent(verbosity=2)
    result = agent.chat_with_workflow(
        [
            {
                "role": "user",
                "content": "Can you find the locations of the pid_template.png in pid.png and tell me if any are nearby 'NOTE 5'?",
            }
        ],
        image="pid.png",
        reference_data={"image": "pid_template.png"},
        visualize_output=True,
        media="pid.png",
    )