Skip to content

Commit

Permalink
fixed issue with zero shot viz
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 25, 2024
1 parent 63c90a1 commit 3bc737b
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 18 deletions.
11 changes: 9 additions & 2 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,13 @@ def _handle_viz_tools(
return image_to_data

for param, call_result in zip(parameters, tool_result["call_results"]):
# calls can fail, so we need to check if the call was successful
# Calls can fail, so we need to check if the call was successful. It can either:
# 1. return a str or some error that's not a dictionary
# 2. return a dictionary but not have the necessary keys
if not isinstance(call_result, dict) or (
"bboxes" not in call_result and "masks" not in call_result
"bboxes" not in call_result
and "masks" not in call_result
and "heat_map" not in call_result
):
return image_to_data

Expand All @@ -352,6 +356,7 @@ def _handle_viz_tools(
image_to_data[image] = {
"bboxes": [],
"masks": [],
"heat_map": [],
"labels": [],
"scores": [],
}
Expand All @@ -360,6 +365,8 @@ def _handle_viz_tools(
image_to_data[image]["labels"].extend(call_result.get("labels", []))
image_to_data[image]["scores"].extend(call_result.get("scores", []))
image_to_data[image]["masks"].extend(call_result.get("masks", []))
# only single heatmap is returned
image_to_data[image]["heat_map"].append(call_result.get("heat_map", []))
if "mask_shape" in call_result:
image_to_data[image]["mask_shape"] = call_result["mask_shape"]

Expand Down
12 changes: 5 additions & 7 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def overlay_masks(
}

for label, mask in zip(masks["labels"], masks["masks"]):
if isinstance(mask, str):
if isinstance(mask, str) or isinstance(mask, Path):
mask = np.array(Image.open(mask))
np_mask = np.zeros((image.size[1], image.size[0], 4))
np_mask[mask > 0, :] = color[label] + (255 * alpha,)
Expand All @@ -221,7 +221,7 @@ def overlay_masks(


def overlay_heat_map(
image: Union[str, Path, np.ndarray, ImageType], masks: Dict, alpha: float = 0.8
image: Union[str, Path, np.ndarray, ImageType], heat_map: Dict, alpha: float = 0.8
) -> ImageType:
r"""Plots heat map on to an image.
Expand All @@ -238,14 +238,12 @@ def overlay_heat_map(
elif isinstance(image, np.ndarray):
image = Image.fromarray(image)

if "masks" not in masks:
if "masks" not in heat_map:
return image.convert("RGB")

# Only one heat map per image, so no need to loop through masks
image = image.convert("L")

if isinstance(masks["masks"][0], str):
mask = b64_to_pil(masks["masks"][0])
# Only one heat map per image, so no need to loop through masks
mask = Image.fromarray(heat_map["heat_map"][0])

overlay = Image.new("RGBA", mask.size)
odraw = ImageDraw.Draw(overlay)
Expand Down
5 changes: 1 addition & 4 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import requests
from openai import AzureOpenAI, OpenAI

from vision_agent.tools import (
CHOOSE_PARAMS,
SYSTEM_PROMPT,
)
from vision_agent.tools import CHOOSE_PARAMS, SYSTEM_PROMPT

_LOGGER = logging.getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions vision_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
GroundingDINO,
GroundingSAM,
ImageCaption,
ZeroShotCounting,
VisualPromptCounting,
VisualQuestionAnswering,
ImageQuestionAnswering,
SegArea,
SegIoU,
Tool,
VisualPromptCounting,
VisualQuestionAnswering,
ZeroShotCounting,
register_tool,
)
7 changes: 5 additions & 2 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
from PIL.Image import Image as ImageType

from vision_agent.image_utils import (
b64_to_pil,
convert_to_b64,
denormalize_bbox,
get_image_size,
normalize_bbox,
rle_decode,
)
from vision_agent.lmm import OpenAILMM
from vision_agent.tools.video import extract_frames_from_video
from vision_agent.type_defs import LandingaiAPIKey
from vision_agent.lmm import OpenAILMM

_LOGGER = logging.getLogger(__name__)
_LND_API_KEY = LandingaiAPIKey().api_key
Expand Down Expand Up @@ -516,7 +517,9 @@ def __call__(self, image: Union[str, ImageType]) -> Dict:
"image": image_b64,
"tool": "zero_shot_counting",
}
return _send_inference_request(data, "tools")
resp_data = _send_inference_request(data, "tools")
resp_data["heat_map"] = np.array(b64_to_pil(resp_data["heat_map"][0]))
return resp_data


class VisualPromptCounting(Tool):
Expand Down

0 comments on commit 3bc737b

Please sign in to comment.