diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 44c3aa08..9740f23e 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -8,7 +8,7 @@
 from PIL import Image
 from tabulate import tabulate
 
-from vision_agent.image_utils import overlay_bboxes, overlay_masks
+from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -335,7 +335,9 @@ def _handle_viz_tools(
 
     for param, call_result in zip(parameters, tool_result["call_results"]):
         # calls can fail, so we need to check if the call was successful
-        if not isinstance(call_result, dict) or "bboxes" not in call_result:
+        if not isinstance(call_result, dict) or (
+            "bboxes" not in call_result and "masks" not in call_result
+        ):
             return image_to_data
 
         # if the call was successful, then we can add the image data
@@ -348,11 +350,12 @@ def _handle_viz_tools(
                 "scores": [],
             }
 
-        image_to_data[image]["bboxes"].extend(call_result["bboxes"])
-        image_to_data[image]["labels"].extend(call_result["labels"])
-        image_to_data[image]["scores"].extend(call_result["scores"])
-        if "masks" in call_result:
-            image_to_data[image]["masks"].extend(call_result["masks"])
+        image_to_data[image]["bboxes"].extend(call_result.get("bboxes", []))
+        image_to_data[image]["labels"].extend(call_result.get("labels", []))
+        image_to_data[image]["scores"].extend(call_result.get("scores", []))
+        image_to_data[image]["masks"].extend(call_result.get("masks", []))
+        if "mask_shape" in call_result:
+            image_to_data[image]["mask_shape"] = call_result["mask_shape"]
 
     return image_to_data
 
@@ -366,6 +369,8 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
             "grounding_dino_",
             "extract_frames_",
             "dinov_",
+            "zero_shot_counting_",
+            "visual_prompt_counting_",
         ]:
             continue
 
@@ -378,8 +383,11 @@ def visualize_result(all_tool_results: List[Dict]) -> Sequence[Union[str, Path]]
     for image_str in image_to_data:
         image_path = Path(image_str)
         image_data = image_to_data[image_str]
-        image = overlay_masks(image_path, image_data)
-        image = overlay_bboxes(image, image_data)
+        if "_counting_" in tool_result["tool_name"]:
+            image = overlay_heat_map(image_path, image_data)
+        else:
+            image = overlay_masks(image_path, image_data)
+            image = overlay_bboxes(image, image_data)
         with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
             image.save(f.name)
             visualized_images.append(f.name)
@@ -477,11 +485,21 @@ def chat_with_workflow(
         if image:
             question += f" Image name: {image}"
         if reference_data:
-            if not ("image" in reference_data and "mask" in reference_data):
+            if not (
+                "image" in reference_data
+                and ("mask" in reference_data or "bbox" in reference_data)
+            ):
                 raise ValueError(
-                    f"Reference data must contain 'image' and 'mask'. but got {reference_data}"
+                    f"Reference data must contain 'image' and a visual prompt which can be 'mask' or 'bbox'. but got {reference_data}"
                 )
-            question += f" Reference image: {reference_data['image']}, Reference mask: {reference_data['mask']}"
+            visual_prompt_data = (
+                f"Reference mask: {reference_data['mask']}"
+                if "mask" in reference_data
+                else f"Reference bbox: {reference_data['bbox']}"
+            )
+            question += (
+                f" Reference image: {reference_data['image']}, {visual_prompt_data}"
+            )
 
         reflections = ""
         final_answer = ""
@@ -524,7 +542,6 @@ def chat_with_workflow(
         final_answer = answer_summarize(
             self.answer_model, question, answers, reflections
         )
-
         visualized_output = visualize_result(all_tool_results)
         all_tool_results.append({"visualized_output": visualized_output})
         if len(visualized_output) > 0:
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index f36a2033..fefa6b13 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -103,6 +103,9 @@ def overlay_bboxes(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "bboxes" not in bboxes:
+        return image.convert("RGB")
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(bboxes["labels"]))
     }
@@ -114,8 +117,6 @@ def overlay_bboxes(
         str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
         fontsize,
     )
-    if "bboxes" not in bboxes:
-        return image.convert("RGB")
 
     for label, box, scores in zip(bboxes["labels"], bboxes["bboxes"], bboxes["scores"]):
         box = [
@@ -150,11 +151,15 @@ def overlay_masks(
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
+    if "masks" not in masks:
+        return image.convert("RGB")
+
+    if "labels" not in masks:
+        masks["labels"] = [""] * len(masks["masks"])
+
     color = {
         label: COLORS[i % len(COLORS)] for i, label in enumerate(set(masks["labels"]))
     }
-    if "masks" not in masks:
-        return image.convert("RGB")
 
     for label, mask in zip(masks["labels"], masks["masks"]):
         if isinstance(mask, str):
@@ -164,3 +169,49 @@ def overlay_masks(
         mask_img = Image.fromarray(np_mask.astype(np.uint8))
         image = Image.alpha_composite(image.convert("RGBA"), mask_img)
     return image.convert("RGB")
+
+
+def overlay_heat_map(
+    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
+) -> ImageType:
+    r"""Plots heat map on to an image.
+
+    Parameters:
+        image: the input image, given as a path, a numpy array or a PIL image
+        masks: dictionary whose "masks" entry holds the heat map to overlay, only
+            the first entry is used and it may be a string decoded by b64_to_pil,
+            a numpy array or a PIL image
+        alpha: the opacity of the overlay, 0.0 is invisible and 1.0 is opaque
+
+    Returns:
+        The image with the heatmap overlayed, or the input image converted to RGB
+        when there is no heat map to draw
+    """
+    if isinstance(image, (str, Path)):
+        image = Image.open(image)
+    elif isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    # a missing "masks" key and an empty list both mean there is nothing to draw
+    if not masks.get("masks"):
+        return image.convert("RGB")
+
+    # Only one heat map per image, so no need to loop through masks
+    mask = masks["masks"][0]
+    if isinstance(mask, str):
+        mask = b64_to_pil(mask)
+    elif isinstance(mask, np.ndarray):
+        mask = Image.fromarray(mask)
+
+    # drop colour from the base image before compositing the red overlay
+    image = image.convert("L")
+
+    overlay = Image.new("RGBA", mask.size)
+    odraw = ImageDraw.Draw(overlay)
+    odraw.bitmap(
+        (0, 0), mask, fill=(255, 0, 0, round(alpha * 255))
+    )  # fill=(R, G, B, Alpha)
+    # resize the overlay to the image size before compositing
+    combined = Image.alpha_composite(image.convert("RGBA"), overlay.resize(image.size))
+
+    return combined.convert("RGB")
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 40728e62..e4d8ada7 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -506,8 +506,8 @@ class ZeroShotCounting(Tool):
     """
 
     name = "zero_shot_counting_"
-    description = """'zero_shot_counting_' is a tool that can count total number of instances of an object present in an image belonging to the same class without a text or visual prompt.
-    It returns the total count of the objects."""
+    description = "'zero_shot_counting_' is a tool that counts and returns the total number of instances of an object present in an image belonging to the same class without a text or visual prompt."
+
     usage = {
         "required_parameters": [
             {"name": "image", "type": "str"},
@@ -561,8 +561,7 @@ class VisualPromptCounting(Tool):
     """
 
     name = "visual_prompt_counting_"
-    description = """'visual_prompt_counting_' is a tool that can count total number of instances of an object present in an image belonging to the same class given an
-    example bounding box around a single instance. It returns the total count of the objects."""
+    description = "'visual_prompt_counting_' is a tool that can count and return total number of instances of an object present in an image belonging to the same class given an example bounding box."
     usage = {
         "required_parameters": [