From 7f23463b6f0a75b33760181fdb6d9022e8364cad Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Wed, 24 Apr 2024 22:58:03 -0700
Subject: [PATCH] Fix zero shot count visualization (#65)

* fixed issue with zero shot viz

* updated docs

* updated return for visual prompt counting

* add minor fixes which were causing issues

---------

Co-authored-by: shankar_ws3
---
 vision_agent/agent/vision_agent.py | 32 +++++++++++++++++++++++---------
 vision_agent/image_utils.py        | 12 +++++-------
 vision_agent/tools/tools.py        |  9 +++++++--
 3 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 514f8fa5..b72e89eb 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -314,6 +314,7 @@ def _handle_extract_frames(
                 image_to_data[image] = {
                     "bboxes": [],
                     "masks": [],
+                    "heat_map": [],
                     "labels": [],
                     "scores": [],
                 }
@@ -340,9 +341,12 @@ def _handle_viz_tools(
         return image_to_data
 
     for param, call_result in zip(parameters, tool_result["call_results"]):
-        # calls can fail, so we need to check if the call was successful
+        # Calls can fail, so we need to check if the call was successful. It can either:
+        # 1. return a str or some error that's not a dictionary
+        # 2. return a dictionary but not have the necessary keys
+
         if not isinstance(call_result, dict) or (
-            "bboxes" not in call_result and "masks" not in call_result
+            "bboxes" not in call_result and "heat_map" not in call_result
         ):
             return image_to_data
 
@@ -352,6 +356,7 @@
             image_to_data[image] = {
                 "bboxes": [],
                 "masks": [],
+                "heat_map": [],
                 "labels": [],
                 "scores": [],
             }
@@ -360,6 +365,8 @@
         image_to_data[image]["labels"].extend(call_result.get("labels", []))
         image_to_data[image]["scores"].extend(call_result.get("scores", []))
         image_to_data[image]["masks"].extend(call_result.get("masks", []))
+        # only single heatmap is returned
+        image_to_data[image]["heat_map"].append(call_result.get("heat_map", []))
 
         if "mask_shape" in call_result:
             image_to_data[image]["mask_shape"] = call_result["mask_shape"]
@@ -480,9 +487,14 @@ def __call__(
        """Invoke the vision agent.
 
         Parameters:
-            input: a prompt that describe the task or a conversation in the format of
+            chat: A conversation in the format of
                 [{"role": "user", "content": "describe your task here..."}].
-            image: the input image referenced in the prompt parameter.
+            image: The input image referenced in the chat parameter.
+            reference_data: A dictionary containing the reference image, mask or bounding
+                box in the format of:
+                {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
+                where the bounding box coordinates are normalized.
+            visualize_output: Whether to visualize the output.
 
         Returns:
             The result of the vision agent in text.
@@ -522,12 +534,14 @@ def chat_with_workflow(
         """Chat with the vision agent and return the final answer and all tool results.
 
         Parameters:
-            chat: a conversation in the format of
+            chat: A conversation in the format of
                 [{"role": "user", "content": "describe your task here..."}].
-            image: the input image referenced in the chat parameter.
-            reference_data: a dictionary containing the reference image and mask. in the
-                format of {"image": "image.jpg", "mask": "mask.jpg}
-            visualize_output: whether to visualize the output.
+            image: The input image referenced in the chat parameter.
+            reference_data: A dictionary containing the reference image, mask or bounding
+                box in the format of:
+                {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
+                where the bounding box coordinates are normalized.
+            visualize_output: Whether to visualize the output.
 
         Returns:
             A tuple where the first item is the final answer and the second item is a
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 4786f84b..23dc8506 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -211,7 +211,7 @@ def overlay_masks(
     }
 
     for label, mask in zip(masks["labels"], masks["masks"]):
-        if isinstance(mask, str):
+        if isinstance(mask, str) or isinstance(mask, Path):
            mask = np.array(Image.open(mask))
         np_mask = np.zeros((image.size[1], image.size[0], 4))
         np_mask[mask > 0, :] = color[label] + (255 * alpha,)
@@ -221,7 +221,7 @@
 
 
 def overlay_heat_map(
-    image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
+    image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8
 ) -> ImageType:
     r"""Plots heat map on to an image.
 
@@ -238,14 +238,12 @@
     elif isinstance(image, np.ndarray):
         image = Image.fromarray(image)
 
-    if "masks" not in masks:
+    if "heat_map" not in heat_map:
         return image.convert("RGB")
 
-    # Only one heat map per image, so no need to loop through masks
     image = image.convert("L")
-
-    if isinstance(masks["masks"][0], str):
-        mask = b64_to_pil(masks["masks"][0])
+    # Only one heat map per image, so no need to loop through masks
+    mask = Image.fromarray(heat_map["heat_map"][0])
 
     overlay = Image.new("RGBA", mask.size)
     odraw = ImageDraw.Draw(overlay)
diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py
index 32a998db..fa06a823 100644
--- a/vision_agent/tools/tools.py
+++ b/vision_agent/tools/tools.py
@@ -11,6 +11,7 @@
 from PIL.Image import Image as ImageType
 
 from vision_agent.image_utils import (
+    b64_to_pil,
     convert_to_b64,
     denormalize_bbox,
     get_image_size,
@@ -516,7 +517,9 @@ def __call__(self, image: Union[str, ImageType]) -> Dict:
             "image": image_b64,
             "tool": "zero_shot_counting",
         }
-        return _send_inference_request(data, "tools")
+        resp_data = _send_inference_request(data, "tools")
+        resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+        return resp_data
 
 
 class VisualPromptCounting(Tool):
@@ -585,7 +588,9 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
             "prompt": prompt,
             "tool": "few_shot_counting",
         }
-        return _send_inference_request(data, "tools")
+        resp_data = _send_inference_request(data, "tools")
+        resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
+        return resp_data
 
 
 class VisualQuestionAnswering(Tool):
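
After this patch, ZeroShotCounting and VisualPromptCounting no longer return the heat map as a base64 string: they decode it into a NumPy array before returning, and overlay_heat_map renders that array with Image.fromarray. The sketch below is not part of the patch; it is a minimal usage example assuming the post-patch API, that ZeroShotCounting is importable from vision_agent.tools, that the hosted "tools" inference endpoint is reachable, and that "counts.jpg" is a hypothetical local image.

# Usage sketch only -- assumes the post-patch API described above;
# "counts.jpg" and the output file name are hypothetical.
import numpy as np

from vision_agent.image_utils import overlay_heat_map
from vision_agent.tools import ZeroShotCounting  # assumption: re-exported here, else vision_agent.tools.tools

counter = ZeroShotCounting()
resp = counter("counts.jpg")  # calls the hosted inference endpoint

# resp["heat_map"] is now a decoded np.ndarray rather than a base64 string.
assert isinstance(resp["heat_map"], np.ndarray)

# overlay_heat_map expects a dict whose "heat_map" entry is a list of arrays
# and renders the first one with Image.fromarray, so wrap the array in a list.
viz = overlay_heat_map("counts.jpg", {"heat_map": [resp["heat_map"]]})
viz.save("counts_heat_map.png")

This mirrors how the agent wires the pieces together: _handle_viz_tools appends the single returned array into image_to_data[image]["heat_map"], and the visualization step passes that dictionary to overlay_heat_map.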