diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 93218e6c..414f36d6 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -340,9 +340,13 @@ def _handle_viz_tools(
             return image_to_data
 
     for param, call_result in zip(parameters, tool_result["call_results"]):
-        # calls can fail, so we need to check if the call was successful
+        # Calls can fail, so we need to check if the call was successful. It can either:
+        # 1. return a str or some error that's not a dictionary
+        # 2. return a dictionary but not have the necessary keys
         if not isinstance(call_result, dict) or (
-            "bboxes" not in call_result and "masks" not in call_result
+            "bboxes" not in call_result
+            and "masks" not in call_result
+            and "heat_map" not in call_result
         ):
             return image_to_data
 
@@ -352,6 +356,7 @@ def _handle_viz_tools(
             image_to_data[image] = {
                 "bboxes": [],
                 "masks": [],
+                "heat_map": [],
                 "labels": [],
                 "scores": [],
             }
@@ -360,6 +365,8 @@ def _handle_viz_tools(
         image_to_data[image]["labels"].extend(call_result.get("labels", []))
         image_to_data[image]["scores"].extend(call_result.get("scores", []))
         image_to_data[image]["masks"].extend(call_result.get("masks", []))
+        # only single heatmap is returned
+        image_to_data[image]["heat_map"].append(call_result.get("heat_map", []))
         if "mask_shape" in call_result:
             image_to_data[image]["mask_shape"] = call_result["mask_shape"]
 
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 4786f84b..4f7b69cd 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -211,7 +211,7 @@ def overlay_masks(
     }
 
     for label, mask in zip(masks["labels"], masks["masks"]):
-        if isinstance(mask, str):
+        if isinstance(mask, str) or isinstance(mask, Path):
             mask = np.array(Image.open(mask))
         np_mask = np.zeros((image.size[1], image.size[0], 4))
         np_mask[mask > 0, :] = color[label] + (255 * alpha,)
@@ -221,7 +221,7 @@ def overlay_masks(
 
 
 def overlay_heat_map(
-    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
+    image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8
 ) -> ImageType:
     r"""Plots heat map on to an image.
 
@@ -238,14 +238,12 @@ def overlay_heat_map(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    if "masks" not in masks:
+    if "masks" not in heat_map:
         return image.convert("RGB")
 
-    # Only one heat map per image, so no need to loop through masks
     image = image.convert("L")
-
-    if isinstance(masks["masks"][0], str):
-        mask = b64_to_pil(masks["masks"][0])
+    # Only one heat map per image, so no need to loop through masks
+    mask = Image.fromarray(heat_map["heat_map"][0])
 
     overlay = Image.new("RGBA", mask.size)
     odraw = ImageDraw.Draw(overlay)
diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py
index a1fcc3c2..cc8861bd 100644
--- a/vision_agent/lmm/lmm.py
+++ b/vision_agent/lmm/lmm.py
@@ -9,10 +9,7 @@
 import requests
 from openai import AzureOpenAI, OpenAI
 
-from vision_agent.tools import (
-    CHOOSE_PARAMS,
-    SYSTEM_PROMPT,
-)
+from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT
 
 _LOGGER = logging.getLogger(__name__)
 
diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py
index 60870b56..10daf7eb 100644
--- a/vision_agent/tools/__init__.py
+++ b/vision_agent/tools/__init__.py
@@ -12,12 +12,12 @@
     GroundingDINO,
     GroundingSAM,
     ImageCaption,
-    ZeroShotCounting,
-    VisualPromptCounting,
-    VisualQuestionAnswering,
     ImageQuestionAnswering,
     SegArea,
     SegIoU,
     Tool,
+    VisualPromptCounting,
+    VisualQuestionAnswering,
+    ZeroShotCounting,
     register_tool,
 )
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 3bf2bfbf..fd315bed 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -11,15 +11,16 @@
 from PIL.Image import Image as ImageType
 
 from vision_agent.image_utils import (
+    b64_to_pil,
     convert_to_b64,
     denormalize_bbox,
     get_image_size,
     normalize_bbox,
     rle_decode,
 )
+from vision_agent.lmm import OpenAILMM
 from vision_agent.tools.video import extract_frames_from_video
 from vision_agent.type_defs import LandingaiAPIKey
-from vision_agent.lmm import OpenAILMM
 
 _LOGGER = logging.getLogger(__name__)
 _LND_API_KEY = LandingaiAPIKey().api_key
@@ -516,7 +517,9 @@ def __call__(self, image: Union[str, ImageType]) -> Dict:
             "image": image_b64,
             "tool": "zero_shot_counting",
         }
-        return _send_inference_request(data, "tools")
+        resp_data = _send_inference_request(data, "tools")
+        resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+        return resp_data
 
 
 class VisualPromptCounting(Tool):