diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 690795f0..62a0a771 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -71,6 +71,14 @@ def test_owl_v2_image(): assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) +def test_owl_v2_image_empty(): + result = owl_v2_image( + prompt="coin", + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] + + def test_owl_v2_fine_tune_id(): img = ski.data.coins() result = owl_v2_image( @@ -110,6 +118,14 @@ def test_florence2_phrase_grounding(): assert all([all([0 <= x <= 1 for x in obj["bbox"]]) for obj in result]) +def test_florence2_phrase_grounding_empty(): + result = florence2_phrase_grounding( + image=np.zeros((0, 0, 3)).astype(np.uint8), + prompt="coin", + ) + assert result == [] + + def test_florence2_phrase_grounding_fine_tune_id(): img = ski.data.coins() result = florence2_phrase_grounding( @@ -195,6 +211,14 @@ def test_florence2_sam2_image_fine_tune_id(): assert len([res["mask"] for res in result]) == len(result) +def test_florence2_sam2_image_empty(): + result = florence2_sam2_image( + prompt="coin", + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] + + def test_florence2_sam2_video(): frames = [ np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10) @@ -208,7 +232,7 @@ def test_florence2_sam2_video(): assert len([res["mask"] for res in result[0]]) == 25 -def test_segmentation(): +def test_detr_segmentation(): img = ski.data.coins() result = detr_segmentation( image=img, @@ -218,6 +242,13 @@ def test_segmentation(): assert len([res["mask"] for res in result]) == 1 +def test_detr_segmentation_empty(): + result = detr_segmentation( + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] + + def test_clip(): img = ski.data.coins() result = clip( @@ -227,6 +258,15 @@ def test_clip(): assert result["scores"] == [0.9999, 0.0001] +def test_clip_empty(): + result = clip( + classes=["coins", "notes"], + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result["scores"] == [] + assert result["labels"] == [] + + def test_vit_classification(): img = ski.data.coins() result = vit_image_classification( @@ -235,6 +275,14 @@ def test_vit_classification(): assert "typewriter keyboard" in result["labels"] +def test_vit_classification_empty(): + result = vit_image_classification( + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result["labels"] == [] + assert result["scores"] == [] + + def test_nsfw_classification(): img = ski.data.coins() result = vit_nsfw_classification( @@ -243,7 +291,7 @@ def test_nsfw_classification(): assert result["label"] == "normal" -def test_image_caption() -> None: +def test_image_caption(): img = ski.data.rocket() result = blip_image_caption( image=img, @@ -251,7 +299,7 @@ def test_image_caption() -> None: assert result.strip() == "a rocket on a stand" -def test_florence_image_caption() -> None: +def test_florence_image_caption(): img = ski.data.rocket() result = florence2_image_caption( image=img, @@ -259,7 +307,7 @@ def test_florence_image_caption() -> None: assert "The image shows a rocket on a launch pad at night" in result.strip() -def test_loca_zero_shot_counting() -> None: +def test_loca_zero_shot_counting(): img = ski.data.coins() result = loca_zero_shot_counting( @@ -268,7 +316,7 @@ def test_loca_zero_shot_counting() -> None: assert result["count"] == 21 -def test_loca_visual_prompt_counting() -> None: +def test_loca_visual_prompt_counting(): img = ski.data.coins() result = loca_visual_prompt_counting( visual_prompt={"bbox": [85, 106, 122, 145]}, @@ -277,7 +325,7 @@ def test_loca_visual_prompt_counting() -> None: assert result["count"] == 25 -def test_git_vqa_v2() -> None: +def test_git_vqa_v2(): img = ski.data.rocket() result = git_vqa_v2( prompt="Is the scene captured during day or night ?", @@ -286,7 +334,7 @@ def test_git_vqa_v2() -> None: assert result.strip() == "night" -def test_image_qa_with_context() -> None: +def test_image_qa_with_context(): img = ski.data.rocket() result = florence2_roberta_vqa( prompt="Is the scene captured during day or night ?", @@ -295,7 +343,7 @@ def test_image_qa_with_context() -> None: assert "night" in result.strip() -def test_ixc25_image_vqa() -> None: +def test_ixc25_image_vqa(): img = ski.data.cat() result = ixc25_image_vqa( prompt="What animal is in this image?", @@ -304,7 +352,7 @@ def test_ixc25_image_vqa() -> None: assert "cat" in result.strip() -def test_ixc25_video_vqa() -> None: +def test_ixc25_video_vqa(): frames = [ np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10) ] @@ -315,7 +363,7 @@ def test_ixc25_video_vqa() -> None: assert "cat" in result.strip() -def test_ixc25_temporal_localization() -> None: +def test_ixc25_temporal_localization(): frames = [ np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10) ] @@ -326,7 +374,7 @@ def test_ixc25_temporal_localization() -> None: assert result == [True] * 10 -def test_ocr() -> None: +def test_ocr(): img = ski.data.page() result = ocr( image=img, @@ -334,7 +382,14 @@ def test_ocr() -> None: assert any("Region-based segmentation" in res["label"] for res in result) -def test_florence2_ocr() -> None: +def test_ocr_empty(): + result = ocr( + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] + + +def test_florence2_ocr(): img = ski.data.page() result = florence2_ocr( image=img, @@ -342,6 +397,13 @@ def test_florence2_ocr() -> None: assert any("Region-based segmentation" in res["label"] for res in result) +def test_florence2_ocr_empty(): + result = florence2_ocr( + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] + + def test_mask_distance(): # Create two binary masks mask1 = np.zeros((100, 100), dtype=np.uint8) @@ -399,14 +461,22 @@ def test_generate_hed(): assert result.shape == img.shape -def test_countgd_counting() -> None: +def test_countgd_counting(): img = ski.data.coins() result = countgd_counting(image=img, prompt="coin") assert len(result) == 24 assert [res["label"] for res in result] == ["coin"] * 24 -def test_countgd_example_based_counting() -> None: +def test_countgd_counting_empty(): + result = countgd_counting( + prompt="coin", + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] + + +def test_countgd_example_based_counting(): img = ski.data.coins() result = countgd_example_based_counting( visual_prompts=[[85, 106, 122, 145]], @@ -414,3 +484,11 @@ def test_countgd_example_based_counting() -> None: ) assert len(result) == 24 assert [res["label"] for res in result] == ["object"] * 24 + + +def test_countgd_example_based_counting_empty(): + result = countgd_example_based_counting( + visual_prompts=[[85, 106, 122, 145]], + image=np.zeros((0, 0, 3)).astype(np.uint8), + ) + assert result == [] diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 45f10c33..9c03467c 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -181,6 +181,8 @@ def owl_v2_image( """ image_size = image.shape[:2] + if image_size[0] < 1 or image_size[1] < 1: + return [] if fine_tune_id is not None: image_b64 = convert_to_b64(image) @@ -413,6 +415,9 @@ def florence2_sam2_image( }, ] """ + if image.shape[0] < 1 or image.shape[1] < 1: + return [] + if fine_tune_id is not None: image_b64 = convert_to_b64(image) landing_api = LandingPublicAPI() @@ -701,6 +706,8 @@ def countgd_counting( ] """ image_size = image.shape[:2] + if image_size[0] < 1 or image_size[1] < 1: + return [] buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] prompt = prompt.replace(", ", " .") @@ -759,6 +766,8 @@ def countgd_example_based_counting( ] """ image_size = image.shape[:2] + if image_size[0] < 1 or image_size[1] < 1: + return [] buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] visual_prompts = [ @@ -828,6 +837,8 @@ def ixc25_image_vqa(prompt: str, image: np.ndarray) -> str: >>> ixc25_image_vqa('What is the cat doing?', image) 'drinking milk' """ + if image.shape[0] < 1 or image.shape[1] < 1: + raise ValueError(f"Image is empty, image shape: {image.shape}") buffer_bytes = numpy_to_bytes(image) files = [("image", buffer_bytes)] @@ -1024,6 +1035,9 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]: {"labels": ["dog", "cat", "bird"], "scores": [0.68, 0.30, 0.02]}, """ + if image.shape[0] < 1 or image.shape[1] < 1: + return {"labels": [], "scores": []} + image_b64 = convert_to_b64(image) data = { "prompt": ",".join(classes), @@ -1052,6 +1066,8 @@ def vit_image_classification(image: np.ndarray) -> Dict[str, Any]: >>> vit_image_classification(image) {"labels": ["leopard", "lemur, otter", "bird"], "scores": [0.68, 0.30, 0.02]}, """ + if image.shape[0] < 1 or image.shape[1] < 1: + return {"labels": [], "scores": []} image_b64 = convert_to_b64(image) data = { @@ -1080,6 +1096,8 @@ def vit_nsfw_classification(image: np.ndarray) -> Dict[str, Any]: >>> vit_nsfw_classification(image) {"label": "normal", "scores": 0.68}, """ + if image.shape[0] < 1 or image.shape[1] < 1: + raise ValueError(f"Image is empty, image shape: {image.shape}") image_b64 = convert_to_b64(image) data = { @@ -1180,6 +1198,8 @@ def florence2_phrase_grounding( ] """ image_size = image.shape[:2] + if image_size[0] < 1 or image_size[1] < 1: + return [] image_b64 = convert_to_b64(image) if fine_tune_id is not None: @@ -1399,6 +1419,8 @@ def detr_segmentation(image: np.ndarray) -> List[Dict[str, Any]]: }, ] """ + if image.shape[0] < 1 or image.shape[1] < 1: + return [] image_b64 = convert_to_b64(image) data = { "image": image_b64, @@ -1442,6 +1464,9 @@ def depth_anything_v2(image: np.ndarray) -> np.ndarray: [10, 11, 15, ..., 202, 202, 205], [10, 10, 10, ..., 200, 200, 200]], dtype=uint8), """ + if image.shape[0] < 1 or image.shape[1] < 1: + raise ValueError(f"Image is empty, image shape: {image.shape}") + image_b64 = convert_to_b64(image) data = { "image": image_b64,