From 4bcc013c06db2d3de60b9cd46b9432f45885a5c7 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 19 Sep 2024 10:19:49 -0700 Subject: [PATCH] update tests --- tests/integ/test_tools.py | 43 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index ba5b989e..4954738c 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -21,8 +21,8 @@ grounding_dino, grounding_sam, ixc25_image_vqa, - ixc25_video_vqa, ixc25_temporal_localization, + ixc25_video_vqa, loca_visual_prompt_counting, loca_zero_shot_counting, ocr, @@ -33,6 +33,8 @@ vit_nsfw_classification, ) +FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" + def test_grounding_dino(): img = ski.data.coins() @@ -65,6 +67,18 @@ def test_owl_v2_image(): assert [res["label"] for res in result] == ["coin"] * len(result) +def test_owl_v2_fine_tune_id(): + img = ski.data.coins() + result = owl_v2_image( + prompt="coin", + image=img, + fine_tune_id=FINE_TUNE_ID, + ) + # this calls a fine-tuned florence2 model which is going to be worse at this task + assert 14 <= len(result) <= 26 + assert [res["label"] for res in result] == ["coin"] * len(result) + + def test_owl_v2_video(): frames = [ np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) @@ -78,7 +92,7 @@ def test_owl_v2_video(): assert 24 <= len([res["label"] for res in result[0]]) <= 26 -def test_object_detection(): +def test_florence2_phrase_grounding(): img = ski.data.coins() result = florence2_phrase_grounding( image=img, @@ -88,6 +102,18 @@ def test_object_detection(): assert [res["label"] for res in result] == ["coin"] * 25 +def test_florence2_phrase_grounding_fine_tune_id(): + img = ski.data.coins() + result = florence2_phrase_grounding( + prompt="coin", + image=img, + fine_tune_id=FINE_TUNE_ID, + ) + # this calls a fine-tuned florence2 model which is going to be worse at this task + assert 14 <= len(result) <= 26 + assert [res["label"] for res in result] == ["coin"] * len(result) + + def test_template_match(): img = ski.data.coins() result = template_match( @@ -119,6 +145,19 @@ def test_florence2_sam2_image(): assert len([res["mask"] for res in result]) == 25 +def test_florence2_sam2_image_fine_tune_id(): + img = ski.data.coins() + result = florence2_sam2_image( + prompt="coin", + image=img, + fine_tune_id=FINE_TUNE_ID, + ) + # this calls a fine-tuned florence2 model which is going to be worse at this task + assert 14 <= len(result) <= 26 + assert [res["label"] for res in result] == ["coin"] * len(result) + assert len([res["mask"] for res in result]) == len(result) + + def test_florence2_sam2_video(): frames = [ np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)