From 2a6908201471d1af760314bcf6b527d3a393bc7f Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Thu, 18 Apr 2024 11:32:21 -0700
Subject: [PATCH] add dinov with updated endpoint

---
Review notes (free-form area, not part of the commit message):
  * Line breaks and indentation were rebuilt from a whitespace-collapsed copy
    of this patch; confirm the context lines against tools.py before applying.
  * DINOv.__call__: "labels" is now built with data.get("masks", []), so a
    response without masks yields an empty label list instead of KeyError.
  * AgentDINOv.__call__: the temp file is created with the ".mask.png" suffix
    and written directly, so no empty "*.png" sibling file is leaked per mask.
  * The blob hash on the "index" line was not recomputed for these changes.

 vision_agent/tools/tools.py | 60 ++++++++++++++++++++++++-------------
 1 file changed, 40 insertions(+), 20 deletions(-)

diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 192b2a1f..5461999b 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -389,8 +389,6 @@ class DINOv(Tool):
            [1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
     """
 
-    _ENDPOINT = "https://rkgkjvqgh7vbzjdb23tr7eay4a0vczdo.lambda-url.us-east-2.on.aws"
-
     name = "dinov_"
     description = "'dinov_' is a tool that can detect and segment similar objects given a reference segmentation mask."
     usage = {
@@ -437,31 +435,52 @@ def __call__(
         for p in prompt:
             p["mask"] = convert_to_b64(p["mask"])
             p["image"] = convert_to_b64(p["image"])
-        data = {
+        request_data = {
             "prompt": prompt,
             "image": image_b64,
+            "tool": "dinov",
         }
-        res = requests.post(
-            self._ENDPOINT,
-            headers={"Content-Type": "application/json"},
-            json=data,
-        )
-        resp_json: Dict[str, Any] = res.json()
-        if (
-            "statusCode" in resp_json and resp_json["statusCode"] != 200
-        ) or "statusCode" not in resp_json:
-            _LOGGER.error(f"Request failed: {resp_json}")
-            raise ValueError(f"Request failed: {resp_json}")
-        rets = resp_json["data"]
-        shape = rets.pop("mask_shape")
+        data: Dict[str, Any] = _send_inference_request(request_data, "dinov")
+        if "bboxes" in data:
+            data["bboxes"] = [normalize_bbox(box, data["mask_shape"]) for box in data["bboxes"]]
+        if "masks" in data:
+            data["masks"] = [
+                rle_decode(mask_rle=mask, shape=data["mask_shape"]) for mask in data["masks"]
+            ]
+        # One label per decoded mask; a response without "masks" gets an empty
+        # label list instead of raising KeyError.
+        data["labels"] = ["visual prompt"] * len(data.get("masks", []))
+        return data
+
+
+class AgentDINOv(DINOv):
+    """DINOv variant that returns masks as temporary PNG file paths."""
+
+    def __call__(
+        self,
+        prompt: List[Dict[str, str]],
+        image: Union[str, ImageType],
+    ) -> Dict:
+        """Run DINOv and replace each returned mask with a PNG file path.
+
+        Parameters:
+            prompt: forwarded unchanged to DINOv.__call__.
+            image: forwarded unchanged to DINOv.__call__.
+
+        Returns:
+            The dict produced by DINOv.__call__, with "masks" replaced by a
+            list of paths to temporary "*.mask.png" files (one per mask). The
+            files are created with delete=False, so the caller owns cleanup.
+        """
+        rets = super().__call__(prompt, image)
         mask_files = []
-        for encoded_mask in rets["masks"]:
-            mask = rle_decode(mask_rle=encoded_mask, shape=shape)
-            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
+        for mask in rets.get("masks", []):
+            # Create the temp file with its final ".mask.png" suffix and write
+            # into it directly, so no empty sibling "*.png" file is left behind.
+            with tempfile.NamedTemporaryFile(suffix=".mask.png", delete=False) as tmp:
                 Image.fromarray(mask * 255).save(tmp)
                 mask_files.append(tmp.name)
         rets["masks"] = mask_files
-        rets["labels"] = ["visual prompt" for _ in range(len(mask_files))]
         return rets
 
 
@@ -745,6 +764,7 @@ def __call__(self, equation: str) -> float:
             ImageCaption,
             GroundingDINO,
             AgentGroundingSAM,
+            AgentDINOv,
             ExtractFrames,
             Crop,
             BboxArea,