Skip to content

Commit

Permalink
Adding Tools - Captioning, Image Processing, Generic OD & Seg (#157)
Browse files Browse the repository at this point in the history
* Adding the new list of tools, adding capability to load youtube videos in extract frames and increases testing footprint to accomodate for new tools

* updating tools to latest names

* fix mypy errors

* fix mypy errors

* fixing integration test

* fixing rle decode for segmentation

* fix documentation

* Image processing tools return grayscale image, so need to convert it to RGB

* adding image question answering with context

* Fixing tool names
  • Loading branch information
shankar-vision-eng authored Jul 1, 2024
1 parent 3f98d56 commit 1305146
Show file tree
Hide file tree
Showing 5 changed files with 488 additions and 14 deletions.
15 changes: 13 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ e2b = "^0.17.1"
e2b-code-interpreter = "^0.0.9"
tenacity = "^8.3.0"
pillow-heif = "^0.16.0"
pytube = "15.0.0"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down
92 changes: 92 additions & 0 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@
blip_image_caption,
clip,
closest_mask_distance,
florencev2_image_caption,
depth_anything_v2,
dpt_hybrid_midas,
generate_pose_image,
generate_soft_edge_image,
florencev2_object_detection,
detr_segmentation,
git_vqa_v2,
grounding_dino,
grounding_sam,
florencev2_roberta_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
owl_v2,
template_match,
vit_image_classification,
vit_nsfw_classification,
)
Expand Down Expand Up @@ -48,6 +57,24 @@ def test_owl():
assert [res["label"] for res in result] == ["coin"] * 25


def test_object_detection():
img = ski.data.coins()
result = florencev2_object_detection(
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_template_match():
img = ski.data.coins()
result = template_match(
image=img,
template_image=img[32:76, 20:68],
)
assert len(result) == 2


def test_grounding_sam():
img = ski.data.coins()
result = grounding_sam(
Expand All @@ -59,6 +86,16 @@ def test_grounding_sam():
assert len([res["mask"] for res in result]) == 24


def test_segmentation():
img = ski.data.coins()
result = detr_segmentation(
image=img,
)
assert len(result) == 1
assert [res["label"] for res in result] == ["pizza"]
assert len([res["mask"] for res in result]) == 1


def test_clip():
img = ski.data.coins()
result = clip(
Expand Down Expand Up @@ -92,6 +129,14 @@ def test_image_caption() -> None:
assert result.strip() == "a rocket on a stand"


def test_florence_image_caption() -> None:
img = ski.data.rocket()
result = florencev2_image_caption(
image=img,
)
assert "The image shows a rocket on a launch pad at night" in result.strip()


def test_loca_zero_shot_counting() -> None:
img = ski.data.coins()

Expand Down Expand Up @@ -119,6 +164,15 @@ def test_git_vqa_v2() -> None:
assert result.strip() == "night"


def test_image_qa_with_context() -> None:
img = ski.data.rocket()
result = florencev2_roberta_vqa(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert "night" in result.strip()


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
Expand All @@ -144,3 +198,41 @@ def test_mask_distance():
np.sqrt(2) * 81,
atol=1e-2,
), f"Expected {np.sqrt(2) * 81}, got {distance}"


def test_generate_depth():
img = ski.data.coins()
result = depth_anything_v2(
image=img,
)

assert result.shape == img.shape


def test_generate_pose():
img = ski.data.coins()
result = generate_pose_image(
image=img,
)
import cv2

cv2.imwrite("imag.png", result)
assert result.shape == img.shape + (3,)


def test_generate_normal():
img = ski.data.coins()
result = dpt_hybrid_midas(
image=img,
)

assert result.shape == img.shape + (3,)


def test_generate_hed():
img = ski.data.coins()
result = generate_soft_edge_image(
image=img,
)

assert result.shape == img.shape
9 changes: 9 additions & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@
closest_box_distance,
closest_mask_distance,
extract_frames,
florencev2_image_caption,
get_tool_documentation,
florencev2_object_detection,
detr_segmentation,
depth_anything_v2,
generate_soft_edge_image,
dpt_hybrid_midas,
generate_pose_image,
git_vqa_v2,
grounding_dino,
grounding_sam,
florencev2_roberta_vqa,
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
Expand All @@ -27,6 +35,7 @@
save_image,
save_json,
save_video,
template_match,
vit_image_classification,
vit_nsfw_classification,
)
Expand Down
Loading

0 comments on commit 1305146

Please sign in to comment.