tools can deal with empty images
dillonalaird committed Oct 12, 2024
1 parent db8ee75 commit f34be2d
Showing 2 changed files with 117 additions and 14 deletions.
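
The commit applies one early-exit guard across the image tools: list-returning tools (owl_v2_image, florence2_sam2_image, florence2_phrase_grounding, detr_segmentation, countgd_counting, countgd_example_based_counting) return [] for a zero-sized image, classifiers (clip, vit_image_classification) return empty label/score lists, and tools with no meaningful empty result (ixc25_image_vqa, vit_nsfw_classification, depth_anything_v2) raise ValueError. The new ocr and florence2_ocr tests assert the same empty-list behavior without a corresponding change in tools.py. The sketch below is a minimal illustration of that pattern, using hypothetical names (_is_empty, detect_like, classify_like, single_answer_like); the real functions inline the shape check rather than calling a helper.

```python
import numpy as np
from typing import Any, Dict, List


def _is_empty(image: np.ndarray) -> bool:
    # A zero-height or zero-width array has no pixels to send to a model.
    return image.shape[0] < 1 or image.shape[1] < 1


def detect_like(image: np.ndarray) -> List[Dict[str, Any]]:
    # Detection, segmentation, and counting tools short-circuit to an empty list.
    if _is_empty(image):
        return []
    return [{"label": "coin", "score": 0.99, "bbox": [0.1, 0.1, 0.2, 0.2]}]  # placeholder


def classify_like(image: np.ndarray) -> Dict[str, Any]:
    # Classification tools return empty label/score lists instead.
    if _is_empty(image):
        return {"labels": [], "scores": []}
    return {"labels": ["coin"], "scores": [0.99]}  # placeholder


def single_answer_like(image: np.ndarray) -> str:
    # VQA, NSFW classification, and depth tools raise because there is no
    # sensible empty answer to return.
    if _is_empty(image):
        raise ValueError(f"Image is empty, image shape: {image.shape}")
    return "normal"  # placeholder
```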
106 changes: 92 additions & 14 deletions tests/integ/test_tools.py
@@ -71,6 +71,14 @@ def test_owl_v2_image():
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_owl_v2_image_empty():
result = owl_v2_image(
prompt="coin",
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_owl_v2_fine_tune_id():
img = ski.data.coins()
result = owl_v2_image(
@@ -110,6 +118,14 @@ def test_florence2_phrase_grounding():
assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result])


def test_florence2_phrase_grounding_empty():
result = florence2_phrase_grounding(
image=np.zeros((0, 0, 3)).astype(np.uint8),
prompt="coin",
)
assert result == []


def test_florence2_phrase_grounding_fine_tune_id():
img = ski.data.coins()
result = florence2_phrase_grounding(
@@ -195,6 +211,14 @@ def test_florence2_sam2_image_fine_tune_id():
assert len([res["mask"] for res in result]) == len(result)


def test_florence2_sam2_image_empty():
result = florence2_sam2_image(
prompt="coin",
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_florence2_sam2_video():
frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
@@ -208,7 +232,7 @@ def test_florence2_sam2_video():
assert len([res["mask"] for res in result[0]]) == 25


def test_segmentation():
def test_detr_segmentation():
img = ski.data.coins()
result = detr_segmentation(
image=img,
@@ -218,6 +242,13 @@ def test_segmentation():
assert len([res["mask"] for res in result]) == 1


def test_detr_segmentation_empty():
result = detr_segmentation(
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_clip():
img = ski.data.coins()
result = clip(
@@ -227,6 +258,15 @@ def test_clip():
assert result["scores"] == [0.9999, 0.0001]


def test_clip_empty():
result = clip(
classes=["coins", "notes"],
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result["scores"] == []
assert result["labels"] == []


def test_vit_classification():
img = ski.data.coins()
result = vit_image_classification(
@@ -235,6 +275,14 @@ def test_vit_classification():
assert "typewriter keyboard" in result["labels"]


def test_vit_classification_empty():
result = vit_image_classification(
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result["labels"] == []
assert result["scores"] == []


def test_nsfw_classification():
img = ski.data.coins()
result = vit_nsfw_classification(
@@ -243,23 +291,23 @@ def test_nsfw_classification():
assert result["label"] == "normal"


def test_image_caption() -> None:
def test_image_caption():
img = ski.data.rocket()
result = blip_image_caption(
image=img,
)
assert result.strip() == "a rocket on a stand"


def test_florence_image_caption() -> None:
def test_florence_image_caption():
img = ski.data.rocket()
result = florence2_image_caption(
image=img,
)
assert "The image shows a rocket on a launch pad at night" in result.strip()


def test_loca_zero_shot_counting() -> None:
def test_loca_zero_shot_counting():
img = ski.data.coins()

result = loca_zero_shot_counting(
@@ -268,7 +316,7 @@ def test_loca_zero_shot_counting() -> None:
assert result["count"] == 21


def test_loca_visual_prompt_counting() -> None:
def test_loca_visual_prompt_counting():
img = ski.data.coins()
result = loca_visual_prompt_counting(
visual_prompt={"bbox": [85, 106, 122, 145]},
@@ -277,7 +325,7 @@ def test_loca_visual_prompt_counting() -> None:
assert result["count"] == 25


def test_git_vqa_v2() -> None:
def test_git_vqa_v2():
img = ski.data.rocket()
result = git_vqa_v2(
prompt="Is the scene captured during day or night ?",
@@ -286,7 +334,7 @@ def test_git_vqa_v2() -> None:
assert result.strip() == "night"


def test_image_qa_with_context() -> None:
def test_image_qa_with_context():
img = ski.data.rocket()
result = florence2_roberta_vqa(
prompt="Is the scene captured during day or night ?",
@@ -295,7 +343,7 @@ def test_image_qa_with_context() -> None:
assert "night" in result.strip()


def test_ixc25_image_vqa() -> None:
def test_ixc25_image_vqa():
img = ski.data.cat()
result = ixc25_image_vqa(
prompt="What animal is in this image?",
@@ -304,7 +352,7 @@ def test_ixc25_image_vqa() -> None:
assert "cat" in result.strip()


def test_ixc25_video_vqa() -> None:
def test_ixc25_video_vqa():
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
]
@@ -315,7 +363,7 @@ def test_ixc25_video_vqa() -> None:
assert "cat" in result.strip()


def test_ixc25_temporal_localization() -> None:
def test_ixc25_temporal_localization():
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
]
@@ -326,22 +374,36 @@ def test_ixc25_temporal_localization() -> None:
assert result == [True] * 10


def test_ocr() -> None:
def test_ocr():
img = ski.data.page()
result = ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)


def test_florence2_ocr() -> None:
def test_ocr_empty():
result = ocr(
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_florence2_ocr():
img = ski.data.page()
result = florence2_ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)


def test_florence2_ocr_empty():
result = florence2_ocr(
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_mask_distance():
# Create two binary masks
mask1 = np.zeros((100, 100), dtype=np.uint8)
@@ -399,18 +461,34 @@ def test_generate_hed():
assert result.shape == img.shape


def test_countgd_counting() -> None:
def test_countgd_counting():
img = ski.data.coins()
result = countgd_counting(image=img, prompt="coin")
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_countgd_example_based_counting() -> None:
def test_countgd_counting_empty():
result = countgd_counting(
prompt="coin",
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []


def test_countgd_example_based_counting():
img = ski.data.coins()
result = countgd_example_based_counting(
visual_prompts=[[85, 106, 122, 145]],
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["object"] * 24


def test_countgd_example_based_counting_empty():
result = countgd_example_based_counting(
visual_prompts=[[85, 106, 122, 145]],
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result == []
25 changes: 25 additions & 0 deletions vision_agent/tools/tools.py
@@ -181,6 +181,8 @@ def owl_v2_image(
"""

image_size = image.shape[:2]
if image_size[0] < 1 or image_size[1] < 1:
return []

if fine_tune_id is not None:
image_b64 = convert_to_b64(image)
@@ -413,6 +415,9 @@ def florence2_sam2_image(
},
]
"""
if image.shape[0] < 1 or image.shape[1] < 1:
return []

if fine_tune_id is not None:
image_b64 = convert_to_b64(image)
landing_api = LandingPublicAPI()
@@ -701,6 +706,8 @@ def countgd_counting(
]
"""
image_size = image.shape[:2]
if image_size[0] < 1 or image_size[1] < 1:
return []
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
prompt = prompt.replace(", ", " .")
@@ -759,6 +766,8 @@ def countgd_example_based_counting(
]
"""
image_size = image.shape[:2]
if image_size[0] < 1 or image_size[1] < 1:
return []
buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
visual_prompts = [
@@ -828,6 +837,8 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str:
>>> ixc25_image_vqa('What is the cat doing?', image)
'drinking milk'
"""
if image.shape[0] < 1 or image.shape[1] < 1:
raise ValueError(f"Image is empty, image shape: {image.shape}")

buffer_bytes = numpy_to_bytes(image)
files = [("image", buffer_bytes)]
@@ -1024,6 +1035,9 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
{"labels": ["dog", "cat", "bird"], "scores": [0.68, 0.30, 0.02]},
"""

if image.shape[0] < 1 or image.shape[1] < 1:
return {"labels": [], "scores": []}

image_b64 = convert_to_b64(image)
data = {
"prompt": ",".join(classes),
@@ -1052,6 +1066,8 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]:
>>> vit_image_classification(image)
{"labels": ["leopard", "lemur, otter", "bird"], "scores": [0.68, 0.30, 0.02]},
"""
if image.shape[0] < 1 or image.shape[1] < 1:
return {"labels": [], "scores": []}

image_b64 = convert_to_b64(image)
data = {
@@ -1080,6 +1096,8 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]:
>>> vit_nsfw_classification(image)
{"label": "normal", "scores": 0.68},
"""
if image.shape[0] < 1 or image.shape[1] < 1:
raise ValueError(f"Image is empty, image shape: {image.shape}")

image_b64 = convert_to_b64(image)
data = {
@@ -1180,6 +1198,8 @@ def florence2_phrase_grounding(
]
"""
image_size = image.shape[:2]
if image_size[0] < 1 or image_size[1] < 1:
return []
image_b64 = convert_to_b64(image)

if fine_tune_id is not None:
@@ -1399,6 +1419,8 @@ def detr_segmentation(image: np.ndarray) -> List[Dict[str, Any]]:
},
]
"""
if image.shape[0] < 1 or image.shape[1] < 1:
return []
image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
@@ -1442,6 +1464,9 @@ def depth_anything_v2(image: np.ndarray) -> np.ndarray:
[10, 11, 15, ..., 202, 202, 205],
[10, 10, 10, ..., 200, 200, 200]], dtype=uint8),
"""
if image.shape[0] < 1 or image.shape[1] < 1:
raise ValueError(f"Image is empty, image shape: {image.shape}")

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
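From the caller's side, the guarded tools now fail fast on zero-sized frames instead of posting a request. The short usage sketch below assumes the functions are importable from vision_agent.tools, inferred from the module path in this diff; adjust the import if the package exposes them elsewhere.

```python
import numpy as np

# Assumed import path, inferred from vision_agent/tools/tools.py.
from vision_agent.tools import clip, owl_v2_image, vit_nsfw_classification

empty = np.zeros((0, 0, 3), dtype=np.uint8)

print(owl_v2_image(prompt="coin", image=empty))       # []
print(clip(image=empty, classes=["coins", "notes"]))  # {"labels": [], "scores": []}

try:
    vit_nsfw_classification(image=empty)
except ValueError as err:
    print(err)  # Image is empty, image shape: (0, 0, 3)
```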
