Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Tools - Captioning, Image Processing, Generic OD & Seg #157

Merged
merged 10 commits into from
Jul 1, 2024
15 changes: 13 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ e2b = "^0.17.1"
e2b-code-interpreter = "^0.0.9"
tenacity = "^8.3.0"
pillow-heif = "^0.16.0"
pytube = "15.0.0"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down
92 changes: 92 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@
blip_image_caption,
clip,
closest_mask_distance,
florencev2_image_caption,
depth_anything_v2,
dpt_hybrid_midas,
generate_pose_image,
generate_soft_edge_image,
florencev2_object_detection,
detr_segmentation,
git_vqa_v2,
grounding_dino,
grounding_sam,
florencev2_roberta_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
owl_v2,
template_match,
vit_image_classification,
vit_nsfw_classification,
)
Expand Down Expand Up @@ -48,6 +57,24 @@ def test_owl():
assert [res["label"] for res in result] == ["coin"] * 25


def test_object_detection():
    """Check florencev2_object_detection on the `ski.data.coins()` image.

    Expects exactly 24 detections, every one labelled "coin".
    """
    coins = ski.data.coins()
    detections = florencev2_object_detection(image=coins)
    assert len(detections) == 24
    labels = [det["label"] for det in detections]
    assert labels == ["coin"] * 24


def test_template_match():
    """Check template_match using a crop of the coins image as the template.

    The template is the `[32:76, 20:68]` slice of the image itself; the test
    expects two matches to be reported.
    """
    coins = ski.data.coins()
    patch = coins[32:76, 20:68]
    matches = template_match(image=coins, template_image=patch)
    assert len(matches) == 2


def test_grounding_sam():
img = ski.data.coins()
result = grounding_sam(
Expand All @@ -59,6 +86,16 @@ def test_grounding_sam():
assert len([res["mask"] for res in result]) == 24


def test_segmentation():
    """Check detr_segmentation on the `ski.data.coins()` image.

    Expects a single result whose label is "pizza" and which carries a
    "mask" entry.
    """
    coins = ski.data.coins()
    segments = detr_segmentation(image=coins)
    assert len(segments) == 1
    labels = [seg["label"] for seg in segments]
    assert labels == ["pizza"]
    masks = [seg["mask"] for seg in segments]
    assert len(masks) == 1


def test_clip():
img = ski.data.coins()
result = clip(
Expand Down Expand Up @@ -92,6 +129,14 @@ def test_image_caption() -> None:
assert result.strip() == "a rocket on a stand"


def test_florence_image_caption() -> None:
    """Check that florencev2_image_caption describes the rocket image.

    The stripped caption must contain the expected sentence fragment; extra
    text around it is tolerated.
    """
    rocket = ski.data.rocket()
    caption = florencev2_image_caption(image=rocket)
    expected = "The image shows a rocket on a launch pad at night"
    assert expected in caption.strip()


def test_loca_zero_shot_counting() -> None:
img = ski.data.coins()

Expand Down Expand Up @@ -119,6 +164,15 @@ def test_git_vqa_v2() -> None:
assert result.strip() == "night"


def test_image_qa_with_context() -> None:
    """Check florencev2_roberta_vqa answers a day/night question on the rocket image.

    The stripped answer only needs to contain the word "night".
    """
    rocket = ski.data.rocket()
    question = "Is the scene captured during day or night ?"
    answer = florencev2_roberta_vqa(prompt=question, image=rocket)
    assert "night" in answer.strip()


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
Expand All @@ -144,3 +198,41 @@ def test_mask_distance():
np.sqrt(2) * 81,
atol=1e-2,
), f"Expected {np.sqrt(2) * 81}, got {distance}"


def test_generate_depth():
    """Check that depth_anything_v2 returns an output shaped like its input."""
    coins = ski.data.coins()
    depth_map = depth_anything_v2(image=coins)
    assert depth_map.shape == coins.shape


def test_generate_pose():
    """Check that generate_pose_image returns a 3-channel image of the input size.

    The input is the `ski.data.coins()` image; the output shape must equal
    the input shape with a trailing channel dimension of 3 appended.
    """
    img = ski.data.coins()
    result = generate_pose_image(
        image=img,
    )
    # The former debug dump (`cv2.imwrite("imag.png", result)`) was removed:
    # it left a stray file in the working directory on every test run and
    # played no part in the assertion.
    assert result.shape == img.shape + (3,)


def test_generate_normal():
    """Check that dpt_hybrid_midas returns a 3-channel output of the input size."""
    coins = ski.data.coins()
    normal_map = dpt_hybrid_midas(image=coins)
    expected_shape = coins.shape + (3,)
    assert normal_map.shape == expected_shape


def test_generate_hed():
    """Check that generate_soft_edge_image returns an output shaped like its input."""
    coins = ski.data.coins()
    edge_map = generate_soft_edge_image(image=coins)
    assert edge_map.shape == coins.shape
9 changes: 9 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@
closest_box_distance,
closest_mask_distance,
extract_frames,
florencev2_image_caption,
get_tool_documentation,
florencev2_object_detection,
detr_segmentation,
depth_anything_v2,
generate_soft_edge_image,
dpt_hybrid_midas,
generate_pose_image,
git_vqa_v2,
grounding_dino,
grounding_sam,
florencev2_roberta_vqa,
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
Expand All @@ -27,6 +35,7 @@
save_image,
save_json,
save_video,
template_match,
vit_image_classification,
vit_nsfw_classification,
)
Expand Down
Loading
Loading