Skip to content

Commit

Permalink
added more test cases for correct format, fixed normalize bboxes for countgd
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 9, 2024
1 parent 82c4d20 commit 760bae5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
10 changes: 10 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def test_owl_v2_image():
)
assert 24 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_fine_tune_id():
Expand All @@ -80,6 +81,7 @@ def test_owl_v2_fine_tune_id():
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_video():
Expand All @@ -93,6 +95,7 @@ def test_owl_v2_video():

assert len(result) == 10
assert 24 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_phrase_grounding():
Expand All @@ -101,8 +104,10 @@ def test_florence2_phrase_grounding():
image=img,
prompt="coin",
)

assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_fine_tune_id():
Expand All @@ -115,6 +120,7 @@ def test_florence2_phrase_grounding_fine_tune_id():
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_video():
Expand All @@ -127,6 +133,7 @@ def test_florence2_phrase_grounding_video():
)
assert len(result) == 10
assert 2 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_florence2_phrase_grounding_video_fine_tune_id():
Expand All @@ -141,6 +148,7 @@ def test_florence2_phrase_grounding_video_fine_tune_id():
)
assert len(result) == 10
assert 16 <= len([res["label"] for res in result[0]]) <= 26
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]])


def test_template_match():
Expand Down Expand Up @@ -395,6 +403,7 @@ def test_countgd_counting() -> None:
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_countgd_example_based_counting() -> None:
Expand All @@ -404,3 +413,4 @@ def test_countgd_example_based_counting() -> None:
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
6 changes: 4 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ def countgd_counting(
{'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58]},
]
"""
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
prompt = prompt.replace(", ", " .")
Expand All @@ -712,7 +713,7 @@ def countgd_counting(
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=normalize_bbox(bbox["bounding_box"], image_size),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
Expand Down Expand Up @@ -757,6 +758,7 @@ def countgd_example_based_counting(
{'score': 0.98, 'label': 'object', 'bbox': [0.44, 0.24, 0.49, 0.58]},
]
"""
image_size = image.shape[:2]
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
visual_prompts = [
Expand All @@ -771,7 +773,7 @@ def countgd_example_based_counting(
bboxes_formatted = [
ODResponseData(
label=bbox["label"],
bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])),
bbox=normalize_bbox(bbox["bounding_box"], image_size),
score=round(bbox["score"], 2),
)
for bbox in bboxes_per_frame
Expand Down

0 comments on commit 760bae5

Please sign in to comment.