Skip to content

Commit

Permalink
fixed heatmap overlay and addressesessed PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Apr 21, 2024
1 parent 5e9ef67 commit dd198bc
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 21 deletions.
43 changes: 30 additions & 13 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from PIL import Image
from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_masks
from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -335,7 +335,9 @@ def _handle_viz_tools(

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
if not isinstance(call_result, dict) or "bboxes" not in call_result:
if not isinstance(call_result, dict) or (
"bboxes" not in call_result and "masks" not in call_result
):
return image_to_data

# if the call was successful, then we can add the image data
Expand All @@ -348,11 +350,12 @@ def _handle_viz_tools(
"scores": [],
}

image_to_data[image]["bboxes"].extend(call_result["bboxes"])
image_to_data[image]["labels"].extend(call_result["labels"])
image_to_data[image]["scores"].extend(call_result["scores"])
if "masks" in call_result:
image_to_data[image]["masks"].extend(call_result["masks"])
image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
image_to_data[image]["labels"].extend(call_result.get("labels", []))
image_to_data[image]["scores"].extend(call_result.get("scores", []))
image_to_data[image]["masks"].extend(call_result.get("masks", []))
if "mask_shape" in call_result:
image_to_data[image]["mask_shape"] = call_result["mask_shape"]

return image_to_data

Expand All @@ -366,6 +369,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
"grounding_dino_",
"extract_frames_",
"dinov_",
"zero_shot_counting_",
"visual_prompt_counting_",
]:
continue

Expand All @@ -378,8 +383,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
for image_str in image_to_data:
image_path = Path(image_str)
image_data = image_to_data[image_str]
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
if "_counting_" in tool_result["tool_name"]:
image = overlay_heat_map(image_path, image_data)
else:
image = overlay_masks(image_path, image_data)
image = overlay_bboxes(image, image_data)
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
image.save(f.name)
visualized_images.append(f.name)
Expand Down Expand Up @@ -477,11 +485,21 @@ def chat_with_workflow(
if image:
question += f" Image name: {image}"
if reference_data:
if not ("image" in reference_data and "mask" in reference_data):
if not (
"image" in reference_data
and ("mask" in reference_data or "bbox" in reference_data)
):
raise ValueError(
f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
)
question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
visual_prompt_data = (
f"Reference mask: {reference_data['mask']}"
if "mask" in reference_data
else f"Reference bbox: {reference_data['bbox']}"
)
question += (
f" Reference image: {reference_data['image']}, {visual_prompt_data}"
)

reflections = ""
final_answer = ""
Expand Down Expand Up @@ -524,7 +542,6 @@ def chat_with_workflow(
final_answer = answer_summarize(
self.answer_model, question, answers, reflections
)

visualized_output = visualize_result(all_tool_results)
all_tool_results.append({"visualized_output": visualized_output})
if len(visualized_output) > 0:
Expand Down
50 changes: 46 additions & 4 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def overlay_bboxes(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "bboxes" not in bboxes:
return image.convert("RGB")

color = {
label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
}
Expand All @@ -114,8 +117,6 @@ def overlay_bboxes(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
)
if "bboxes" not in bboxes:
return image.convert("RGB")

for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
box = [
Expand Down Expand Up @@ -150,11 +151,15 @@ def overlay_masks(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "masks" not in masks:
return image.convert("RGB")

if "labels" not in masks:
masks["labels"] = [""] * len(masks["masks"])

color = {
label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
}
if "masks" not in masks:
return image.convert("RGB")

for label, mask in zip(masks["labels"], masks["masks"]):
if isinstance(mask, str):
Expand All @@ -164,3 +169,40 @@ def overlay_masks(
mask_img = Image.fromarray(np_mask.astype(np.uint8))
image = Image.alpha_composite(image.convert("RGBA"), mask_img)
return image.convert("RGB")


def overlay_heat_map(
image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
) -> ImageType:
r"""Plots heat map on to an image.
Parameters:
image: the input image
masks: the heatmap to overlay
alpha: the transparency of the overlay
Returns:
The image with the heatmap overlayed
"""
if isinstance(image, (str, Path)):
image = Image.open(image)
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "masks" not in masks:
return image.convert("RGB")

# Only one heat map per image, so no need to loop through masks
image = image.convert("L")

if isinstance(masks["masks"][0], str):
mask = b64_to_pil(masks["masks"][0])

overlay = Image.new("RGBA", mask.size)
odraw = ImageDraw.Draw(overlay)
odraw.bitmap(
(0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
) # fill=(R, G, B, Alpha)
combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))

return combined.convert("RGB")
7 changes: 3 additions & 4 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,8 @@ class ZeroShotCounting(Tool):
"""

name = "zero_shot_counting_"
description = """'zero_shot_counting_' is a tool that can count total number of instances of an object present in an image belonging to the same class without a text or visual prompt.
It returns the total count of the objects."""
description = "'zero_shot_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class without a text or visual prompt."

usage = {
"required_parameters": [
{"name": "image", "type": "str"},
Expand Down Expand Up @@ -561,8 +561,7 @@ class VisualPromptCounting(Tool):
"""

name = "visual_prompt_counting_"
description = """'visual_prompt_counting_' is a tool that can count total number of instances of an object present in an image belonging to the same class given an
example bounding box around a single instance. It returns the total count of the objects."""
description = "'visual_prompt_counting_' is a tool that can count and return total number of instances of an object present in an image belonging to the same class given an example bounding box."

usage = {
"required_parameters": [
Expand Down

0 comments on commit dd198bc

Please sign in to comment.