Skip to content

Commit

Permalink
add dinov with updated endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 18, 2024
1 parent 81f3cf3 commit 2a69082
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,6 @@ class DINOv(Tool):
[1, 1, 1, ..., 1, 1, 1]], dtype=uint8)]}]
"""

_ENDPOINT = "https://rkgkjvqgh7vbzjdb23tr7eay4a0vczdo.lambda-url.us-east-2.on.aws"

name = "dinov_"
description = "'dinov_' is a tool that can detect and segment similar objects given a reference segmentation mask."
usage = {
Expand Down Expand Up @@ -437,31 +435,36 @@ def __call__(
for p in prompt:
p["mask"] = convert_to_b64(p["mask"])
p["image"] = convert_to_b64(p["image"])
data = {
request_data = {
"prompt": prompt,
"image": image_b64,
"tool": "dinov",
}
res = requests.post(
self._ENDPOINT,
headers={"Content-Type": "application/json"},
json=data,
)
resp_json: Dict[str, Any] = res.json()
if (
"statusCode" in resp_json and resp_json["statusCode"] != 200
) or "statusCode" not in resp_json:
_LOGGER.error(f"Request failed: {resp_json}")
raise ValueError(f"Request failed: {resp_json}")
rets = resp_json["data"]
shape = rets.pop("mask_shape")
data: Dict[str, Any] = _send_inference_request(request_data, "dinov")
if "bboxes" in data:
data["bboxes"] = [normalize_bbox(box, data["mask_shape"]) for box in data["bboxes"]]
if "masks" in data:
data["masks"] = [
rle_decode(mask_rle=mask, shape=data["mask_shape"]) for mask in data["masks"]
]
data["labels"] = ["visual prompt" for _ in range(len(data["masks"]))]
return data


class AgentDINOv(DINOv):
    """DINOv variant that returns mask file paths instead of in-memory masks.

    Each mask produced by ``DINOv.__call__`` is written to its own PNG file
    and the ``"masks"`` entry of the result is replaced by the list of those
    file paths.
    """

    def __call__(
        self,
        prompt: List[Dict[str, str]],
        image: Union[str, ImageType],
    ) -> Dict:
        """Run DINOv and persist every returned mask as a ``.mask.png`` file.

        Args:
            prompt: Visual prompts, forwarded unchanged to ``DINOv.__call__``.
            image: Target image, forwarded unchanged to ``DINOv.__call__``.

        Returns:
            The parent's result dict with ``"masks"`` replaced by a list of
            PNG file paths, one per mask. The files are created with
            ``delete=False``, so they are not removed automatically.
        """
        rets = super().__call__(prompt, image)
        mask_files = []
        for mask in rets["masks"]:
            # Create the temp file with its final ".mask.png" suffix and write
            # the image straight into it. Saving to a *different* sibling path
            # would leave an empty, never-deleted temp file behind per mask.
            with tempfile.NamedTemporaryFile(
                suffix=".mask.png", delete=False
            ) as tmp:
                # NOTE(review): presumably masks are 0/1 uint8 arrays, so the
                # multiplication yields a black/white image -- confirm.
                # The format is given explicitly because a file object, not a
                # path, is handed to ``save``.
                Image.fromarray(mask * 255).save(tmp, format="PNG")
            mask_files.append(tmp.name)
        rets["masks"] = mask_files
        return rets


Expand Down Expand Up @@ -745,6 +748,7 @@ def __call__(self, equation: str) -> float:
ImageCaption,
GroundingDINO,
AgentGroundingSAM,
AgentDINOv,
ExtractFrames,
Crop,
BboxArea,
Expand Down

0 comments on commit 2a69082

Please sign in to comment.