Skip to content

Commit

Permalink
Fix for several issues with VisionAgent (#251)
Browse files Browse the repository at this point in the history
* fix issues around vision agent coder

* fix flake8

* fixed issue where it can't see media from view_media_artifact

* fixed user exec obs

* fixed side cases with agent

* fixed bug with edit vision code

* fixed bug with chat app

* added more test cases for string replacement funcs

* fix linting error
  • Loading branch information
dillonalaird authored Sep 26, 2024
1 parent d2074d7 commit d14a76f
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 119 deletions.
35 changes: 20 additions & 15 deletions examples/chat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,12 @@


def update_messages(messages, lock):
if Path("artifacts.pkl").exists():
artifacts.load("artifacts.pkl")
new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts)
with lock:
for new_message in new_chat:
if new_message not in messages:
messages.append(new_message)
if Path("artifacts.pkl").exists():
artifacts.load("artifacts.pkl")
new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts)
for new_message in new_chat[len(messages) :]:
messages.append(new_message)


def get_updates(updates, lock):
Expand Down Expand Up @@ -106,15 +105,21 @@ def main():
prompt = st.session_state.input_text

if prompt:
st.session_state.messages.append({"role": "user", "content": prompt})
messages.chat_message("user").write(prompt)
message_thread = threading.Thread(
target=update_messages,
args=(st.session_state.messages, message_lock),
)
message_thread.daemon = True
message_thread.start()
st.session_state.input_text = ""
if (
len(st.session_state.messages) == 0
or prompt != st.session_state.messages[-1]["content"]
):
st.session_state.messages.append(
{"role": "user", "content": prompt}
)
messages.chat_message("user").write(prompt)
message_thread = threading.Thread(
target=update_messages,
args=(st.session_state.messages, message_lock),
)
message_thread.daemon = True
message_thread.start()
st.session_state.input_text = ""

with tabs[1]:
updates = st.container(height=400)
Expand Down
73 changes: 73 additions & 0 deletions tests/unit/test_meta_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from vision_agent.tools.meta_tools import (
Artifacts,
check_and_load_image,
use_object_detection_fine_tuning,
)


def test_check_and_load_image_none():
    """Code without a `view_media_artifact` call produces an empty list."""
    found = check_and_load_image("print('Hello, World!')")
    assert found == []


def test_check_and_load_image_one():
    """A single `view_media_artifact` call produces its one file name."""
    snippet = "view_media_artifact(artifacts, 'image.jpg')"
    expected = ["image.jpg"]
    assert check_and_load_image(snippet) == expected


def test_check_and_load_image_two():
    """Two `view_media_artifact` calls produce both file names, in order."""
    calls = [
        "view_media_artifact(artifacts, 'image1.jpg')",
        "view_media_artifact(artifacts, 'image2.jpg')",
    ]
    found = check_and_load_image("\n".join(calls))
    assert found == ["image1.jpg", "image2.jpg"]


def test_use_object_detection_fine_tuning_none():
    """With nothing to replace, a notice is returned and the artifact is untouched."""
    original = "print('Hello, World!')"
    store = Artifacts("test")
    store["code"] = original

    message = use_object_detection_fine_tuning(store, "code", "123")

    expected_message = (
        "[No function calls to replace with fine tuning id in artifact code]"
    )
    assert message == expected_message
    assert store["code"] == original


def test_use_object_detection_fine_tuning():
    """Each detection call gains the fine-tune id, in the output and the artifact."""
    source_lines = [
        "florence2_phrase_grounding('one', image1)",
        "owl_v2_image('two', image2)",
        "florence2_sam2_image('three', image3)",
    ]
    rewritten_lines = [
        'florence2_phrase_grounding("one", image1, "123")',
        'owl_v2_image("two", image2, "123")',
        'florence2_sam2_image("three", image3, "123")',
    ]
    store = Artifacts("test")
    store["code"] = "\n".join(source_lines)

    output = use_object_detection_fine_tuning(store, "code", "123")

    # The returned text must mention every rewritten call ...
    for rewritten in rewritten_lines:
        assert rewritten in output
    # ... and the stored artifact must be exactly the rewritten program.
    assert store["code"] == "\n".join(rewritten_lines)


def test_use_object_detection_fine_tuning_twice():
    """Applying a second fine-tune id replaces the first one rather than stacking."""
    store = Artifacts("test")
    store["code"] = "\n".join(
        [
            "florence2_phrase_grounding('one', image1)",
            "owl_v2_image('two', image2)",
            "florence2_sam2_image('three', image3)",
        ]
    )

    # Run the replacement with "123" first, then with "456" on the same artifact.
    for tune_id in ("123", "456"):
        rewritten_lines = [
            f'florence2_phrase_grounding("one", image1, "{tune_id}")',
            f'owl_v2_image("two", image2, "{tune_id}")',
            f'florence2_sam2_image("three", image3, "{tune_id}")',
        ]

        output = use_object_detection_fine_tuning(store, "code", tune_id)

        for rewritten in rewritten_lines:
            assert rewritten in output
        assert store["code"] == "\n".join(rewritten_lines)
52 changes: 52 additions & 0 deletions tests/unit/test_va.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from vision_agent.agent.vision_agent import parse_execution


def test_parse_execution_zero():
    """A response with no <execute_python> tag parses to None."""
    response = "print('Hello, World!')"
    parsed = parse_execution(response)
    assert parsed is None


def test_parse_execution_one():
    """A single <execute_python> block parses to the code inside the tags."""
    response = "<execute_python>print('Hello, World!')</execute_python>"
    parsed = parse_execution(response)
    assert parsed == "print('Hello, World!')"


def test_parse_execution_no_test_multi_plan_generate():
    """With test_multi_plan False, generate_vision_code gains test_multi_plan=False."""
    response = "<execute_python>generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])</execute_python>"
    expected = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False)"
    assert parse_execution(response, False) == expected


def test_parse_execution_no_test_multi_plan_edit():
    """With test_multi_plan False, an edit_vision_code call is returned unchanged."""
    response = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
    expected = (
        "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
    )
    assert parse_execution(response, False) == expected


def test_parse_execution_custom_tool_names_generate():
    """Custom tool names are appended to generate_vision_code as custom_tool_names."""
    response = "<execute_python>generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])</execute_python>"
    expected = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"

    parsed = parse_execution(
        response, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
    )

    assert parsed == expected


def test_prase_execution_custom_tool_names_edit():
    """Custom tool names are appended to edit_vision_code as custom_tool_names."""
    response = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
    expected = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"

    parsed = parse_execution(
        response, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
    )

    assert parsed == expected


def test_parse_execution_multiple_executes():
    """Consecutive <execute_python> blocks are joined with a newline."""
    block = "<execute_python>print('Hello, World!')</execute_python>"
    response = block + block
    expected = "print('Hello, World!')\nprint('Hello, World!')"
    assert parse_execution(response) == expected
103 changes: 62 additions & 41 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
return extract_json(orch([message], stream=False)) # type: ignore


def run_code_action(
def execute_code_action(
code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
Expand All @@ -106,19 +106,53 @@ def parse_execution(
customed_tool_names: Optional[List[str]] = None,
) -> Optional[str]:
code = None
if "<execute_python>" in response:
code = response[response.find("<execute_python>") + len("<execute_python>") :]
code = code[: code.find("</execute_python>")]
remaining = response
all_code = []
while "<execute_python>" in remaining:
code_i = remaining[
remaining.find("<execute_python>") + len("<execute_python>") :
]
code_i = code_i[: code_i.find("</execute_python>")]
remaining = remaining[
remaining.find("</execute_python>") + len("</execute_python>") :
]
all_code.append(code_i)

if len(all_code) > 0:
code = "\n".join(all_code)

if code is not None:
code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names)
return code


def execute_user_code_action(
    last_user_message: Message,
    code_interpreter: CodeInterpreter,
    artifact_remote_path: str,
) -> Tuple[Optional[Execution], Optional[str]]:
    """Run code embedded in the last user message, if there is any.

    Args:
        last_user_message: The most recent chat message. Only messages whose
            "role" is "user" are considered.
        code_interpreter: Interpreter handed to `execute_code_action`.
        artifact_remote_path: Remote artifacts path handed to
            `execute_code_action`.

    Returns:
        A `(result, observation)` pair from `execute_code_action`, with the
        execution error appended to the observation when one is reported.
        `(None, None)` when the message is not from the user or
        `parse_execution` finds no code in it.
    """
    if last_user_message["role"] != "user":
        return None, None

    content = cast(str, last_user_message.get("content", ""))
    # test_multi_plan is passed as False, as in the original call.
    action = parse_execution(content, False)
    if action is None:
        return None, None

    result, observation = execute_code_action(
        action, code_interpreter, artifact_remote_path
    )
    if result.error:
        # Surface the failure text alongside the regular observation.
        observation += f"\n{result.error}"
    return result, observation


class VisionAgent(Agent):
"""Vision Agent is an agent that can chat with the user and call tools or other
agents to generate code for it. Vision Agent uses python code to execute actions
for the user. Vision Agent is inspired by by OpenDev
for the user. Vision Agent is inspired by OpenDevin
https://github.com/OpenDevin/OpenDevin and CodeAct https://arxiv.org/abs/2402.01030
Example
Expand Down Expand Up @@ -278,9 +312,24 @@ def chat_with_code(
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})

finished = self.execute_user_code_action(
last_user_message, code_interpreter, remote_artifacts_path
user_result, user_obs = execute_user_code_action(
last_user_message, code_interpreter, str(remote_artifacts_path)
)
finished = user_result is not None and user_obs is not None
if user_result is not None and user_obs is not None:
# be sure to update the chat with user execution results
chat_elt: Message = {"role": "observation", "content": user_obs}
int_chat.append(chat_elt)
chat_elt["execution"] = user_result
orig_chat.append(chat_elt)
self.streaming_message(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
"finished": finished,
}
)

while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
Expand Down Expand Up @@ -322,7 +371,7 @@ def chat_with_code(
)

if code_action is not None:
result, obs = run_code_action(
result, obs = execute_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
)

Expand All @@ -331,17 +380,17 @@ def chat_with_code(
if self.verbosity >= 1:
_LOGGER.info(obs)

chat_elt: Message = {"role": "observation", "content": obs}
obs_chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
chat_elt["media"] = [
obs_chat_elt["media"] = [
Path(code_interpreter.remote_path) / media_ob
for media_ob in media_obs
]

# don't add execution results to internal chat
int_chat.append(chat_elt)
chat_elt["execution"] = result
orig_chat.append(chat_elt)
int_chat.append(obs_chat_elt)
obs_chat_elt["execution"] = result
orig_chat.append(obs_chat_elt)
self.streaming_message(
{
"role": "observation",
Expand All @@ -362,34 +411,6 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts

def execute_user_code_action(
self,
last_user_message: Message,
code_interpreter: CodeInterpreter,
remote_artifacts_path: Path,
) -> bool:
if last_user_message["role"] != "user":
return False
user_code_action = parse_execution(
cast(str, last_user_message.get("content", "")), False
)
if user_code_action is not None:
user_result, user_obs = run_code_action(
user_code_action, code_interpreter, str(remote_artifacts_path)
)
if self.verbosity >= 1:
_LOGGER.info(user_obs)
self.streaming_message(
{
"role": "observation",
"content": user_obs,
"execution": user_result,
"finished": True,
}
)
return True
return False

def streaming_message(self, message: Dict[str, Any]) -> None:
if self.callback_message:
self.callback_message(message)
Expand Down
8 changes: 4 additions & 4 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
customized_tool_names: Optional[List[str]] = None,
custom_tool_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
Expand All @@ -707,8 +707,8 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
If not provided, default to full tool set from vision_agent.tools.
custom_tool_names (List[str]): A list of custom tools for the agent to pick
and use. If not provided, default to full tool set from vision_agent.tools.
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
Expand Down Expand Up @@ -760,7 +760,7 @@ def chat_with_workflow(
success = False

plans = self._create_plans(
int_chat, customized_tool_names, working_memory, self.planner
int_chat, custom_tool_names, working_memory, self.planner
)

if test_multi_plan:
Expand Down
Loading

0 comments on commit d14a76f

Please sign in to comment.