Skip to content

Commit

Permalink
Change it back to numpy array
Browse files Browse the repository at this point in the history
  • Loading branch information
humpydonkey committed Jun 4, 2024
1 parent 583b3a2 commit c1a43f9
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def save_image(image: np.ndarray) -> str:

def overlay_bounding_boxes(
image: np.ndarray, bboxes: List[Dict[str, Any]]
) -> Image.Image:
) -> np.ndarray:
"""'overlay_bounding_boxes' is a utility function that displays bounding boxes on
an image.
Expand All @@ -533,7 +533,7 @@ def overlay_bounding_boxes(
boxes.
Returns:
PIL.Image.Image: The image with the bounding boxes, labels and scores displayed.
np.ndarray: The image with the bounding boxes, labels and scores displayed.
Example
-------
Expand Down Expand Up @@ -577,12 +577,12 @@ def overlay_bounding_boxes(
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)
return pil_image.convert("RGB")
return np.array(pil_image.convert("RGB"))


def overlay_segmentation_masks(
image: np.ndarray, masks: List[Dict[str, Any]]
) -> Image.Image:
) -> np.ndarray:
"""'overlay_segmentation_masks' is a utility function that displays segmentation
masks.
Expand All @@ -591,7 +591,7 @@ def overlay_segmentation_masks(
masks (List[Dict[str, Any]]): A list of dictionaries containing the masks.
Returns:
PIL.Image.Image: The image with the masks displayed.
np.ndarray: The image with the masks displayed.
Example
-------
Expand Down Expand Up @@ -627,12 +627,12 @@ def overlay_segmentation_masks(
np_mask[mask > 0, :] = color[label] + (255 * 0.5,)
mask_img = Image.fromarray(np_mask.astype(np.uint8))
pil_image = Image.alpha_composite(pil_image, mask_img)
return pil_image.convert("RGB")
return np.array(pil_image.convert("RGB"))


def overlay_heat_map(
image: np.ndarray, heat_map: Dict[str, Any], alpha: float = 0.8
) -> Image.Image:
) -> np.ndarray:
"""'overlay_heat_map' is a utility function that displays a heat map on an image.
Parameters:
Expand All @@ -642,7 +642,7 @@ def overlay_heat_map(
alpha (float, optional): The transparency of the overlay. Defaults to 0.8.
Returns:
PIL.Image.Image: The image with the heat map displayed.
np.ndarray: The image with the heat map displayed.
Example
-------
Expand All @@ -660,7 +660,7 @@ def overlay_heat_map(
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if "heat_map" not in heat_map or len(heat_map["heat_map"]) == 0:
return pil_image
return image

pil_image = pil_image.convert("L")
mask = Image.fromarray(heat_map["heat_map"])
Expand All @@ -672,7 +672,7 @@ def overlay_heat_map(
combined = Image.alpha_composite(
pil_image.convert("RGBA"), overlay.resize(pil_image.size)
)
return combined.convert("RGB")
return np.array(combined.convert("RGB"))


def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
Expand Down

0 comments on commit c1a43f9

Please sign in to comment.