Adding more CV tools to coder agent #88

Merged
merged 11 commits on May 17, 2024
333 changes: 332 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.2.23"
version = "0.2.24"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
@@ -33,6 +33,7 @@ nbclient = "^0.10.0"
nbformat = "^5.10.4"
rich = "^13.7.1"
langsmith = "^0.1.58"
ipykernel = "^6.29.4"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
82 changes: 62 additions & 20 deletions tests/test_tools.py
Collaborator
nice!

@@ -1,42 +1,84 @@
import skimage as ski
from PIL import Image

from vision_agent.tools.tools import CLIP, GroundingDINO, GroundingSAM, ImageCaption
from vision_agent.tools.tools_v2 import (
clip,
zero_shot_counting,
visual_prompt_counting,
image_question_answering,
ocr,
grounding_dino,
grounding_sam,
image_caption,
)


def test_grounding_dino():
img = Image.fromarray(ski.data.coins())
result = GroundingDINO()(
img = ski.data.coins()
result = grounding_dino(
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_grounding_sam():
img = Image.fromarray(ski.data.coins())
result = GroundingSAM()(
img = ski.data.coins()
result = grounding_sam(
prompt="coin",
image=img,
)
assert result["labels"] == ["coin"] * 24
assert len(result["bboxes"]) == 24
assert len(result["scores"]) == 24
assert len(result["masks"]) == 24
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert len([res["mask"] for res in result]) == 24


def test_clip():
img = Image.fromarray(ski.data.coins())
result = CLIP()(
prompt="coins",
img = ski.data.coins()
result = clip(
classes=["coins", "notes"],
image=img,
)
assert result["scores"] == [1.0]
assert result["scores"] == [0.9999, 0.0001]


def test_image_caption() -> None:
img = Image.fromarray(ski.data.coins())
result = ImageCaption()(image=img)
assert result["text"]
img = ski.data.rocket()
result = image_caption(
image=img,
)
assert result.strip() == "a rocket on a stand"


def test_zero_shot_counting() -> None:
img = ski.data.coins()
result = zero_shot_counting(
image=img,
)
assert result["count"] == 21


def test_visual_prompt_counting() -> None:
img = ski.data.coins()
result = visual_prompt_counting(
visual_prompt={"bbox": [85, 106, 122, 145]},
image=img,
)
assert result["count"] == 25


def test_image_question_answering() -> None:
img = ski.data.rocket()
result = image_question_answering(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert result.strip() == "night"


def test_ocr() -> None:
img = ski.data.page()
result = ocr(
image=img,
)
assert any("Region-based segmentation" in res["label"] for res in result)
2 changes: 1 addition & 1 deletion vision_agent/tools/tools.py
@@ -53,7 +53,7 @@ def __call__(self) -> None:


class CLIP(Tool):
r"""CLIP is a tool that can classify or tag any image given a set if input classes
r"""CLIP is a tool that can classify or tag any image given a set of input classes
or tags.
Example
180 changes: 171 additions & 9 deletions vision_agent/tools/tools_v2.py
@@ -15,7 +15,14 @@

from vision_agent.tools.tool_utils import _send_inference_request
from vision_agent.utils import extract_frames_from_video
from vision_agent.utils.image_utils import convert_to_b64, normalize_bbox, rle_decode
from vision_agent.utils.image_utils import (
convert_to_b64,
normalize_bbox,
rle_decode,
b64_to_pil,
get_image_size,
denormalize_bbox,
)

COLORS = [
(158, 218, 229),
@@ -49,7 +56,7 @@ def grounding_dino(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.20,
iou_threshold: float = 0.75,
iou_threshold: float = 0.20,
) -> List[Dict[str, Any]]:
"""'grounding_dino' is a tool that can detect and count objects given a text prompt
such as category names or referring expressions. It returns a list and count of
@@ -61,12 +68,13 @@ def grounding_dino(
box_threshold (float, optional): The threshold for the box detection. Defaults
to 0.20.
iou_threshold (float, optional): The threshold for the Intersection over Union
(IoU). Defaults to 0.75.
(IoU). Defaults to 0.20.

Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label, and
bounding box of the detected objects with normalized coordinates
(x1, y1, x2, y2).
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
xmax and ymax are the coordinates of the bottom-right of the bounding box.

Example
-------
@@ -77,7 +85,7 @@ def grounding_dino(
]
"""
image_size = image.shape[:2]
image_b64 = convert_to_b64(Image.fromarray(image))
image_b64 = convert_to_b64(image)
request_data = {
"prompt": prompt,
"image": image_b64,
@@ -101,7 +109,7 @@ def grounding_sam(
prompt: str,
image: np.ndarray,
box_threshold: float = 0.20,
iou_threshold: float = 0.75,
iou_threshold: float = 0.20,
) -> List[Dict[str, Any]]:
"""'grounding_sam' is a tool that can detect and segment objects given a text
prompt such as category names or referring expressions. It returns a list of
@@ -113,12 +121,15 @@ def grounding_sam(
box_threshold (float, optional): The threshold for the box detection. Defaults
to 0.20.
iou_threshold (float, optional): The threshold for the Intersection over Union
(IoU). Defaults to 0.75.
(IoU). Defaults to 0.20.

Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
bounding box, and mask of the detected objects with normalized coordinates
(x1, y1, x2, y2).
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left and
xmax and ymax are the coordinates of the bottom-right of the bounding box.
The mask is a binary 2D numpy array where 1 indicates the object and 0 indicates
the background.

Example
-------
@@ -137,7 +148,7 @@ def grounding_sam(
]
"""
image_size = image.shape[:2]
image_b64 = convert_to_b64(Image.fromarray(image))
image_b64 = convert_to_b64(image)
request_data = {
"prompt": prompt,
"image": image_b64,
@@ -235,6 +246,152 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
return output


def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
"""'zero_shot_counting' is a tool that counts the dominant foreground object given an image and no other information about the content.
It returns only the count of the objects in the image.

Parameters:
image (np.ndarray): The image that contains many instances of a single object

Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {'count': 12}.

Example
-------
>>> zero_shot_counting(image)
{'count': 45}

"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"tool": "zero_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data
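Besides the count, the decoded response carries a heat_map numpy array. A minimal visualization sketch (matplotlib is our assumption here; it is not a dependency added by this PR):

import matplotlib.pyplot as plt
import skimage as ski

from vision_agent.tools.tools_v2 import zero_shot_counting

img = ski.data.coins()
resp = zero_shot_counting(image=img)

# Overlay the decoded heat map on the source image.
plt.imshow(img, cmap="gray")
plt.imshow(resp["heat_map"], alpha=0.5)
plt.title(f"count = {resp['count']}")
plt.show()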


def visual_prompt_counting(
image: np.ndarray, visual_prompt: Dict[str, List[float]]
) -> Dict[str, Any]:
"""'visual_prompt_counting' is a tool that counts the dominant foreground object given an image and a visual prompt which is a bounding box describing the object.
It returns only the count of the objects in the image.

Parameters:
image (np.ndarray): The image that contains many instances of a single object
visual_prompt (Dict[str, List[float]]): A bounding box with normalized coordinates, {'bbox': [xmin, ymin, xmax, ymax]}, marking one example of the object to count

Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {'count': 12}.

Example
-------
>>> visual_prompt_counting(image, {"bbox": [0.1, 0.1, 0.4, 0.42]})
{'count': 45}

"""

image_size = get_image_size(image)
bbox = visual_prompt["bbox"]
bbox_str = ", ".join(map(str, denormalize_bbox(bbox, image_size)))
image_b64 = convert_to_b64(image)

data = {
"image": image_b64,
"prompt": bbox_str,
"tool": "few_shot_counting",
}
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data
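Note that the implementation passes visual_prompt["bbox"] through denormalize_bbox, and the docstring example uses normalized [0, 1] coordinates, while the new test passes pixel-valued numbers. If a box starts out in pixel space, a sketch of normalizing it first (normalize_box is our hypothetical helper, and the (width, height) ordering is an assumption):

import skimage as ski

from vision_agent.tools.tools_v2 import visual_prompt_counting

def normalize_box(pixel_bbox, image_size):
    # pixel_bbox is (xmin, ymin, xmax, ymax) in pixels;
    # image_size is assumed to be (width, height).
    w, h = image_size
    xmin, ymin, xmax, ymax = pixel_bbox
    return [xmin / w, ymin / h, xmax / w, ymax / h]

img = ski.data.coins()  # shape (303, 384): height 303, width 384
prompt = {"bbox": normalize_box((85, 106, 122, 145), (384, 303))}
result = visual_prompt_counting(image=img, visual_prompt=prompt)
print(result["count"])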


def image_question_answering(image: np.ndarray, prompt: str) -> str:
"""'image_question_answering_' is a tool that can answer questions about the visual contents of an image given a question and an image.
It returns an answer to the question

Parameters:
image (np.ndarray): The reference image used for the question
prompt (str): The question about the image

Returns:
str: A string which is the answer to the given prompt. E.g. 'This image contains a cat sitting on a table with a bowl of milk.'

Example
-------
>>> image_question_answering(image, 'What is the cat doing ?')
'drinking milk'

"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"prompt": prompt,
"tool": "image_question_answering",
}

answer = _send_inference_request(data, "tools")
return answer["text"][0] # type: ignore


def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
"""'clip' is a tool that can classify an image given a list of input classes or tags.
It returns the input classes along with their probability scores based on the image content.

Parameters:
image (np.ndarray): The image to classify or tag
classes (List[str]): The list of classes or tags that are associated with the image

Returns:
Dict[str, Any]: A dictionary containing 'labels' and 'scores': a list of the given labels and a corresponding list of scores.

Example
-------
>>> clip(image, ['dog', 'cat', 'bird'])
{"labels": ["dog", "cat", "bird"], "scores": [0.68, 0.30, 0.02]},

"""

image_b64 = convert_to_b64(image)
data = {
"prompt": ",".join(classes),
"image": image_b64,
"tool": "closed_set_image_classification",
}
resp_data = _send_inference_request(data, "tools")
resp_data["scores"] = [round(prob, 4) for prob in resp_data["scores"]]
return resp_data
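Because "labels" and "scores" come back as parallel lists, selecting the best-matching class is a one-liner. A minimal sketch:

import skimage as ski

from vision_agent.tools.tools_v2 import clip

img = ski.data.coins()
result = clip(image=img, classes=["coins", "notes"])

# The argmax over scores indexes the corresponding label.
best = max(range(len(result["scores"])), key=lambda i: result["scores"][i])
print(result["labels"][best], result["scores"][best])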


def image_caption(image: np.ndarray) -> str:
"""'image_caption' is a tool that can caption an image based on its contents.
It returns a text describing the image.

Parameters:
image (np.ndarray): The image to caption

Returns:
str: A string which is the caption for the given image.

Example
-------
>>> image_caption(image)
'This image contains a cat sitting on a table with a bowl of milk.'

"""

image_b64 = convert_to_b64(image)
data = {
"image": image_b64,
"tool": "image_captioning",
}

answer = _send_inference_request(data, "tools")
return answer["text"][0] # type: ignore


def closest_mask_distance(mask1: np.ndarray, mask2: np.ndarray) -> float:
"""'closest_mask_distance' calculates the closest distance between two masks.

@@ -504,6 +661,11 @@ def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
grounding_sam,
extract_frames,
ocr,
clip,
zero_shot_counting,
visual_prompt_counting,
image_question_answering,
image_caption,
closest_mask_distance,
closest_box_distance,
save_json,
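The final hunk registers the new functions in the module-level list consumed by get_tools_df, which is how the coder agent discovers the added tools. For intuition, a rough reconstruction of what a docstring-to-dataframe step could look like (our sketch, not the PR's actual implementation):

import inspect
from typing import Any, Callable, List

import pandas as pd

def get_tools_df_sketch(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
    rows = []
    for func in funcs:
        doc = inspect.getdoc(func) or ""
        rows.append(
            {
                "name": func.__name__,
                # The first docstring line doubles as the description the
                # agent sees when choosing a tool.
                "description": doc.split("\n")[0],
            }
        )
    return pd.DataFrame(rows)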