diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index fffd379d..4f5c674f 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -68,6 +68,7 @@ def test_owl_v2_image(): ) assert 24 <= len(result) <= 26 assert [res["label"] for res in result] == ["coin"] * len(result) + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_owl_v2_fine_tune_id(): @@ -80,6 +81,7 @@ def test_owl_v2_fine_tune_id(): # this calls a fine-tuned florence2 model which is going to be worse at this task assert 14 <= len(result) <= 26 assert [res["label"] for res in result] == ["coin"] * len(result) + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_owl_v2_video(): @@ -93,6 +95,7 @@ def test_owl_v2_video(): assert len(result) == 10 assert 24 <= len([res["label"] for res in result[0]]) <= 26 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) def test_florence2_phrase_grounding(): @@ -101,8 +104,10 @@ def test_florence2_phrase_grounding(): image=img, prompt="coin", ) + assert len(result) == 25 assert [res["label"] for res in result] == ["coin"] * 25 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_florence2_phrase_grounding_fine_tune_id(): @@ -115,6 +120,7 @@ def test_florence2_phrase_grounding_fine_tune_id(): # this calls a fine-tuned florence2 model which is going to be worse at this task assert 14 <= len(result) <= 26 assert [res["label"] for res in result] == ["coin"] * len(result) + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) def test_florence2_phrase_grounding_video(): @@ -127,6 +133,7 @@ def test_florence2_phrase_grounding_video(): ) assert len(result) == 10 assert 2 <= len([res["label"] for res in result[0]]) <= 26 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) def test_florence2_phrase_grounding_video_fine_tune_id(): @@ -141,6 +148,7 @@ def test_florence2_phrase_grounding_video_fine_tune_id(): ) assert len(result) == 10 assert 16 <= len([res["label"] for res in result[0]]) <= 26 + assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result[0]]) def test_template_match(): @@ -395,6 +403,7 @@ def test_countgd_counting() -> None: img = ski.data.coins() result = countgd_counting(image=img, prompt="coin") assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 def test_countgd_example_based_counting() -> None: @@ -404,3 +413,4 @@ def test_countgd_example_based_counting() -> None: image=img, ) assert len(result) == 24 + assert [res["label"] for res in result] == ["coin"] * 24 diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 71646c45..bf4da892 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -700,6 +700,7 @@ def countgd_counting( {'score': 0.98, 'label': 'flower', 'bbox': [0.44, 0.24, 0.49, 0.58}, ] """ + image_size = image.shape[:2] buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] prompt = prompt.replace(", ", " .") @@ -712,7 +713,7 @@ def countgd_counting( bboxes_formatted = [ ODResponseData( label=bbox["label"], - bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + bbox=normalize_bbox(bbox["bounding_box"], image_size), score=round(bbox["score"], 2), ) for bbox in bboxes_per_frame @@ -757,6 +758,7 @@ def countgd_example_based_counting( {'score': 0.98, 'label': 'object', 'bounding_box': [0.44, 0.24, 0.49, 0.58}, ] """ + image_size = image.shape[:2] buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] visual_prompts = [ @@ -771,7 +773,7 @@ def countgd_example_based_counting( bboxes_formatted = [ ODResponseData( label=bbox["label"], - bbox=list(map(lambda x: round(x, 2), bbox["bounding_box"])), + bbox=normalize_bbox(bbox["bounding_box"], image_size), score=round(bbox["score"], 2), ) for bbox in bboxes_per_frame