Skip to content

Commit

Permalink
added overlay heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed May 29, 2024
1 parent 503b3e0 commit e1a3543
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
4 changes: 3 additions & 1 deletion tests/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import pytest
from PIL import Image

from vision_agent.tools.easytool_tools import TOOLS, Tool, register_tool
from vision_agent.tools.easytool_tools import (
TOOLS,
BboxIoU,
BoxDistance,
MaskDistance,
SegArea,
SegIoU,
Tool,
register_tool,
)


Expand Down
1 change: 1 addition & 0 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
load_image,
ocr,
overlay_bounding_boxes,
overlay_heat_map,
overlay_segmentation_masks,
save_image,
save_json,
Expand Down
79 changes: 66 additions & 13 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the detected text, bbox,
and confidence score.
and confidence score.
Example
-------
Expand Down Expand Up @@ -247,14 +247,16 @@ def ocr(image: np.ndarray) -> List[Dict[str, Any]]:


def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
"""'zero_shot_counting' is a tool that counts the dominant foreground object given an image and no other information about the content.
It returns only the count of the objects in the image.
"""'zero_shot_counting' is a tool that counts the dominant foreground object given
an image and no other information about the content. It returns only the count of
the objects in the image.
Parameters:
image (np.ndarray): The image that contains lot of instances of a single object
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {count: 12}.
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
value. E.g. {count: 12}.
Example
-------
Expand All @@ -276,14 +278,16 @@ def zero_shot_counting(image: np.ndarray) -> Dict[str, Any]:
def visual_prompt_counting(
image: np.ndarray, visual_prompt: Dict[str, List[float]]
) -> Dict[str, Any]:
"""'visual_prompt_counting' is a tool that counts the dominant foreground object given an image and a visual prompt which is a bounding box describing the object.
"""'visual_prompt_counting' is a tool that counts the dominant foreground object
given an image and a visual prompt which is a bounding box describing the object.
It returns only the count of the objects in the image.
Parameters:
image (np.ndarray): The image that contains lot of instances of a single object
Returns:
Dict[str, Any]: A dictionary containing the key 'count' and the count as a value. E.g. {count: 12}.
Dict[str, Any]: A dictionary containing the key 'count' and the count as a
value. E.g. {count: 12}.
Example
-------
Expand All @@ -308,15 +312,17 @@ def visual_prompt_counting(


def image_question_answering(image: np.ndarray, prompt: str) -> str:
"""'image_question_answering_' is a tool that can answer questions about the visual contents of an image given a question and an image.
It returns an answer to the question
"""'image_question_answering_' is a tool that can answer questions about the visual
contents of an image given a question and an image. It returns an answer to the
question
Parameters:
image (np.ndarray): The reference image used for the question
prompt (str): The question about the image
Returns:
str: A string which is the answer to the given prompt. E.g. {'text': 'This image contains a cat sitting on a table with a bowl of milk.'}.
str: A string which is the answer to the given prompt. E.g. {'text': 'This
image contains a cat sitting on a table with a bowl of milk.'}.
Example
-------
Expand All @@ -338,14 +344,16 @@ def image_question_answering(image: np.ndarray, prompt: str) -> str:

def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:
"""'clip' is a tool that can classify an image given a list of input classes or tags.
It returns the same list of the input classes along with their probability scores based on image content.
It returns the same list of the input classes along with their probability scores
based on image content.
Parameters:
image (np.ndarray): The image to classify or tag
classes (List[str]): The list of classes or tags that is associated with the image
Returns:
Dict[str, Any]: A dictionary containing the labels and scores. One dictionary contains a list of given labels and other a list of scores.
Dict[str, Any]: A dictionary containing the labels and scores. One dictionary
contains a list of given labels and other a list of scores.
Example
-------
Expand All @@ -366,8 +374,8 @@ def clip(image: np.ndarray, classes: List[str]) -> Dict[str, Any]:


def image_caption(image: np.ndarray) -> str:
"""'image_caption' is a tool that can caption an image based on its contents.
It returns a text describing the image.
"""'image_caption' is a tool that can caption an image based on its contents. It
returns a text describing the image.
Parameters:
image (np.ndarray): The image to caption
Expand Down Expand Up @@ -619,6 +627,51 @@ def overlay_segmentation_masks(
return np.array(pil_image.convert("RGB"))


def overlay_heat_map(
image: np.ndarray, heat_map: Dict[str, Any], alpha: float = 0.8
) -> np.ndarray:
"""'display_heat_map' is a utility function that displays a heat map on an image.
Parameters:
image (np.ndarray): The image to display the heat map on.
heat_map (Dict[str, Any]): A dictionary containing the heat map under the key
'heat_map'.
alpha (float, optional): The transparency of the overlay. Defaults to 0.8.
Returns:
np.ndarray: The image with the heat map displayed.
Example
-------
>>> image_with_heat_map = display_heat_map(
image,
{
'heat_map': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 125, 125, 125]], dtype=uint8),
},
)
"""
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0:
return image

pil_image = pil_image.convert("L")
mask = Image.fromarray(heat_map["heat_map"])
mask = mask.resize(pil_image.size)

overlay = Image.new("RGBA", mask.size)
odraw = ImageDraw.Draw(overlay)
odraw.bitmap((0, 0), mask, fill=(255, 0, 0, round(alpha * 255)))
combined = Image.alpha_composite(
pil_image.convert("RGBA"), overlay.resize(pil_image.size)
)
return np.array(combined.convert("RGB"))


def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
docstrings = ""
for func in funcs:
Expand Down

0 comments on commit e1a3543

Please sign in to comment.