Summary of this patch. Descriptive text placed before the first "diff --git"
header is skipped by git apply and by patch.

  * tests/test_tools.py: test_image_caption gains a "-> None" return
    annotation and now only asserts that result["text"] is truthy, instead of
    comparing it with one exact caption string.
  * vision_agent/agent/vision_agent.py: additionally imports convert_to_b64,
    adds the method _report_visualization_via_callback (it passes each image,
    converted with convert_to_b64, to report_progress_callback, with one
    empty-string call before and one after), and calls that method from
    chat_with_workflow before the images are opened with
    Image.open(...).show().
  * vision_agent/image_utils.py: reorders the names in the typing import and
    makes convert_to_b64 write PIL images to the buffer as PNG, not JPEG.

NOTE(review): this copy was rebuilt from a whitespace-collapsed paste. Line
breaks and blank context lines follow the hunk-header line counts, but the
leading indentation is inferred from Python nesting; the lone ")" context line
in the chat_with_workflow hunk is the least certain. Confirm against the
target files, or apply with whitespace-insensitive matching.

diff --git a/tests/test_tools.py b/tests/test_tools.py
index ab2d20b7..b35cd7e3 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -36,7 +36,7 @@ def test_clip():
     assert result["scores"] == [1.0]
 
 
-def test_image_caption():
+def test_image_caption() -> None:
     img = Image.fromarray(ski.data.coins())
     result = ImageCaption()(image=img)
-    assert result["text"] == ["a black and white photo of a coin"]
+    assert result["text"]
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index b02cdf72..8627b06c 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -8,7 +8,12 @@
 from PIL import Image
 from tabulate import tabulate
 
-from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
+from vision_agent.image_utils import (
+    convert_to_b64,
+    overlay_bboxes,
+    overlay_heat_map,
+    overlay_masks,
+)
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -481,6 +486,17 @@ def log_progress(self, description: str) -> None:
         if self.report_progress_callback:
             self.report_progress_callback(description)
 
+    def _report_visualization_via_callback(
+        self, images: Sequence[Union[str, Path]]
+    ) -> None:
+        """This is intended for streaming the visualization images via the callback to the client side."""
+        if self.report_progress_callback:
+            self.report_progress_callback("")
+            if images:
+                for img in images:
+                    self.report_progress_callback(f"{convert_to_b64(img)}")
+            self.report_progress_callback("")
+
     def chat_with_workflow(
         self,
         chat: List[Dict[str, str]],
@@ -577,9 +593,12 @@ def chat_with_workflow(
             )
 
         if visualize_output:
-            visualized_output = all_tool_results[-1]["visualized_output"]
-            for image in visualized_output:
-                Image.open(image).show()
+            viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][
+                "visualized_output"
+            ]
+            self._report_visualization_via_callback(viz_images)
+            for img in viz_images:
+                Image.open(img).show()
 
         return final_answer, all_tool_results
 
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 43da645f..4786f84b 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -4,7 +4,7 @@
 from importlib import resources
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, Tuple, Union, List
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont
@@ -108,7 +108,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
         data = Image.open(data)
     if isinstance(data, Image.Image):
         buffer = BytesIO()
-        data.convert("RGB").save(buffer, format="JPEG")
+        data.convert("RGB").save(buffer, format="PNG")
         return base64.b64encode(buffer.getvalue()).decode("utf-8")
     else:
         arr_bytes = data.tobytes()