From d5286d136fbc1e6c226b2ed1bfd7caf298352324 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 22 Mar 2024 10:52:43 -0700 Subject: [PATCH] updated docs --- vision_agent/tools/__init__.py | 10 ++++- vision_agent/tools/tools.py | 72 ++++++++++++---------------------- 2 files changed, 33 insertions(+), 49 deletions(-) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 7e5636c2..67921aae 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -1,2 +1,10 @@ from .prompts import CHOOSE_PARAMS, SYSTEM_PROMPT -from .tools import CLIP, TOOLS, Counter, Crop, GroundingDINO, GroundingSAM, Tool +from .tools import ( + CLIP, + TOOLS, + Counter, + Crop, + GroundingDINO, + GroundingSAM, + Tool, +) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 120d45a0..9f437ec8 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -58,8 +58,8 @@ class CLIP(Tool): Examples:: >>> import vision_agent as va >>> clip = va.tools.CLIP() - >>> clip(["red line", "yellow dot"], "examples/img/ct_scan1.jpg")) - >>> [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]] + >>> clip(["red line", "yellow dot"], "ct_scan1.jpg")) + >>> [{"labels": ["red line", "yellow dot"], "scores": [0.98, 0.02]}] """ _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws" @@ -115,6 +115,18 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> List[Dict class GroundingDINO(Tool): + r"""Grounding DINO is a tool that can detect arbitrary objects with inputs such as + category names or referring expressions. + + Examples:: + >>> import vision_agent as va + >>> t = va.tools.GroundingDINO() + >>> t("red line. yellow dot", "ct_scan1.jpg") + >>> [{'labels': ['red line', 'yellow dot'], + >>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]], + >>> 'scores': [0.98, 0.02]}] + """ + _ENDPOINT = "https://chnicr4kes5ku77niv2zoytggq0qyqlp.lambda-url.us-east-2.on.aws" name = "grounding_dino_" @@ -177,18 +189,21 @@ class GroundingSAM(Tool): inputs such as category names or referring expressions. Examples:: - >>> from vision_agent.tools import tools - >>> t = tools.GroundingSAM(["red line", "yellow dot", "none"]) - >>> t("examples/img/ct_scan1.jpg") - >>> [{'label': 'yellow dot', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> import vision_agent as va + >>> t = va.tools.GroundingSAM() + >>> t(["red line", "yellow dot"], ct_scan1.jpg"]) + >>> [{'labels': ['yellow dot', 'red line'], + >>> 'bboxes': [[0.38, 0.15, 0.59, 0.7], [0.48, 0.25, 0.69, 0.71]], + >>> 'masks': [array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [0, 0, 0, ..., 0, 0, 0], - >>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + >>> [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, + >>> array([[0, 0, 0, ..., 0, 0, 0], >>> [0, 0, 0, ..., 0, 0, 0], >>> ..., >>> [1, 1, 1, ..., 1, 1, 1], - >>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}] + >>> [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}] """ _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws" @@ -287,7 +302,6 @@ class Counter(Tool): def __call__(self, prompt: str, image: Union[str, ImageType]) -> Dict: resp = GroundingDINO()(prompt, image) - __import__("ipdb").set_trace() return dict(CounterClass(resp[0]["labels"])) @@ -320,51 +334,13 @@ def __call__(self, bbox: List[float], image: Union[str, Path]) -> str: int(bbox[2] * width), int(bbox[3] * height), ] - cropped_image = pil_image.crop(bbox) + cropped_image = pil_image.crop(bbox) # type: ignore with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp: cropped_image.save(tmp.name) return tmp.name -class ImageSearch(Tool): - name = "image_search_" - description = "'image_search_' searches for images similar to the input image." - usage = { - "required_parameters": [{"name": "image", "type": "str"}], - "examples": [ - { - "scenario": "Can you find images similar to the image? Image name: image.jpg", - "parameters": {"image": "image.jpg"}, - } - ], - } - - def __call__(self, image: Union[str, Path]) -> List[str]: - assert isinstance(image, str), "The input image must be a string url." - image = "https://popmenucloud.com/cdn-cgi/image/width%3D1920%2Cheight%3D1920%2Cfit%3Dscale-down%2Cformat%3Dauto%2Cquality%3D60/vpylarnm/a6ad1671-8938-457f-b4cd-3215caa122cb.png" - url = "https://www.googleapis.com/customsearch/v1" - api_key = os.getenv("GOOGLE_API_KEY") - search_engine_id = os.getenv("GOOGLE_SEARCH_ENGINE_ID") - assert api_key and search_engine_id, "Please set the GOOGLE_API_KEY and GOOGLE_SEARCH_ENGINE_ID environment variable. See https://developers.google.com/custom-search/v1/using_rest for more information." - params = { - "key": api_key, - "cx": search_engine_id, - "q": image, - "num": 10, - "searchType":"image", - } - response = requests.get(url, params=params) - - # Check if the request was successful - if response.status_code != 200: - raise RuntimeError(f"Failed to fetch data: {response.status_code} {response.reason}") - - resp = response.json() - items = resp.get("items", []) - return [item["link"] for item in items] - - class Add(Tool): name = "add_" description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."