Skip to content

Commit

Permalink
Merge branch 'main' into add-dinov
Browse files (browse the repository at this point in the history)
  • Loading branch information
dillonalaird authored Apr 17, 2024
2 parents da818ad + 85a6170 commit 33cdf31
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions vision_agent/tools/tools.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -440,28 +440,23 @@ def __call__(
data = {
"prompt": prompt,
"image": image_b64,
"tool": "visual_grounding_segment",
"kwargs": {"box_threshold": box_threshold, "iou_threshold": iou_threshold},
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
)
resp_json: Dict[str, Any] = res.json()
if (
"statusCode" in resp_json and resp_json["statusCode"] != 200
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
rets: Dict[str, Any] = resp_json["data"]
shape = rets.pop("mask_shape")
mask_files = []
for encoded_mask in rets["masks"]:
mask = rle_decode(mask_rle=encoded_mask, shape=shape)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
Image.fromarray(mask * 255).save(tmp)
mask_files.append(tmp.name)
rets["masks"] = mask_files
return rets
data: Dict[str, Any] = _send_inference_request(request_data, "tools")
ret_pred: Dict[str, List] = {"labels": [], "bboxes": [], "masks": []}
if "bboxes" in data:
ret_pred["bboxes"] = [
normalize_bbox(box, image_size) for box in data["bboxes"]
]
if "masks" in data:
ret_pred["masks"] = [
rle_decode(mask_rle=mask, shape=data["mask_shape"])
for mask in data["masks"]
]
ret_pred["labels"] = data["labels"]
ret_pred["scores"] = data["scores"]
return ret_pred


class AgentGroundingSAM(GroundingSAM):
Expand Down

0 comments on commit 33cdf31

Please sign in to comment.