|
8 | 8 | from PIL import Image |
9 | 9 | from tabulate import tabulate |
10 | 10 |
|
11 | | -from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map |
| 11 | +from vision_agent.image_utils import ( |
| 12 | + convert_to_b64, |
| 13 | + overlay_bboxes, |
| 14 | + overlay_heat_map, |
| 15 | + overlay_masks, |
| 16 | +) |
12 | 17 | from vision_agent.llm import LLM, OpenAILLM |
13 | 18 | from vision_agent.lmm import LMM, OpenAILMM |
14 | 19 | from vision_agent.tools import TOOLS |
@@ -481,6 +486,17 @@ def log_progress(self, description: str) -> None: |
481 | 486 | if self.report_progress_callback: |
482 | 487 | self.report_progress_callback(description) |
483 | 488 |
|
| 489 | + def _report_visualization_via_callback( |
| 490 | + self, images: Sequence[Union[str, Path]] |
| 491 | + ) -> None: |
| 492 | + """Stream the visualization images to the client side via the progress callback: send `<VIZ>`, then one `<IMG>...</IMG>` message per image (payload produced by convert_to_b64), then `</VIZ>`; nothing is sent when no callback is set.""" |
| 493 | + if self.report_progress_callback: |
| 494 | + self.report_progress_callback("<VIZ>") |
| 495 | + if images: |
| 496 | + for img in images: |
| 497 | + self.report_progress_callback(f"<IMG>{convert_to_b64(img)}</IMG>") |
| 498 | + self.report_progress_callback("</VIZ>") |
| 499 | + |
484 | 500 | def chat_with_workflow( |
485 | 501 | self, |
486 | 502 | chat: List[Dict[str, str]], |
@@ -577,9 +593,12 @@ def chat_with_workflow( |
577 | 593 | ) |
578 | 594 |
|
579 | 595 | if visualize_output: |
580 | | - visualized_output = all_tool_results[-1]["visualized_output"] |
581 | | - for image in visualized_output: |
582 | | - Image.open(image).show() |
| 596 | + viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][ |
| 597 | + "visualized_output" |
| 598 | + ] |
| 599 | + self._report_visualization_via_callback(viz_images) |
| 600 | + for img in viz_images: |
| 601 | + Image.open(img).show() |
583 | 602 |
|
584 | 603 | return final_answer, all_tool_results |
585 | 604 |
|
|
0 commit comments