-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding more cv tools to coder agent (#88)
* adding more cv tools to coder agent * Bug fix: missing kernel python3 (#87) * Bug fix: missing kernel python3 * Update API key * [skip ci] chore(release): vision-agent 0.2.24 * adding more cv tools to coder agent * Added test cases for the tools added * added test case for every tool * fix linting * fixing tests * fix linting * fixing test cases for grounding tools as the output format changed --------- Co-authored-by: Asia <[email protected]> Co-authored-by: GitHub Actions Bot <[email protected]>
- Loading branch information
1 parent
92fd2c8
commit 21849f4
Showing
4 changed files
with
241 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,84 @@ | ||
import skimage as ski | ||
from PIL import Image | ||
|
||
from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption | ||
from vision_agent.tools.tools_v2 import ( | ||
clip, | ||
zero_shot_counting, | ||
visual_prompt_counting, | ||
image_question_answering, | ||
ocr, | ||
grounding_dino, | ||
grounding_sam, | ||
image_caption, | ||
) | ||
|
||
|
||
def test_grounding_dino(): | ||
img = Image.fromarray(ski.data.coins()) | ||
result = GroundingDINO()( | ||
img = ski.data.coins() | ||
result = grounding_dino( | ||
prompt="coin", | ||
image=img, | ||
) | ||
assert result["labels"] == ["coin"] * 24 | ||
assert len(result["bboxes"]) == 24 | ||
assert len(result["scores"]) == 24 | ||
assert len(result) == 24 | ||
assert [res["label"] for res in result] == ["coin"] * 24 | ||
|
||
|
||
def test_grounding_sam(): | ||
img = Image.fromarray(ski.data.coins()) | ||
result = GroundingSAM()( | ||
img = ski.data.coins() | ||
result = grounding_sam( | ||
prompt="coin", | ||
image=img, | ||
) | ||
assert result["labels"] == ["coin"] * 24 | ||
assert len(result["bboxes"]) == 24 | ||
assert len(result["scores"]) == 24 | ||
assert len(result["masks"]) == 24 | ||
assert len(result) == 24 | ||
assert [res["label"] for res in result] == ["coin"] * 24 | ||
assert len([res["mask"] for res in result]) == 24 | ||
|
||
|
||
def test_clip(): | ||
img = Image.fromarray(ski.data.coins()) | ||
result = CLIP()( | ||
prompt="coins", | ||
img = ski.data.coins() | ||
result = clip( | ||
classes=["coins", "notes"], | ||
image=img, | ||
) | ||
assert result["scores"] == [1.0] | ||
assert result["scores"] == [0.9999, 0.0001] | ||
|
||
|
||
def test_image_caption() -> None: | ||
img = Image.fromarray(ski.data.coins()) | ||
result = ImageCaption()(image=img) | ||
assert result["text"] | ||
img = ski.data.rocket() | ||
result = image_caption( | ||
image=img, | ||
) | ||
assert result.strip() == "a rocket on a stand" | ||
|
||
|
||
def test_zero_shot_counting() -> None: | ||
img = ski.data.coins() | ||
result = zero_shot_counting( | ||
image=img, | ||
) | ||
assert result["count"] == 21 | ||
|
||
|
||
def test_visual_prompt_counting() -> None: | ||
img = ski.data.coins() | ||
result = visual_prompt_counting( | ||
visual_prompt={"bbox": [85, 106, 122, 145]}, | ||
image=img, | ||
) | ||
assert result["count"] == 25 | ||
|
||
|
||
def test_image_question_answering() -> None: | ||
img = ski.data.rocket() | ||
result = image_question_answering( | ||
prompt="Is the scene captured during day or night ?", | ||
image=img, | ||
) | ||
assert result.strip() == "night" | ||
|
||
|
||
def test_ocr() -> None: | ||
img = ski.data.page() | ||
result = ocr( | ||
image=img, | ||
) | ||
assert any("Region-based segmentation" in res["label"] for res in result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters