Skip to content

Commit

Permalink
feat: add callback to stream back message
Browse files Browse the repository at this point in the history
  • Loading branch information
wuyiqunLu committed Sep 10, 2024
1 parent 66dfbe1 commit 5ef03f2
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import copy
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable

from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_json
Expand All @@ -13,8 +14,9 @@
VA_CODE,
)
from vision_agent.lmm import LMM, Message, OpenAILMM
from vision_agent.tools import META_TOOL_DOCSTRING
from vision_agent.tools import META_TOOL_DOCSTRING, save_image
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
from vision_agent.tools.tools import load_image
from vision_agent.utils import CodeInterpreterFactory
from vision_agent.utils.execute import CodeInterpreter, Execution

Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Message], None]] = None,
) -> None:
"""Initialize the VisionAgent.
Expand All @@ -141,6 +144,7 @@ def __init__(
self.max_iterations = 100
self.verbosity = verbosity
self.code_sandbox_runtime = code_sandbox_runtime
self.callback_message = callback_message
if self.verbosity >= 1:
_LOGGER.setLevel(logging.INFO)
self.local_artifacts_path = cast(
Expand Down Expand Up @@ -181,7 +185,7 @@ def chat_with_code(
self,
chat: List[Message],
artifacts: Optional[Artifacts] = None,
test_multi_plan: bool = True,
test_multi_plan: bool = False,
customized_tool_names: Optional[List[str]] = None,
) -> Tuple[List[Message], Artifacts]:
"""Chat with VisionAgent, it will use code to execute actions to accomplish
Expand Down Expand Up @@ -220,7 +224,13 @@ def chat_with_code(
for chat_i in int_chat:
if "media" in chat_i:
for media in chat_i["media"]:
media = cast(str, media)
if type(media) is str and media.startswith(("http", "https")):
file_path = Path(media).name
ndarray = load_image(media)
save_image(ndarray, file_path)
media = file_path
else:
media = cast(str, media)
artifacts.artifacts[Path(media).name] = open(media, "rb").read()

media_remote_path = (
Expand Down Expand Up @@ -262,6 +272,7 @@ def chat_with_code(
artifacts_loaded = artifacts.show()
int_chat.append({"role": "observation", "content": artifacts_loaded})
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.callback_message({"role": "observation", "content": artifacts_loaded})

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
Expand All @@ -274,6 +285,8 @@ def chat_with_code(
if last_response == response:
response["let_user_respond"] = True

self.callback_message({"role": "assistant", "content": response})

if response["let_user_respond"]:
break

Expand All @@ -293,6 +306,13 @@ def chat_with_code(
orig_chat.append(
{"role": "observation", "content": obs, "execution": result}
)
self.callback_message(
{
"role": "observation",
"content": obs,
"execution": result,
}
)

iterations += 1
last_response = response
Expand All @@ -305,5 +325,9 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts

def streaming_message(self, message: Message) -> None:
    """Forward ``message`` to the user-supplied callback, if one was given.

    ``self.callback_message`` is optional and defaults to ``None``, so code
    that wants to stream a message should go through this helper instead of
    invoking the callback attribute directly.

    Args:
        message: The chat message to hand to the callback.

    Returns:
        None. Any value returned by the callback is discarded.
    """
    callback = self.callback_message
    # Compare against None rather than relying on truthiness: a callable
    # object may legitimately evaluate as falsy (for example one that defines
    # ``__len__`` and is currently empty) and must still be invoked.
    if callback is not None:
        callback(message)

def log_progress(self, data: Dict[str, Any]) -> None:
    """Accept a progress payload and deliberately do nothing with it.

    Args:
        data: Arbitrary progress information; it is ignored by this
            implementation.

    Returns:
        None.
    """
    # NOTE(review): presumably kept as a no-op hook for callers/subclasses
    # that expect the method to exist -- confirm against the Agent base class.
    return None

0 comments on commit 5ef03f2

Please sign in to comment.