From b5f4fac39fee2ecb7ffd192fd7b52da18b081108 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Mon, 14 Oct 2024 10:19:56 -0700 Subject: [PATCH] fix side case for replace args --- tests/unit/test_meta_tools.py | 4 ++-- vision_agent/tools/meta_tools.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py index 6cac95ce..fff867d9 100644 --- a/tests/unit/test_meta_tools.py +++ b/tests/unit/test_meta_tools.py @@ -91,7 +91,7 @@ def test_use_extra_vision_agent_args_real_case(): assert out_code == expected_code code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" - expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)" + expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" out_code = use_extra_vision_agent_args(code) assert out_code == expected_code @@ -103,6 +103,6 @@ def test_use_extra_vision_args_with_custom_tools(): assert out_code == expected_code code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" - expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])" + expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], custom_tool_names=['tool1', 'tool2'])" out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"]) assert out_code == expected_code diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index d9537e7c..b481f4f7 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -676,12 +676,13 @@ def use_extra_vision_agent_args( for node in red: # seems to always be atomtrailers not call type if node.type == "atomtrailers": + if node.name.value == "generate_vision_code": + node.value[1].value.append(f"test_multi_plan={test_multi_plan}") + if ( node.name.value == "generate_vision_code" or node.name.value == "edit_vision_code" ): - node.value[1].value.append(f"test_multi_plan={test_multi_plan}") - if custom_tool_names is not None: node.value[1].value.append(f"custom_tool_names={custom_tool_names}") cleaned_code = red.dumps().strip()