diff --git a/examples/chat/app.py b/examples/chat/app.py
index 9291f65a..0389b2f1 100644
--- a/examples/chat/app.py
+++ b/examples/chat/app.py
@@ -51,13 +51,12 @@
def update_messages(messages, lock):
- if Path("artifacts.pkl").exists():
- artifacts.load("artifacts.pkl")
- new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts)
with lock:
- for new_message in new_chat:
- if new_message not in messages:
- messages.append(new_message)
+ if Path("artifacts.pkl").exists():
+ artifacts.load("artifacts.pkl")
+ new_chat, _ = agent.chat_with_code(messages, artifacts=artifacts)
+ for new_message in new_chat[len(messages) :]:
+ messages.append(new_message)
def get_updates(updates, lock):
@@ -106,15 +105,21 @@ def main():
prompt = st.session_state.input_text
if prompt:
- st.session_state.messages.append({"role": "user", "content": prompt})
- messages.chat_message("user").write(prompt)
- message_thread = threading.Thread(
- target=update_messages,
- args=(st.session_state.messages, message_lock),
- )
- message_thread.daemon = True
- message_thread.start()
- st.session_state.input_text = ""
+ if (
+ len(st.session_state.messages) == 0
+ or prompt != st.session_state.messages[-1]["content"]
+ ):
+ st.session_state.messages.append(
+ {"role": "user", "content": prompt}
+ )
+ messages.chat_message("user").write(prompt)
+ message_thread = threading.Thread(
+ target=update_messages,
+ args=(st.session_state.messages, message_lock),
+ )
+ message_thread.daemon = True
+ message_thread.start()
+ st.session_state.input_text = ""
with tabs[1]:
updates = st.container(height=400)
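Note on the `update_messages` hunk above: running the agent call inside the lock and appending only `new_chat[len(messages):]` fixes both the race and the old `not in` membership test, which silently dropped legitimately repeated messages. A minimal runnable sketch of the pattern, with a hypothetical `fake_agent` standing in for `agent.chat_with_code`:

```python
import threading

def fake_agent(chat):
    # Stand-in for agent.chat_with_code: returns the full conversation,
    # old messages included, plus whatever the agent appended.
    return chat + [{"role": "assistant", "content": "hello"}]

messages = [{"role": "user", "content": "hi"}]
lock = threading.Lock()

def update_messages(messages, lock):
    with lock:
        new_chat = fake_agent(messages)
        # Slicing by length keeps only the agent's new suffix; unlike the
        # removed `if new_message not in messages` check, it cannot drop
        # a message the user genuinely sent twice.
        for new_message in new_chat[len(messages):]:
            messages.append(new_message)

update_messages(messages, lock)
assert [m["content"] for m in messages] == ["hi", "hello"]
```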
diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py
new file mode 100644
index 00000000..fced644b
--- /dev/null
+++ b/tests/unit/test_meta_tools.py
@@ -0,0 +1,73 @@
+from vision_agent.tools.meta_tools import (
+ Artifacts,
+ check_and_load_image,
+ use_object_detection_fine_tuning,
+)
+
+
+def test_check_and_load_image_none():
+ assert check_and_load_image("print('Hello, World!')") == []
+
+
+def test_check_and_load_image_one():
+ assert check_and_load_image("view_media_artifact(artifacts, 'image.jpg')") == [
+ "image.jpg"
+ ]
+
+
+def test_check_and_load_image_two():
+ code = "view_media_artifact(artifacts, 'image1.jpg')\nview_media_artifact(artifacts, 'image2.jpg')"
+ assert check_and_load_image(code) == ["image1.jpg", "image2.jpg"]
+
+
+def test_use_object_detection_fine_tuning_none():
+ artifacts = Artifacts("test")
+ code = "print('Hello, World!')"
+ artifacts["code"] = code
+ output = use_object_detection_fine_tuning(artifacts, "code", "123")
+ assert (
+ output == "[No function calls to replace with fine tuning id in artifact code]"
+ )
+ assert artifacts["code"] == code
+
+
+def test_use_object_detection_fine_tuning():
+ artifacts = Artifacts("test")
+ code = """florence2_phrase_grounding('one', image1)
+owl_v2_image('two', image2)
+florence2_sam2_image('three', image3)"""
+ expected_code = """florence2_phrase_grounding("one", image1, "123")
+owl_v2_image("two", image2, "123")
+florence2_sam2_image("three", image3, "123")"""
+ artifacts["code"] = code
+
+ output = use_object_detection_fine_tuning(artifacts, "code", "123")
+ assert 'florence2_phrase_grounding("one", image1, "123")' in output
+ assert 'owl_v2_image("two", image2, "123")' in output
+ assert 'florence2_sam2_image("three", image3, "123")' in output
+ assert artifacts["code"] == expected_code
+
+
+def test_use_object_detection_fine_tuning_twice():
+ artifacts = Artifacts("test")
+ code = """florence2_phrase_grounding('one', image1)
+owl_v2_image('two', image2)
+florence2_sam2_image('three', image3)"""
+ expected_code1 = """florence2_phrase_grounding("one", image1, "123")
+owl_v2_image("two", image2, "123")
+florence2_sam2_image("three", image3, "123")"""
+ expected_code2 = """florence2_phrase_grounding("one", image1, "456")
+owl_v2_image("two", image2, "456")
+florence2_sam2_image("three", image3, "456")"""
+ artifacts["code"] = code
+ output = use_object_detection_fine_tuning(artifacts, "code", "123")
+ assert 'florence2_phrase_grounding("one", image1, "123")' in output
+ assert 'owl_v2_image("two", image2, "123")' in output
+ assert 'florence2_sam2_image("three", image3, "123")' in output
+ assert artifacts["code"] == expected_code1
+
+ output = use_object_detection_fine_tuning(artifacts, "code", "456")
+ assert 'florence2_phrase_grounding("one", image1, "456")' in output
+ assert 'owl_v2_image("two", image2, "456")' in output
+ assert 'florence2_sam2_image("three", image3, "456")' in output
+ assert artifacts["code"] == expected_code2
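These tests pin down `check_and_load_image`'s switch from `re.search` (first match only, against a misspelled `show_media_artifact` pattern) to `re.findall`. A standalone sketch of the multi-match behaviour, reusing the same pattern the patched function uses:

```python
import re

# Same pattern as the patched check_and_load_image in meta_tools.py.
pattern = r"view_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)"

code = (
    "view_media_artifact(artifacts, 'image1.jpg')\n"
    "view_media_artifact(artifacts, 'image2.jpg')"
)
# re.findall yields one (args, name) tuple per call site, so every
# referenced artifact is collected, not just the first one.
matches = re.findall(pattern, code)
assert [m[1] for m in matches] == ["image1.jpg", "image2.jpg"]
```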
diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py
new file mode 100644
index 00000000..ff4e9b46
--- /dev/null
+++ b/tests/unit/test_va.py
@@ -0,0 +1,52 @@
+from vision_agent.agent.vision_agent import parse_execution
+
+
+def test_parse_execution_zero():
+ code = "print('Hello, World!')"
+ assert parse_execution(code) is None
+
+
+def test_parse_execution_one():
+    code = "<execute_python>print('Hello, World!')</execute_python>"
+ assert parse_execution(code) == "print('Hello, World!')"
+
+
+def test_parse_execution_no_test_multi_plan_generate():
+    code = "<execute_python>generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])</execute_python>"
+ assert (
+ parse_execution(code, False)
+ == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False)"
+ )
+
+
+def test_parse_execution_no_test_multi_plan_edit():
+    code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
+ assert (
+ parse_execution(code, False)
+ == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
+ )
+
+
+def test_parse_execution_custom_tool_names_generate():
+    code = "<execute_python>generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])</execute_python>"
+ assert (
+ parse_execution(
+ code, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
+ )
+ == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"
+ )
+
+
+def test_parse_execution_custom_tool_names_edit():
+    code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
+ assert (
+ parse_execution(
+ code, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
+ )
+ == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"
+ )
+
+
+def test_parse_execution_multiple_executes():
+    code = "<execute_python>print('Hello, World!')</execute_python><execute_python>print('Hello, World!')</execute_python>"
+ assert parse_execution(code) == "print('Hello, World!')\nprint('Hello, World!')"
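The tests above rely on `parse_execution` extracting code between `<execute_python>` delimiters and joining multiple blocks with newlines. A self-contained sketch of that extraction loop; `extract_all_code` is an illustrative stand-in for the real function, which additionally rewrites agent-specific arguments:

```python
from typing import Optional

OPEN, CLOSE = "<execute_python>", "</execute_python>"

def extract_all_code(response: str) -> Optional[str]:
    remaining, all_code = response, []
    while OPEN in remaining:
        # Take the text between the next pair of tags, then continue
        # scanning after the closing tag.
        code_i = remaining[remaining.find(OPEN) + len(OPEN) :]
        code_i = code_i[: code_i.find(CLOSE)]
        remaining = remaining[remaining.find(CLOSE) + len(CLOSE) :]
        all_code.append(code_i)
    # Multiple blocks are joined so they run as a single script.
    return "\n".join(all_code) if all_code else None

assert extract_all_code("no code here") is None
assert (
    extract_all_code(f"{OPEN}print('a'){CLOSE}{OPEN}print('b'){CLOSE}")
    == "print('a')\nprint('b')"
)
```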
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index bf35e5e9..a303c631 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -87,7 +87,7 @@ def run_conversation(orch: LMM, chat: List[Message]) -> Dict[str, Any]:
return extract_json(orch([message], stream=False)) # type: ignore
-def run_code_action(
+def execute_code_action(
code: str, code_interpreter: CodeInterpreter, artifact_remote_path: str
) -> Tuple[Execution, str]:
result = code_interpreter.exec_isolation(
@@ -106,19 +106,53 @@ def parse_execution(
customed_tool_names: Optional[List[str]] = None,
) -> Optional[str]:
code = None
-    if "<execute_python>" in response:
-        code = response[response.find("<execute_python>") + len("<execute_python>") :]
-        code = code[: code.find("</execute_python>")]
+ remaining = response
+ all_code = []
+    while "<execute_python>" in remaining:
+        code_i = remaining[
+            remaining.find("<execute_python>") + len("<execute_python>") :
+        ]
+        code_i = code_i[: code_i.find("</execute_python>")]
+        remaining = remaining[
+            remaining.find("</execute_python>") + len("</execute_python>") :
+        ]
+        all_code.append(code_i)
+
+ if len(all_code) > 0:
+ code = "\n".join(all_code)
if code is not None:
code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names)
return code
+def execute_user_code_action(
+ last_user_message: Message,
+ code_interpreter: CodeInterpreter,
+ artifact_remote_path: str,
+) -> Tuple[Optional[Execution], Optional[str]]:
+ user_result = None
+ user_obs = None
+
+ if last_user_message["role"] != "user":
+ return user_result, user_obs
+
+ last_user_content = cast(str, last_user_message.get("content", ""))
+
+ user_code_action = parse_execution(last_user_content, False)
+ if user_code_action is not None:
+ user_result, user_obs = execute_code_action(
+ user_code_action, code_interpreter, artifact_remote_path
+ )
+ if user_result.error:
+ user_obs += f"\n{user_result.error}"
+ return user_result, user_obs
+
+
class VisionAgent(Agent):
"""Vision Agent is an agent that can chat with the user and call tools or other
agents to generate code for it. Vision Agent uses python code to execute actions
- for the user. Vision Agent is inspired by by OpenDev
+    for the user. Vision Agent is inspired by OpenDevin
https://github.com/OpenDevin/OpenDevin and CodeAct https://arxiv.org/abs/2402.01030
Example
@@ -278,9 +312,24 @@ def chat_with_code(
orig_chat.append({"role": "observation", "content": artifacts_loaded})
self.streaming_message({"role": "observation", "content": artifacts_loaded})
- finished = self.execute_user_code_action(
- last_user_message, code_interpreter, remote_artifacts_path
+ user_result, user_obs = execute_user_code_action(
+ last_user_message, code_interpreter, str(remote_artifacts_path)
)
+ finished = user_result is not None and user_obs is not None
+ if user_result is not None and user_obs is not None:
+ # be sure to update the chat with user execution results
+ chat_elt: Message = {"role": "observation", "content": user_obs}
+ int_chat.append(chat_elt)
+ chat_elt["execution"] = user_result
+ orig_chat.append(chat_elt)
+ self.streaming_message(
+ {
+ "role": "observation",
+ "content": user_obs,
+ "execution": user_result,
+ "finished": finished,
+ }
+ )
while not finished and iterations < self.max_iterations:
response = run_conversation(self.agent, int_chat)
@@ -322,7 +371,7 @@ def chat_with_code(
)
if code_action is not None:
- result, obs = run_code_action(
+ result, obs = execute_code_action(
code_action, code_interpreter, str(remote_artifacts_path)
)
@@ -331,17 +380,17 @@ def chat_with_code(
if self.verbosity >= 1:
_LOGGER.info(obs)
- chat_elt: Message = {"role": "observation", "content": obs}
+ obs_chat_elt: Message = {"role": "observation", "content": obs}
if media_obs and result.success:
- chat_elt["media"] = [
+ obs_chat_elt["media"] = [
Path(code_interpreter.remote_path) / media_ob
for media_ob in media_obs
]
# don't add execution results to internal chat
- int_chat.append(chat_elt)
- chat_elt["execution"] = result
- orig_chat.append(chat_elt)
+ int_chat.append(obs_chat_elt)
+ obs_chat_elt["execution"] = result
+ orig_chat.append(obs_chat_elt)
self.streaming_message(
{
"role": "observation",
@@ -362,34 +411,6 @@ def chat_with_code(
artifacts.save()
return orig_chat, artifacts
- def execute_user_code_action(
- self,
- last_user_message: Message,
- code_interpreter: CodeInterpreter,
- remote_artifacts_path: Path,
- ) -> bool:
- if last_user_message["role"] != "user":
- return False
- user_code_action = parse_execution(
- cast(str, last_user_message.get("content", "")), False
- )
- if user_code_action is not None:
- user_result, user_obs = run_code_action(
- user_code_action, code_interpreter, str(remote_artifacts_path)
- )
- if self.verbosity >= 1:
- _LOGGER.info(user_obs)
- self.streaming_message(
- {
- "role": "observation",
- "content": user_obs,
- "execution": user_result,
- "finished": True,
- }
- )
- return True
- return False
-
def streaming_message(self, message: Dict[str, Any]) -> None:
if self.callback_message:
self.callback_message(message)
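The refactor above turns the bool-returning `VisionAgent.execute_user_code_action` method into a free function returning the execution and its observation, so `chat_with_code` can attach both to the observation message it appends. A rough sketch of the new contract, using stub values rather than the real `CodeInterpreter`:

```python
from typing import Optional, Tuple

def execute_user_code_action_stub(
    last_user_message: dict,
) -> Tuple[Optional[str], Optional[str]]:
    # (None, None) means "nothing to run"; the real function calls
    # parse_execution() on the content and runs it in the interpreter.
    if last_user_message["role"] != "user":
        return None, None
    return "execution-result", f"ran: {last_user_message['content']}"

user_result, user_obs = execute_user_code_action_stub(
    {"role": "user", "content": "print('hi')"}
)
# The caller now derives `finished` from the pair instead of receiving a
# bool, and can thread user_result into the chat history it builds.
finished = user_result is not None and user_obs is not None
assert finished
```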
diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py
index ec5ece0b..aa4d83da 100644
--- a/vision_agent/agent/vision_agent_coder.py
+++ b/vision_agent/agent/vision_agent_coder.py
@@ -691,7 +691,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
- customized_tool_names: Optional[List[str]] = None,
+ custom_tool_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
@@ -707,8 +707,8 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
- customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
- If not provided, default to full tool set from vision_agent.tools.
+        custom_tool_names (List[str]): A list of custom tools for the agent to pick
+            and use. If not provided, defaults to the full tool set from vision_agent.tools.
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
@@ -760,7 +760,7 @@ def chat_with_workflow(
success = False
plans = self._create_plans(
- int_chat, customized_tool_names, working_memory, self.planner
+ int_chat, custom_tool_names, working_memory, self.planner
)
if test_multi_plan:
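Callers of `chat_with_workflow` must move to the renamed keyword. An illustrative call shape only; the tool name and media path are made-up examples, and actually running this requires the package's model credentials:

```python
from vision_agent.agent import VisionAgentCoder

coder = VisionAgentCoder()
conv = [{"role": "user", "content": "Count the dogs", "media": ["dogs.jpg"]}]
# Before this patch: customized_tool_names=["owl_v2_image"]
result = coder.chat_with_workflow(conv, custom_tool_names=["owl_v2_image"])
```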
diff --git a/vision_agent/agent/vision_agent_prompts.py b/vision_agent/agent/vision_agent_prompts.py
index 80623016..bc3295ef 100644
--- a/vision_agent/agent/vision_agent_prompts.py
+++ b/vision_agent/agent/vision_agent_prompts.py
@@ -26,7 +26,9 @@
**Instructions**:
1. **Understand and Clarify**: Make sure you understand the task, ask clarifying questions if the task is not clear.
-2. **Output in JSON**: Respond in the following format in JSON:
+2. **Code Generation**: Only use code provided in the Documentation in your <execute_python> tags. Only use `edit_vision_code` to modify code written by `generate_vision_code`.
+3. **Execute**: Do only what the user asked you to do and no more. If you need to ask the user a question, set `let_user_respond` to `true`.
+4. **Output in JSON**: Respond in the following format in JSON:
```json
{{"thoughts": , "response": , "let_user_respond": }}.
@@ -149,7 +151,7 @@
16| return count
[End of artifact]
-AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visualization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
+AGENT: {"thoughts": "I have generated the code to count the workers with helmets in the image, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visualization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png', 'workers_viz.png')</execute_python>", "let_user_respond": false}
OBSERVATION:
----- stdout -----
@@ -164,7 +166,6 @@
OBSERVATION:
[Fine tuning id: 23b3b022-5ebf-4798-9373-20ef36429abf]
-
AGENT: {"thoughts": "The model has finished fine tuning, I will now replace the original florence2_phrase_grounding call with the fine tuning id.", "response": "I will now update the code to use the fine tuned model. <execute_python>use_object_detection_fine_tuning(artifacts, "code.py", "23b3b022-5ebf-4798-9373-20ef36429abf")</execute_python>", "let_user_respond": false}
OBSERVATION:
@@ -182,8 +183,7 @@ def count_workers_with_helmets(image_path: str, output_path: str):
count = 0
[End of edits]
-
-AGENT: {"thoughts": "I have updated the code to use the fine tuned model, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visualization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png')</execute_python>", "let_user_respond": false}
+AGENT: {"thoughts": "I have updated the code to use the fine tuned model, I must now run the code and print the output and write the visualization to the artifacts so I can see the result and the user can see the visualization.", "response": "<execute_python>from code import count_workers_with_helmets\n print(count_workers_with_helmets('/path/to/images/workers.png', 'workers_viz.png'))\n write_media_artifact(artifacts, 'workers_viz.png', 'workers_viz.png')</execute_python>", "let_user_respond": false}
OBSERVATION:
----- stdout -----
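The prompt examples changed because `write_media_artifact` now takes an artifact name plus the media itself rather than a local path alone (see `meta_tools.py` below). A hedged sketch of the accepted media types, assuming the package's image/video dependencies are installed:

```python
import numpy as np

from vision_agent.tools.meta_tools import Artifacts, write_media_artifact

artifacts = Artifacts("demo")
frame = np.zeros((64, 64, 3), dtype=np.uint8)

write_media_artifact(artifacts, "black.png", frame)                # single image
write_media_artifact(artifacts, "clip.mp4", [frame] * 8, fps=4.0)  # video frames
# A plain string is treated as a path to an existing file on disk:
# write_media_artifact(artifacts, "copy.png", "some/local.png")
```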
diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py
index e108f8a4..7d70e031 100644
--- a/vision_agent/tools/meta_tools.py
+++ b/vision_agent/tools/meta_tools.py
@@ -8,6 +8,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
+import numpy as np
from IPython.display import display
import vision_agent as va
@@ -17,7 +18,8 @@
from vision_agent.tools.tools import TOOL_DESCRIPTIONS
from vision_agent.tools.tools_types import BboxInput, BboxInputBase64, PromptTask
from vision_agent.utils.execute import Execution, MimeType
-from vision_agent.utils.image_utils import convert_to_b64
+from vision_agent.utils.image_utils import convert_to_b64, numpy_to_bytes
+from vision_agent.utils.video import frames_to_bytes
# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent
@@ -328,7 +330,7 @@ def generate_vision_code(
chat: str,
media: List[str],
test_multi_plan: bool = True,
- customized_tool_names: Optional[List[str]] = None,
+ custom_tool_names: Optional[List[str]] = None,
) -> str:
"""Generates python code to solve vision based tasks.
@@ -338,7 +340,7 @@ def generate_vision_code(
chat (str): The chat message from the user.
media (List[str]): The media files to use.
test_multi_plan (bool): Do not change this parameter.
- customized_tool_names (Optional[List[str]]): Do not change this parameter.
+ custom_tool_names (Optional[List[str]]): Do not change this parameter.
Returns:
str: The generated code.
@@ -366,7 +368,7 @@ def detect_dogs(image_path: str):
response = agent.chat_with_workflow(
fixed_chat,
test_multi_plan=test_multi_plan,
- customized_tool_names=customized_tool_names,
+ custom_tool_names=custom_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
@@ -432,19 +434,21 @@ def detect_dogs(image_path: str):
# Append latest code to second to last message from assistant
fixed_chat_history: List[Message] = []
+ user_message = "Previous user requests:"
for i, chat in enumerate(chat_history):
- if i == 0:
- fixed_chat_history.append({"role": "user", "content": chat, "media": media})
- elif i > 0 and i < len(chat_history) - 1:
- fixed_chat_history.append({"role": "user", "content": chat})
- elif i == len(chat_history) - 1:
+ if i < len(chat_history) - 1:
+ user_message += " " + chat
+ else:
+ fixed_chat_history.append(
+ {"role": "user", "content": user_message, "media": media}
+ )
fixed_chat_history.append({"role": "assistant", "content": code})
fixed_chat_history.append({"role": "user", "content": chat})
response = agent.chat_with_workflow(
fixed_chat_history,
test_multi_plan=False,
- customized_tool_names=customized_tool_names,
+ custom_tool_names=customized_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
@@ -467,17 +471,34 @@ def detect_dogs(image_path: str):
return view_lines(code_lines, 0, total_lines, name, total_lines)
-def write_media_artifact(artifacts: Artifacts, local_path: str) -> str:
+def write_media_artifact(
+ artifacts: Artifacts,
+ name: str,
+ media: Union[str, np.ndarray, List[np.ndarray]],
+ fps: Optional[float] = None,
+) -> str:
"""Writes a media file to the artifacts object.
Parameters:
artifacts (Artifacts): The artifacts object to save the media to.
- local_path (str): The local path to the media file.
+ name (str): The name of the media artifact to save.
+ media (Union[str, np.ndarray, List[np.ndarray]]): The media to save, can either
+ be a file path, single image or list of frames for a video.
+ fps (Optional[float]): The frames per second if you are writing a video.
"""
- with open(local_path, "rb") as f:
- media = f.read()
- artifacts[Path(local_path).name] = media
- return f"[Media {Path(local_path).name} saved]"
+ if isinstance(media, str):
+ with open(media, "rb") as f:
+ media_bytes = f.read()
+ elif isinstance(media, list):
+ media_bytes = frames_to_bytes(media, fps=fps if fps is not None else 1.0)
+ elif isinstance(media, np.ndarray):
+ media_bytes = numpy_to_bytes(media)
+ else:
+ print(f"[Invalid media type {type(media)}]")
+ return f"[Invalid media type {type(media)}]"
+ artifacts[name] = media_bytes
+ print(f"[Media {name} saved]")
+ return f"[Media {name} saved]"
def list_artifacts(artifacts: Artifacts) -> str:
@@ -491,16 +512,14 @@ def check_and_load_image(code: str) -> List[str]:
if not code.strip():
return []
- pattern = r"show_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)"
- match = re.search(pattern, code)
- if match:
- name = match.group(2)
- return [name]
- return []
+ pattern = r"view_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)"
+ matches = re.findall(pattern, code)
+ return [match[1] for match in matches]
def view_media_artifact(artifacts: Artifacts, name: str) -> str:
- """Views the image artifact with the given name.
+ """Allows you to view the media artifact with the given name. This does not show
+ the media to the user, the user can already see all media saved in the artifacts.
Parameters:
artifacts (Artifacts): The artifacts object to show the image from.
@@ -598,7 +617,7 @@ def generate_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}"
if customized_tool_names is not None:
- out_str += f", customized_tool_names={customized_tool_names})"
+ out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str
@@ -609,7 +628,7 @@ def edit_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"edit_vision_code({arg}"
if customized_tool_names is not None:
- out_str += f", customized_tool_names={customized_tool_names})"
+ out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str
@@ -646,50 +665,28 @@ def use_object_detection_fine_tuning(
patterns_with_fine_tune_id = [
(
- r'florence2_phrase_grounding\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
+ r'florence2_phrase_grounding\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_phrase_grounding("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
- r'owl_v2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
+ r'owl_v2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'owl_v2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
- r'florence2_sam2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
+ r'florence2_sam2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_sam2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
]
- patterns_without_fine_tune_id = [
- (
- r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)",
- lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")',
- ),
- (
- r"owl_v2_image\(\s*([^\)]+)\s*\)",
- lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")',
- ),
- (
- r"florence2_sam2_image\(\s*([^\)]+)\s*\)",
- lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")',
- ),
- ]
-
new_code = code
-
- for index, (pattern_with_fine_tune_id, replacer_with_fine_tune_id) in enumerate(
- patterns_with_fine_tune_id
- ):
+ for (
+ pattern_with_fine_tune_id,
+ replacer_with_fine_tune_id,
+ ) in patterns_with_fine_tune_id:
if re.search(pattern_with_fine_tune_id, new_code):
new_code = re.sub(
pattern_with_fine_tune_id, replacer_with_fine_tune_id, new_code
)
- else:
- (pattern_without_fine_tune_id, replacer_without_fine_tune_id) = (
- patterns_without_fine_tune_id[index]
- )
- new_code = re.sub(
- pattern_without_fine_tune_id, replacer_without_fine_tune_id, new_code
- )
if new_code == code:
output_str = (