Skip to content

Commit

Permalink
EasyTool demo (#22)
Browse files Browse the repository at this point in the history
* fixing prompting and failure cases

* fix typo

* minimize description of tools, add test tools

* added counter tool

* Finish ImageSearch

* fix counter class

* Remove keys

* updated docs

* remove image search

* fixed typign issue

* ran isort

* remove extra import

* fix imports

---------

Co-authored-by: Yazhou Cao <[email protected]>
  • Loading branch information
dillonalaird and AsiaCao authored Mar 22, 2024
1 parent b3fde00 commit 4d5c6fa
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 54 deletions.
2 changes: 1 addition & 1 deletion vision_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .agent import Agent
from .data import DataStore, build_data_store
from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder
from .llm import LLM, OpenAILLM
from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm
from .agent import Agent
2 changes: 1 addition & 1 deletion vision_agent/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .agent import Agent
from .reflexion import Reflexion
from .easytool import EasyTool
from .reflexion import Reflexion
18 changes: 11 additions & 7 deletions vision_agent/agent/easytool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from vision_agent import LLM, LMM, OpenAILLM
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM
from vision_agent.tools import TOOLS

from .agent import Agent
Expand Down Expand Up @@ -42,10 +43,10 @@ def change_name(name: str) -> str:

def format_tools(tools: Dict[int, Any]) -> str:
# Format this way so it's clear what the ID's are
tool_list = []
tool_str = ""
for key in tools:
tool_list.append(f"ID: {key}, {tools[key]}\\n")
return str(tool_list)
tool_str += f"ID: {key}, {tools[key]}\n"
return tool_str


def task_decompose(
Expand Down Expand Up @@ -151,7 +152,11 @@ def answer_summarize(


def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any:
return tool()(**parameters)
try:
return tool()(**parameters)
except Exception as e:
_LOGGER.error(f"Failed function_call on: {e}")
return None


def retrieval(
Expand All @@ -160,7 +165,6 @@ def retrieval(
tools: Dict[int, Any],
previous_log: str,
) -> Tuple[List[Dict], str]:
# TODO: remove tools_used?
tool_id = choose_tool(
model, question, {k: v["description"] for k, v in tools.items()}
)
Expand Down Expand Up @@ -200,7 +204,7 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any:
call_results.extend(parse_tool_results(result))
tool_results[i]["call_results"] = call_results

call_results_str = "\n\n".join([str(e) for e in call_results])
call_results_str = "\n\n".join([str(e) for e in call_results if e is not None])
_LOGGER.info(f"\tCall Results: {call_results_str}")
return tool_results, call_results_str

Expand Down
7 changes: 4 additions & 3 deletions vision_agent/agent/reflexion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from vision_agent import LLM, LMM, OpenAILLM
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM

from .agent import Agent
from .reflexion_prompts import (
Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
self.reflect_prompt = reflect_prompt
self.finsh_prompt = finsh_prompt
self.cot_examples = cot_examples
self.refelct_examples = reflect_examples
self.reflect_examples = reflect_examples
self.reflections: List[str] = []
if verbose:
_LOGGER.setLevel(logging.INFO)
Expand Down Expand Up @@ -273,7 +274,7 @@ def _build_reflect_prompt(
self, question: str, context: str = "", scratchpad: str = ""
) -> str:
return self.reflect_prompt.format(
examples=self.refelct_examples,
examples=self.reflect_examples,
context=context,
question=question,
scratchpad=scratchpad,
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .tools import CLIP, TOOLS, GroundingDINO, GroundingSAM, Tool
from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool
177 changes: 136 additions & 41 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import tempfile
from abc import ABC
from collections import Counter as CounterClass
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast

import numpy as np
import requests
from PIL import Image
from PIL.Image import Image as ImageType

from vision_agent.image_utils import convert_to_b64, get_image_size
Expand Down Expand Up @@ -52,19 +55,16 @@ class CLIP(Tool):
or tags.
Examples::
>>> from vision_agent.tools import tools
>>> t = tools.CLIP(["red line", "yellow dot", "none"])
>>> t("examples/img/ct_scan1.jpg"))
>>> [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]]
>>> import vision_agent as va
>>> clip = va.tools.CLIP()
>>> clip(["red line", "yellow dot"], "ct_scan1.jpg"))
>>> [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}]
"""

_ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws"

name = "clip_"
description = (
"'clip_' is a tool that can classify or tag any image given a set if input classes or tags."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
)
description = "'clip_' is a tool that can classify or tag any image given a set if input classes or tags."
usage = {
"required_parameters": [
{"name": "prompt", "type": "List[str]"},
Expand Down Expand Up @@ -106,22 +106,30 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
return cast(List[Dict], resp_json["data"])

rets = []
for elt in resp_json["data"]:
rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]})
return cast(List[Dict], rets)


class GroundingDINO(Tool):
r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as
category names or referring expressions.
Examples::
>>> import vision_agent as va
>>> t = va.tools.GroundingDINO()
>>> t("red line. yellow dot", "ct_scan1.jpg")
>>> [{'labels': ['red line', 'yellow dot'],
>>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]],
>>> 'scores': [0.98, 0.02]}]
"""

_ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws"

name = "grounding_dino_"
description = (
"'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
"The tool returns a list of dictionaries, each containing the following keys:\n"
' - "label": The label of the detected object.\n'
' - "score": The confidence score of the detection.\n'
' - "bbox": The bounding box of the detected object. The box coordinates are normalize to [0, 1]\n'
'An example output would be: [{"label": ["car"], "score": [0.99], "bbox": [[0.1, 0.2, 0.3, 0.4]]}]\n'
)
description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions."
usage = {
"required_parameters": [
{"name": "prompt", "type": "str"},
Expand Down Expand Up @@ -180,27 +188,27 @@ class GroundingSAM(Tool):
inputs such as category names or referring expressions.
Examples::
>>> from vision_agent.tools import tools
>>> t = tools.GroundingSAM(["red line", "yellow dot", "none"])
>>> t("examples/img/ct_scan1.jpg")
>>> [{'label': 'none', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
>>> import vision_agent as va
>>> t = va.tools.GroundingSAM()
>>> t(["red line", "yellow dot"], ct_scan1.jpg"])
>>> [{'labels': ['yellow dot', 'red line'],
>>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]],
>>> 'masks': [array([[0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0],
>>> ...,
>>> [0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)},
>>> array([[0, 0, 0, ..., 0, 0, 0],
>>> [0, 0, 0, ..., 0, 0, 0],
>>> ...,
>>> [1, 1, 1, ..., 1, 1, 1],
>>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}]
>>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
"""

_ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws"

name = "grounding_sam_"
description = (
"'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
)
description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
usage = {
"required_parameters": [
{"name": "prompt", "type": "List[str]"},
Expand All @@ -226,6 +234,7 @@ class GroundingSAM(Tool):
}

def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
image_size = get_image_size(image)
image_b64 = convert_to_b64(image)
data = {
"classes": prompt,
Expand All @@ -243,24 +252,100 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
resp_data = resp_json["data"]
preds = []
ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
for pred in resp_data["preds"]:
encoded_mask = pred["encoded_mask"]
mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"])
preds.append(
{
"label": pred["label_name"],
"mask": mask,
}
)
return preds
ret_pred["labels"].append(pred["label_name"])
ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size))
ret_pred["masks"].append(mask)
ret_preds = [ret_pred]
return ret_preds


class AgentGroundingSAM(GroundingSAM):
r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files
returns the file name. This makes it easier for agents to use.
"""

def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]:
rets = super().__call__(prompt, image)
for ret in rets:
mask_files = []
for mask in ret["masks"]:
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
Image.fromarray(mask * 255).save(tmp)
mask_files.append(tmp.name)
ret["masks"] = mask_files
return rets


class Counter(Tool):
name = "counter_"
description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression."
usage = {
"required_parameters": [
{"name": "prompt", "type": "str"},
{"name": "image", "type": "str"},
],
"examples": [
{
"scenario": "Can you count the number of cars in this image? Image name image.jpg",
"parameters": {"prompt": "car", "image": "image.jpg"},
},
{
"scenario": "Can you count the number of people? Image name: people.png",
"parameters": {"prompt": "person", "image": "people.png"},
},
],
}

def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict:
resp = GroundingDINO()(prompt, image)
return dict(CounterClass(resp[0]["labels"]))


class Crop(Tool):
name = "crop_"
description = "'crop_' crops an image given a bounding box and returns a file name of the cropped image."
usage = {
"required_parameters": [
{"name": "bbox", "type": "List[float]"},
{"name": "image", "type": "str"},
],
"examples": [
{
"scenario": "Can you crop the image to the bounding box [0.1, 0.1, 0.9, 0.9]? Image name: image.jpg",
"parameters": {"bbox": [0.1, 0.1, 0.9, 0.9], "image": "image.jpg"},
},
{
"scenario": "Cut out the image to the bounding box [0.2, 0.2, 0.8, 0.8]. Image name: car.jpg",
"parameters": {"bbox": [0.2, 0.2, 0.8, 0.8], "image": "car.jpg"},
},
],
}

def __call__(self, bbox: List[float], image: Union[str, Path]) -> str:
pil_image = Image.open(image)
width, height = pil_image.size
bbox = [
int(bbox[0] * width),
int(bbox[1] * height),
int(bbox[2] * width),
int(bbox[3] * height),
]
cropped_image = pil_image.crop(bbox) # type: ignore
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
cropped_image.save(tmp.name)

return tmp.name


class Add(Tool):
name = "add_"
description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"required_parameters": [{"name": "input", "type": "List[int]"}],
"examples": [
{
"scenario": "If you want to calculate 2 + 4",
Expand All @@ -277,7 +362,7 @@ class Subtract(Tool):
name = "subtract_"
description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places."
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"required_parameters": [{"name": "input", "type": "List[int]"}],
"examples": [
{
"scenario": "If you want to calculate 4 - 2",
Expand All @@ -294,7 +379,7 @@ class Multiply(Tool):
name = "multiply_"
description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places."
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"required_parameters": [{"name": "input", "type": "List[int]"}],
"examples": [
{
"scenario": "If you want to calculate 2 * 4",
Expand All @@ -311,7 +396,7 @@ class Divide(Tool):
name = "divide_"
description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places."
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"required_parameters": [{"name": "input", "type": "List[int]"}],
"examples": [
{
"scenario": "If you want to calculate 4 / 2",
Expand All @@ -327,7 +412,17 @@ def __call__(self, input: List[int]) -> float:
TOOLS = {
i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c}
for i, c in enumerate(
[CLIP, GroundingDINO, GroundingSAM, Add, Subtract, Multiply, Divide]
[
CLIP,
GroundingDINO,
AgentGroundingSAM,
Counter,
Crop,
Add,
Subtract,
Multiply,
Divide,
]
)
if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
}

0 comments on commit 4d5c6fa

Please sign in to comment.