diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py
index e7a6e7c5..b0b80e2d 100644
--- a/tests/unit/test_va.py
+++ b/tests/unit/test_va.py
@@ -1,4 +1,5 @@
from vision_agent.agent.agent_utils import extract_tag
+from vision_agent.agent.vision_agent import _clean_response
from vision_agent.tools.meta_tools import use_extra_vision_agent_args
@@ -31,7 +32,7 @@ def test_parse_execution_no_test_multi_plan_edit():
code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
assert (
parse_execution(code, False)
- == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False)"
+ == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
)
@@ -47,10 +48,29 @@ def test_parse_execution_custom_tool_names_edit():
code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
assert (
parse_execution(code, test_multi_plan=False, custom_tool_names=["owl_v2_image"])
- == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"
+ == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"
)
def test_parse_execution_multiple_executes():
code = "print('Hello, World!')print('Hello, World!')"
assert parse_execution(code) == "print('Hello, World!')\nprint('Hello, World!')"
+
+
+def test_clean_response():
+ response = """Thinking...
+Here is the code:
+print('Hello, World!')"""
+ assert _clean_response(response) == response
+
+
+def test_clean_response_remove_extra():
+ response = """Thinking...
+Here is the code:
+print('Hello, World!')
+More thinking...
+Response to code..."""
+ expected_response = """Thinking...
+Here is the code:
+print('Hello, World!')"""
+ assert _clean_response(response) == expected_response
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 29643ecd..bfd2697b 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -85,6 +85,15 @@ def format_agent_message(agent_message: str) -> str:
return output
+def _clean_response(response: str) -> str:
+ # Sometimes the LLM will hallucinate responses to an tag as if it
+ # had already executed the code. This function removes the hallucinated response.
+ if "" in response:
+ end_execute_python = response.find("")
+ response = response[: end_execute_python + len("")]
+ return response
+
+
def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
chat = copy.deepcopy(chat)
@@ -114,6 +123,10 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
message["media"] = chat[-1]["media"]
conv_resp = cast(str, orch([message], stream=False))
+ # clean the response first, if we are executing code, do not resond or end
+ # conversation before the code has been executed.
+ conv_resp = _clean_response(conv_resp)
+
let_user_respond_str = extract_tag(conv_resp, "let_user_respond")
let_user_respond = (
"true" in let_user_respond_str.lower() if let_user_respond_str else False