diff --git a/vision_agent/__init__.py b/vision_agent/__init__.py index 3704c363..360b8a78 100644 --- a/vision_agent/__init__.py +++ b/vision_agent/__init__.py @@ -1,5 +1,5 @@ +from .agent import Agent from .data import DataStore, build_data_store from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder from .llm import LLM, OpenAILLM from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm -from .agent import Agent diff --git a/vision_agent/agent/__init__.py b/vision_agent/agent/__init__.py index aec05098..f0954251 100644 --- a/vision_agent/agent/__init__.py +++ b/vision_agent/agent/__init__.py @@ -1,3 +1,3 @@ from .agent import Agent -from .reflexion import Reflexion from .easytool import EasyTool +from .reflexion import Reflexion diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py index 929689ad..6a3c6c77 100644 --- a/vision_agent/agent/easytool.py +++ b/vision_agent/agent/easytool.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from vision_agent import LLM, LMM, OpenAILLM +from vision_agent.llm import LLM, OpenAILLM +from vision_agent.lmm import LMM from vision_agent.tools import TOOLS from .agent import Agent @@ -42,10 +43,10 @@ def change_name(name: str) -> str: def format_tools(tools: Dict[int, Any]) -> str: # Format this way so it's clear what the ID's are - tool_list = [] + tool_str = "" for key in tools: - tool_list.append(f"ID: {key}, {tools[key]}\\n") - return str(tool_list) + tool_str += f"ID: {key}, {tools[key]}\n" + return tool_str def task_decompose( @@ -151,7 +152,11 @@ def answer_summarize( def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any: - return tool()(**parameters) + try: + return tool()(**parameters) + except Exception as e: + _LOGGER.error(f"Failed function_call on: {e}") + return None def retrieval( @@ -160,7 +165,6 @@ def retrieval( tools: Dict[int, Any], previous_log: str, ) -> Tuple[List[Dict], str]: - # TODO: remove tools_used? tool_id = choose_tool( model, question, {k: v["description"] for k, v in tools.items()} ) @@ -200,7 +204,7 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: call_results.extend(parse_tool_results(result)) tool_results[i]["call_results"] = call_results - call_results_str = "\n\n".join([str(e) for e in call_results]) + call_results_str = "\n\n".join([str(e) for e in call_results if e is not None]) _LOGGER.info(f"\tCall Results: {call_results_str}") return tool_results, call_results_str diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index 4d0d05c3..f13984b1 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -from vision_agent import LLM, LMM, OpenAILLM +from vision_agent.llm import LLM, OpenAILLM +from vision_agent.lmm import LMM from .agent import Agent from .reflexion_prompts import ( @@ -114,7 +115,7 @@ def __init__( self.reflect_prompt = reflect_prompt self.finsh_prompt = finsh_prompt self.cot_examples = cot_examples - self.refelct_examples = reflect_examples + self.reflect_examples = reflect_examples self.reflections: List[str] = [] if verbose: _LOGGER.setLevel(logging.INFO) @@ -273,7 +274,7 @@ def _build_reflect_prompt( self, question: str, context: str = "", scratchpad: str = "" ) -> str: return self.reflect_prompt.format( - examples=self.refelct_examples, + examples=self.reflect_examples, context=context, question=question, scratchpad=scratchpad, diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index dae14d66..7e5636c2 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,2 +1,2 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import CLIP, TOOLS, GroundingDINO, GroundingSAM, Tool +from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index c1b8fe2d..fdcece18 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,10 +1,13 @@ import logging +import tempfile from abc import ABC +from collections import Counter as CounterClass from pathlib import Path from typing import Any, Dict, List, Tuple, Union, cast import numpy as np import requests +from PIL import Image from PIL.Image import Image as ImageType from vision_agent.image_utils import convert_to_b64, get_image_size @@ -52,19 +55,16 @@ class CLIP(Tool): or tags. Examples:: - >>> from vision_agent.tools import tools - >>> t = tools.CLIP(["red line", "yellow dot", "none"]) - >>> t("examples/img/ct_scan1.jpg")) - >>> [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]] + >>> import vision_agent as va + >>> clip = va.tools.CLIP() + >>> clip(["red line", "yellow dot"], "ct_scan1.jpg")) + >>> [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}] """ _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws" name = "clip_" - description = ( - "'clip_' is a tool that can classify or tag any image given a set if input classes or tags." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - ) + description = "'clip_' is a tool that can classify or tag any image given a set if input classes or tags." usage = { "required_parameters": [ {"name": "prompt", "type": "List[str]"}, @@ -106,22 +106,30 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict ) or "statusCode" not in resp_json: _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") - return cast(List[Dict], resp_json["data"]) + + rets = [] + for elt in resp_json["data"]: + rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]}) + return cast(List[Dict], rets) class GroundingDINO(Tool): + r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as + category names or referring expressions. + + Examples:: + >>> import vision_agent as va + >>> t = va.tools.GroundingDINO() + >>> t("red line. yellow dot", "ct_scan1.jpg") + >>> [{'labels': ['red line', 'yellow dot'], + >>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]], + >>> 'scores': [0.98, 0.02]}] + """ + _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws" name = "grounding_dino_" - description = ( - "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - "The tool returns a list of dictionaries, each containing the following keys:\n" - ' - "label": The label of the detected object.\n' - ' - "score": The confidence score of the detection.\n' - ' - "bbox": The bounding box of the detected object. The box coordinates are normalize to [0, 1]\n' - 'An example output would be: [{"label": ["car"], "score": [0.99], "bbox": [[0.1, 0.2, 0.3, 0.4]]}]\n' - ) + description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." usage = { "required_parameters": [ {"name": "prompt", "type": "str"}, @@ -180,27 +188,27 @@ class GroundingSAM(Tool): inputs such as category names or referring expressions. Examples:: - >>> from vision_agent.tools import tools - >>> t = tools.GroundingSAM(["red line", "yellow dot", "none"]) - >>> t("examples/img/ct_scan1.jpg") - >>> [{'label': 'none', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> import vision_agent as va + >>> t = va.tools.GroundingSAM() + >>> t(["red line", "yellow dot"], ct_scan1.jpg"]) + >>> [{'labels': ['yellow dot', 'red line'], + >>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]], + >>> 'masks': [array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [0, 0, 0, ..., 0, 0, 0], - >>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, + >>> array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [1, 1, 1, ..., 1, 1, 1], - >>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}] + >>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}] """ _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws" name = "grounding_sam_" - description = ( - "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - ) + description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." usage = { "required_parameters": [ {"name": "prompt", "type": "List[str]"}, @@ -226,6 +234,7 @@ class GroundingSAM(Tool): } def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + image_size = get_image_size(image) image_b64 = convert_to_b64(image) data = { "classes": prompt, @@ -243,24 +252,100 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") resp_data = resp_json["data"] - preds = [] + ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} for pred in resp_data["preds"]: encoded_mask = pred["encoded_mask"] mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"]) - preds.append( - { - "label": pred["label_name"], - "mask": mask, - } - ) - return preds + ret_pred["labels"].append(pred["label_name"]) + ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size)) + ret_pred["masks"].append(mask) + ret_preds = [ret_pred] + return ret_preds + + +class AgentGroundingSAM(GroundingSAM): + r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files + returns the file name. This makes it easier for agents to use. + """ + + def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + rets = super().__call__(prompt, image) + for ret in rets: + mask_files = [] + for mask in ret["masks"]: + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + Image.fromarray(mask * 255).save(tmp) + mask_files.append(tmp.name) + ret["masks"] = mask_files + return rets + + +class Counter(Tool): + name = "counter_" + description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression." + usage = { + "required_parameters": [ + {"name": "prompt", "type": "str"}, + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you count the number of cars in this image? Image name image.jpg", + "parameters": {"prompt": "car", "image": "image.jpg"}, + }, + { + "scenario": "Can you count the number of people? Image name: people.png", + "parameters": {"prompt": "person", "image": "people.png"}, + }, + ], + } + + def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: + resp = GroundingDINO()(prompt, image) + return dict(CounterClass(resp[0]["labels"])) + + +class Crop(Tool): + name = "crop_" + description = "'crop_' crops an image given a bounding box and returns a file name of the cropped image." + usage = { + "required_parameters": [ + {"name": "bbox", "type": "List[float]"}, + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you crop the image to the bounding box [0.1, 0.1, 0.9, 0.9]? Image name: image.jpg", + "parameters": {"bbox": [0.1, 0.1, 0.9, 0.9], "image": "image.jpg"}, + }, + { + "scenario": "Cut out the image to the bounding box [0.2, 0.2, 0.8, 0.8]. Image name: car.jpg", + "parameters": {"bbox": [0.2, 0.2, 0.8, 0.8], "image": "car.jpg"}, + }, + ], + } + + def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: + pil_image = Image.open(image) + width, height = pil_image.size + bbox = [ + int(bbox[0] * width), + int(bbox[1] * height), + int(bbox[2] * width), + int(bbox[3] * height), + ] + cropped_image = pil_image.crop(bbox) # type: ignore + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + cropped_image.save(tmp.name) + + return tmp.name class Add(Tool): name = "add_" description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 2 + 4", @@ -277,7 +362,7 @@ class Subtract(Tool): name = "subtract_" description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 4 - 2", @@ -294,7 +379,7 @@ class Multiply(Tool): name = "multiply_" description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 2 * 4", @@ -311,7 +396,7 @@ class Divide(Tool): name = "divide_" description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 4 / 2", @@ -327,7 +412,17 @@ def __call__(self, input: List[int]) -> float: TOOLS = { i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c} for i, c in enumerate( - [CLIP, GroundingDINO, GroundingSAM, Add, Subtract, Multiply, Divide] + [ + CLIP, + GroundingDINO, + AgentGroundingSAM, + Counter, + Crop, + Add, + Subtract, + Multiply, + Divide, + ] ) if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage")) }