Skip to content

Commit

Permalink
Fix zero shot count visualization (#65)
Browse files Browse the repository at this point in the history
* fixed issue with zero shot viz

* updated docs

* updated return for visual prompt counting

* add minor fixes which were causing issues

---------

Co-authored-by: shankar_ws3 <[email protected]>
  • Loading branch information
dillonalaird and shankar-vision-eng authored Apr 25, 2024
1 parent 0241adf commit 7f23463
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 18 deletions.
32 changes: 23 additions & 9 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def _handle_extract_frames(
image_to_data[image] = {
"bboxes": [],
"masks": [],
"heat_map": [],
"labels": [],
"scores": [],
}
Expand All @@ -340,9 +341,12 @@ def _handle_viz_tools(
return image_to_data

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
# Calls can fail, so we need to check if the call was successful. It can either:
# 1. return a str or some error that's not a dictionary
# 2. return a dictionary but not have the necessary keys

if not isinstance(call_result, dict) or (
"bboxes" not in call_result and "masks" not in call_result
"bboxes" not in call_result and "heat_map" not in call_result
):
return image_to_data

Expand All @@ -352,6 +356,7 @@ def _handle_viz_tools(
image_to_data[image] = {
"bboxes": [],
"masks": [],
"heat_map": [],
"labels": [],
"scores": [],
}
Expand All @@ -360,6 +365,8 @@ def _handle_viz_tools(
image_to_data[image]["labels"].extend(call_result.get("labels", []))
image_to_data[image]["scores"].extend(call_result.get("scores", []))
image_to_data[image]["masks"].extend(call_result.get("masks", []))
# only single heatmap is returned
image_to_data[image]["heat_map"].append(call_result.get("heat_map", []))
if "mask_shape" in call_result:
image_to_data[image]["mask_shape"] = call_result["mask_shape"]

Expand Down Expand Up @@ -480,9 +487,14 @@ def __call__(
"""Invoke the vision agent.
Parameters:
input: a prompt that describe the task or a conversation in the format of
chat: A conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the prompt parameter.
image: The input image referenced in the chat parameter.
reference_data: A dictionary containing the reference image, mask or bounding
box in the format of:
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.
Returns:
The result of the vision agent in text.
Expand Down Expand Up @@ -522,12 +534,14 @@ def chat_with_workflow(
"""Chat with the vision agent and return the final answer and all tool results.
Parameters:
chat: a conversation in the format of
chat: A conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the chat parameter.
reference_data: a dictionary containing the reference image and mask. in the
format of {"image": "image.jpg", "mask": "mask.jpg}
visualize_output: whether to visualize the output.
image: The input image referenced in the chat parameter.
reference_data: A dictionary containing the reference image, mask or bounding
box in the format of:
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.
Returns:
A tuple where the first item is the final answer and the second item is a
Expand Down
12 changes: 5 additions & 7 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def overlay_masks(
}

for label, mask in zip(masks["labels"], masks["masks"]):
if isinstance(mask, str):
if isinstance(mask, str) or isinstance(mask, Path):
mask = np.array(Image.open(mask))
np_mask = np.zeros((image.size[1], image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * alpha,)
Expand All @@ -221,7 +221,7 @@ def overlay_masks(


def overlay_heat_map(
image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8
) -> ImageType:
r"""Plots heat map on to an image.
Expand All @@ -238,14 +238,12 @@ def overlay_heat_map(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "masks" not in masks:
if "heat_map" not in heat_map:
return image.convert("RGB")

# Only one heat map per image, so no need to loop through masks
image = image.convert("L")

if isinstance(masks["masks"][0], str):
mask = b64_to_pil(masks["masks"][0])
# Only one heat map per image, so no need to loop through masks
mask = Image.fromarray(heat_map["heat_map"][0])

overlay = Image.new("RGBA", mask.size)
odraw = ImageDraw.Draw(overlay)
Expand Down
9 changes: 7 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PIL.Image import Image as ImageType

from vision_agent.image_utils import (
b64_to_pil,
convert_to_b64,
denormalize_bbox,
get_image_size,
Expand Down Expand Up @@ -516,7 +517,9 @@ def __call__(self, image: Union[str, ImageType]) -> Dict:
"image": image_b64,
"tool": "zero_shot_counting",
}
return _send_inference_request(data, "tools")
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


class VisualPromptCounting(Tool):
Expand Down Expand Up @@ -585,7 +588,9 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
"prompt": prompt,
"tool": "few_shot_counting",
}
return _send_inference_request(data, "tools")
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


class VisualQuestionAnswering(Tool):
Expand Down

0 comments on commit 7f23463

Please sign in to comment.