Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix zero shot count visualization #65

Merged
merged 4 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def _handle_extract_frames(
image_to_data[image] = {
"bboxes": [],
"masks": [],
"heat_map": [],
"labels": [],
"scores": [],
}
Expand All @@ -340,9 +341,12 @@ def _handle_viz_tools(
return image_to_data

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
# Calls can fail, so we need to check if the call was successful. A failed call can either:
# 1. return a str or some error that's not a dictionary
# 2. return a dictionary but not have the necessary keys

if not isinstance(call_result, dict) or (
"bboxes" not in call_result and "masks" not in call_result
"bboxes" not in call_result and "heat_map" not in call_result
):
return image_to_data

Expand All @@ -352,6 +356,7 @@ def _handle_viz_tools(
image_to_data[image] = {
"bboxes": [],
"masks": [],
"heat_map": [],
"labels": [],
"scores": [],
}
Expand All @@ -360,6 +365,8 @@ def _handle_viz_tools(
image_to_data[image]["labels"].extend(call_result.get("labels", []))
image_to_data[image]["scores"].extend(call_result.get("scores", []))
image_to_data[image]["masks"].extend(call_result.get("masks", []))
# Only a single heat map is returned per call
image_to_data[image]["heat_map"].append(call_result.get("heat_map", []))
if "mask_shape" in call_result:
image_to_data[image]["mask_shape"] = call_result["mask_shape"]

Expand Down Expand Up @@ -466,9 +473,14 @@ def __call__(
"""Invoke the vision agent.

Parameters:
input: a prompt that describe the task or a conversation in the format of
chat: A conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the prompt parameter.
image: The input image referenced in the chat parameter.
reference_data: A dictionary containing the reference image, mask or bounding
box in the format of:
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.

Returns:
The result of the vision agent in text.
Expand Down Expand Up @@ -508,12 +520,14 @@ def chat_with_workflow(
"""Chat with the vision agent and return the final answer and all tool results.

Parameters:
chat: a conversation in the format of
chat: A conversation in the format of
[{"role": "user", "content": "describe your task here..."}].
image: the input image referenced in the chat parameter.
reference_data: a dictionary containing the reference image and mask. in the
format of {"image": "image.jpg", "mask": "mask.jpg}
visualize_output: whether to visualize the output.
image: The input image referenced in the chat parameter.
reference_data: A dictionary containing the reference image, mask or bounding
box in the format of:
{"image": "image.jpg", "mask": "mask.jpg", "bbox": [0.1, 0.2, 0.1, 0.2]}
where the bounding box coordinates are normalized.
visualize_output: Whether to visualize the output.

Returns:
A tuple where the first item is the final answer and the second item is a
Expand Down
12 changes: 5 additions & 7 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def overlay_masks(
}

for label, mask in zip(masks["labels"], masks["masks"]):
if isinstance(mask, str):
if isinstance(mask, str) or isinstance(mask, Path):
mask = np.array(Image.open(mask))
np_mask = np.zeros((image.size[1], image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * alpha,)
Expand All @@ -221,7 +221,7 @@ def overlay_masks(


def overlay_heat_map(
image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8
) -> ImageType:
r"""Plots heat map on to an image.

Expand All @@ -238,14 +238,12 @@ def overlay_heat_map(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "masks" not in masks:
if "heat_map" not in heat_map:
return image.convert("RGB")

# Only one heat map per image, so no need to loop through masks
image = image.convert("L")

if isinstance(masks["masks"][0], str):
mask = b64_to_pil(masks["masks"][0])
# Only one heat map per image, so take the first entry instead of looping
mask = Image.fromarray(heat_map["heat_map"][0])

overlay = Image.new("RGBA", mask.size)
odraw = ImageDraw.Draw(overlay)
Expand Down
5 changes: 1 addition & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import requests
from openai import AzureOpenAI, OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
SYSTEM_PROMPT,
)
from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT

_LOGGER = logging.getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
GroundingDINO,
GroundingSAM,
ImageCaption,
ZeroShotCounting,
VisualPromptCounting,
VisualQuestionAnswering,
ImageQuestionAnswering,
SegArea,
SegIoU,
Tool,
VisualPromptCounting,
VisualQuestionAnswering,
ZeroShotCounting,
register_tool,
)
11 changes: 8 additions & 3 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
from PIL.Image import Image as ImageType

from vision_agent.image_utils import (
b64_to_pil,
convert_to_b64,
denormalize_bbox,
get_image_size,
normalize_bbox,
rle_decode,
)
from vision_agent.lmm import OpenAILMM
from vision_agent.tools.video import extract_frames_from_video
from vision_agent.type_defs import LandingaiAPIKey
from vision_agent.lmm import OpenAILMM

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
Expand Down Expand Up @@ -516,7 +517,9 @@ def __call__(self, image: Union[str, ImageType]) -> Dict:
"image": image_b64,
"tool": "zero_shot_counting",
}
return _send_inference_request(data, "tools")
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


class VisualPromptCounting(Tool):
Expand Down Expand Up @@ -585,7 +588,9 @@ def __call__(self, image: Union[str, ImageType], prompt: str) -> Dict:
"prompt": prompt,
"tool": "few_shot_counting",
}
return _send_inference_request(data, "tools")
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


class VisualQuestionAnswering(Tool):
Expand Down
Loading