diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py index 6cac95ce..fff867d9 100644 --- a/tests/unit/test_meta_tools.py +++ b/tests/unit/test_meta_tools.py @@ -91,7 +91,7 @@ def test_use_extra_vision_agent_args_real_case(): assert out_code == expected_code code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" - expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)" + expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" out_code = use_extra_vision_agent_args(code) assert out_code == expected_code @@ -103,6 +103,6 @@ def test_use_extra_vision_args_with_custom_tools(): assert out_code == expected_code code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" - expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])" + expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], custom_tool_names=['tool1', 'tool2'])" out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"]) assert out_code == expected_code diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index d9537e7c..b481f4f7 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -676,12 +676,13 @@ def use_extra_vision_agent_args( for node in red: # seems to always be atomtrailers not call type if node.type == "atomtrailers": + if node.name.value == "generate_vision_code": + node.value[1].value.append(f"test_multi_plan={test_multi_plan}") + if ( node.name.value == "generate_vision_code" or node.name.value == "edit_vision_code" ): - node.value[1].value.append(f"test_multi_plan={test_multi_plan}") - if custom_tool_names is not None: node.value[1].value.append(f"custom_tool_names={custom_tool_names}") cleaned_code = red.dumps().strip()