Skip to content

Commit

Permalink
Fix Issues (#268)
Browse files Browse the repository at this point in the history
* fix side case for replace args

* fix new names

* fix hallucination cases

* add forced responses to streaming
  • Loading branch information
dillonalaird authored Oct 14, 2024
1 parent 47d0057 commit 840f0fc
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/chat/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def update_messages(messages, lock):
with lock:
if Path("artifacts.pkl").exists():
artifacts.load("artifacts.pkl")
new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts)
new_chat, _ = agent.chat_with_artifacts(messages, artifacts=artifacts)
for new_message in new_chat[len(messages) :]:
messages.append(new_message)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_use_extra_vision_agent_args_real_case():
assert out_code == expected_code

code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)"
expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
out_code = use_extra_vision_agent_args(code)
assert out_code == expected_code

Expand All @@ -103,6 +103,6 @@ def test_use_extra_vision_args_with_custom_tools():
assert out_code == expected_code

code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])"
expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], custom_tool_names=['tool1', 'tool2'])"
out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"])
assert out_code == expected_code
24 changes: 22 additions & 2 deletions tests/unit/test_va.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from vision_agent.agent.agent_utils import extract_tag
from vision_agent.agent.vision_agent import _clean_response
from vision_agent.tools.meta_tools import use_extra_vision_agent_args


Expand Down Expand Up @@ -31,7 +32,7 @@ def test_parse_execution_no_test_multi_plan_edit():
code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
assert (
parse_execution(code, False)
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False)"
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
)


Expand All @@ -47,10 +48,29 @@ def test_parse_execution_custom_tool_names_edit():
code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
assert (
parse_execution(code, test_multi_plan=False, custom_tool_names=["owl_v2_image"])
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"
)


def test_parse_execution_multiple_executes():
    """Several <execute_python> blocks in one reply are merged, one per line."""
    raw_reply = "<execute_python>print('Hello, World!')</execute_python><execute_python>print('Hello, World!')</execute_python>"
    expected_script = "print('Hello, World!')\nprint('Hello, World!')"

    parsed_script = parse_execution(raw_reply)

    assert parsed_script == expected_script


def test_clean_response():
    """A reply that already ends at </execute_python> is returned unchanged."""
    well_formed = """<thinking>Thinking...</thinking>
<response>Here is the code:</response>
<execute_python>print('Hello, World!')</execute_python>"""

    cleaned = _clean_response(well_formed)

    assert cleaned == well_formed


def test_clean_response_remove_extra():
    """Text that follows the closing </execute_python> tag is stripped."""
    # The expected output is the reply up to and including </execute_python>.
    expected_response = """<thinking>Thinking...</thinking>
<response>Here is the code:</response>
<execute_python>print('Hello, World!')</execute_python>"""
    # The input carries extra tags after the execute block.
    hallucinated = """<thinking>Thinking...</thinking>
<response>Here is the code:</response>
<execute_python>print('Hello, World!')</execute_python>
<thinking>More thinking...</thinking>
<response>Response to code...</response>"""

    cleaned = _clean_response(hallucinated)

    assert cleaned == expected_response
15 changes: 14 additions & 1 deletion vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ def format_agent_message(agent_message: str) -> str:
return output


def _clean_response(response: str) -> str:
# Sometimes the LLM will hallucinate responses to an <execute_python> tag as if it
# had already executed the code. This function removes the hallucinated response.
if "<execute_python>" in response:
end_execute_python = response.find("</execute_python>")
response = response[: end_execute_python + len("</execute_python>")]
return response


def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
chat = copy.deepcopy(chat)

Expand Down Expand Up @@ -114,6 +123,10 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
message["media"] = chat[-1]["media"]
conv_resp = cast(str, orch([message], stream=False))

# Clean the response first: if we are executing code, do not respond or end
# the conversation before the code has been executed.
conv_resp = _clean_response(conv_resp)

let_user_respond_str = extract_tag(conv_resp, "let_user_respond")
let_user_respond = (
"true" in let_user_respond_str.lower() if let_user_respond_str else False
Expand Down Expand Up @@ -458,7 +471,7 @@ def chat_with_artifacts(
self.streaming_message(
{
"role": "assistant",
"content": json.dumps(response),
"content": json.dumps(add_step_descriptions(response)),
"finished": finished and code_action is None,
}
)
Expand Down
5 changes: 3 additions & 2 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,13 @@ def use_extra_vision_agent_args(
for node in red:
# seems to always be atomtrailers not call type
if node.type == "atomtrailers":
if node.name.value == "generate_vision_code":
node.value[1].value.append(f"test_multi_plan={test_multi_plan}")

if (
node.name.value == "generate_vision_code"
or node.name.value == "edit_vision_code"
):
node.value[1].value.append(f"test_multi_plan={test_multi_plan}")

if custom_tool_names is not None:
node.value[1].value.append(f"custom_tool_names={custom_tool_names}")
cleaned_code = red.dumps().strip()
Expand Down

0 comments on commit 840f0fc

Please sign in to comment.