From 41115278c325b2c0c4a73e2225739d8729c0975e Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 26 Sep 2024 14:47:27 -0700 Subject: [PATCH] added more test cases for string replacement funcs --- tests/unit/test_meta_tools.py | 73 ++++++++++++++++++++++++ tests/unit/test_va.py | 52 +++++++++++++++++ vision_agent/agent/vision_agent.py | 18 +++++- vision_agent/agent/vision_agent_coder.py | 8 +-- vision_agent/tools/meta_tools.py | 56 +++++------------- 5 files changed, 159 insertions(+), 48 deletions(-) create mode 100644 tests/unit/test_meta_tools.py create mode 100644 tests/unit/test_va.py diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py new file mode 100644 index 00000000..fced644b --- /dev/null +++ b/tests/unit/test_meta_tools.py @@ -0,0 +1,73 @@ +from vision_agent.tools.meta_tools import ( + Artifacts, + check_and_load_image, + use_object_detection_fine_tuning, +) + + +def test_check_and_load_image_none(): + assert check_and_load_image("print('Hello, World!')") == [] + + +def test_check_and_load_image_one(): + assert check_and_load_image("view_media_artifact(artifacts, 'image.jpg')") == [ + "image.jpg" + ] + + +def test_check_and_load_image_two(): + code = "view_media_artifact(artifacts, 'image1.jpg')\nview_media_artifact(artifacts, 'image2.jpg')" + assert check_and_load_image(code) == ["image1.jpg", "image2.jpg"] + + +def test_use_object_detection_fine_tuning_none(): + artifacts = Artifacts("test") + code = "print('Hello, World!')" + artifacts["code"] = code + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert ( + output == "[No function calls to replace with fine tuning id in artifact code]" + ) + assert artifacts["code"] == code + + +def test_use_object_detection_fine_tuning(): + artifacts = Artifacts("test") + code = """florence2_phrase_grounding('one', image1) +owl_v2_image('two', image2) +florence2_sam2_image('three', image3)""" + expected_code = """florence2_phrase_grounding("one", image1, "123") +owl_v2_image("two", image2, "123") +florence2_sam2_image("three", image3, "123")""" + artifacts["code"] = code + + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert 'florence2_phrase_grounding("one", image1, "123")' in output + assert 'owl_v2_image("two", image2, "123")' in output + assert 'florence2_sam2_image("three", image3, "123")' in output + assert artifacts["code"] == expected_code + + +def test_use_object_detection_fine_tuning_twice(): + artifacts = Artifacts("test") + code = """florence2_phrase_grounding('one', image1) +owl_v2_image('two', image2) +florence2_sam2_image('three', image3)""" + expected_code1 = """florence2_phrase_grounding("one", image1, "123") +owl_v2_image("two", image2, "123") +florence2_sam2_image("three", image3, "123")""" + expected_code2 = """florence2_phrase_grounding("one", image1, "456") +owl_v2_image("two", image2, "456") +florence2_sam2_image("three", image3, "456")""" + artifacts["code"] = code + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert 'florence2_phrase_grounding("one", image1, "123")' in output + assert 'owl_v2_image("two", image2, "123")' in output + assert 'florence2_sam2_image("three", image3, "123")' in output + assert artifacts["code"] == expected_code1 + + output = use_object_detection_fine_tuning(artifacts, "code", "456") + assert 'florence2_phrase_grounding("one", image1, "456")' in output + assert 'owl_v2_image("two", image2, "456")' in output + assert 'florence2_sam2_image("three", image3, "456")' in output + assert artifacts["code"] == expected_code2 diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py new file mode 100644 index 00000000..85e75426 --- /dev/null +++ b/tests/unit/test_va.py @@ -0,0 +1,52 @@ +from vision_agent.agent.vision_agent import parse_execution + + +def test_parse_execution_zero(): + code = "print('Hello, World!')" + assert parse_execution(code) == None + + +def test_parse_execution_one(): + code = "print('Hello, World!')" + assert parse_execution(code) == "print('Hello, World!')" + + +def test_parse_execution_no_test_multi_plan_generate(): + code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])" + assert ( + parse_execution(code, False) + == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False)" + ) + + +def test_parse_execution_no_test_multi_plan_edit(): + code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + assert ( + parse_execution(code, False) + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + ) + + +def test_parse_execution_custom_tool_names_generate(): + code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])" + assert ( + parse_execution( + code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] + ) + == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])" + ) + + +def test_prase_execution_custom_tool_names_edit(): + code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + assert ( + parse_execution( + code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] + ) + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])" + ) + + +def test_parse_execution_multiple_executes(): + code = "print('Hello, World!')print('Hello, World!')" + assert parse_execution(code) == "print('Hello, World!')\nprint('Hello, World!')" diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 14dec56b..a303c631 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -106,9 +106,20 @@ def parse_execution( customed_tool_names: Optional[List[str]] = None, ) -> Optional[str]: code = None - if "" in response: - code = response[response.find("") + len("") :] - code = code[: code.find("")] + remaining = response + all_code = [] + while "" in remaining: + code_i = remaining[ + remaining.find("") + len("") : + ] + code_i = code_i[: code_i.find("")] + remaining = remaining[ + remaining.find("") + len("") : + ] + all_code.append(code_i) + + if len(all_code) > 0: + code = "\n".join(all_code) if code is not None: code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names) @@ -306,6 +317,7 @@ def chat_with_code( ) finished = user_result is not None and user_obs is not None if user_result is not None and user_obs is not None: + # be sure to update the chat with user execution results chat_elt: Message = {"role": "observation", "content": user_obs} int_chat.append(chat_elt) chat_elt["execution"] = user_result diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index ec5ece0b..aa4d83da 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -691,7 +691,7 @@ def chat_with_workflow( chat: List[Message], test_multi_plan: bool = True, display_visualization: bool = False, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> Dict[str, Any]: """Chat with VisionAgentCoder and return intermediate information regarding the task. @@ -707,8 +707,8 @@ def chat_with_workflow( with the first plan. display_visualization (bool): If True, it opens a new window locally to show the image(s) created by visualization code (if there is any). - customized_tool_names (List[str]): A list of customized tools for agent to pick and use. - If not provided, default to full tool set from vision_agent.tools. + custom_tool_names (List[str]): A list of custom tools for the agent to pick + and use. If not provided, default to full tool set from vision_agent.tools. Returns: Dict[str, Any]: A dictionary containing the code, test, test result, plan, @@ -760,7 +760,7 @@ def chat_with_workflow( success = False plans = self._create_plans( - int_chat, customized_tool_names, working_memory, self.planner + int_chat, custom_tool_names, working_memory, self.planner ) if test_multi_plan: diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 6a5ad289..7d70e031 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -330,7 +330,7 @@ def generate_vision_code( chat: str, media: List[str], test_multi_plan: bool = True, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> str: """Generates python code to solve vision based tasks. @@ -340,7 +340,7 @@ def generate_vision_code( chat (str): The chat message from the user. media (List[str]): The media files to use. test_multi_plan (bool): Do not change this parameter. - customized_tool_names (Optional[List[str]]): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. Returns: str: The generated code. @@ -368,7 +368,7 @@ def detect_dogs(image_path: str): response = agent.chat_with_workflow( fixed_chat, test_multi_plan=test_multi_plan, - customized_tool_names=customized_tool_names, + custom_tool_names=custom_tool_names, ) redisplay_results(response["test_result"]) code = response["code"] @@ -448,7 +448,7 @@ def detect_dogs(image_path: str): response = agent.chat_with_workflow( fixed_chat_history, test_multi_plan=False, - customized_tool_names=customized_tool_names, + custom_tool_names=customized_tool_names, ) redisplay_results(response["test_result"]) code = response["code"] @@ -513,11 +513,8 @@ def check_and_load_image(code: str) -> List[str]: return [] pattern = r"view_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)" - match = re.search(pattern, code) - if match: - name = match.group(2) - return [name] - return [] + matches = re.findall(pattern, code) + return [match[1] for match in matches] def view_media_artifact(artifacts: Artifacts, name: str) -> str: @@ -620,7 +617,7 @@ def generate_replacer(match: re.Match) -> str: arg = match.group(1) out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}" if customized_tool_names is not None: - out_str += f", customized_tool_names={customized_tool_names})" + out_str += f", custom_tool_names={customized_tool_names})" else: out_str += ")" return out_str @@ -631,7 +628,7 @@ def edit_replacer(match: re.Match) -> str: arg = match.group(1) out_str = f"edit_vision_code({arg}" if customized_tool_names is not None: - out_str += f", customized_tool_names={customized_tool_names})" + out_str += f", custom_tool_names={customized_tool_names})" else: out_str += ")" return out_str @@ -668,51 +665,28 @@ def use_object_detection_fine_tuning( patterns_with_fine_tune_id = [ ( - r'florence2_phrase_grounding\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)', + r'florence2_phrase_grounding\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)', lambda match: f'florence2_phrase_grounding("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")', ), ( - r'owl_v2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)', + r'owl_v2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)', lambda match: f'owl_v2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")', ), ( - r'florence2_sam2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)', + r'florence2_sam2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)', lambda match: f'florence2_sam2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")', ), ] - patterns_without_fine_tune_id = [ - ( - r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)", - lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")', - ), - ( - r"owl_v2_image\(\s*([^\)]+)\s*\)", - lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")', - ), - ( - r"florence2_sam2_image\(\s*([^\)]+)\s*\)", - lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")', - ), - ] - new_code = code - - for index, (pattern_with_fine_tune_id, replacer_with_fine_tune_id) in enumerate( - patterns_with_fine_tune_id - ): + for ( + pattern_with_fine_tune_id, + replacer_with_fine_tune_id, + ) in patterns_with_fine_tune_id: if re.search(pattern_with_fine_tune_id, new_code): new_code = re.sub( pattern_with_fine_tune_id, replacer_with_fine_tune_id, new_code ) - else: - ( - pattern_without_fine_tune_id, - replacer_without_fine_tune_id, - ) = patterns_without_fine_tune_id[index] - new_code = re.sub( - pattern_without_fine_tune_id, replacer_without_fine_tune_id, new_code - ) if new_code == code: output_str = (