Add two new models: CLIP and Grounded SAM (#18)
Add two new tools/models: CLIP and Grounded SAM

Co-authored-by: Yazhou Cao <[email protected]>
humpydonkey and AsiaCao authored Mar 17, 2024
1 parent eda075c commit 62c4982
Showing 1 changed file with 97 additions and 7 deletions.
104 changes: 97 additions & 7 deletions vision_agent/tools/tools.py
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast

import numpy as np
import requests
from PIL.Image import Image as ImageType

@@ -23,13 +24,40 @@ def normalize_bbox(
    return [x1, y1, x2, y2]


def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
    """
    mask_rle: run-length encoding as a string of space-separated (start, length) pairs
    shape: (height, width) of the array to return
    Returns a numpy array where 1 marks the mask and 0 the background
    """
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)
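
# A minimal decode example (an editorial illustration, not part of this commit);
# the RLE string and shape below are hypothetical. Starts are 1-indexed, and the
# flat array is rebuilt in C (row-major) order, matching reshape's default above.
#
#   >>> rle_decode("1 2 5 1", (2, 3))
#   array([[1, 1, 0],
#          [0, 1, 0]], dtype=uint8)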


class ImageTool(ABC):
    @abstractmethod
    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
        pass


class CLIP(ImageTool):
    """
    Example usage:
    > from vision_agent.tools import tools
    > t = tools.CLIP(["red line", "yellow dot", "none"])
    > t("examples/img/ct_scan1.jpg")
    [[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]]
    """

    _ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws"

    doc = (
        "CLIP is a tool that can classify or tag any image given a set of input classes or tags."
        "Here are some examples of how to use the tool. Each example is in the format of 'User Question:', which contains the user's question in quotes, followed by the parameters in JSON format, which are the parameters you need to output to call the API to solve the user's question.\n"
@@ -38,11 +66,27 @@ class CLIP(ImageTool):
        'Example 3: User Question: "Can you build me a classifier that classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n'
    )

    def __init__(self, prompt: str):
    def __init__(self, prompt: list[str]):
        self.prompt = prompt

    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
        raise NotImplementedError
        image_b64 = convert_to_b64(image)
        data = {
            "classes": self.prompt,
            "images": [image_b64],
        }
        res = requests.post(
            self._ENDPOINT,
            headers={"Content-Type": "application/json"},
            json=data,
        )
        resp_json: Dict[str, Any] = res.json()
        if (
            "statusCode" in resp_json and resp_json["statusCode"] != 200
        ) or "statusCode" not in resp_json:
            _LOGGER.error(f"Request failed: {resp_json}")
            raise ValueError(f"Request failed: {resp_json}")
        return cast(List[Dict], resp_json["data"])
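
# A hedged usage sketch (an editorial illustration, not part of this commit):
# the image path and class names are hypothetical, and the call requires the
# Lambda endpoint above to be reachable. Per the docstring, the response is
# one list of per-class scores per image.
#
#   >>> from vision_agent.tools import tools
#   >>> clip = tools.CLIP(["cat", "dog", "other"])
#   >>> scores = clip("examples/img/pet.jpg")[0]
#   >>> clip.prompt[scores.index(max(scores))]   # label with the highest score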


class GroundingDINO(ImageTool):
@@ -92,16 +136,62 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:


class GroundingSAM(ImageTool):
    """
    Example usage:
    > from vision_agent.tools import tools
    > t = tools.GroundingSAM(["red line", "yellow dot", "none"])
    > t("examples/img/ct_scan1.jpg")
    [{'label': 'none', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [1, 1, 1, ..., 1, 1, 1],
        [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}]
    """

    _ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws"

    doc = (
        "Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
        "Here are some examples of how to use the tool. Each example is in the format of 'User Question:', which contains the user's question in quotes, followed by the parameters in JSON format, which are the parameters you need to output to call the API to solve the user's question.\n"
        'Example 1: User Question: "Can you build me a car segmenter?" {{"Parameters":{{"prompt": "car"}}}}\n'
        'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}}}\n'
        'Example 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
        'Example 1: User Question: "Can you build me a car segmenter?" {{"Parameters":{{"prompt": ["car"]}}}}\n'
        'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": ["person on the left"]}}}}\n'
        'Example 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": ["red shirt", "green shirt"]}}}}\n'
    )

    def __init__(self, prompt: str):
    def __init__(self, prompt: list[str]):
        self.prompt = prompt

    def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
        raise NotImplementedError
        image_b64 = convert_to_b64(image)
        data = {
            "classes": self.prompt,
            "image": image_b64,
        }
        res = requests.post(
            self._ENDPOINT,
            headers={"Content-Type": "application/json"},
            json=data,
        )
        resp_json: Dict[str, Any] = res.json()
        if (
            "statusCode" in resp_json and resp_json["statusCode"] != 200
        ) or "statusCode" not in resp_json:
            _LOGGER.error(f"Request failed: {resp_json}")
            raise ValueError(f"Request failed: {resp_json}")
        resp_data = resp_json["data"]
        preds = []
        for pred in resp_data["preds"]:
            encoded_mask = pred["encoded_mask"]
            mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"])
            preds.append(
                {
                    "label": pred["label_name"],
                    "mask": mask,
                }
            )
        return preds
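
# A hedged usage sketch (an editorial illustration, not part of this commit):
# painting one returned mask onto the source image. The paths are hypothetical,
# the Lambda endpoint above must be live, and this assumes the decoded
# mask_shape matches the image's height and width.
#
#   >>> import numpy as np
#   >>> from PIL import Image
#   >>> from vision_agent.tools import tools
#   >>> sam = tools.GroundingSAM(["red line"])
#   >>> preds = sam("examples/img/ct_scan1.jpg")
#   >>> img = np.array(Image.open("examples/img/ct_scan1.jpg").convert("RGB"))
#   >>> img[preds[0]["mask"] == 1] = (255, 0, 0)   # color masked pixels red
#   >>> Image.fromarray(img).save("overlay.png")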
