diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 9fd9f15c..690795f0 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -6,6 +6,8 @@ blip_image_caption, clip, closest_mask_distance, + countgd_counting, + countgd_example_based_counting, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -32,8 +34,6 @@ template_match, vit_image_classification, vit_nsfw_classification, - countgd_counting, - countgd_example_based_counting, ) FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py index fced644b..6cac95ce 100644 --- a/tests/unit/test_meta_tools.py +++ b/tests/unit/test_meta_tools.py @@ -1,6 +1,7 @@ from vision_agent.tools.meta_tools import ( Artifacts, check_and_load_image, + use_extra_vision_agent_args, use_object_detection_fine_tuning, ) @@ -71,3 +72,37 @@ def test_use_object_detection_fine_tuning_twice(): assert 'owl_v2_image("two", image2, "456")' in output assert 'florence2_sam2_image("three", image3, "456")' in output assert artifacts["code"] == expected_code2 + + +def test_use_object_detection_fine_tuning_real_case(): + artifacts = Artifacts("test") + code = "florence2_phrase_grounding('(strange arg)', image1)" + expected_code = 'florence2_phrase_grounding("(strange arg)", image1, "123")' + artifacts["code"] = code + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert 'florence2_phrase_grounding("(strange arg)", image1, "123")' in output + assert artifacts["code"] == expected_code + + +def test_use_extra_vision_agent_args_real_case(): + code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)" + out_code = use_extra_vision_agent_args(code) + assert out_code == expected_code + + code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)" + out_code = use_extra_vision_agent_args(code) + assert out_code == expected_code + + +def test_use_extra_vision_args_with_custom_tools(): + code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])" + out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"]) + assert out_code == expected_code + + code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])" + out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"]) + assert out_code == expected_code diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 0fb46cee..32d6d8e1 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -11,6 +11,7 @@ import numpy as np from IPython.display import display +from redbaron import RedBaron import vision_agent as va from vision_agent.agent.agent_utils import extract_json @@ -491,7 +492,7 @@ def edit_vision_code( name: str, chat_history: List[str], media: List[str], - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> str: """Edits python code to solve a vision based task. @@ -499,7 +500,7 @@ def edit_vision_code( artifacts (Artifacts): The artifacts object to save the code to. name (str): The file path to the code. chat_history (List[str]): The chat history to used to generate the code. - customized_tool_names (Optional[List[str]]): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. Returns: str: The edited code. @@ -542,7 +543,7 @@ def detect_dogs(image_path: str): response = agent.generate_code( fixed_chat_history, test_multi_plan=False, - custom_tool_names=customized_tool_names, + custom_tool_names=custom_tool_names, ) redisplay_results(response["test_result"]) code = response["code"] @@ -705,7 +706,7 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str: def use_extra_vision_agent_args( code: str, test_multi_plan: bool = True, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> str: """This is for forcing arguments passed by the user to VisionAgent into the VisionAgentCoder call. @@ -713,36 +714,24 @@ def use_extra_vision_agent_args( Parameters: code (str): The code to edit. test_multi_plan (bool): Do not change this parameter. - customized_tool_names (Optional[List[str]]): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. Returns: str: The edited code. """ - generate_pattern = r"generate_vision_code\(\s*([^\)]+)\s*\)" - - def generate_replacer(match: re.Match) -> str: - arg = match.group(1) - out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}" - if customized_tool_names is not None: - out_str += f", custom_tool_names={customized_tool_names})" - else: - out_str += ")" - return out_str - - edit_pattern = r"edit_vision_code\(\s*([^\)]+)\s*\)" - - def edit_replacer(match: re.Match) -> str: - arg = match.group(1) - out_str = f"edit_vision_code({arg}" - if customized_tool_names is not None: - out_str += f", custom_tool_names={customized_tool_names})" - else: - out_str += ")" - return out_str - - new_code = re.sub(generate_pattern, generate_replacer, code) - new_code = re.sub(edit_pattern, edit_replacer, new_code) - return new_code + red = RedBaron(code) + for node in red: + # seems to always be atomtrailers not call type + if node.type == "atomtrailers": + if ( + node.name.value == "generate_vision_code" + or node.name.value == "edit_vision_code" + ): + node.value[1].value.append(f"test_multi_plan={test_multi_plan}") + + if custom_tool_names is not None: + node.value[1].value.append(f"custom_tool_names={custom_tool_names}") + return red.dumps().strip() def use_object_detection_fine_tuning(