Skip to content

Commit

Permalink
added zmq logging
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Jul 24, 2024
1 parent de38bcb commit 9d5e365
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 75 deletions.
197 changes: 124 additions & 73 deletions examples/chat/app.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import os
import pickle as pkl
import threading
import time
from pathlib import Path
from typing import Dict, List

import streamlit as st
import zmq
from code_editor import code_editor
from streamlit_autorefresh import st_autorefresh

import vision_agent as va

WORKSPACE = Path(os.environ.get("WORKSPACE", ""))
ZMQ_PORT = os.environ.get("ZMQ_PORT", None)
if ZMQ_PORT is None:
ZMQ_PORT = "5555"
os.environ["ZMQ_PORT"] = ZMQ_PORT

st.set_page_config(layout="wide")

st.title("Vision Agent")
left_column, right_column = st.columns([2, 3])

CACHE = Path(".cache.pkl")
SAVE = {
Expand All @@ -26,75 +28,124 @@
}
agent = va.agent.VisionAgent(verbosity=1)


def generate_response(chat: List[Dict]) -> List[Dict]:
return agent.chat_with_code(chat)

st.set_page_config(layout="wide")

if "file_path" not in st.session_state:
st.session_state.file_path = None


if "messages" not in st.session_state:
if CACHE.exists():
with open(CACHE, "rb") as f:
st.session_state.messages = pkl.load(f)
else:
st.session_state.messages = []

with left_column:
st.title("Chat")

messages = st.container(height=400)
for message in st.session_state.messages:
messages.chat_message(message["role"]).write(message["content"])

if prompt := st.chat_input("Chat here"):
st.session_state.messages.append({"role": "user", "content": prompt})
messages.chat_message("user").write(prompt)

updated_chat = generate_response(st.session_state.messages)
for chat in updated_chat:
if chat not in st.session_state.messages:
st.session_state.messages.append(chat)
if chat["role"] in {"user", "assistant"}:
msg = chat["content"]
msg = msg.replace("<execute_python>", "<execute_python>`")
msg = msg.replace("</execute_python>", "`</execute_python>")
messages.chat_message(chat["role"]).write(msg)
else:
messages.chat_message("observation").text(chat["content"])


with right_column:
st.title("File & Code Editor")

tabs = st.tabs(["File Browser", "Code Editor"])

with tabs[0]:
uploaded_file = st.file_uploader("Upload a file")
if uploaded_file is not None:
with open(WORKSPACE / uploaded_file.name, "wb") as f:
f.write(uploaded_file.getbuffer())

for file in WORKSPACE.iterdir():
if "__pycache__" not in str(file) and not str(file).startswith("."):
if st.button(file.name):
st.session_state.file_path = file

with tabs[1]:
if (
"file_path" not in st.session_state
or st.session_state.file_path is None
or st.session_state.file_path.suffix != ".py"
):
st.write("Please select a python file from the file browser.")
else:
with open(WORKSPACE / st.session_state.file_path, "r") as f:
code = f.read()

resp = code_editor(code, lang="python", buttons=[SAVE])
if resp["type"] == "saved":
text = resp["text"]
with open(st.session_state.file_path, "w") as f:
f.write(text)
st.session_state.messages = []

if "updates" not in st.session_state:
st.session_state.updates = []

if "input_text" not in st.session_state:
st.session_state.input_text = ""


def update_messages(messages, lock):
new_chat = agent.chat_with_code(messages)
with lock:
for new_message in new_chat:
if new_message not in messages:
messages.append(new_message)


def get_updates(updates, lock):
context = zmq.Context()
socket = context.socket(zmq.PULL)
socket.bind(f"tcp://*:{ZMQ_PORT}")

while True:
message = socket.recv_json()
with lock:
updates.append(message)
time.sleep(0.1)


def submit():
st.session_state.input_text = st.session_state.widget
st.session_state.widget = ""


update_lock = threading.Lock()
message_lock = threading.Lock()

st_autorefresh(interval=1000, key="refresh")


def main():
st.title("Vision Agent")
left_column, right_column = st.columns([2, 3])

with left_column:
st.title("Chat & Code Execution")
tabs = st.tabs(["Chat", "Code Execution"])

with tabs[0]:
messages = st.container(height=400)
for message in st.session_state.messages:
messages.chat_message(message["role"]).write(message["content"])

st.text_input("Chat here", key="widget", on_change=submit)
prompt = st.session_state.input_text

if prompt:
st.session_state.messages.append({"role": "user", "content": prompt})
messages.chat_message("user").write(prompt)
message_thread = threading.Thread(
target=update_messages,
args=(st.session_state.messages, message_lock),
)
message_thread.daemon = True
message_thread.start()
st.session_state.input_text = ""

with tabs[1]:
updates = st.container(height=400)
for update in st.session_state.updates:
updates.chat_message("coder").write(update)

with right_column:
st.title("File Browser & Code Editor")
tabs = st.tabs(["File Browser", "Code Editor"])

with tabs[0]:
uploaded_file = st.file_uploader("Upload a file")
if uploaded_file is not None:
with open(WORKSPACE / uploaded_file.name, "wb") as f:
f.write(uploaded_file.getbuffer())

for file in WORKSPACE.iterdir():
if "__pycache__" not in str(file) and not str(file).startswith("."):
if st.button(file.name):
st.session_state.file_path = file

with tabs[1]:
if (
"file_path" not in st.session_state
or st.session_state.file_path is None
or st.session_state.file_path.suffix != ".py"
):
st.write("Please select a python file from the file browser.")
else:
with open(WORKSPACE / st.session_state.file_path, "r") as f:
code = f.read()

resp = code_editor(code, lang="python", buttons=[SAVE])
if resp["type"] == "saved":
text = resp["text"]
with open(st.session_state.file_path, "w") as f:
f.write(text)

if "update_thread_started" not in st.session_state:
update_thread = threading.Thread(
target=get_updates, args=(st.session_state.updates, update_lock)
)
update_thread.daemon = True
update_thread.start()
st.session_state.update_thread_started = True


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions examples/chat/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
streamlit
streamlit_code_editor
streamlit-autorefresh
zmq
42 changes: 42 additions & 0 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,17 @@ def pick_plan(
tool_info: str,
model: LMM,
code_interpreter: CodeInterpreter,
log_progress: Callable[[Dict[str, Any]], None],
verbosity: int = 0,
max_retries: int = 3,
) -> Tuple[str, str]:
log_progress(
{
"type": "pick_plan",
"status": "started",
}
)

chat = copy.deepcopy(chat)
if chat[-1]["role"] != "user":
raise ValueError("Last chat message must be from the user.")
Expand All @@ -151,6 +159,15 @@ def pick_plan(
)

code = extract_code(model(prompt))
log_progress(
{
"type": "pick_plan",
"status": "running",
"payload": {
"code": DefaultImports.prepend_imports(code),
},
}
)
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
tool_output_str = ""
if len(tool_output.logs.stdout) > 0:
Expand All @@ -171,6 +188,15 @@ def pick_plan(
),
)
code = extract_code(model(prompt))
log_progress(
{
"type": "pick_plan",
"status": "running",
"payload": {
"code": DefaultImports.prepend_imports(code),
},
}
)
tool_output = code_interpreter.exec_isolation(
DefaultImports.prepend_imports(code)
)
Expand Down Expand Up @@ -198,8 +224,16 @@ def pick_plan(
)
chat[-1]["content"] = prompt
best_plan = extract_json(model(chat))

if verbosity >= 1:
_LOGGER.info(f"Best plan:\n{best_plan}")
log_progress(
{
"type": "pick_plan",
"status": "completed",
"payload": best_plan,
}
)
return best_plan["best_plan"], tool_output_str


Expand Down Expand Up @@ -662,6 +696,13 @@ def chat_with_workflow(
_LOGGER.info(
f"\n{tabulate(tabular_data=plans[p], headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}"
)
self.log_progress(
{
"type": "plans",
"status": "completed",
"payload": plans,
}
)

tool_infos = retrieve_tools(
plans,
Expand All @@ -677,6 +718,7 @@ def chat_with_workflow(
tool_infos["all"],
self.coder,
code_interpreter,
self.log_progress,
verbosity=self.verbosity,
)
else:
Expand Down
23 changes: 21 additions & 2 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
import subprocess
from pathlib import Path
from typing import List, Union
from typing import Any, Dict, List, Union

import vision_agent as va
from vision_agent.lmm.types import Message
Expand All @@ -12,6 +14,16 @@
CURRENT_FILE = None
CURRENT_LINE = 0
DEFAULT_WINDOW_SIZE = 100
ZMQ_PORT = os.environ.get("ZMQ_PORT", None)


def report_progress_callback(port: int, inp: Dict[str, Any]) -> None:
import zmq

context = zmq.Context()
socket = context.socket(zmq.PUSH)
socket.connect(f"tcp://localhost:{port}")
socket.send_json(inp)


def filter_file(file_name: Union[str, Path]) -> bool:
Expand Down Expand Up @@ -45,7 +57,14 @@ def detect_dogs(image_path: str):
return dogs
"""

agent = va.agent.VisionAgentCoder()
if ZMQ_PORT is not None:
agent = va.agent.VisionAgentCoder(
report_progress_callback=lambda inp: report_progress_callback(
int(ZMQ_PORT), inp # type: ignore
)
)
else:
agent = va.agent.VisionAgentCoder()
try:
fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
response = agent.chat_with_workflow(fixed_chat)
Expand Down

0 comments on commit 9d5e365

Please sign in to comment.