Add two new models: CLIP and Grounded SAM #18

Merged · 1 commit · Mar 17, 2024
104 changes: 97 additions & 7 deletions vision_agent/tools/tools.py
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast

import numpy as np
import requests
from PIL.Image import Image as ImageType

@@ -23,13 +24,40 @@ def normalize_bbox(
return [x1, y1, x2, y2]


def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
"""
mask_rle: run-length encoding as a string of space-separated (start length) pairs
shape: (height, width) of the array to return
Returns a numpy array, 1 - mask, 0 - background
"""
s = mask_rle.split()
starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
starts -= 1
ends = starts + lengths
img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
for lo, hi in zip(starts, ends):
img[lo:hi] = 1
return img.reshape(shape)
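
# Illustrative sketch, not part of the original diff: a tiny worked example of the
# "start length" RLE format decoded by rle_decode above. Starts are 1-indexed offsets
# into the flattened (row-major) mask, so "1 2 5 1" marks pixels 0-1 and 4 of a 2x3 mask.
_example_mask = rle_decode("1 2 5 1", (2, 3))
assert (_example_mask == np.array([[1, 1, 0], [0, 1, 0]], dtype=np.uint8)).all()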


class ImageTool(ABC):
@abstractmethod
def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
pass


class CLIP(ImageTool):
"""
Example usage:
> from vision_agent.tools import tools
> t = tools.CLIP(["red line", "yellow dot", "none"])
> t("examples/img/ct_scan1.jpg"))

[[0.02567436918616295, 0.9534115791320801, 0.020914122462272644]]
"""

_ENDPOINT = "https://rb4ii6dfacmwqfxivi4aedyyfm0endsv.lambda-url.us-east-2.on.aws"

doc = (
"CLIP is a tool that can classify or tag any image given a set if input classes or tags."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
@@ -38,11 +66,27 @@ class CLIP(ImageTool):
'Example 3: User Question: "Can you build me a classifier that classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n'
)

def __init__(self, prompt: str):
def __init__(self, prompt: list[str]):
self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError
image_b64 = convert_to_b64(image)
data = {
"classes": self.prompt,
"images": [image_b64],
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
)
resp_json: Dict[str, Any] = res.json()
if (
"statusCode" in resp_json and resp_json["statusCode"] != 200
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
return cast(List[Dict], resp_json["data"])
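
# Hedged usage sketch, not part of the original diff: assuming the endpoint returns one
# probability per prompt class for each input image (as in the CLIP docstring example
# above), this illustrative helper picks the top-scoring label for a single image.
def _top_clip_label(classifier: CLIP, image: Union[str, ImageType]) -> str:
    scores = classifier(image)[0]  # probabilities aligned with classifier.prompt
    return classifier.prompt[int(np.argmax(scores))]
# e.g. _top_clip_label(CLIP(["red line", "yellow dot", "none"]), "examples/img/ct_scan1.jpg")
# would return "yellow dot" for the example output shown in the docstring.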


class GroundingDINO(ImageTool):
@@ -92,16 +136,62 @@ def __call__(self, image: Union[str, Path, ImageType]) -> List[Dict]:


class GroundingSAM(ImageTool):
"""
Example usage:
> from vision_agent.tools import tools
> t = tools.GroundingSAM(["red line", "yellow dot", "none"])
> t("examples/img/ct_scan1.jpg")

[{'label': 'none', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}, {'label': 'red line', 'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], dtype=uint8)}]
"""

_ENDPOINT = "https://cou5lfmus33jbddl6hoqdfbw7e0qidrw.lambda-url.us-east-2.on.aws"

doc = (
"Grounding SAM is a tool that can detect and segment arbitrary objects with inputs such as category names or referring expressions."
"Here are some exmaples of how to use the tool, the examples are in the format of User Question: which will have the user's question in quotes followed by the parameters in JSON format, which is the parameters you need to output to call the API to solve the user's question.\n"
'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": "car"}}}}\n'
'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": "person on the left"}}}}\n'
'Example 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": "red shirt. green shirt"}}}}\n'
'Example 1: User Question: "Can you build me a car segmentor?" {{"Parameters":{{"prompt": ["car"]}}}}\n'
'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": ["person on the left"]}}}}\n'
'Example 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": ["red shirt", "green shirt"]}}}}\n'
)

def __init__(self, prompt: str):
def __init__(self, prompt: list[str]):
self.prompt = prompt

def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
raise NotImplementedError
image_b64 = convert_to_b64(image)
data = {
"classes": self.prompt,
"image": image_b64,
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
)
resp_json: Dict[str, Any] = res.json()
if (
"statusCode" in resp_json and resp_json["statusCode"] != 200
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
resp_data = resp_json["data"]
preds = []
for pred in resp_data["preds"]:
encoded_mask = pred["encoded_mask"]
mask = rle_decode(mask_rle=encoded_mask, shape=pred["mask_shape"])
preds.append(
{
"label": pred["label_name"],
"mask": mask,
}
)
return preds
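
# Hedged usage sketch, not part of the original diff: assuming the output format shown in
# the GroundingSAM docstring (a list of {"label", "mask"} dicts with uint8 masks), this
# illustrative helper converts one non-empty binary mask into an (x1, y1, x2, y2) pixel
# bounding box.
def _mask_to_bbox(mask: np.ndarray) -> Tuple[int, int, int, int]:
    ys, xs = np.nonzero(mask)  # row (y) and column (x) indices of mask pixels
    return int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())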