diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 65d43943..192b2a1f 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -440,23 +440,29 @@ def __call__( data = { "prompt": prompt, "image": image_b64, - "tool": "visual_grounding_segment", - "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, } - data: Dict[str, Any] = _send_inference_request(request_data, "tools") - ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} - if "bboxes" in data: - ret_pred["bboxes"] = [ - normalize_bbox(box, image_size) for box in data["bboxes"] - ] - if "masks" in data: - ret_pred["masks"] = [ - rle_decode(mask_rle=mask, shape=data["mask_shape"]) - for mask in data["masks"] - ] - ret_pred["labels"] = data["labels"] - ret_pred["scores"] = data["scores"] - return ret_pred + res = requests.post( + self._ENDPOINT, + headers={"Content-Type": "application/json"}, + json=data, + ) + resp_json: Dict[str, Any] = res.json() + if ( + "statusCode" in resp_json and resp_json["statusCode"] != 200 + ) or "statusCode" not in resp_json: + _LOGGER.error(f"Request failed: {resp_json}") + raise ValueError(f"Request failed: {resp_json}") + rets = resp_json["data"] + shape = rets.pop("mask_shape") + mask_files = [] + for encoded_mask in rets["masks"]: + mask = rle_decode(mask_rle=encoded_mask, shape=shape) + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + Image.fromarray(mask * 255).save(tmp) + mask_files.append(tmp.name) + rets["masks"] = mask_files + rets["labels"] = ["visual prompt" for _ in range(len(mask_files))] + return rets class AgentGroundingSAM(GroundingSAM):