Stream visualization images through the progress callback.

Adds `_report_visualization_via_callback` next to `log_progress`: when
`report_progress_callback` is set, it invokes the callback once with an empty
string, once per image with the result of `convert_to_b64(img)`, and once more
with an empty string. `chat_with_workflow` calls it with the last tool
result's "visualized_output" before opening the images. The image_utils
imports are regrouped to bring in `convert_to_b64`, and the `typing` import in
image_utils.py is re-sorted.

NOTE(review): the two empty-string callback arguments look like begin/end
markers whose contents may have been lost before this patch was captured;
confirm against the upstream commit.
NOTE(review): indentation of the Python lines inside the hunks was rebuilt from
the hunk line counts and normal class/method nesting; verify against the target
file before applying.

diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index b02cdf72..a15eddaa 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -8,7 +8,12 @@
 from PIL import Image
 from tabulate import tabulate
 
-from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
+from vision_agent.image_utils import (
+    convert_to_b64,
+    overlay_bboxes,
+    overlay_heat_map,
+    overlay_masks,
+)
 from vision_agent.llm import LLM, OpenAILLM
 from vision_agent.lmm import LMM, OpenAILMM
 from vision_agent.tools import TOOLS
@@ -481,6 +486,15 @@ def log_progress(self, description: str) -> None:
         if self.report_progress_callback:
             self.report_progress_callback(description)
 
+    def _report_visualization_via_callback(self, images: List[Union[str, Path]]) -> None:
+        """This is intended for streaming the visualization images via the callback to the client side."""
+        if self.report_progress_callback:
+            self.report_progress_callback("" )
+            if images:
+                for img in images:
+                    self.report_progress_callback(f"{convert_to_b64(img)}")
+            self.report_progress_callback("" )
+
     def chat_with_workflow(
         self,
         chat: List[Dict[str, str]],
@@ -578,6 +592,7 @@ def chat_with_workflow(
 
         if visualize_output:
             visualized_output = all_tool_results[-1]["visualized_output"]
+            self._report_visualization_via_callback(visualized_output)
             for image in visualized_output:
                 Image.open(image).show()
 
diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index 43da645f..5a28df1d 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -4,7 +4,7 @@
 from importlib import resources
 from io import BytesIO
 from pathlib import Path
-from typing import Dict, Tuple, Union, List
+from typing import Dict, List, Tuple, Union
 
 import numpy as np
 from PIL import Image, ImageDraw, ImageFont