added test case for every tool
shankar-vision-eng committed May 16, 2024
1 parent 6604608 commit 3a8d3c4
Showing 3 changed files with 66 additions and 27 deletions.
37 changes: 21 additions & 16 deletions tests/test_tools.py
@@ -1,19 +1,22 @@
 import skimage as ski
 from PIL import Image

-from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption
-
 from vision_agent.tools.tools_v2 import (
     clip,
     zero_shot_counting,
     visual_prompt_counting,
     image_question_answering,
     ocr,
+    grounding_dino,
+    grounding_sam,
+    image_caption,
 )


 def test_grounding_dino():
     img = Image.fromarray(ski.data.coins())
-    result = GroundingDINO()(
+    result = grounding_dino(
         prompt="coin",
         image=img,
     )
@@ -24,7 +27,7 @@ def test_grounding_dino():

 def test_grounding_sam():
     img = Image.fromarray(ski.data.coins())
-    result = GroundingSAM()(
+    result = grounding_sam(
         prompt="coin",
         image=img,
     )
@@ -40,44 +43,46 @@ def test_clip():
         classes=["coins", "notes"],
         image=img,
     )
-    assert result["scores"] == [0.99, 0.01]
+    assert result["scores"] == [0.9999, 0.0001]


 def test_image_caption() -> None:
-    img = Image.fromarray(ski.data.coins())
-    result = ImageCaption()(image=img)
-    assert result["text"]
+    img = ski.data.rocket()
+    result = image_caption(
+        image=img,
+    )
+    assert result.strip() == "a rocket on a stand"


 def test_zero_shot_counting() -> None:
-    img = Image.fromarray(ski.data.coins())
+    img = ski.data.coins()
     result = zero_shot_counting(
         image=img,
     )
-    assert result["count"] == 24
+    assert result["count"] == 21


 def test_visual_prompt_counting() -> None:
-    img = Image.fromarray(ski.data.checkerboard())
+    img = ski.data.coins()
     result = visual_prompt_counting(
-        visual_prompt={"bbox": [0.125, 0, 0.25, 0.125]},
+        visual_prompt={"bbox": [85, 106, 122, 145]},
         image=img,
     )
-    assert result["count"] == 32
+    assert result["count"] == 25


 def test_image_question_answering() -> None:
-    img = Image.fromarray(ski.data.rocket())
+    img = ski.data.rocket()
     result = image_question_answering(
         prompt="Is the scene captured during day or night ?",
         image=img,
     )
-    assert result == "night"
+    assert result.strip() == "night"


 def test_ocr() -> None:
-    img = Image.fromarray(ski.data.page())
+    img = ski.data.page()
     result = ocr(
         image=img,
     )
-    assert result[0]["label"] == "Region-based segmentation"
+    assert any("Region-based segmentation" in res["label"] for res in result)
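
For reference, the updated suite can be exercised locally with a short driver like the following (a sketch: it assumes pytest and the repo's dependencies are installed, and that the hosted tool endpoints are reachable, since these tests call live inference):

    # Sketch: run only the updated tool tests from the repo root.
    import pytest

    raise SystemExit(pytest.main(["-v", "tests/test_tools.py"]))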
47 changes: 38 additions & 9 deletions vision_agent/tools/tools_v2.py
@@ -56,7 +56,7 @@ def grounding_dino(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.20,
-    iou_threshold: float = 0.75,
+    iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
     """'grounding_dino' is a tool that can detect and count objects given a text prompt
     such as category names or referring expressions. It returns a list and count of
@@ -84,7 +84,7 @@ def grounding_dino(
         ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(Image.fromarray(image))
+    image_b64 = convert_to_b64(image)
     request_data = {
         "prompt": prompt,
         "image": image_b64,
@@ -108,7 +108,7 @@ def grounding_sam(
     prompt: str,
     image: np.ndarray,
     box_threshold: float = 0.20,
-    iou_threshold: float = 0.75,
+    iou_threshold: float = 0.20,
 ) -> List[Dict[str, Any]]:
     """'grounding_sam' is a tool that can detect and segment objects given a text
     prompt such as category names or referring expressions. It returns a list of
@@ -144,7 +144,7 @@ def grounding_sam(
         ]
     """
     image_size = image.shape[:2]
-    image_b64 = convert_to_b64(Image.fromarray(image))
+    image_b64 = convert_to_b64(image)
     request_data = {
         "prompt": prompt,
         "image": image_b64,
@@ -242,7 +242,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
     return output


-def zero_shot_counting(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
+def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
     """'zero_shot_counting' is a tool that counts the dominant foreground object given an image and no other information about the content.
     It returns only the count of the objects in the image.
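
With classes gone from the signature, the call site reduces to the image alone. A sketch mirroring the updated test (the count value depends on the counting backend):

    import skimage as ski

    from vision_agent.tools.tools_v2 import zero_shot_counting

    # No class list anymore: the model counts the dominant foreground object.
    result = zero_shot_counting(image=ski.data.coins())
    print(result["count"])  # the test above expects 21 for ski.data.coins()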
@@ -305,7 +305,7 @@ def visual_prompt_counting(

 def image_question_answering(image: np.ndarray, prompt: str) -> str:
     """'image_question_answering' is a tool that can answer questions about the visual contents of an image given a question and an image.
-    It returns a text describing the image and the answer to the question
+    It returns an answer to the question

     Parameters:
         image (np.ndarray): The reference image used for the question
@@ -317,7 +317,7 @@ def image_question_answering(image: np.ndarray, prompt: str) -> str:

     Example
     -------
     >>> image_question_answering(image, 'What is the cat doing ?')
-    'This image contains a cat sitting on a table with a bowl of milk.'
+    'drinking milk'
     """

@@ -328,7 +328,8 @@ def image_question_answering(image: np.ndarray, prompt: str) -> str:
         "tool": "image_question_answering",
     }

-    return _send_inference_request(data, "tools")["text"]
+    answer = _send_inference_request(data, "tools")
+    return answer["text"][0]


 def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
@@ -351,7 +352,7 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:

     image_b64 = convert_to_b64(image)
     data = {
-        "prompt": classes,
+        "prompt": ",".join(classes),
         "image": image_b64,
         "tool": "closed_set_image_classification",
     }
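
The classification endpoint evidently expects one comma-separated string rather than a JSON list, hence the join. A sketch of the call site, with scores as asserted in the updated test:

    import skimage as ski

    from vision_agent.tools.tools_v2 import clip

    # The class list is serialized to "coins,notes" inside the tool.
    result = clip(image=ski.data.coins(), classes=["coins", "notes"])
    print(result["scores"])  # e.g. [0.9999, 0.0001]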
@@ -360,6 +361,33 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
     return resp_data


+def image_caption(image: np.ndarray) -> str:
+    """'image_caption' is a tool that can caption an image based on its contents.
+    It returns a text describing the image.
+
+    Parameters:
+        image (np.ndarray): The image to caption
+
+    Returns:
+        str: A string which is the caption for the given image.
+
+    Example
+    -------
+    >>> image_caption(image)
+    'This image contains a cat sitting on a table with a bowl of milk.'
+    """
+
+    image_b64 = convert_to_b64(image)
+    data = {
+        "image": image_b64,
+        "tool": "image_captioning",
+    }
+
+    answer = _send_inference_request(data, "tools")
+    return answer["text"][0]
+
+
 def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
     """'closest_mask_distance' calculates the closest distance between two masks.
@@ -633,6 +661,7 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
     zero_shot_counting,
     visual_prompt_counting,
     image_question_answering,
+    image_caption,
     closest_mask_distance,
     closest_box_distance,
     save_json,
9 changes: 7 additions & 2 deletions vision_agent/utils/image_utils.py
@@ -104,15 +104,20 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     """
     if data is None:
         raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
+
     if isinstance(data, (str, Path)):
         data = Image.open(data)
+    elif isinstance(data, np.ndarray):
+        data = Image.fromarray(data)
+
     if isinstance(data, Image.Image):
         buffer = BytesIO()
         data.convert("RGB").save(buffer, format="PNG")
         return base64.b64encode(buffer.getvalue()).decode("utf-8")
     else:
-        arr_bytes = data.tobytes()
-        return base64.b64encode(arr_bytes).decode("utf-8")
+        raise ValueError(
+            f"Invalid input image: {data}. Input image must be a PIL Image or a numpy array."
+        )


 def denormalize_bbox(
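
A sketch of what the helper accepts and rejects after this change: paths, PIL Images, and numpy arrays are converted to a base64 PNG, while anything else now raises instead of being silently encoded as raw bytes:

    import skimage as ski
    from PIL import Image

    from vision_agent.utils.image_utils import convert_to_b64

    arr = ski.data.coins()
    b64_a = convert_to_b64(arr)                   # numpy array: wrapped via Image.fromarray
    b64_b = convert_to_b64(Image.fromarray(arr))  # PIL Image: encoded directly

    try:
        convert_to_b64(42)  # unsupported type
    except ValueError as err:
        print(err)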
