diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index bca1f6ea..ec45f7c9 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -10,11 +10,11 @@ detr_segmentation, dpt_hybrid_midas, florence2_image_caption, - florence2_phrase_grounding, florence2_ocr, + florence2_phrase_grounding, florence2_roberta_vqa, florence2_sam2_image, - florence2_sam2_video, + florence2_sam2_video_tracking, generate_pose_image, generate_soft_edge_image, git_vqa_v2, @@ -25,7 +25,8 @@ loca_visual_prompt_counting, loca_zero_shot_counting, ocr, - owl_v2, + owl_v2_image, + owl_v2_video, template_match, vit_image_classification, vit_nsfw_classification, @@ -53,9 +54,9 @@ def test_grounding_dino_tiny(): assert [res["label"] for res in result] == ["coin"] * 24 -def test_owl(): +def test_owl_v2_image(): img = ski.data.coins() - result = owl_v2( + result = owl_v2_image( prompt="coin", image=img, ) @@ -63,6 +64,19 @@ def test_owl(): assert [res["label"] for res in result] == ["coin"] * 25 +def test_owl_v2_video(): + frames = [ + np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) + ] + result = owl_v2_video( + prompt="coin", + frames=frames, + ) + + assert len(result) == 10 + assert len([res["label"] for res in result[0]]) == 25 + + def test_object_detection(): img = ski.data.coins() result = florence2_phrase_grounding( @@ -108,7 +122,7 @@ def test_florence2_sam2_video(): frames = [ np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) ] - result = florence2_sam2_video( + result = florence2_sam2_video_tracking( prompt="coin", frames=frames, ) diff --git a/tests/integration_dev/test_tools.py b/tests/integration_dev/test_tools.py index 29262245..246c5642 100644 --- a/tests/integration_dev/test_tools.py +++ b/tests/integration_dev/test_tools.py @@ -1,9 +1,6 @@ import skimage as ski -from vision_agent.tools import ( - countgd_counting, - countgd_example_based_counting, -) +from vision_agent.tools import countgd_counting, countgd_example_based_counting def test_countgd_counting() -> None: