From d04db84a5ed05b491ba60590ae1c7d635ed1e4b9 Mon Sep 17 00:00:00 2001 From: wuyiqunLu <132986242+wuyiqunLu@users.noreply.github.com> Date: Wed, 11 Sep 2024 11:50:44 +0800 Subject: [PATCH] feat: add callback to stream back message (#232) * feat: add callback to stream back message * address comment * fix lint * change back test code * fix type error --- vision_agent/agent/vision_agent.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 776ab964..736c9754 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -3,7 +3,7 @@ import os import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable from vision_agent.agent import Agent from vision_agent.agent.agent_utils import extract_json @@ -13,7 +13,7 @@ VA_CODE, ) from vision_agent.lmm import LMM, Message, OpenAILMM -from vision_agent.tools import META_TOOL_DOCSTRING +from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args from vision_agent.utils import CodeInterpreterFactory from vision_agent.utils.execute import CodeInterpreter, Execution @@ -123,6 +123,7 @@ def __init__( verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, code_sandbox_runtime: Optional[str] = None, + callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent. @@ -141,6 +142,7 @@ def __init__( self.max_iterations = 100 self.verbosity = verbosity self.code_sandbox_runtime = code_sandbox_runtime + self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) self.local_artifacts_path = cast( @@ -220,7 +222,14 @@ def chat_with_code( for chat_i in int_chat: if "media" in chat_i: for media in chat_i["media"]: - media = cast(str, media) + if type(media) is str and media.startswith(("http", "https")): + # TODO: Ideally we should not call VA.tools here, we should come to revisit how to better support remote image later + file_path = Path(media).name + ndarray = load_image(media) + save_image(ndarray, file_path) + media = file_path + else: + media = cast(str, media) artifacts.artifacts[Path(media).name] = open(media, "rb").read() media_remote_path = ( @@ -262,6 +271,7 @@ def chat_with_code( artifacts_loaded = artifacts.show() int_chat.append({"role": "observation", "content": artifacts_loaded}) orig_chat.append({"role": "observation", "content": artifacts_loaded}) + self.streaming_message({"role": "observation", "content": artifacts_loaded}) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -274,6 +284,8 @@ def chat_with_code( if last_response == response: response["let_user_respond"] = True + self.streaming_message({"role": "assistant", "content": response}) + if response["let_user_respond"]: break @@ -293,6 +305,13 @@ def chat_with_code( orig_chat.append( {"role": "observation", "content": obs, "execution": result} ) + self.streaming_message( + { + "role": "observation", + "content": obs, + "execution": result, + } + ) iterations += 1 last_response = response @@ -305,5 +324,9 @@ def chat_with_code( artifacts.save() return orig_chat, artifacts + def streaming_message(self, message: Dict[str, Any]) -> None: + if self.callback_message: + self.callback_message(message) + def log_progress(self, data: Dict[str, Any]) -> None: pass