Skip to content

Commit

Permalink
fix tests'
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 6, 2024
1 parent 0fcfe04 commit 45a160a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
26 changes: 20 additions & 6 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
detr_segmentation,
dpt_hybrid_midas,
florence2_image_caption,
florence2_phrase_grounding,
florence2_ocr,
florence2_phrase_grounding,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video,
florence2_sam2_video_tracking,
generate_pose_image,
generate_soft_edge_image,
git_vqa_v2,
Expand All @@ -25,7 +25,8 @@
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
owl_v2,
owl_v2_image,
owl_v2_video,
template_match,
vit_image_classification,
vit_nsfw_classification,
Expand Down Expand Up @@ -53,16 +54,29 @@ def test_grounding_dino_tiny():
assert [res["label"] for res in result] == ["coin"] * 24


def test_owl():
def test_owl_v2_image():
img = ski.data.coins()
result = owl_v2(
result = owl_v2_image(
prompt="coin",
image=img,
)
assert len(result) == 25
assert [res["label"] for res in result] == ["coin"] * 25


def test_owl_v2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = owl_v2_video(
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert len([res["label"] for res in result[0]]) == 25


def test_object_detection():
img = ski.data.coins()
result = florence2_phrase_grounding(
Expand Down Expand Up @@ -108,7 +122,7 @@ def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = florence2_sam2_video(
result = florence2_sam2_video_tracking(
prompt="coin",
frames=frames,
)
Expand Down
5 changes: 1 addition & 4 deletions tests/integration_dev/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import skimage as ski

from vision_agent.tools import (
countgd_counting,
countgd_example_based_counting,
)
from vision_agent.tools import countgd_counting, countgd_example_based_counting


def test_countgd_counting() -> None:
Expand Down

0 comments on commit 45a160a

Please sign in to comment.