Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Issues #268

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/chat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def update_messages(messages, lock):
with lock:
if Path("artifacts.pkl").exists():
artifacts.load("artifacts.pkl")
new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts)
new_chat, _ = agent.chat_with_artifacts(messages, artifacts=artifacts)
for new_message in new_chat[len(messages) :]:
messages.append(new_message)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_use_extra_vision_agent_args_real_case():
assert out_code == expected_code

code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)"
expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
out_code = use_extra_vision_agent_args(code)
assert out_code == expected_code

Expand All @@ -103,6 +103,6 @@ def test_use_extra_vision_args_with_custom_tools():
assert out_code == expected_code

code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])"
expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], custom_tool_names=['tool1', 'tool2'])"
out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"])
assert out_code == expected_code
24 changes: 22 additions & 2 deletions tests/unit/test_va.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vision_agent.agent.agent_utils import extract_tag
from vision_agent.agent.vision_agent import _clean_response
from vision_agent.tools.meta_tools import use_extra_vision_agent_args


Expand Down Expand Up @@ -31,7 +32,7 @@ def test_parse_execution_no_test_multi_plan_edit():
code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
assert (
parse_execution(code, False)
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False)"
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
)


Expand All @@ -47,10 +48,29 @@ def test_parse_execution_custom_tool_names_edit():
code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
assert (
parse_execution(code, test_multi_plan=False, custom_tool_names=["owl_v2_image"])
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"
)


def test_parse_execution_multiple_executes():
    """Two consecutive <execute_python> blocks parse to newline-joined code."""
    code = "<execute_python>print('Hello, World!')</execute_python><execute_python>print('Hello, World!')</execute_python>"
    expected = "print('Hello, World!')\nprint('Hello, World!')"
    parsed = parse_execution(code)
    assert parsed == expected


def test_clean_response():
    """A response ending at </execute_python> comes back unchanged."""
    original = """<thinking>Thinking...</thinking>
<response>Here is the code:</response>
<execute_python>print('Hello, World!')</execute_python>"""
    cleaned = _clean_response(original)
    assert cleaned == original


def test_clean_response_remove_extra():
    """Text following the closing </execute_python> tag is stripped."""
    expected_response = """<thinking>Thinking...</thinking>
<response>Here is the code:</response>
<execute_python>print('Hello, World!')</execute_python>"""
    response = """<thinking>Thinking...</thinking>
<response>Here is the code:</response>
<execute_python>print('Hello, World!')</execute_python>
<thinking>More thinking...</thinking>
<response>Response to code...</response>"""
    cleaned = _clean_response(response)
    assert cleaned == expected_response
15 changes: 14 additions & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ def format_agent_message(agent_message: str) -> str:
return output


def _clean_response(response: str) -> str:
# Sometimes the LLM will hallucinate responses to an <execute_python> tag as if it
# had already executed the code. This function removes the hallucinated response.
if "<execute_python>" in response:
end_execute_python = response.find("</execute_python>")
response = response[: end_execute_python + len("</execute_python>")]
return response


def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
chat = copy.deepcopy(chat)

Expand Down Expand Up @@ -114,6 +123,10 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
message["media"] = chat[-1]["media"]
conv_resp = cast(str, orch([message], stream=False))

# clean the response first, if we are executing code, do not respond or end
# conversation before the code has been executed.
conv_resp = _clean_response(conv_resp)

let_user_respond_str = extract_tag(conv_resp, "let_user_respond")
let_user_respond = (
"true" in let_user_respond_str.lower() if let_user_respond_str else False
Expand Down Expand Up @@ -458,7 +471,7 @@ def chat_with_artifacts(
self.streaming_message(
{
"role": "assistant",
"content": json.dumps(response),
"content": json.dumps(add_step_descriptions(response)),
"finished": finished and code_action is None,
}
)
Expand Down
5 changes: 3 additions & 2 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,13 @@ def use_extra_vision_agent_args(
for node in red:
# seems to always be atomtrailers not call type
if node.type == "atomtrailers":
if node.name.value == "generate_vision_code":
node.value[1].value.append(f"test_multi_plan={test_multi_plan}")

if (
node.name.value == "generate_vision_code"
or node.name.value == "edit_vision_code"
):
node.value[1].value.append(f"test_multi_plan={test_multi_plan}")

if custom_tool_names is not None:
node.value[1].value.append(f"custom_tool_names={custom_tool_names}")
cleaned_code = red.dumps().strip()
Expand Down
Loading