Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 19, 2024
1 parent 478eddf commit 4bcc013
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
ixc25_temporal_localization,
ixc25_video_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
Expand All @@ -33,6 +33,8 @@
vit_nsfw_classification,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"


def test_grounding_dino():
img = ski.data.coins()
Expand Down Expand Up @@ -65,6 +67,18 @@ def test_owl_v2_image():
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_owl_v2_fine_tune_id():
img = ski.data.coins()
result = owl_v2_image(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_owl_v2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
Expand All @@ -78,7 +92,7 @@ def test_owl_v2_video():
assert 24 <= len([res["label"] for res in result[0]]) <= 26


def test_object_detection():
def test_florence2_phrase_grounding():
img = ski.data.coins()
result = florence2_phrase_grounding(
image=img,
Expand All @@ -88,6 +102,18 @@ def test_object_detection():
assert [res["label"] for res in result] == ["coin"] * 25


def test_florence2_phrase_grounding_fine_tune_id():
img = ski.data.coins()
result = florence2_phrase_grounding(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)


def test_template_match():
img = ski.data.coins()
result = template_match(
Expand Down Expand Up @@ -119,6 +145,19 @@ def test_florence2_sam2_image():
assert len([res["mask"] for res in result]) == 25


def test_florence2_sam2_image_fine_tune_id():
img = ski.data.coins()
result = florence2_sam2_image(
prompt="coin",
image=img,
fine_tune_id=FINE_TUNE_ID,
)
# this calls a fine-tuned florence2 model which is going to be worse at this task
assert 14 <= len(result) <= 26
assert [res["label"] for res in result] == ["coin"] * len(result)
assert len([res["mask"] for res in result]) == len(result)


def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
Expand Down

0 comments on commit 4bcc013

Please sign in to comment.