Skip to content

Commit

Permalink
Add DINOv as a new tool
Browse files Browse the repository at this point in the history
  • Loading branch information
AsiaCao committed Apr 9, 2024
1 parent c5de3b8 commit 44feb20
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,96 @@ def __call__(self, prompt: List[str], image: Union[str, ImageType]) -> Dict:
return ret_pred


class DINOv(Tool):
    r"""DINOv is a tool that can detect and segment similar objects with the given input masks.

    The tool is driven by *visual* prompts: each prompt pairs a reference image
    with a mask file marking the object of interest in that reference image.

    Example
    -------
        >>> import vision_agent as va
        >>> t = va.tools.DINOv()
        >>> t(
        ...     prompt=[{"mask": "balloon_mask.jpg", "image": "balloon.jpg"}],
        ...     image="input.jpg",
        ... )
        {'scores': [0.512, 0.212],
         'masks': ['/tmp/tmpa1b2c3.png', '/tmp/tmpd4e5f6.png']}
    """

    _ENDPOINT = "https://rkgkjvqgh7vbzjdb23tr7eay4a0vczdo.lambda-url.us-east-2.on.aws"

    name = "dinov_"
    description = "'dinov_' is a tool that can detect and segment similar objects with the given input segmentation masks."
    usage = {
        "required_parameters": [
            {"name": "prompt", "type": "List[Dict[str, str]]"},
            {"name": "image", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Can you find all the balloons in this image that is similar to the provided masked area?",
                "parameters": {
                    "prompt": [
                        {"mask": "balloon_mask1.jpg", "image": "balloon.jpg"},
                        {"mask": "balloon_mask2.jpg", "image": "balloon.jpg"},
                    ],
                    "image": "input.jpg",
                },
            },
            {
                "scenario": "Count all the objects in this image that is similar to the provided masked area? image: input.jpg, mask: mask.jpg, mask_image: background.jpg",
                "parameters": {
                    "prompt": [
                        {"mask": "obj_mask1.jpg", "image": "background.jpg"},
                    ],
                    "image": "input.jpg",
                },
            },
        ],
    }

    def __call__(
        self, prompt: List[Dict[str, str]], image: Union[str, ImageType]
    ) -> Dict:
        """Invoke the DINOv model.

        Parameters:
            prompt: a list of visual prompts in the form of
                {'mask': 'MASK_FILE_PATH', 'image': 'IMAGE_FILE_PATH'}.
                The dictionaries are not modified by this call.
            image: the input image to segment.

        Returns:
            The 'data' payload of the service response with 'mask_shape'
            removed and 'masks' replaced by a list of PNG file paths, one per
            decoded mask. Other keys are passed through unchanged (the service
            is expected to include 'scores' -- confirm against the endpoint).
            The PNG files are created with ``delete=False``, so removing them
            is the caller's responsibility.

        Raises:
            ValueError: if the response has no 'statusCode' or it is not 200.
        """
        image_b64 = convert_to_b64(image)
        # Build fresh dicts rather than overwriting the caller's entries, so the
        # same prompt objects (still holding file paths) can be reused later.
        encoded_prompt = [
            {
                **p,
                "mask": convert_to_b64(p["mask"]),
                "image": convert_to_b64(p["image"]),
            }
            for p in prompt
        ]
        data = {
            "prompt": encoded_prompt,
            "image": image_b64,
        }
        res = requests.post(
            self._ENDPOINT,
            headers={"Content-Type": "application/json"},
            json=data,
        )
        resp_json: Dict[str, Any] = res.json()
        # A missing status code is treated the same as a non-200 one.
        if "statusCode" not in resp_json or resp_json["statusCode"] != 200:
            _LOGGER.error(f"Request failed: {resp_json}")
            raise ValueError(f"Request failed: {resp_json}")
        rets = resp_json["data"]
        # The shape is only needed to decode the RLE masks; it is not returned.
        shape = rets.pop("mask_shape")
        mask_files = []
        for encoded_mask in rets["masks"]:
            mask = rle_decode(mask_rle=encoded_mask, shape=shape)
            # delete=False keeps the file on disk after the handle is closed so
            # the returned path stays usable.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                # Scale by 255 before saving; presumably the decoded mask is
                # 0/1 valued -- confirm against rle_decode.
                Image.fromarray(mask * 255).save(tmp)
                mask_files.append(tmp.name)
        rets["masks"] = mask_files
        return rets


class AgentGroundingSAM(GroundingSAM):
r"""AgentGroundingSAM is the same as GroundingSAM but it saves the masks as files
returns the file name. This makes it easier for agents to use.
Expand Down

0 comments on commit 44feb20

Please sign in to comment.