Skip to content

Commit

Permalink
feat: add callback to stream back message (#232)
Browse files Browse the repository at this point in the history
* feat: add callback to stream back message

* address comment

* fix lint

* change back test code

* fix type error
  • Loading branch information
wuyiqunLu authored Sep 11, 2024
1 parent 66dfbe1 commit d04db84
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable

from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_json
Expand All @@ -13,7 +13,7 @@
VA_CODE,
)
from vision_agent.lmm import LMM, Message, OpenAILMM
from vision_agent.tools import META_TOOL_DOCSTRING
from vision_agent.tools import META_TOOL_DOCSTRING, save_image, load_image
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
from vision_agent.utils import CodeInterpreterFactory
from vision_agent.utils.execute import CodeInterpreter, Execution
Expand Down Expand Up @@ -123,6 +123,7 @@ def __init__(
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Dict[str, Any]], None]] = None,
) -> None:
"""Initialize the VisionAgent.
Expand All @@ -141,6 +142,7 @@ def __init__(
self.max_iterations = 100
self.verbosity = verbosity
self.code_sandbox_runtime = code_sandbox_runtime
self.callback_message = callback_message
if self.verbosity >= 1:
_LOGGER.setLevel(logging.INFO)
self.local_artifacts_path = cast(
Expand Down Expand Up @@ -220,7 +222,14 @@ def chat_with_code(
for chat_i in int_chat:
if "media" in chat_i:
for media in chat_i["media"]:
media = cast(str, media)
if type(media) is str and media.startswith(("http", "https")):
# TODO: Ideally we should not call VA.tools here, we should come to revisit how to better support remote image later
file_path = Path(media).name
ndarray = load_image(media)
save_image(ndarray, file_path)
media = file_path
else:
media = cast(str, media)
artifacts.artifacts[Path(media).name] = open(media, "rb").read()

media_remote_path = (
Expand Down Expand Up @@ -262,6 +271,7 @@ def chat_with_code(
artifacts_loaded = artifacts.show()
int_chat.append({"role": "observation", "content": artifacts_loaded})
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
Expand All @@ -274,6 +284,8 @@ def chat_with_code(
if last_response == response:
response["let_user_respond"] = True

self.streaming_message({"role": "assistant", "content": response})

if response["let_user_respond"]:
break

Expand All @@ -293,6 +305,13 @@ def chat_with_code(
orig_chat.append(
{"role": "observation", "content": obs, "execution": result}
)
self.streaming_message(
{
"role": "observation",
"content": obs,
"execution": result,
}
)

iterations += 1
last_response = response
Expand All @@ -305,5 +324,9 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts

def streaming_message(self, message: Dict[str, Any]) -> None:
if self.callback_message:
self.callback_message(message)

def log_progress(self, data: Dict[str, Any]) -> None:
pass

0 comments on commit d04db84

Please sign in to comment.