Skip to content

Commit

Permalink
correct output format for cgd
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Sep 2, 2024
1 parent 18cce4f commit 7dd6a37
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
5 changes: 2 additions & 3 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,8 @@ def test_loca_visual_prompt_counting() -> None:

def test_countgd_counting() -> None:
img = ski.data.coins()

result = countgd_counting(image=img, prompt="coin")
assert result["count"] == 24
assert len(result) == 24


def test_countgd_example_based_counting() -> None:
Expand All @@ -201,7 +200,7 @@ def test_countgd_example_based_counting() -> None:
visual_prompts=[[85, 106, 122, 145]],
image=img,
)
assert result["count"] == 24
assert len(result) == 24


def test_git_vqa_v2() -> None:
Expand Down
13 changes: 9 additions & 4 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,10 +540,11 @@ def countgd_counting(
"box_threshold": box_threshold,
}
metadata_payload = {"function_name": "countgd_counting"}
resp: List[Dict[str, Any]] = send_inference_request(
resp_data: List[Dict[str, Any]] = send_inference_request(
payload, "countgd", v2=True, metadata_payload=metadata_payload
) # type: ignore
return resp["data"]

return resp_data


def countgd_example_based_counting(
Expand Down Expand Up @@ -583,16 +584,20 @@ def countgd_example_based_counting(
]
"""
image_b64 = convert_to_b64(image)
visual_prompts = [
denormalize_bbox(bbox, image.shape[:2]) for bbox in visual_prompts
]
payload = {
"image": image_b64,
"visual_prompts": visual_prompts,
"box_threshold": box_threshold,
}
metadata_payload = {"function_name": "countgd_example_based_counting"}
resp: List[Dict[str, Any]] = send_inference_request(
resp_data: List[Dict[str, Any]] = send_inference_request(
payload, "countgd", v2=True, metadata_payload=metadata_payload
) # type: ignore
return resp["data"]

return resp_data


def florence2_roberta_vqa(prompt: str, image: np.ndarray) -> str:
Expand Down

0 comments on commit 7dd6a37

Please sign in to comment.