Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add callback to stream back message #232

Merged
merged 5 commits into from
Sep 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import copy
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast, Callable

from vision_agent.agent import Agent
from vision_agent.agent.agent_utils import extract_json
Expand All @@ -13,8 +14,9 @@
VA_CODE,
)
from vision_agent.lmm import LMM, Message, OpenAILMM
from vision_agent.tools import META_TOOL_DOCSTRING
from vision_agent.tools import META_TOOL_DOCSTRING, save_image
from vision_agent.tools.meta_tools import Artifacts, use_extra_vision_agent_args
from vision_agent.tools.tools import load_image
from vision_agent.utils import CodeInterpreterFactory
from vision_agent.utils.execute import CodeInterpreter, Execution

Expand Down Expand Up @@ -123,6 +125,7 @@ def __init__(
verbosity: int = 0,
local_artifacts_path: Optional[Union[str, Path]] = None,
code_sandbox_runtime: Optional[str] = None,
callback_message: Optional[Callable[[Message], None]] = None,
) -> None:
"""Initialize the VisionAgent.

Expand All @@ -141,6 +144,7 @@ def __init__(
self.max_iterations = 100
self.verbosity = verbosity
self.code_sandbox_runtime = code_sandbox_runtime
self.callback_message = callback_message
if self.verbosity >= 1:
_LOGGER.setLevel(logging.INFO)
self.local_artifacts_path = cast(
Expand Down Expand Up @@ -181,7 +185,7 @@ def chat_with_code(
self,
chat: List[Message],
artifacts: Optional[Artifacts] = None,
test_multi_plan: bool = True,
test_multi_plan: bool = False,
customized_tool_names: Optional[List[str]] = None,
) -> Tuple[List[Message], Artifacts]:
"""Chat with VisionAgent, it will use code to execute actions to accomplish
Expand Down Expand Up @@ -220,7 +224,13 @@ def chat_with_code(
for chat_i in int_chat:
if "media" in chat_i:
for media in chat_i["media"]:
media = cast(str, media)
if type(media) is str and media.startswith(("http", "https")):
file_path = Path(media).name
ndarray = load_image(media)
save_image(ndarray, file_path)
media = file_path
else:
media = cast(str, media)
artifacts.artifacts[Path(media).name] = open(media, "rb").read()

media_remote_path = (
Expand Down Expand Up @@ -262,6 +272,7 @@ def chat_with_code(
artifacts_loaded = artifacts.show()
int_chat.append({"role": "observation", "content": artifacts_loaded})
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.callback_message({"role": "observation", "content": artifacts_loaded})

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
Expand All @@ -274,6 +285,8 @@ def chat_with_code(
if last_response == response:
response["let_user_respond"] = True

self.callback_message({"role": "assistant", "content": response})
wuyiqunLu marked this conversation as resolved.
Show resolved Hide resolved

if response["let_user_respond"]:
break

Expand All @@ -293,6 +306,13 @@ def chat_with_code(
orig_chat.append(
{"role": "observation", "content": obs, "execution": result}
)
self.callback_message(
{
"role": "observation",
"content": obs,
"execution": result,
}
)

iterations += 1
last_response = response
Expand All @@ -305,5 +325,9 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts

def streaming_message(self, message: Message) -> None:
    """Forward ``message`` to the registered callback, if one is set.

    Does nothing when ``self.callback_message`` is ``None``.

    Args:
        message: The chat message to pass to ``self.callback_message``.
    """
    callback = self.callback_message
    # Compare against None instead of relying on truthiness: a callable
    # object may define __bool__/__len__ and evaluate falsy while still
    # being a perfectly valid callback.
    if callback is not None:
        callback(message)

def log_progress(self, data: Dict[str, Any]) -> None:
    """Accept a progress payload and do nothing with it.

    The body is intentionally empty: ``data`` is ignored and nothing is
    returned.

    Args:
        data: Progress information; unused by this implementation.
    """
Loading