class DINOv(Tool):
    r"""DINOv is a tool that can detect and segment similar objects with the given input masks.

    The visual prompt is a list of ``{"mask": ..., "image": ...}`` pairs: each
    mask marks the reference object inside its companion image. The tool then
    looks for similar objects in the target ``image``.

    Example
    -------
        >>> import vision_agent as va
        >>> t = va.tools.DINOv()
        >>> t(
        ...     prompt=[{"mask": "balloon_mask.jpg", "image": "balloon.jpg"}],
        ...     image="input.jpg",
        ... )
        {'scores': [0.512, 0.212],
         'masks': ['/tmp/tmpa1b2c3d4.png', '/tmp/tmpe5f6g7h8.png']}
    """

    _ENDPOINT = "https://rkgkjvqgh7vbzjdb23tr7eay4a0vczdo.lambda-url.us-east-2.on.aws"

    name = "dinov_"
    description = "'dinov_' is a tool that can detect and segment similar objects with the given input segmentation masks."
    usage = {
        "required_parameters": [
            {"name": "prompt", "type": "List[Dict[str, str]]"},
            {"name": "image", "type": "str"},
        ],
        "examples": [
            {
                "scenario": "Can you find all the balloons in this image that is similar to the provided masked area?",
                "parameters": {
                    "prompt": [
                        {"mask": "balloon_mask1.jpg", "image": "balloon.jpg"},
                        {"mask": "balloon_mask2.jpg", "image": "balloon.jpg"},
                    ],
                    "image": "input.jpg",
                },
            },
            {
                "scenario": "Count all the objects in this image that is similar to the provided masked area? \nimage: input.jpg, mask: mask.jpg, mask_image: background.jpg",
                "parameters": {
                    "prompt": [
                        {"mask": "obj_mask1.jpg", "image": "background.jpg"},
                    ],
                    "image": "input.jpg",
                },
            },
        ],
    }

    def __call__(
        self, prompt: List[Dict[str, str]], image: Union[str, ImageType]
    ) -> Dict:
        """Invoke the DINOv model.

        Parameters:
            prompt: a list of visual prompts in the form of
                {'mask': 'MASK_FILE_PATH', 'image': 'IMAGE_FILE_PATH'}.
                The list and its dictionaries are not modified.
            image: the input image to segment.

        Returns:
            The 'data' dictionary of the service response with 'mask_shape'
            removed and 'masks' replaced by a list of paths to PNG files, one
            per detected mask. Per the service contract it also carries
            'scores' for the detected masks (not verified here). The PNG files
            are created with ``delete=False``; removing them is up to the
            caller.

        Raises:
            ValueError: if the response has no 'statusCode' or it is not 200.
        """
        image_b64 = convert_to_b64(image)
        # Build fresh dicts instead of overwriting the caller's entries, so
        # the same prompt list can be reused for another call. Any extra keys
        # are forwarded unchanged, as before.
        encoded_prompt = [
            {
                **p,
                "mask": convert_to_b64(p["mask"]),
                "image": convert_to_b64(p["image"]),
            }
            for p in prompt
        ]
        data = {
            "prompt": encoded_prompt,
            "image": image_b64,
        }
        res = requests.post(
            self._ENDPOINT,
            headers={"Content-Type": "application/json"},
            json=data,
        )
        resp_json: Dict[str, Any] = res.json()
        # A missing status code is treated the same as a failing one.
        if resp_json.get("statusCode") != 200:
            _LOGGER.error(f"Request failed: {resp_json}")
            raise ValueError(f"Request failed: {resp_json}")
        rets = resp_json["data"]
        # The shape is only needed to decode the RLE masks; it is not returned.
        shape = rets.pop("mask_shape")
        mask_files = []
        for encoded_mask in rets["masks"]:
            mask = rle_decode(mask_rle=encoded_mask, shape=shape)
            # delete=False keeps the file on disk after the handle is closed,
            # so the returned path stays valid for the caller.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                # Presumably rle_decode yields a 0/1 mask; scaling by 255
                # makes it visible as black/white -- TODO confirm.
                Image.fromarray(mask * 255).save(tmp)
                mask_files.append(tmp.name)
        rets["masks"] = mask_files
        return rets