Adding more cv tools to coder agent (#88)
* adding more cv tools to coder agent

* Bug fix: missing kernel python3 (#87)

* Bug fix: missing kernel python3

* Update API key

* [skip ci] chore(release): vision-agent 0.2.24

* adding more cv tools to coder agent

* Added test cases for the tools added

* added test case for every tool

* fix linting

* fixing tests

* fix linting

* fixing test cases for grounding tools as the output format changed

---------

Co-authored-by: Asia <[email protected]>
Co-authored-by: GitHub Actions Bot <[email protected]>
3 people authored May 17, 2024
1 parent 92fd2c8 commit 21849f4
Showing 4 changed files with 241 additions and 32 deletions.
82 changes: 62 additions & 20 deletions tests/test_tools.py
@@ -1,42 +1,84 @@
import skimage as ski
from PIL import Image

from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption
from vision_agent.tools.tools_v2 import (
clip,
zero_shot_counting,
visual_prompt_counting,
image_question_answering,
ocr,
grounding_dino,
grounding_sam,
image_caption,
)


def test_grounding_dino():
img = Image.fromarray(ski.data.coins())
result = GroundingDINO()(
img = ski.data.coins()
result = grounding_dino(
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_grounding_sam():
img = Image.fromarray(ski.data.coins())
result = GroundingSAM()(
img = ski.data.coins()
result = grounding_sam(
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result["masks"]) == 24
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert len([res["mask"] for res in result]) == 24


def test_clip():
img = Image.fromarray(ski.data.coins())
result = CLIP()(
prompt="coins",
img = ski.data.coins()
result = clip(
classes=["coins", "notes"],
image=img,
)
assert result["scores"] == [1.0]
assert result["scores"] == [0.9999, 0.0001]


def test_image_caption() -> None:
img = Image.fromarray(ski.data.coins())
result = ImageCaption()(image=img)
assert result["text"]
img = ski.data.rocket()
result = image_caption(
image=img,
)
assert result.strip() == "a rocket on a stand"


def test_zero_shot_counting() -> None:
img = ski.data.coins()
result = zero_shot_counting(
image=img,
)
assert result["count"] == 21


def test_visual_prompt_counting() -> None:
img = ski.data.coins()
result = visual_prompt_counting(
visual_prompt={"bbox": [85, 106, 122, 145]},
image=img,
)
assert result["count"] == 25


def test_image_question_answering() -> None:
img = ski.data.rocket()
result = image_question_answering(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert result.strip() == "night"


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)
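
The rewritten tests assert against a flat detection format: each grounding call now returns a list of per-object dictionaries rather than a single dictionary of parallel lists. Below is a minimal sketch (not part of this commit) of consuming that format; the "label" key is confirmed by the tests, while the exact "bbox" and "score" key names are assumed from the docstrings further down.

# Hypothetical helper, not from the repository: summarize a flat detection list.
from collections import Counter

def summarize_detections(detections):
    """Count objects per label and pick the highest-scoring detection."""
    counts = Counter(det["label"] for det in detections)
    best = max(detections, key=lambda det: det["score"], default=None)
    return counts, best

# Example with a hand-written two-entry result:
result = [
    {"label": "coin", "bbox": [0.17, 0.50, 0.32, 0.71], "score": 0.98},
    {"label": "coin", "bbox": [0.55, 0.20, 0.70, 0.40], "score": 0.95},
]
counts, best = summarize_detections(result)
print(counts["coin"], best["score"])  # -> 2 0.98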
2 changes: 1 addition & 1 deletion vision_agent/tools/tools.py
@@ -53,7 +53,7 @@ def __call__(self) -> None:


class CLIP(Tool):
r"""CLIP is a tool that can classify or tag any image given a set if input classes
r"""CLIP is a tool that can classify or tag any image given a set of input classes
or tags.
Example
180 changes: 171 additions & 9 deletions vision_agent/tools/tools_v2.py
@@ -15,7 +15,14 @@

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox, rle_decode
from vision_agent.utils.image_utils import (
convert_to_b64,
normalize_bbox,
rle_decode,
b64_to_pil,
get_image_size,
denormalize_bbox,
)

COLORS = [
(158, 218, 229),
@@ -49,7 +56,7 @@ def grounding_dino(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.20,
iou_threshold: float = 0.75,
iou_threshold: float = 0.20,
) -> List[Dict[str, Any]]:
"""'grounding_dino' is a tool that can detect and count objects given a text prompt
such as category names or referring expressions. It returns a list and count of
@@ -61,12 +68,13 @@
box_threshold (float, optional): The threshold for the box detection. Defaults
to 0.20.
iou_threshold (float, optional): The threshold for the Intersection over Union
(IoU). Defaults to 0.75.
(IoU). Defaults to 0.20.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
bounding box of the detected objects with normalized coordinates
(x1, y1, x2, y2).
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
xmax and ymax are the coordinates of the bottom-right of the bounding box.
Example
-------
@@ -77,7 +85,7 @@
]
"""
image_size = image.shape[:2]
image_b64 = convert_to_b64(Image.fromarray(image))
image_b64 = convert_to_b64(image)
request_data = {
"prompt": prompt,
"image": image_b64,
@@ -101,7 +109,7 @@ def grounding_sam(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.20,
iou_threshold: float = 0.75,
iou_threshold: float = 0.20,
) -> List[Dict[str, Any]]:
"""'grounding_sam' is a tool that can detect and segment objects given a text
prompt such as category names or referring expressions. It returns a list of
@@ -113,12 +121,15 @@
box_threshold (float, optional): The threshold for the box detection. Defaults
to 0.20.
iou_threshold (float, optional): The threshold for the Intersection over Union
(IoU). Defaults to 0.75.
(IoU). Defaults to 0.20.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
bounding box, and mask of the detected objects with normalized coordinates
(x1, y1, x2, y2).
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
xmax and ymax are the coordinates of the bottom-right of the bounding box.
The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
the background.
Example
-------
@@ -137,7 +148,7 @@
]
"""
image_size = image.shape[:2]
image_b64 = convert_to_b64(Image.fromarray(image))
image_b64 = convert_to_b64(image)
request_data = {
"prompt": prompt,
"image": image_b64,
@@ -235,6 +246,152 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
return output


def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
"""'zero_shot_counting' is a tool that counts the dominant foreground object given an image and no other information about the content.
It returns only the count of the objects in the image.
Parameters:
image (np.ndarray): The image that contains many instances of a single object
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {count: 12}.
Example
-------
>>> zero_shot_counting(image)
{'count': 45},
"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"tool": "zero_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


def visual_prompt_counting(
image: np.ndarray, visual_prompt: Dict[str, List[float]]
) -> Dict[str, Any]:
"""'visual_prompt_counting' is a tool that counts the dominant foreground object given an image and a visual prompt which is a bounding box describing the object.
It returns only the count of the objects in the image.
Parameters:
image (np.ndarray): The image that contains many instances of a single object
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {count: 12}.
Example
-------
>>> visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
{'count': 45},
"""

image_size = get_image_size(image)
bbox = visual_prompt["bbox"]
bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
image_b64 = convert_to_b64(image)

data = {
"image": image_b64,
"prompt": bbox_str,
"tool": "few_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


def image_question_answering(image: np.ndarray, prompt: str) -> str:
"""'image_question_answering_' is a tool that can answer questions about the visual contents of an image given a question and an image.
It returns an answer to the question
Parameters:
image (np.ndarray): The reference image used for the question
prompt (str): The question about the image
Returns:
str: A string which is the answer to the given prompt. E.g. 'This image contains a cat sitting on a table with a bowl of milk.'
Example
-------
>>> image_question_answering(image, 'What is the cat doing ?')
'drinking milk'
"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"prompt": prompt,
"tool": "image_question_answering",
}

answer = _send_inference_request(data, "tools")
return answer["text"][0] # type: ignore


def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
"""'clip' is a tool that can classify an image given a list of input classes or tags.
It returns the same list of the input classes along with their probability scores based on image content.
Parameters:
image (np.ndarray): The image to classify or tag
classes (List[str]): The list of classes or tags that is associated with the image
Returns:
Dict[str, Any]: A dictionary containing the labels and scores. One key holds the list of given labels and the other the list of corresponding scores.
Example
-------
>>> clip(image, ['dog', 'cat', 'bird'])
{"labels": ["dog", "cat", "bird"], "scores": [0.68, 0.30, 0.02]},
"""

image_b64 = convert_to_b64(image)
data = {
"prompt": ",".join(classes),
"image": image_b64,
"tool": "closed_set_image_classification",
}
resp_data = _send_inference_request(data, "tools")
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
return resp_data


def image_caption(image: np.ndarray) -> str:
"""'image_caption' is a tool that can caption an image based on its contents.
It returns a text describing the image.
Parameters:
image (np.ndarray): The image to caption
Returns:
str: A string which is the caption for the given image.
Example
-------
>>> image_caption(image)
'This image contains a cat sitting on a table with a bowl of milk.'
"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"tool": "image_captioning",
}

answer = _send_inference_request(data, "tools")
return answer["text"][0] # type: ignore


def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
"""'closest_mask_distance' calculates the closest distance between two masks.
@@ -504,6 +661,11 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
grounding_sam,
extract_frames,
ocr,
clip,
zero_shot_counting,
visual_prompt_counting,
image_question_answering,
image_caption,
closest_mask_distance,
closest_box_distance,
save_json,
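
The hunks above register several new tools in tools_v2. A minimal sketch (not part of this commit) of calling them the way the updated tests do, using skimage sample images; it assumes a reachable inference endpoint behind _send_inference_request, and the "bbox" key name on detections is taken from the docstrings rather than shown explicitly in the tests.

# Hypothetical usage sketch, mirroring tests/test_tools.py.
import skimage as ski
from vision_agent.tools.tools_v2 import (
    clip,
    grounding_dino,
    image_caption,
    zero_shot_counting,
)

coins = ski.data.coins()    # grayscale sample image, shape (H, W)
rocket = ski.data.rocket()  # RGB sample image, shape (H, W, 3)

detections = grounding_dino(prompt="coin", image=coins)
classification = clip(image=coins, classes=["coins", "notes"])
counting = zero_shot_counting(image=coins)
caption = image_caption(image=rocket)

print(len(detections))           # number of detected coins
print(classification["scores"])  # probabilities aligned with the input classes
print(counting["count"])         # integer object count
print(caption)                   # one-line description of the image

# Bounding boxes come back normalized to [0, 1]; scale by the image size
# to recover pixel coordinates (xmin, ymin, xmax, ymax).
if detections:
    h, w = coins.shape[:2]
    x1, y1, x2, y2 = detections[0]["bbox"]
    pixel_box = (x1 * w, y1 * h, x2 * w, y2 * h)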
9 changes: 7 additions & 2 deletions vision_agent/utils/image_utils.py
@@ -104,15 +104,20 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
"""
if data is None:
raise ValueError(f"Invalid input image: {data}. Input image can't be None.")

if isinstance(data, (str, Path)):
data = Image.open(data)
elif isinstance(data, np.ndarray):
data = Image.fromarray(data)

if isinstance(data, Image.Image):
buffer = BytesIO()
data.convert("RGB").save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
arr_bytes = data.tobytes()
return base64.b64encode(arr_bytes).decode("utf-8")
raise ValueError(
f"Invalid input image: {data}. Input image must be a PIL Image or a numpy array."
)


def denormalize_bbox(
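
After this change, convert_to_b64 routes numpy arrays through Image.fromarray and PNG-encodes any PIL Image, raising a ValueError for other types instead of base64-encoding raw array bytes. A minimal sketch (not part of this commit) of the resulting behaviour:

# Hypothetical usage sketch of the updated convert_to_b64.
import skimage as ski
from PIL import Image
from vision_agent.utils.image_utils import convert_to_b64

b64_from_array = convert_to_b64(ski.data.coins())                 # np.ndarray -> PNG base64
b64_from_pil = convert_to_b64(Image.fromarray(ski.data.coins()))  # PIL.Image -> PNG base64
assert b64_from_array == b64_from_pil  # both inputs take the same PIL -> PNG path

try:
    convert_to_b64([1, 2, 3])  # unsupported type now raises instead of encoding raw bytes
except ValueError as err:
    print(err)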
