From 5ef03f21d9ccd1bd53152f5930996c0f3a372576 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Tue, 10 Sep 2024 20:50:50 +0800 Subject: [PATCH 1/5] feat: add callback to stream back message --- vision_agent/agent/vision_agent.py | 32 ++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 776ab964..d087ce86 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -1,9 +1,10 @@ import copy +import json import logging import os import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable from vision_agent.agent import Agent from vision_agent.agent.agent_utils import extract_json @@ -13,8 +14,9 @@ VA_CODE, ) from vision_agent.lmm import LMM, Message, OpenAILMM -from vision_agent.tools import META_TOOL_DOCSTRING +from vision_agent.tools import META_TOOL_DOCSTRING, save_image from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args +from vision_agent.tools.tools import load_image from vision_agent.utils import CodeInterpreterFactory from vision_agent.utils.execute import CodeInterpreter, Execution @@ -123,6 +125,7 @@ def __init__( verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, code_sandbox_runtime: Optional[str] = None, + callback_message: Optional[Callable[[Message], None]] = None, ) -> None: """Initialize the VisionAgent. @@ -141,6 +144,7 @@ def __init__( self.max_iterations = 100 self.verbosity = verbosity self.code_sandbox_runtime = code_sandbox_runtime + self.callback_message = callback_message if self.verbosity >= 1: _LOGGER.setLevel(logging.INFO) self.local_artifacts_path = cast( @@ -181,7 +185,7 @@ def chat_with_code( self, chat: List[Message], artifacts: Optional[Artifacts] = None, - test_multi_plan: bool = True, + test_multi_plan: bool = False, customized_tool_names: Optional[List[str]] = None, ) -> Tuple[List[Message], Artifacts]: """Chat with VisionAgent, it will use code to execute actions to accomplish @@ -220,7 +224,13 @@ def chat_with_code( for chat_i in int_chat: if "media" in chat_i: for media in chat_i["media"]: - media = cast(str, media) + if type(media) is str and media.startswith(("http", "https")): + file_path = Path(media).name + ndarray = load_image(media) + save_image(ndarray, file_path) + media = file_path + else: + media = cast(str, media) artifacts.artifacts[Path(media).name] = open(media, "rb").read() media_remote_path = ( @@ -262,6 +272,7 @@ def chat_with_code( artifacts_loaded = artifacts.show() int_chat.append({"role": "observation", "content": artifacts_loaded}) orig_chat.append({"role": "observation", "content": artifacts_loaded}) + self.callback_message({"role": "observation", "content": artifacts_loaded}) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -274,6 +285,8 @@ def chat_with_code( if last_response == response: response["let_user_respond"] = True + self.callback_message({"role": "assistant", "content": response}) + if response["let_user_respond"]: break @@ -293,6 +306,13 @@ def chat_with_code( orig_chat.append( {"role": "observation", "content": obs, "execution": result} ) + self.callback_message( + { + "role": "observation", + "content": obs, + "execution": result, + } + ) iterations += 1 last_response = response @@ -305,5 +325,9 @@ def chat_with_code( artifacts.save() return orig_chat, artifacts + def streaming_message(self, message: Message) -> None: + if self.callback_message: + self.callback_message(message) + def log_progress(self, data: Dict[str, Any]) -> None: pass From 2a8c779cb2b7e04246ecb2136fa15cbb90e10d30 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Wed, 11 Sep 2024 00:15:40 +0800 Subject: [PATCH 2/5] address comment --- vision_agent/agent/vision_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index d087ce86..d3450088 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -1,5 +1,4 @@ import copy -import json import logging import os import tempfile @@ -225,6 +224,7 @@ def chat_with_code( if "media" in chat_i: for media in chat_i["media"]: if type(media) is str and media.startswith(("http", "https")): + # TODO: Ideally we should not call VA.tools here, we should come to revisit how to better support remote image later file_path = Path(media).name ndarray = load_image(media) save_image(ndarray, file_path) From 4e0bfd996727277a5c31897bb58b0fd69809e25e Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Wed, 11 Sep 2024 00:17:51 +0800 Subject: [PATCH 3/5] fix lint --- vision_agent/agent/vision_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index d3450088..cc8b4f53 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -272,7 +272,7 @@ def chat_with_code( artifacts_loaded = artifacts.show() int_chat.append({"role": "observation", "content": artifacts_loaded}) orig_chat.append({"role": "observation", "content": artifacts_loaded}) - self.callback_message({"role": "observation", "content": artifacts_loaded}) + self.streaming_message({"role": "observation", "content": artifacts_loaded}) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -285,7 +285,7 @@ def chat_with_code( if last_response == response: response["let_user_respond"] = True - self.callback_message({"role": "assistant", "content": response}) + self.streaming_message({"role": "assistant", "content": response}) if response["let_user_respond"]: break @@ -306,7 +306,7 @@ def chat_with_code( orig_chat.append( {"role": "observation", "content": obs, "execution": result} ) - self.callback_message( + self.streaming_message( { "role": "observation", "content": obs, From 53fff2dd43ebb8fe1f29840160decdcb3749b2a5 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Wed, 11 Sep 2024 11:43:51 +0800 Subject: [PATCH 4/5] change back test code --- vision_agent/agent/vision_agent.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index cc8b4f53..b60e8391 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -13,9 +13,8 @@ VA_CODE, ) from vision_agent.lmm import LMM, Message, OpenAILMM -from vision_agent.tools import META_TOOL_DOCSTRING, save_image +from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args -from vision_agent.tools.tools import load_image from vision_agent.utils import CodeInterpreterFactory from vision_agent.utils.execute import CodeInterpreter, Execution @@ -184,7 +183,7 @@ def chat_with_code( self, chat: List[Message], artifacts: Optional[Artifacts] = None, - test_multi_plan: bool = False, + test_multi_plan: bool = True, customized_tool_names: Optional[List[str]] = None, ) -> Tuple[List[Message], Artifacts]: """Chat with VisionAgent, it will use code to execute actions to accomplish From 77cf29bbd4d860f8324e9179fa053672bb801516 Mon Sep 17 00:00:00 2001 From: wuyiqunLu Date: Wed, 11 Sep 2024 11:47:35 +0800 Subject: [PATCH 5/5] fix type error --- vision_agent/agent/vision_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index b60e8391..736c9754 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -123,7 +123,7 @@ def __init__( verbosity: int = 0, local_artifacts_path: Optional[Union[str, Path]] = None, code_sandbox_runtime: Optional[str] = None, - callback_message: Optional[Callable[[Message], None]] = None, + callback_message: Optional[Callable[[Dict[str, Any]], None]] = None, ) -> None: """Initialize the VisionAgent. @@ -324,7 +324,7 @@ def chat_with_code( artifacts.save() return orig_chat, artifacts - def streaming_message(self, message: Message) -> None: + def streaming_message(self, message: Dict[str, Any]) -> None: if self.callback_message: self.callback_message(message)