diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 8c2f81dd..c8c9d52e 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -440,28 +440,23 @@ def __call__( data = { "prompt": prompt, "image": image_b64, + "tool": "visual_grounding_segment", + "kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold}, } - res = requests.post( - self._ENDPOINT, - headers={"Content-Type": "application/json"}, - json=data, - ) - resp_json: Dict[str, Any] = res.json() - if ( - "statusCode" in resp_json and resp_json["statusCode"] != 200 - ) or "statusCode" not in resp_json: - _LOGGER.error(f"Request failed: {resp_json}") - raise ValueError(f"Request failed: {resp_json}") - rets: Dict[str, Any] = resp_json["data"] - shape = rets.pop("mask_shape") - mask_files = [] - for encoded_mask in rets["masks"]: - mask = rle_decode(mask_rle=encoded_mask, shape=shape) - with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: - Image.fromarray(mask * 255).save(tmp) - mask_files.append(tmp.name) - rets["masks"] = mask_files - return rets + data: Dict[str, Any] = _send_inference_request(request_data, "tools") + ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []} + if "bboxes" in data: + ret_pred["bboxes"] = [ + normalize_bbox(box, image_size) for box in data["bboxes"] + ] + if "masks" in data: + ret_pred["masks"] = [ + rle_decode(mask_rle=mask, shape=data["mask_shape"]) + for mask in data["masks"] + ] + ret_pred["labels"] = data["labels"] + ret_pred["scores"] = data["scores"] + return ret_pred class AgentGroundingSAM(GroundingSAM):