From f661a93ce5d23fe4f25c87a89d5d9d04fbc9bea7 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 21 Mar 2024 17:21:27 -0700 Subject: [PATCH 01/13] fixing prompting and failure cases --- vision_agent/agent/easytool.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py index 929689ad..b86400d9 100644 --- a/vision_agent/agent/easytool.py +++ b/vision_agent/agent/easytool.py @@ -42,10 +42,10 @@ def change_name(name: str) -> str: def format_tools(tools: Dict[int, Any]) -> str: # Format this way so it's clear what the ID's are - tool_list = [] + tool_str = "" for key in tools: - tool_list.append(f"ID: {key}, {tools[key]}\\n") - return str(tool_list) + tool_str += f"ID: {key}, {tools[key]}\n" + return tool_str def task_decompose( @@ -151,7 +151,11 @@ def answer_summarize( def function_call(tool: Callable, parameters: Dict[str, Any]) -> Any: - return tool()(**parameters) + try: + return tool()(**parameters) + except Exception as e: + _LOGGER.error(f"Failed function_call on: {e}") + return None def retrieval( @@ -160,7 +164,6 @@ def retrieval( tools: Dict[int, Any], previous_log: str, ) -> Tuple[List[Dict], str]: - # TODO: remove tools_used? tool_id = choose_tool( model, question, {k: v["description"] for k, v in tools.items()} ) @@ -200,7 +203,7 @@ def parse_tool_results(result: Dict[str, Union[Dict, List]]) -> Any: call_results.extend(parse_tool_results(result)) tool_results[i]["call_results"] = call_results - call_results_str = "\n\n".join([str(e) for e in call_results]) + call_results_str = "\n\n".join([str(e) for e in call_results if e is not None]) _LOGGER.info(f"\tCall Results: {call_results_str}") return tool_results, call_results_str From 4a9ef0d02c724af00dc75c4b3e49fe88a4e029da Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 21 Mar 2024 17:21:36 -0700 Subject: [PATCH 02/13] fix typo --- vision_agent/agent/reflexion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index 4d0d05c3..b28e0a34 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -114,7 +114,7 @@ def __init__( self.reflect_prompt = reflect_prompt self.finsh_prompt = finsh_prompt self.cot_examples = cot_examples - self.refelct_examples = reflect_examples + self.reflect_examples = reflect_examples self.reflections: List[str] = [] if verbose: _LOGGER.setLevel(logging.INFO) @@ -273,7 +273,7 @@ def _build_reflect_prompt( self, question: str, context: str = "", scratchpad: str = "" ) -> str: return self.reflect_prompt.format( - examples=self.refelct_examples, + examples=self.reflect_examples, context=context, question=question, scratchpad=scratchpad, From 31cdc13d7e18358fd26a1ccd8cd3e38fe7c25cb9 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 21 Mar 2024 17:21:54 -0700 Subject: [PATCH 03/13] minimize description of tools, add test tools --- vision_agent/tools/tools.py | 85 +++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index c1b8fe2d..0a2584d0 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,10 +1,12 @@ import logging +import tempfile from abc import ABC from pathlib import Path from typing import Any, Dict, List, Tuple, Union, cast import numpy as np import requests +from PIL import Image from PIL.Image import Image as ImageType from vision_agent.image_utils import convert_to_b64, get_image_size @@ -61,10 +63,7 @@ class CLIP(Tool): _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws" name = "clip_" - description = ( - "'clip_' is a tool that can classify or tag any image given a set if input classes or tags." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - ) + description = "'clip_' is a tool that can classify or tag any image given a set if input classes or tags." usage = { "required_parameters": [ {"name": "prompt", "type": "List[str]"}, @@ -113,15 +112,7 @@ class GroundingDINO(Tool): _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws" name = "grounding_dino_" - description = ( - "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - "The tool returns a list of dictionaries, each containing the following keys:\n" - ' - "label": The label of the detected object.\n' - ' - "score": The confidence score of the detection.\n' - ' - "bbox": The bounding box of the detected object. The box coordinates are normalize to [0, 1]\n' - 'An example output would be: [{"label": ["car"], "score": [0.99], "bbox": [[0.1, 0.2, 0.3, 0.4]]}]\n' - ) + description = "'grounding_dino_' is a tool that can detect arbitrary objects with inputs such as category names or referring expressions." usage = { "required_parameters": [ {"name": "prompt", "type": "str"}, @@ -197,10 +188,7 @@ class GroundingSAM(Tool): _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws" name = "grounding_sam_" - description = ( - "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." - "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - ) + description = "'grounding_sam_' is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." usage = { "required_parameters": [ {"name": "prompt", "type": "List[str]"}, @@ -256,11 +244,64 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict return preds +class Crop(Tool): + name = "crop_" + description = "'crop_' crops an image given a bounding box and returns a file name of the cropped image." + usage = { + "required_parameters": [ + {"name": "bbox", "type": "List[float]"}, + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you crop the image to the bounding box [0.1, 0.1, 0.9, 0.9]? Image name: image.jpg", + "parameters": {"bbox": [0.1, 0.1, 0.9, 0.9], "image": "image.jpg"}, + }, + { + "scenario": "Cut out the image to the bounding box [0.2, 0.2, 0.8, 0.8]. Image name: car.jpg", + "parameters": {"bbox": [0.2, 0.2, 0.8, 0.8], "image": "car.jpg"}, + }, + ], + } + + def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: + pil_image = Image.open(image) + width, height = pil_image.size + bbox = [ + int(bbox[0] * width), + int(bbox[1] * height), + int(bbox[2] * width), + int(bbox[3] * height), + ] + cropped_image = pil_image.crop(bbox) + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + cropped_image.save(tmp.name) + + return tmp.name + + +class ImageSearch(Tool): + name = "image_search_" + description = "'image_search_' searches for images similar to the input image." + usage = { + "required_parameters": [{"name": "image", "type": "str"}], + "examples": [ + { + "scenario": "Can you find images similar to the image? Image name: image.jpg", + "parameters": {"image": "image.jpg"}, + } + ], + } + + def __call__(self, image: Union[str, Path]) -> List[str]: + return ["image1.png", "image2.png", "image3.png"] + + class Add(Tool): name = "add_" description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 2 + 4", @@ -277,7 +318,7 @@ class Subtract(Tool): name = "subtract_" description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 4 - 2", @@ -294,7 +335,7 @@ class Multiply(Tool): name = "multiply_" description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 2 * 4", @@ -311,7 +352,7 @@ class Divide(Tool): name = "divide_" description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places." usage = { - "required_parameters": {"name": "input", "type": "List[int]"}, + "required_parameters": [{"name": "input", "type": "List[int]"}], "examples": [ { "scenario": "If you want to calculate 4 / 2", @@ -327,7 +368,7 @@ def __call__(self, input: List[int]) -> float: TOOLS = { i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c} for i, c in enumerate( - [CLIP, GroundingDINO, GroundingSAM, Add, Subtract, Multiply, Divide] + [CLIP, GroundingDINO, GroundingSAM, Crop, ImageSearch, Add, Subtract, Multiply, Divide] ) if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage")) } From 9d70eb0efe4e95de3f83f71a8792a7b9cf10f83d Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 21 Mar 2024 20:35:10 -0700 Subject: [PATCH 04/13] added counter tool --- vision_agent/tools/tools.py | 83 ++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 14 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 0a2584d0..625e4865 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -54,9 +54,9 @@ class CLIP(Tool): or tags. Examples:: - >>> from vision_agent.tools import tools - >>> t = tools.CLIP(["red line", "yellow dot", "none"]) - >>> t("examples/img/ct_scan1.jpg")) + >>> import vision_agent as va + >>> clip = va.tools.CLIP() + >>> clip(["red line", "yellow dot"], "examples/img/ct_scan1.jpg")) >>> [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]] """ @@ -105,7 +105,11 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict ) or "statusCode" not in resp_json: _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") - return cast(List[Dict], resp_json["data"]) + + rets = [] + for elt in resp_json["data"]: + rets.append({"labels": prompt, "scores": [round(prob, 2) for prob in elt]}) + return cast(List[Dict], rets) class GroundingDINO(Tool): @@ -174,7 +178,7 @@ class GroundingSAM(Tool): >>> from vision_agent.tools import tools >>> t = tools.GroundingSAM(["red line", "yellow dot", "none"]) >>> t("examples/img/ct_scan1.jpg") - >>> [{'label': 'none', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> [{'label': 'yellow dot', 'mask': array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [0, 0, 0, ..., 0, 0, 0], @@ -214,6 +218,7 @@ class GroundingSAM(Tool): } def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + image_size = get_image_size(image) image_b64 = convert_to_b64(image) data = { "classes": prompt, @@ -231,17 +236,56 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") resp_data = resp_json["data"] - preds = [] + ret_pred = {"labels": [], "bboxes": [], "masks": []} for pred in resp_data["preds"]: encoded_mask = pred["encoded_mask"] mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"]) - preds.append( - { - "label": pred["label_name"], - "mask": mask, - } - ) - return preds + ret_pred["labels"].append(pred["label_name"]) + ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size)) + ret_pred["masks"].append(mask) + return [ret_pred] + + +class AgentGroundingSAM(GroundingSAM): + r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files + returns the file name. This makes it easier for agents to use. + """ + + def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict]: + rets = super().__call__(prompt, image) + for ret in rets: + mask_files = [] + for mask in ret["masks"]: + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: + Image.fromarray(mask * 255).save(tmp) + mask_files.append(tmp.name) + ret["masks"] = mask_files + return rets + + +class Counter(Tool): + name = "counter_" + description = "'counter_' detects and counts the number of objects in an image given an input such as a category name or referring expression." + usage = { + "required_parameters": [ + {"name": "prompt", "type": "str"}, + {"name": "image", "type": "str"}, + ], + "examples": [ + { + "scenario": "Can you count the number of cars in this image? Image name image.jpg", + "parameters": {"prompt": "car", "image": "image.jpg"}, + }, + { + "scenario": "Can you count the number of people? Image name: people.png", + "parameters": {"prompt": "person", "image": "people.png"}, + }, + ], + } + + def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: + resp = GroundingDINO()(prompt, image) + return dict(Counter=resp[0]["labels"]) class Crop(Tool): @@ -368,7 +412,18 @@ def __call__(self, input: List[int]) -> float: TOOLS = { i: {"name": c.name, "description": c.description, "usage": c.usage, "class": c} for i, c in enumerate( - [CLIP, GroundingDINO, GroundingSAM, Crop, ImageSearch, Add, Subtract, Multiply, Divide] + [ + CLIP, + GroundingDINO, + AgentGroundingSAM, + Counter, + Crop, + ImageSearch, + Add, + Subtract, + Multiply, + Divide, + ] ) if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage")) } From 688431c09b36814ab3a4c1df2f439a412e8535dc Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Thu, 21 Mar 2024 21:53:08 -0700 Subject: [PATCH 05/13] Finish ImageSearch --- vision_agent/tools/tools.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 625e4865..5ff1c7ce 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -338,7 +338,25 @@ class ImageSearch(Tool): } def __call__(self, image: Union[str, Path]) -> List[str]: - return ["image1.png", "image2.png", "image3.png"] + assert isinstance(image, str), "The input image must be a string url." + url = "https://www.googleapis.com/customsearch/v1" + params = { + "key": "AIzaSyDy3UMHL1E3nFLTLdIQb3nyIU5-zhSfzPo", + "cx": "831f248aa2e1d4daf", + "q": image, + "num": 10, + "searchType":"image", + } + response = requests.get(url, params=params) + + # Check if the request was successful + if response.status_code != 200: + raise RuntimeError(f"Failed to fetch data: {response.status_code} {response.reason}") + + resp = response.json() + items = resp.get("items", []) + print(f"Found {len(items)} results.") + return [item["link"] for item in items] class Add(Tool): From 11288ae48778a4f00673931d08f71617170f5b27 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 08:12:05 -0700 Subject: [PATCH 06/13] fix counter class --- vision_agent/tools/__init__.py | 2 +- vision_agent/tools/tools.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index dae14d66..7e5636c2 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,2 +1,2 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import CLIP, TOOLS, GroundingDINO, GroundingSAM, Tool +from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 5ff1c7ce..2d2e2a0b 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -3,6 +3,7 @@ from abc import ABC from pathlib import Path from typing import Any, Dict, List, Tuple, Union, cast +from collections import Counter as CounterClass import numpy as np import requests @@ -285,7 +286,8 @@ class Counter(Tool): def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: resp = GroundingDINO()(prompt, image) - return dict(Counter=resp[0]["labels"]) + __import__("ipdb").set_trace() + return dict(CounterClass(resp[0]["labels"])) class Crop(Tool): @@ -339,6 +341,7 @@ class ImageSearch(Tool): def __call__(self, image: Union[str, Path]) -> List[str]: assert isinstance(image, str), "The input image must be a string url." + image = "https://popmenucloud.com/cdn-cgi/image/width%3D1920%2Cheight%3D1920%2Cfit%3Dscale-down%2Cformat%3Dauto%2Cquality%3D60/vpylarnm/a6ad1671-8938-457f-b4cd-3215caa122cb.png" url = "https://www.googleapis.com/customsearch/v1" params = { "key": "AIzaSyDy3UMHL1E3nFLTLdIQb3nyIU5-zhSfzPo", @@ -355,7 +358,6 @@ def __call__(self, image: Union[str, Path]) -> List[str]: resp = response.json() items = resp.get("items", []) - print(f"Found {len(items)} results.") return [item["link"] for item in items] From 37f0e75268dfcc8aaba92e0839ef0d12fecd2277 Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Fri, 22 Mar 2024 10:42:10 -0700 Subject: [PATCH 07/13] Remove keys --- vision_agent/tools/tools.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 2d2e2a0b..120d45a0 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,9 +1,10 @@ import logging +import os import tempfile from abc import ABC +from collections import Counter as CounterClass from pathlib import Path from typing import Any, Dict, List, Tuple, Union, cast -from collections import Counter as CounterClass import numpy as np import requests @@ -343,9 +344,12 @@ def __call__(self, image: Union[str, Path]) -> List[str]: assert isinstance(image, str), "The input image must be a string url." image = "https://popmenucloud.com/cdn-cgi/image/width%3D1920%2Cheight%3D1920%2Cfit%3Dscale-down%2Cformat%3Dauto%2Cquality%3D60/vpylarnm/a6ad1671-8938-457f-b4cd-3215caa122cb.png" url = "https://www.googleapis.com/customsearch/v1" + api_key = os.getenv("GOOGLE_API_KEY") + search_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID") + assert api_key and search_engine_id, "Please set the GOOGLE_API_KEY and GOOGLE_SEARCH_ENGINE_ID environment variable. See https://developers.google.com/custom-search/v1/using_rest for more information." params = { - "key": "AIzaSyDy3UMHL1E3nFLTLdIQb3nyIU5-zhSfzPo", - "cx": "831f248aa2e1d4daf", + "key": api_key, + "cx": search_engine_id, "q": image, "num": 10, "searchType":"image", From d5286d136fbc1e6c226b2ed1bfd7caf298352324 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 10:52:43 -0700 Subject: [PATCH 08/13] updated docs --- vision_agent/tools/__init__.py | 10 ++++- vision_agent/tools/tools.py | 72 ++++++++++++---------------------- 2 files changed, 33 insertions(+), 49 deletions(-) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 7e5636c2..67921aae 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,2 +1,10 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool +from .tools import ( + CLIP, + TOOLS, + Counter, + Crop, + GroundingDINO, + GroundingSAM, + Tool, +) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 120d45a0..9f437ec8 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -58,8 +58,8 @@ class CLIP(Tool): Examples:: >>> import vision_agent as va >>> clip = va.tools.CLIP() - >>> clip(["red line", "yellow dot"], "examples/img/ct_scan1.jpg")) - >>> [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]] + >>> clip(["red line", "yellow dot"], "ct_scan1.jpg")) + >>> [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}] """ _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws" @@ -115,6 +115,18 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict class GroundingDINO(Tool): + r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as + category names or referring expressions. + + Examples:: + >>> import vision_agent as va + >>> t = va.tools.GroundingDINO() + >>> t("red line. yellow dot", "ct_scan1.jpg") + >>> [{'labels': ['red line', 'yellow dot'], + >>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]], + >>> 'scores': [0.98, 0.02]}] + """ + _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws" name = "grounding_dino_" @@ -177,18 +189,21 @@ class GroundingSAM(Tool): inputs such as category names or referring expressions. Examples:: - >>> from vision_agent.tools import tools - >>> t = tools.GroundingSAM(["red line", "yellow dot", "none"]) - >>> t("examples/img/ct_scan1.jpg") - >>> [{'label': 'yellow dot', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> import vision_agent as va + >>> t = va.tools.GroundingSAM() + >>> t(["red line", "yellow dot"], ct_scan1.jpg"]) + >>> [{'labels': ['yellow dot', 'red line'], + >>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]], + >>> 'masks': [array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [0, 0, 0, ..., 0, 0, 0], - >>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, + >>> array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [1, 1, 1, ..., 1, 1, 1], - >>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}] + >>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}] """ _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws" @@ -287,7 +302,6 @@ class Counter(Tool): def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: resp = GroundingDINO()(prompt, image) - __import__("ipdb").set_trace() return dict(CounterClass(resp[0]["labels"])) @@ -320,51 +334,13 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: int(bbox[2] * width), int(bbox[3] * height), ] - cropped_image = pil_image.crop(bbox) + cropped_image = pil_image.crop(bbox) # type: ignore with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: cropped_image.save(tmp.name) return tmp.name -class ImageSearch(Tool): - name = "image_search_" - description = "'image_search_' searches for images similar to the input image." - usage = { - "required_parameters": [{"name": "image", "type": "str"}], - "examples": [ - { - "scenario": "Can you find images similar to the image? Image name: image.jpg", - "parameters": {"image": "image.jpg"}, - } - ], - } - - def __call__(self, image: Union[str, Path]) -> List[str]: - assert isinstance(image, str), "The input image must be a string url." - image = "https://popmenucloud.com/cdn-cgi/image/width%3D1920%2Cheight%3D1920%2Cfit%3Dscale-down%2Cformat%3Dauto%2Cquality%3D60/vpylarnm/a6ad1671-8938-457f-b4cd-3215caa122cb.png" - url = "https://www.googleapis.com/customsearch/v1" - api_key = os.getenv("GOOGLE_API_KEY") - search_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID") - assert api_key and search_engine_id, "Please set the GOOGLE_API_KEY and GOOGLE_SEARCH_ENGINE_ID environment variable. See https://developers.google.com/custom-search/v1/using_rest for more information." - params = { - "key": api_key, - "cx": search_engine_id, - "q": image, - "num": 10, - "searchType":"image", - } - response = requests.get(url, params=params) - - # Check if the request was successful - if response.status_code != 200: - raise RuntimeError(f"Failed to fetch data: {response.status_code} {response.reason}") - - resp = response.json() - items = resp.get("items", []) - return [item["link"] for item in items] - - class Add(Tool): name = "add_" description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places." From 16313fb2dda2ff905a7a8ff1f08c72636c6f7cb0 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 10:53:46 -0700 Subject: [PATCH 09/13] remove image search --- vision_agent/tools/tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 9f437ec8..f66f576c 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -418,7 +418,6 @@ def __call__(self, input: List[int]) -> float: AgentGroundingSAM, Counter, Crop, - ImageSearch, Add, Subtract, Multiply, From b2dcd028ab739545cf18f82a352553f288cd8464 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 10:55:08 -0700 Subject: [PATCH 10/13] fixed typign issue --- vision_agent/tools/tools.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index f66f576c..403c1c15 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -253,14 +253,15 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict _LOGGER.error(f"Request failed: {resp_json}") raise ValueError(f"Request failed: {resp_json}") resp_data = resp_json["data"] - ret_pred = {"labels": [], "bboxes": [], "masks": []} + ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} for pred in resp_data["preds"]: encoded_mask = pred["encoded_mask"] mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"]) ret_pred["labels"].append(pred["label_name"]) ret_pred["bboxes"].append(normalize_bbox(pred["bbox"], image_size)) ret_pred["masks"].append(mask) - return [ret_pred] + ret_preds = [ret_pred] + return ret_preds class AgentGroundingSAM(GroundingSAM): From 4293a1cdf9de729fc6308e7de886b7212d626607 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 10:57:38 -0700 Subject: [PATCH 11/13] ran isort --- vision_agent/__init__.py | 2 +- vision_agent/agent/__init__.py | 2 +- vision_agent/tools/__init__.py | 10 +--------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/vision_agent/__init__.py b/vision_agent/__init__.py index 3704c363..360b8a78 100644 --- a/vision_agent/__init__.py +++ b/vision_agent/__init__.py @@ -1,5 +1,5 @@ +from .agent import Agent from .data import DataStore, build_data_store from .emb import Embedder, OpenAIEmb, SentenceTransformerEmb, get_embedder from .llm import LLM, OpenAILLM from .lmm import LMM, LLaVALMM, OpenAILMM, get_lmm -from .agent import Agent diff --git a/vision_agent/agent/__init__.py b/vision_agent/agent/__init__.py index aec05098..f0954251 100644 --- a/vision_agent/agent/__init__.py +++ b/vision_agent/agent/__init__.py @@ -1,3 +1,3 @@ from .agent import Agent -from .reflexion import Reflexion from .easytool import EasyTool +from .reflexion import Reflexion diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 67921aae..7e5636c2 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,10 +1,2 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import ( - CLIP, - TOOLS, - Counter, - Crop, - GroundingDINO, - GroundingSAM, - Tool, -) +from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool From 7b2754984cf8e64ac1add833f5a493b9f2824f24 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 11:03:01 -0700 Subject: [PATCH 12/13] remove extra import --- vision_agent/tools/tools.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 403c1c15..fdcece18 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1,5 +1,4 @@ import logging -import os import tempfile from abc import ABC from collections import Counter as CounterClass From 546c858a8c8a1845d44f4a2cb17e4adc3e1d00a8 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 11:18:12 -0700 Subject: [PATCH 13/13] fix imports --- vision_agent/agent/easytool.py | 3 ++- vision_agent/agent/reflexion.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py index b86400d9..6a3c6c77 100644 --- a/vision_agent/agent/easytool.py +++ b/vision_agent/agent/easytool.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from vision_agent import LLM, LMM, OpenAILLM +from vision_agent.llm import LLM, OpenAILLM +from vision_agent.lmm import LMM from vision_agent.tools import TOOLS from .agent import Agent diff --git a/vision_agent/agent/reflexion.py b/vision_agent/agent/reflexion.py index b28e0a34..f13984b1 100644 --- a/vision_agent/agent/reflexion.py +++ b/vision_agent/agent/reflexion.py @@ -4,7 +4,8 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -from vision_agent import LLM, LMM, OpenAILLM +from vision_agent.llm import LLM, OpenAILLM +from vision_agent.lmm import LMM from .agent import Agent from .reflexion_prompts import (