From 3bc737b9316e52111e48ccc19f28d4dded9a113e Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 20:45:32 -0700 Subject: [PATCH 1/4] fixed issue with zero shot viz --- vision_agent/agent/vision_agent.py | 11 +++++++++-- vision_agent/image_utils.py | 12 +++++------- vision_agent/lmm/lmm.py | 5 +---- vision_agent/tools/__init__.py | 6 +++--- vision_agent/tools/tools.py | 7 +++++-- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 93218e6c..414f36d6 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -340,9 +340,13 @@ def _handle_viz_tools( return image_to_data for param, call_result in zip(parameters, tool_result["call_results"]): - # calls can fail, so we need to check if the call was successful + # Calls can fail, so we need to check if the call was successful. It can either: + # 1. return a str or some error that's not a dictionary + # 2. return a dictionary but not have the necessary keys if not isinstance(call_result, dict) or ( - "bboxes" not in call_result and "masks" not in call_result + "bboxes" not in call_result + and "masks" not in call_result + and "heat_map" not in call_result ): return image_to_data @@ -352,6 +356,7 @@ def _handle_viz_tools( image_to_data[image] = { "bboxes": [], "masks": [], + "heat_map": [], "labels": [], "scores": [], } @@ -360,6 +365,8 @@ def _handle_viz_tools( image_to_data[image]["labels"].extend(call_result.get("labels", [])) image_to_data[image]["scores"].extend(call_result.get("scores", [])) image_to_data[image]["masks"].extend(call_result.get("masks", [])) + # only single heatmap is returned + image_to_data[image]["heat_map"].append(call_result.get("heat_map", [])) if "mask_shape" in call_result: image_to_data[image]["mask_shape"] = call_result["mask_shape"] diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 4786f84b..4f7b69cd 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -211,7 +211,7 @@ def overlay_masks( } for label, mask in zip(masks["labels"], masks["masks"]): - if isinstance(mask, str): + if isinstance(mask, str) or isinstance(mask, Path): mask = np.array(Image.open(mask)) np_mask = np.zeros((image.size[1], image.size[0], 4)) np_mask[mask > 0, :] = color[label] + (255 * alpha,) @@ -221,7 +221,7 @@ def overlay_masks( def overlay_heat_map( - image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8 + image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8 ) -> ImageType: r"""Plots heat map on to an image. @@ -238,14 +238,12 @@ def overlay_heat_map( elif isinstance(image, np.ndarray): image = Image.fromarray(image) - if "masks" not in masks: + if "masks" not in heat_map: return image.convert("RGB") - # Only one heat map per image, so no need to loop through masks image = image.convert("L") - - if isinstance(masks["masks"][0], str): - mask = b64_to_pil(masks["masks"][0]) + # Only one heat map per image, so no need to loop through masks + mask = Image.fromarray(heat_map["heat_map"][0]) overlay = Image.new("RGBA", mask.size) odraw = ImageDraw.Draw(overlay) diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index a1fcc3c2..cc8861bd 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -9,10 +9,7 @@ import requests from openai import AzureOpenAI, OpenAI -from vision_agent.tools import ( - CHOOSE_PARAMS, - SYSTEM_PROMPT, -) +from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT _LOGGER = logging.getLogger(__name__) diff --git a/vision_agent/tools/__init__.py b/vision_agent/tools/__init__.py index 60870b56..10daf7eb 100644 --- a/vision_agent/tools/__init__.py +++ b/vision_agent/tools/__init__.py @@ -12,12 +12,12 @@ GroundingDINO, GroundingSAM, ImageCaption, - ZeroShotCounting, - VisualPromptCounting, - VisualQuestionAnswering, ImageQuestionAnswering, SegArea, SegIoU, Tool, + VisualPromptCounting, + VisualQuestionAnswering, + ZeroShotCounting, register_tool, ) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 3bf2bfbf..fd315bed 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -11,15 +11,16 @@ from PIL.Image import Image as ImageType from vision_agent.image_utils import ( + b64_to_pil, convert_to_b64, denormalize_bbox, get_image_size, normalize_bbox, rle_decode, ) +from vision_agent.lmm import OpenAILMM from vision_agent.tools.video import extract_frames_from_video from vision_agent.type_defs import LandingaiAPIKey -from vision_agent.lmm import OpenAILMM _LOGGER = logging.getLogger(__name__) _LND_API_KEY = LandingaiAPIKey().api_key @@ -516,7 +517,9 @@ def __call__(self, image: Union[str, ImageType]) -> Dict: "image": image_b64, "tool": "zero_shot_counting", } - return _send_inference_request(data, "tools") + resp_data = _send_inference_request(data, "tools") + resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0])) + return resp_data class VisualPromptCounting(Tool): From 8e3c52563a4f991c0b8ecbeaf97408db94b60de6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 20:54:39 -0700 Subject: [PATCH 2/4] updated docs --- vision_agent/agent/vision_agent.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 414f36d6..62dae768 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -473,9 +473,14 @@ def __call__( """Invoke the vision agent. Parameters: - input: a prompt that describe the task or a conversation in the format of + chat: A conversation in the format of [{"role": "user", "content": "describe your task here..."}]. - image: the input image referenced in the prompt parameter. + image: The input image referenced in the chat parameter. + reference_data: A dictionary containing the reference image, mask or bounding + box in the format of: + {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]} + where the bounding box coordinates are normalized. + visualize_output: Whether to visualize the output. Returns: The result of the vision agent in text. @@ -515,12 +520,14 @@ def chat_with_workflow( """Chat with the vision agent and return the final answer and all tool results. Parameters: - chat: a conversation in the format of + chat: A conversation in the format of [{"role": "user", "content": "describe your task here..."}]. - image: the input image referenced in the chat parameter. - reference_data: a dictionary containing the reference image and mask. in the - format of {"image": "image.jpg", "mask": "mask.jpg} - visualize_output: whether to visualize the output. + image: The input image referenced in the chat parameter. + reference_data: A dictionary containing the reference image, mask or bounding + box in the format of: + {"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]} + where the bounding box coordinates are normalized. + visualize_output: Whether to visualize the output. Returns: A tuple where the first item is the final answer and the second item is a From 577fe664cb5b931bef7031f3e91207d86fa3844d Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Wed, 24 Apr 2024 20:55:58 -0700 Subject: [PATCH 3/4] updated return for visual prompt counting --- vision_agent/tools/tools.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index fd315bed..fa06a823 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -588,7 +588,9 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict: "prompt": prompt, "tool": "few_shot_counting", } - return _send_inference_request(data, "tools") + resp_data = _send_inference_request(data, "tools") + resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0])) + return resp_data class VisualQuestionAnswering(Tool): From adbc5451de6ebd8bdcd277fcc1911efe83364852 Mon Sep 17 00:00:00 2001 From: shankar_ws3 Date: Wed, 24 Apr 2024 22:52:28 -0700 Subject: [PATCH 4/4] add minor fixes which were causing issues --- vision_agent/agent/vision_agent.py | 6 +++--- vision_agent/image_utils.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 62dae768..34be1fc9 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -314,6 +314,7 @@ def _handle_extract_frames( image_to_data[image] = { "bboxes": [], "masks": [], + "heat_map": [], "labels": [], "scores": [], } @@ -343,10 +344,9 @@ def _handle_viz_tools( # Calls can fail, so we need to check if the call was successful. It can either: # 1. return a str or some error that's not a dictionary # 2. return a dictionary but not have the necessary keys + if not isinstance(call_result, dict) or ( - "bboxes" not in call_result - and "masks" not in call_result - and "heat_map" not in call_result + "bboxes" not in call_result and "heat_map" not in call_result ): return image_to_data diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 4f7b69cd..23dc8506 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -238,7 +238,7 @@ def overlay_heat_map( elif isinstance(image, np.ndarray): image = Image.fromarray(image) - if "masks" not in heat_map: + if "heat_map" not in heat_map: return image.convert("RGB") image = image.convert("L")