From 62c49826ccf76056aa30c6cc297abef3aa44dbab Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Sat, 16 Mar 2024 22:43:56 -0700 Subject: [PATCH] Add two new models: CLIP and Grounded SAM (#18) Add two new tools/models: CLIP and Grounded SAM Co-authored-by: Yazhou Cao --- vision_agent/tools/tools.py | 104 +++++++++++++++++++++++++++++++++--- 1 file changed, 97 insertions(+), 7 deletions(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index de8960dd..b892e2c9 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Dict, List, Tuple, Union, cast +import numpy as np import requests from PIL.Image import Image as ImageType @@ -23,6 +24,22 @@ def normalize_bbox( return [x1, y1, x2, y2] +def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray: + """ + mask_rle: run-length as string formated (start length) + shape: (height,width) of array to return + Returns numpy array, 1 - mask, 0 - background + """ + s = mask_rle.split() + starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])] + starts -= 1 + ends = starts + lengths + img = np.zeros(shape[0] * shape[1], dtype=np.uint8) + for lo, hi in zip(starts, ends): + img[lo:hi] = 1 + return img.reshape(shape) + + class ImageTool(ABC): @abstractmethod def __call__(self, image: Union[str, ImageType]) -> List[Dict]: @@ -30,6 +47,17 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]: class CLIP(ImageTool): + """ + Example usage: + > from vision_agent.tools import tools + > t = tools.CLIP(["red line", "yellow dot", "none"]) + > t("examples/img/ct_scan1.jpg")) + + [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]] + """ + + _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws" + doc = ( "CLIP is a tool that can classify or tag any image given a set if input classes or tags." "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" @@ -38,11 +66,27 @@ class CLIP(ImageTool): 'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n' ) - def __init__(self, prompt: str): + def __init__(self, prompt: list[str]): self.prompt = prompt def __call__(self, image: Union[str, ImageType]) -> List[Dict]: - raise NotImplementedError + image_b64 = convert_to_b64(image) + data = { + "classes": self.prompt, + "images": [image_b64], + } + res = requests.post( + self._ENDPOINT, + headers={"Content-Type": "application/json"}, + json=data, + ) + resp_json: Dict[str, Any] = res.json() + if ( + "statusCode" in resp_json and resp_json["statusCode"] != 200 + ) or "statusCode" not in resp_json: + _LOGGER.error(f"Request failed: {resp_json}") + raise ValueError(f"Request failed: {resp_json}") + return cast(List[Dict], resp_json["data"]) class GroundingDINO(ImageTool): @@ -92,16 +136,62 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]: class GroundingSAM(ImageTool): + """ + Example usage: + > from vision_agent.tools import tools + > t = tools.GroundingSAM(["red line", "yellow dot", "none"]) + > t("examples/img/ct_scan1.jpg") + + [{'label': 'none', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0], + [0, 0, 0, ..., 0, 0, 0], + ..., + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}] + """ + + _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws" + doc = ( "Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions." "Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n" - 'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n' - 'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}\n' - 'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n' + 'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": ["car"]}}}}\n' + 'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": ["person on the left"]}}\n' + 'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": ["red shirt", "green shirt"]}}}}\n' ) - def __init__(self, prompt: str): + def __init__(self, prompt: list[str]): self.prompt = prompt def __call__(self, image: Union[str, ImageType]) -> List[Dict]: - raise NotImplementedError + image_b64 = convert_to_b64(image) + data = { + "classes": self.prompt, + "image": image_b64, + } + res = requests.post( + self._ENDPOINT, + headers={"Content-Type": "application/json"}, + json=data, + ) + resp_json: Dict[str, Any] = res.json() + if ( + "statusCode" in resp_json and resp_json["statusCode"] != 200 + ) or "statusCode" not in resp_json: + _LOGGER.error(f"Request failed: {resp_json}") + raise ValueError(f"Request failed: {resp_json}") + resp_data = resp_json["data"] + preds = [] + for pred in resp_data["preds"]: + encoded_mask = pred["encoded_mask"] + mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"]) + preds.append( + { + "label": pred["label_name"], + "mask": mask, + } + ) + return preds