diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py index ff4e9b46..8f5e234d 100644 --- a/tests/unit/test_va.py +++ b/tests/unit/test_va.py @@ -23,7 +23,7 @@ def test_parse_execution_no_test_multi_plan_edit(): code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" assert ( parse_execution(code, False) - == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False)" ) @@ -31,19 +31,19 @@ def test_parse_execution_custom_tool_names_generate(): code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])" assert ( parse_execution( - code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] + code, test_multi_plan=False, custom_tool_names=["owl_v2_image"] ) == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])" ) -def test_prase_execution_custom_tool_names_edit(): +def test_parse_execution_custom_tool_names_edit(): code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" assert ( parse_execution( - code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] + code, test_multi_plan=False, custom_tool_names=["owl_v2_image"] ) - == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])" + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])" ) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 6e1621f0..42541d33 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -103,7 +103,7 @@ def execute_code_action( def parse_execution( response: str, test_multi_plan: bool = True, - customed_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> Optional[str]: code = None remaining = response @@ -122,7 +122,7 @@ def parse_execution( code = "\n".join(all_code) if code is not None: - code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names) + code = use_extra_vision_agent_args(code, test_multi_plan, custom_tool_names) return code @@ -278,7 +278,7 @@ def chat_with_artifacts( chat: List[Message], artifacts: Optional[Artifacts] = None, test_multi_plan: bool = True, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> Tuple[List[Message], Artifacts]: """Chat with VisionAgent, it will use code to execute actions to accomplish its tasks. @@ -292,7 +292,7 @@ def chat_with_artifacts( test_multi_plan (bool): If True, it will test tools for multiple plans and pick the best one based off of the tool results. If False, it will go with the first plan. - customized_tool_names (List[str]): A list of customized tools for agent to + custom_tool_names (List[str]): A list of customized tools for agent to pick and use. If not provided, default to full tool set from vision_agent.tools. @@ -411,7 +411,7 @@ def chat_with_artifacts( finished = response["let_user_respond"] code_action = parse_execution( - response["response"], test_multi_plan, customized_tool_names + response["response"], test_multi_plan, custom_tool_names ) if last_response == response: