From d14a76fa790702033d99b6c0d1674e5cc5109ff3 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 26 Sep 2024 14:55:13 -0700 Subject: [PATCH] Fix for several issues with VisionAgent (#251) * fix issues around vision agent coder * fix flake8 * fixed issue where it can't see media from view_media_artifact * fixed user exec obs * fixed side cases with agent * fixed bug with edit vision code * fixed bug with chat app * added more test cases for string replacement funcs * fix linting error --- examples/chat/app.py | 35 ++++--- tests/unit/test_meta_tools.py | 73 ++++++++++++++ tests/unit/test_va.py | 52 ++++++++++ vision_agent/agent/vision_agent.py | 103 ++++++++++++-------- vision_agent/agent/vision_agent_coder.py | 8 +- vision_agent/agent/vision_agent_prompts.py | 10 +- vision_agent/tools/meta_tools.py | 105 ++++++++++----------- 7 files changed, 267 insertions(+), 119 deletions(-) create mode 100644 tests/unit/test_meta_tools.py create mode 100644 tests/unit/test_va.py diff --git a/examples/chat/app.py b/examples/chat/app.py index 9291f65a..0389b2f1 100644 --- a/examples/chat/app.py +++ b/examples/chat/app.py @@ -51,13 +51,12 @@ def update_messages(messages, lock): - if Path("artifacts.pkl").exists(): - artifacts.load("artifacts.pkl") - new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts) with lock: - for new_message in new_chat: - if new_message not in messages: - messages.append(new_message) + if Path("artifacts.pkl").exists(): + artifacts.load("artifacts.pkl") + new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts) + for new_message in new_chat[len(messages) :]: + messages.append(new_message) def get_updates(updates, lock): @@ -106,15 +105,21 @@ def main(): prompt = st.session_state.input_text if prompt: - st.session_state.messages.append({"role": "user", "content": prompt}) - messages.chat_message("user").write(prompt) - message_thread = threading.Thread( - target=update_messages, - args=(st.session_state.messages, message_lock), - ) - message_thread.daemon = True - message_thread.start() - st.session_state.input_text = "" + if ( + len(st.session_state.messages) == 0 + or prompt != st.session_state.messages[-1]["content"] + ): + st.session_state.messages.append( + {"role": "user", "content": prompt} + ) + messages.chat_message("user").write(prompt) + message_thread = threading.Thread( + target=update_messages, + args=(st.session_state.messages, message_lock), + ) + message_thread.daemon = True + message_thread.start() + st.session_state.input_text = "" with tabs[1]: updates = st.container(height=400) diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py new file mode 100644 index 00000000..fced644b --- /dev/null +++ b/tests/unit/test_meta_tools.py @@ -0,0 +1,73 @@ +from vision_agent.tools.meta_tools import ( + Artifacts, + check_and_load_image, + use_object_detection_fine_tuning, +) + + +def test_check_and_load_image_none(): + assert check_and_load_image("print('Hello, World!')") == [] + + +def test_check_and_load_image_one(): + assert check_and_load_image("view_media_artifact(artifacts, 'image.jpg')") == [ + "image.jpg" + ] + + +def test_check_and_load_image_two(): + code = "view_media_artifact(artifacts, 'image1.jpg')\nview_media_artifact(artifacts, 'image2.jpg')" + assert check_and_load_image(code) == ["image1.jpg", "image2.jpg"] + + +def test_use_object_detection_fine_tuning_none(): + artifacts = Artifacts("test") + code = "print('Hello, World!')" + artifacts["code"] = code + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert ( + output == "[No function calls to replace with fine tuning id in artifact code]" + ) + assert artifacts["code"] == code + + +def test_use_object_detection_fine_tuning(): + artifacts = Artifacts("test") + code = """florence2_phrase_grounding('one', image1) +owl_v2_image('two', image2) +florence2_sam2_image('three', image3)""" + expected_code = """florence2_phrase_grounding("one", image1, "123") +owl_v2_image("two", image2, "123") +florence2_sam2_image("three", image3, "123")""" + artifacts["code"] = code + + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert 'florence2_phrase_grounding("one", image1, "123")' in output + assert 'owl_v2_image("two", image2, "123")' in output + assert 'florence2_sam2_image("three", image3, "123")' in output + assert artifacts["code"] == expected_code + + +def test_use_object_detection_fine_tuning_twice(): + artifacts = Artifacts("test") + code = """florence2_phrase_grounding('one', image1) +owl_v2_image('two', image2) +florence2_sam2_image('three', image3)""" + expected_code1 = """florence2_phrase_grounding("one", image1, "123") +owl_v2_image("two", image2, "123") +florence2_sam2_image("three", image3, "123")""" + expected_code2 = """florence2_phrase_grounding("one", image1, "456") +owl_v2_image("two", image2, "456") +florence2_sam2_image("three", image3, "456")""" + artifacts["code"] = code + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert 'florence2_phrase_grounding("one", image1, "123")' in output + assert 'owl_v2_image("two", image2, "123")' in output + assert 'florence2_sam2_image("three", image3, "123")' in output + assert artifacts["code"] == expected_code1 + + output = use_object_detection_fine_tuning(artifacts, "code", "456") + assert 'florence2_phrase_grounding("one", image1, "456")' in output + assert 'owl_v2_image("two", image2, "456")' in output + assert 'florence2_sam2_image("three", image3, "456")' in output + assert artifacts["code"] == expected_code2 diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py new file mode 100644 index 00000000..ff4e9b46 --- /dev/null +++ b/tests/unit/test_va.py @@ -0,0 +1,52 @@ +from vision_agent.agent.vision_agent import parse_execution + + +def test_parse_execution_zero(): + code = "print('Hello, World!')" + assert parse_execution(code) is None + + +def test_parse_execution_one(): + code = "print('Hello, World!')" + assert parse_execution(code) == "print('Hello, World!')" + + +def test_parse_execution_no_test_multi_plan_generate(): + code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])" + assert ( + parse_execution(code, False) + == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False)" + ) + + +def test_parse_execution_no_test_multi_plan_edit(): + code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + assert ( + parse_execution(code, False) + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + ) + + +def test_parse_execution_custom_tool_names_generate(): + code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])" + assert ( + parse_execution( + code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] + ) + == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])" + ) + + +def test_prase_execution_custom_tool_names_edit(): + code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + assert ( + parse_execution( + code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] + ) + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])" + ) + + +def test_parse_execution_multiple_executes(): + code = "print('Hello, World!')print('Hello, World!')" + assert parse_execution(code) == "print('Hello, World!')\nprint('Hello, World!')" diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index bf35e5e9..a303c631 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -87,7 +87,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]: return extract_json(orch([message], stream=False)) # type: ignore -def run_code_action( +def execute_code_action( code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str ) -> Tuple[Execution, str]: result = code_interpreter.exec_isolation( @@ -106,19 +106,53 @@ def parse_execution( customed_tool_names: Optional[List[str]] = None, ) -> Optional[str]: code = None - if "" in response: - code = response[response.find("") + len("") :] - code = code[: code.find("")] + remaining = response + all_code = [] + while "" in remaining: + code_i = remaining[ + remaining.find("") + len("") : + ] + code_i = code_i[: code_i.find("")] + remaining = remaining[ + remaining.find("") + len("") : + ] + all_code.append(code_i) + + if len(all_code) > 0: + code = "\n".join(all_code) if code is not None: code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names) return code +def execute_user_code_action( + last_user_message: Message, + code_interpreter: CodeInterpreter, + artifact_remote_path: str, +) -> Tuple[Optional[Execution], Optional[str]]: + user_result = None + user_obs = None + + if last_user_message["role"] != "user": + return user_result, user_obs + + last_user_content = cast(str, last_user_message.get("content", "")) + + user_code_action = parse_execution(last_user_content, False) + if user_code_action is not None: + user_result, user_obs = execute_code_action( + user_code_action, code_interpreter, artifact_remote_path + ) + if user_result.error: + user_obs += f"\n{user_result.error}" + return user_result, user_obs + + class VisionAgent(Agent): """Vision Agent is an agent that can chat with the user and call tools or other agents to generate code for it. Vision Agent uses python code to execute actions - for the user. Vision Agent is inspired by by OpenDev + for the user. Vision Agent is inspired by by OpenDevin https://github.com/OpenDevin/OpenDevin and CodeAct https://arxiv.org/abs/2402.01030 Example @@ -278,9 +312,24 @@ def chat_with_code( orig_chat.append({"role": "observation", "content": artifacts_loaded}) self.streaming_message({"role": "observation", "content": artifacts_loaded}) - finished = self.execute_user_code_action( - last_user_message, code_interpreter, remote_artifacts_path + user_result, user_obs = execute_user_code_action( + last_user_message, code_interpreter, str(remote_artifacts_path) ) + finished = user_result is not None and user_obs is not None + if user_result is not None and user_obs is not None: + # be sure to update the chat with user execution results + chat_elt: Message = {"role": "observation", "content": user_obs} + int_chat.append(chat_elt) + chat_elt["execution"] = user_result + orig_chat.append(chat_elt) + self.streaming_message( + { + "role": "observation", + "content": user_obs, + "execution": user_result, + "finished": finished, + } + ) while not finished and iterations < self.max_iterations: response = run_conversation(self.agent, int_chat) @@ -322,7 +371,7 @@ def chat_with_code( ) if code_action is not None: - result, obs = run_code_action( + result, obs = execute_code_action( code_action, code_interpreter, str(remote_artifacts_path) ) @@ -331,17 +380,17 @@ def chat_with_code( if self.verbosity >= 1: _LOGGER.info(obs) - chat_elt: Message = {"role": "observation", "content": obs} + obs_chat_elt: Message = {"role": "observation", "content": obs} if media_obs and result.success: - chat_elt["media"] = [ + obs_chat_elt["media"] = [ Path(code_interpreter.remote_path) / media_ob for media_ob in media_obs ] # don't add execution results to internal chat - int_chat.append(chat_elt) - chat_elt["execution"] = result - orig_chat.append(chat_elt) + int_chat.append(obs_chat_elt) + obs_chat_elt["execution"] = result + orig_chat.append(obs_chat_elt) self.streaming_message( { "role": "observation", @@ -362,34 +411,6 @@ def chat_with_code( artifacts.save() return orig_chat, artifacts - def execute_user_code_action( - self, - last_user_message: Message, - code_interpreter: CodeInterpreter, - remote_artifacts_path: Path, - ) -> bool: - if last_user_message["role"] != "user": - return False - user_code_action = parse_execution( - cast(str, last_user_message.get("content", "")), False - ) - if user_code_action is not None: - user_result, user_obs = run_code_action( - user_code_action, code_interpreter, str(remote_artifacts_path) - ) - if self.verbosity >= 1: - _LOGGER.info(user_obs) - self.streaming_message( - { - "role": "observation", - "content": user_obs, - "execution": user_result, - "finished": True, - } - ) - return True - return False - def streaming_message(self, message: Dict[str, Any]) -> None: if self.callback_message: self.callback_message(message) diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index ec5ece0b..aa4d83da 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -691,7 +691,7 @@ def chat_with_workflow( chat: List[Message], test_multi_plan: bool = True, display_visualization: bool = False, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> Dict[str, Any]: """Chat with VisionAgentCoder and return intermediate information regarding the task. @@ -707,8 +707,8 @@ def chat_with_workflow( with the first plan. display_visualization (bool): If True, it opens a new window locally to show the image(s) created by visualization code (if there is any). - customized_tool_names (List[str]): A list of customized tools for agent to pick and use. - If not provided, default to full tool set from vision_agent.tools. + custom_tool_names (List[str]): A list of custom tools for the agent to pick + and use. If not provided, default to full tool set from vision_agent.tools. Returns: Dict[str, Any]: A dictionary containing the code, test, test result, plan, @@ -760,7 +760,7 @@ def chat_with_workflow( success = False plans = self._create_plans( - int_chat, customized_tool_names, working_memory, self.planner + int_chat, custom_tool_names, working_memory, self.planner ) if test_multi_plan: diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py index 80623016..bc3295ef 100644 --- a/vision_agent/agent/vision_agent_prompts.py +++ b/vision_agent/agent/vision_agent_prompts.py @@ -26,7 +26,9 @@ **Instructions**: 1. **Understand and Clarify**: Make sure you understand the task, ask clarifying questions if the task is not clear. -2. **Output in JSON**: Respond in the following format in JSON: +2. **Code Generation**: Only use code provided in the Documentation in your tags. Only use `edit_vision_code` to modify code written by `generate_vision_code`. +3. **Execute**: Do only what the user asked you to do and no more. If you need to ask the user a question, set `let_user_respond` to `true`. +4. **Output in JSON**: Respond in the following format in JSON: ```json {{"thoughts": , "response": , "let_user_respond": }}. @@ -149,7 +151,7 @@ 16| return count [End of artifact] -AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')", "let_user_respond": false} +AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png', 'workers_viz.png')", "let_user_respond": false} OBSERVATION: ----- stdout ----- @@ -164,7 +166,6 @@ OBSERVATION: [Fine tuning id: 23b3b022-5ebf-4798-9373-20ef36429abf] - AGENT: {"thoughts": "The model has finished fine tuning, I will now replace the original florence2_phrase_grounding call with the fine tuning id.", "response": "I will not update the code to use the fine tuned model. use_object_detection_fine_tuning(artifacts, "code.py", "23b3b022-5ebf-4798-9373-20ef36429abf")", "let_user_respond": false} OBSERVATION: @@ -182,8 +183,7 @@ def count_workers_with_helmets(image_path: str, output_path: str): count = 0 [End of edits] - -AGENT: {"thoughts": "I have updated the code to use the fine tuned model, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')", "let_user_respond": false} +AGENT: {"thoughts": "I have updated the code to use the fine tuned model, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visaulization.", "response": "from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png', 'workers_viz.png')", "let_user_respond": false} OBSERVATION: ----- stdout ----- diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index e108f8a4..7d70e031 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +import numpy as np from IPython.display import display import vision_agent as va @@ -17,7 +18,8 @@ from vision_agent.tools.tools import TOOL_DESCRIPTIONS from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask from vision_agent.utils.execute import Execution, MimeType -from vision_agent.utils.image_utils import convert_to_b64 +from vision_agent.utils.image_utils import convert_to_b64, numpy_to_bytes +from vision_agent.utils.video import frames_to_bytes # These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent @@ -328,7 +330,7 @@ def generate_vision_code( chat: str, media: List[str], test_multi_plan: bool = True, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> str: """Generates python code to solve vision based tasks. @@ -338,7 +340,7 @@ def generate_vision_code( chat (str): The chat message from the user. media (List[str]): The media files to use. test_multi_plan (bool): Do not change this parameter. - customized_tool_names (Optional[List[str]]): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. Returns: str: The generated code. @@ -366,7 +368,7 @@ def detect_dogs(image_path: str): response = agent.chat_with_workflow( fixed_chat, test_multi_plan=test_multi_plan, - customized_tool_names=customized_tool_names, + custom_tool_names=custom_tool_names, ) redisplay_results(response["test_result"]) code = response["code"] @@ -432,19 +434,21 @@ def detect_dogs(image_path: str): # Append latest code to second to last message from assistant fixed_chat_history: List[Message] = [] + user_message = "Previous user requests:" for i, chat in enumerate(chat_history): - if i == 0: - fixed_chat_history.append({"role": "user", "content": chat, "media": media}) - elif i > 0 and i < len(chat_history) - 1: - fixed_chat_history.append({"role": "user", "content": chat}) - elif i == len(chat_history) - 1: + if i < len(chat_history) - 1: + user_message += " " + chat + else: + fixed_chat_history.append( + {"role": "user", "content": user_message, "media": media} + ) fixed_chat_history.append({"role": "assistant", "content": code}) fixed_chat_history.append({"role": "user", "content": chat}) response = agent.chat_with_workflow( fixed_chat_history, test_multi_plan=False, - customized_tool_names=customized_tool_names, + custom_tool_names=customized_tool_names, ) redisplay_results(response["test_result"]) code = response["code"] @@ -467,17 +471,34 @@ def detect_dogs(image_path: str): return view_lines(code_lines, 0, total_lines, name, total_lines) -def write_media_artifact(artifacts: Artifacts, local_path: str) -> str: +def write_media_artifact( + artifacts: Artifacts, + name: str, + media: Union[str, np.ndarray, List[np.ndarray]], + fps: Optional[float] = None, +) -> str: """Writes a media file to the artifacts object. Parameters: artifacts (Artifacts): The artifacts object to save the media to. - local_path (str): The local path to the media file. + name (str): The name of the media artifact to save. + media (Union[str, np.ndarray, List[np.ndarray]]): The media to save, can either + be a file path, single image or list of frames for a video. + fps (Optional[float]): The frames per second if you are writing a video. """ - with open(local_path, "rb") as f: - media = f.read() - artifacts[Path(local_path).name] = media - return f"[Media {Path(local_path).name} saved]" + if isinstance(media, str): + with open(media, "rb") as f: + media_bytes = f.read() + elif isinstance(media, list): + media_bytes = frames_to_bytes(media, fps=fps if fps is not None else 1.0) + elif isinstance(media, np.ndarray): + media_bytes = numpy_to_bytes(media) + else: + print(f"[Invalid media type {type(media)}]") + return f"[Invalid media type {type(media)}]" + artifacts[name] = media_bytes + print(f"[Media {name} saved]") + return f"[Media {name} saved]" def list_artifacts(artifacts: Artifacts) -> str: @@ -491,16 +512,14 @@ def check_and_load_image(code: str) -> List[str]: if not code.strip(): return [] - pattern = r"show_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)" - match = re.search(pattern, code) - if match: - name = match.group(2) - return [name] - return [] + pattern = r"view_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)" + matches = re.findall(pattern, code) + return [match[1] for match in matches] def view_media_artifact(artifacts: Artifacts, name: str) -> str: - """Views the image artifact with the given name. + """Allows you to view the media artifact with the given name. This does not show + the media to the user, the user can already see all media saved in the artifacts. Parameters: artifacts (Artifacts): The artifacts object to show the image from. @@ -598,7 +617,7 @@ def generate_replacer(match: re.Match) -> str: arg = match.group(1) out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}" if customized_tool_names is not None: - out_str += f", customized_tool_names={customized_tool_names})" + out_str += f", custom_tool_names={customized_tool_names})" else: out_str += ")" return out_str @@ -609,7 +628,7 @@ def edit_replacer(match: re.Match) -> str: arg = match.group(1) out_str = f"edit_vision_code({arg}" if customized_tool_names is not None: - out_str += f", customized_tool_names={customized_tool_names})" + out_str += f", custom_tool_names={customized_tool_names})" else: out_str += ")" return out_str @@ -646,50 +665,28 @@ def use_object_detection_fine_tuning( patterns_with_fine_tune_id = [ ( - r'florence2_phrase_grounding\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)', + r'florence2_phrase_grounding\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)', lambda match: f'florence2_phrase_grounding("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")', ), ( - r'owl_v2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)', + r'owl_v2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)', lambda match: f'owl_v2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")', ), ( - r'florence2_sam2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)', + r'florence2_sam2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)', lambda match: f'florence2_sam2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")', ), ] - patterns_without_fine_tune_id = [ - ( - r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)", - lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")', - ), - ( - r"owl_v2_image\(\s*([^\)]+)\s*\)", - lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")', - ), - ( - r"florence2_sam2_image\(\s*([^\)]+)\s*\)", - lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")', - ), - ] - new_code = code - - for index, (pattern_with_fine_tune_id, replacer_with_fine_tune_id) in enumerate( - patterns_with_fine_tune_id - ): + for ( + pattern_with_fine_tune_id, + replacer_with_fine_tune_id, + ) in patterns_with_fine_tune_id: if re.search(pattern_with_fine_tune_id, new_code): new_code = re.sub( pattern_with_fine_tune_id, replacer_with_fine_tune_id, new_code ) - else: - (pattern_without_fine_tune_id, replacer_without_fine_tune_id) = ( - patterns_without_fine_tune_id[index] - ) - new_code = re.sub( - pattern_without_fine_tune_id, replacer_without_fine_tune_id, new_code - ) if new_code == code: output_str = (