Skip to content

Commit

Permalink
Merge branch 'main' into custom-tools
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird authored Apr 23, 2024
2 parents 0f3a55c + cec32f7 commit 3aca8c7
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "vision-agent"
version = "0.2.1"
version = "0.2.2"
description = "Toolset for Vision Agent"
authors = ["Landing AI <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_clip():
assert result["scores"] == [1.0]


def test_image_caption():
def test_image_caption() -> None:
img = Image.fromarray(ski.data.coins())
result = ImageCaption()(image=img)
assert result["text"] == ["a black and white photo of a coin"]
assert result["text"]
27 changes: 23 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from PIL import Image
from tabulate import tabulate

from vision_agent.image_utils import overlay_bboxes, overlay_heat_map, overlay_masks
from vision_agent.image_utils import (
convert_to_b64,
overlay_bboxes,
overlay_heat_map,
overlay_masks,
)
from vision_agent.llm import LLM, OpenAILLM
from vision_agent.lmm import LMM, OpenAILMM
from vision_agent.tools import TOOLS
Expand Down Expand Up @@ -481,6 +486,17 @@ def log_progress(self, description: str) -> None:
if self.report_progress_callback:
self.report_progress_callback(description)

def _report_visualization_via_callback(
self, images: Sequence[Union[str, Path]]
) -> None:
"""This is intended for streaming the visualization images via the callback to the client side."""
if self.report_progress_callback:
self.report_progress_callback("<VIZ>")
if images:
for img in images:
self.report_progress_callback(f"<IMG>{convert_to_b64(img)}</IMG>")
self.report_progress_callback("</VIZ>")

def chat_with_workflow(
self,
chat: List[Dict[str, str]],
Expand Down Expand Up @@ -577,9 +593,12 @@ def chat_with_workflow(
)

if visualize_output:
visualized_output = all_tool_results[-1]["visualized_output"]
for image in visualized_output:
Image.open(image).show()
viz_images: Sequence[Union[str, Path]] = all_tool_results[-1][
"visualized_output"
]
self._report_visualization_via_callback(viz_images)
for img in viz_images:
Image.open(img).show()

return final_answer, all_tool_results

Expand Down
2 changes: 1 addition & 1 deletion vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
data = Image.open(data)
if isinstance(data, Image.Image):
buffer = BytesIO()
data.convert("RGB").save(buffer, format="JPEG")
data.convert("RGB").save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
arr_bytes = data.tobytes()
Expand Down

0 comments on commit 3aca8c7

Please sign in to comment.