From fad3b18d8ff6445c6271c30e286a8cbf64bb0ea5 Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Tue, 23 Apr 2024 09:30:38 -0700 Subject: [PATCH 1/4] Send visualization images via callback --- vision_agent/agent/vision_agent.py | 17 ++++++++++++++++- vision_agent/image_utils.py | 2 +- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index b02cdf72..a15eddaa 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -8,7 +8,12 @@ from PIL import Image from tabulate import tabulate -from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map +from vision_agent.image_utils import ( + convert_to_b64, + overlay_bboxes, + overlay_heat_map, + overlay_masks, +) from vision_agent.llm import LLM, OpenAILLM from vision_agent.lmm import LMM, OpenAILMM from vision_agent.tools import TOOLS @@ -481,6 +486,15 @@ def log_progress(self, description: str) -> None: if self.report_progress_callback: self.report_progress_callback(description) + def _report_visualization_via_callback(self, images: List[Union[str, Path]]) -> None: + """This is intended for streaming the visualization images via the callback to the client side.""" + if self.report_progress_callback: + self.report_progress_callback("" ) + if images: + for img in images: + self.report_progress_callback(f"{convert_to_b64(img)}") + self.report_progress_callback("" ) + def chat_with_workflow( self, chat: List[Dict[str, str]], @@ -578,6 +592,7 @@ def chat_with_workflow( if visualize_output: visualized_output = all_tool_results[-1]["visualized_output"] + self._report_visualization_via_callback(visualized_output) for image in visualized_output: Image.open(image).show() diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 43da645f..5a28df1d 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -4,7 +4,7 @@ from importlib import resources from io import BytesIO from pathlib import Path -from typing import Dict, Tuple, Union, List +from typing import Dict, List, Tuple, Union import numpy as np from PIL import Image, ImageDraw, ImageFont From 5bc186badbf1afcdb7f183b7692c116453a64bfc Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Tue, 23 Apr 2024 09:39:21 -0700 Subject: [PATCH 2/4] Fix lint errors --- vision_agent/agent/vision_agent.py | 8 +++++--- vision_agent/image_utils.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index a15eddaa..81224e44 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -486,14 +486,16 @@ def log_progress(self, description: str) -> None: if self.report_progress_callback: self.report_progress_callback(description) - def _report_visualization_via_callback(self, images: List[Union[str, Path]]) -> None: + def _report_visualization_via_callback( + self, images: Sequence[Union[str, Path]] + ) -> None: """This is intended for streaming the visualization images via the callback to the client side.""" if self.report_progress_callback: - self.report_progress_callback("" ) + self.report_progress_callback("") if images: for img in images: self.report_progress_callback(f"{convert_to_b64(img)}") - self.report_progress_callback("" ) + self.report_progress_callback("") def chat_with_workflow( self, diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py index 5a28df1d..4786f84b 100644 --- a/vision_agent/image_utils.py +++ b/vision_agent/image_utils.py @@ -108,7 +108,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str: data = Image.open(data) if isinstance(data, Image.Image): buffer = BytesIO() - data.convert("RGB").save(buffer, format="JPEG") + data.convert("RGB").save(buffer, format="PNG") return base64.b64encode(buffer.getvalue()).decode("utf-8") else: arr_bytes = data.tobytes() From c325965a312e14091e9cf06b72848268615de249 Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Tue, 23 Apr 2024 09:42:56 -0700 Subject: [PATCH 3/4] Fix lint errors --- tests/test_tools.py | 4 ++-- vision_agent/agent/vision_agent.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index ab2d20b7..b35cd7e3 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -36,7 +36,7 @@ def test_clip(): assert result["scores"] == [1.0] -def test_image_caption(): +def test_image_caption() -> None: img = Image.fromarray(ski.data.coins()) result = ImageCaption()(image=img) - assert result["text"] == ["a black and white photo of a coin"] + assert result["text"] diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 81224e44..3928ea71 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -593,9 +593,9 @@ def chat_with_workflow( ) if visualize_output: - visualized_output = all_tool_results[-1]["visualized_output"] - self._report_visualization_via_callback(visualized_output) - for image in visualized_output: + viz_images = all_tool_results[-1]["visualized_output"] + self._report_visualization_via_callback(viz_images) + for image in viz_images: Image.open(image).show() return final_answer, all_tool_results From 884ccde28a7c453ea0ad9194454578885c495e0b Mon Sep 17 00:00:00 2001 From: Yazhou Cao Date: Tue, 23 Apr 2024 09:46:13 -0700 Subject: [PATCH 4/4] Fix mypy error --- vision_agent/agent/vision_agent.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 3928ea71..8627b06c 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -593,10 +593,12 @@ def chat_with_workflow( ) if visualize_output: - viz_images = all_tool_results[-1]["visualized_output"] + viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][ + "visualized_output" + ] self._report_visualization_via_callback(viz_images) - for image in viz_images: - Image.open(image).show() + for img in viz_images: + Image.open(img).show() return final_answer, all_tool_results