Skip to content

Commit 648c9f1

Browse files
humpydonkeyAsiaCao
andauthored
feat(vision_agent): report visualization images via callback (landing-ai#62)
* Send visualization images via callback * Fix lint errors * Fix lint errors * Fix mypy error --------- Co-authored-by: Yazhou Cao <[email protected]>
1 parent 4dc0cfb commit 648c9f1

File tree

3 files changed

+27
-8
lines changed

3 files changed

+27
-8
lines changed

tests/test_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_clip():
3636
assert result["scores"] == [1.0]
3737

3838

39-
def test_image_caption():
39+
def test_image_caption() -> None:
4040
img = Image.fromarray(ski.data.coins())
4141
result = ImageCaption()(image=img)
42-
assert result["text"] == ["a black and white photo of a coin"]
42+
assert result["text"]

vision_agent/agent/vision_agent.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
from PIL import Image
99
from tabulate import tabulate
1010

11-
from vision_agent.image_utils import overlay_bboxes, overlay_masks, overlay_heat_map
11+
from vision_agent.image_utils import (
12+
convert_to_b64,
13+
overlay_bboxes,
14+
overlay_heat_map,
15+
overlay_masks,
16+
)
1217
from vision_agent.llm import LLM, OpenAILLM
1318
from vision_agent.lmm import LMM, OpenAILMM
1419
from vision_agent.tools import TOOLS
@@ -481,6 +486,17 @@ def log_progress(self, description: str) -> None:
481486
if self.report_progress_callback:
482487
self.report_progress_callback(description)
483488

489+
def _report_visualization_via_callback(
490+
self, images: Sequence[Union[str, Path]]
491+
) -> None:
492+
"""This is intended for streaming the visualization images via the callback to the client side."""
493+
if self.report_progress_callback:
494+
self.report_progress_callback("<VIZ>")
495+
if images:
496+
for img in images:
497+
self.report_progress_callback(f"<IMG>{convert_to_b64(img)}</IMG>")
498+
self.report_progress_callback("</VIZ>")
499+
484500
def chat_with_workflow(
485501
self,
486502
chat: List[Dict[str, str]],
@@ -577,9 +593,12 @@ def chat_with_workflow(
577593
)
578594

579595
if visualize_output:
580-
visualized_output = all_tool_results[-1]["visualized_output"]
581-
for image in visualized_output:
582-
Image.open(image).show()
596+
viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][
597+
"visualized_output"
598+
]
599+
self._report_visualization_via_callback(viz_images)
600+
for img in viz_images:
601+
Image.open(img).show()
583602

584603
return final_answer, all_tool_results
585604

vision_agent/image_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from importlib import resources
55
from io import BytesIO
66
from pathlib import Path
7-
from typing import Dict, Tuple, Union, List
7+
from typing import Dict, List, Tuple, Union
88

99
import numpy as np
1010
from PIL import Image, ImageDraw, ImageFont
@@ -108,7 +108,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
108108
data = Image.open(data)
109109
if isinstance(data, Image.Image):
110110
buffer = BytesIO()
111-
data.convert("RGB").save(buffer, format="JPEG")
111+
data.convert("RGB").save(buffer, format="PNG")
112112
return base64.b64encode(buffer.getvalue()).decode("utf-8")
113113
else:
114114
arr_bytes = data.tobytes()

0 commit comments

Comments
 (0)