diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py
new file mode 100644
index 00000000..fced644b
--- /dev/null
+++ b/tests/unit/test_meta_tools.py
@@ -0,0 +1,73 @@
+from vision_agent.tools.meta_tools import (
+ Artifacts,
+ check_and_load_image,
+ use_object_detection_fine_tuning,
+)
+
+
+def test_check_and_load_image_none():
+ assert check_and_load_image("print('Hello, World!')") == []
+
+
+def test_check_and_load_image_one():
+ assert check_and_load_image("view_media_artifact(artifacts, 'image.jpg')") == [
+ "image.jpg"
+ ]
+
+
+def test_check_and_load_image_two():
+ code = "view_media_artifact(artifacts, 'image1.jpg')\nview_media_artifact(artifacts, 'image2.jpg')"
+ assert check_and_load_image(code) == ["image1.jpg", "image2.jpg"]
+
+
+def test_use_object_detection_fine_tuning_none():
+ artifacts = Artifacts("test")
+ code = "print('Hello, World!')"
+ artifacts["code"] = code
+ output = use_object_detection_fine_tuning(artifacts, "code", "123")
+ assert (
+ output == "[No function calls to replace with fine tuning id in artifact code]"
+ )
+ assert artifacts["code"] == code
+
+
+def test_use_object_detection_fine_tuning():
+ artifacts = Artifacts("test")
+ code = """florence2_phrase_grounding('one', image1)
+owl_v2_image('two', image2)
+florence2_sam2_image('three', image3)"""
+ expected_code = """florence2_phrase_grounding("one", image1, "123")
+owl_v2_image("two", image2, "123")
+florence2_sam2_image("three", image3, "123")"""
+ artifacts["code"] = code
+
+ output = use_object_detection_fine_tuning(artifacts, "code", "123")
+ assert 'florence2_phrase_grounding("one", image1, "123")' in output
+ assert 'owl_v2_image("two", image2, "123")' in output
+ assert 'florence2_sam2_image("three", image3, "123")' in output
+ assert artifacts["code"] == expected_code
+
+
+def test_use_object_detection_fine_tuning_twice():
+ artifacts = Artifacts("test")
+ code = """florence2_phrase_grounding('one', image1)
+owl_v2_image('two', image2)
+florence2_sam2_image('three', image3)"""
+ expected_code1 = """florence2_phrase_grounding("one", image1, "123")
+owl_v2_image("two", image2, "123")
+florence2_sam2_image("three", image3, "123")"""
+ expected_code2 = """florence2_phrase_grounding("one", image1, "456")
+owl_v2_image("two", image2, "456")
+florence2_sam2_image("three", image3, "456")"""
+ artifacts["code"] = code
+ output = use_object_detection_fine_tuning(artifacts, "code", "123")
+ assert 'florence2_phrase_grounding("one", image1, "123")' in output
+ assert 'owl_v2_image("two", image2, "123")' in output
+ assert 'florence2_sam2_image("three", image3, "123")' in output
+ assert artifacts["code"] == expected_code1
+
+ output = use_object_detection_fine_tuning(artifacts, "code", "456")
+ assert 'florence2_phrase_grounding("one", image1, "456")' in output
+ assert 'owl_v2_image("two", image2, "456")' in output
+ assert 'florence2_sam2_image("three", image3, "456")' in output
+ assert artifacts["code"] == expected_code2
diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py
new file mode 100644
index 00000000..85e75426
--- /dev/null
+++ b/tests/unit/test_va.py
@@ -0,0 +1,52 @@
+from vision_agent.agent.vision_agent import parse_execution
+
+
+def test_parse_execution_zero():
+ code = "print('Hello, World!')"
+ assert parse_execution(code) == None
+
+
+def test_parse_execution_one():
+ code = "print('Hello, World!')"
+ assert parse_execution(code) == "print('Hello, World!')"
+
+
+def test_parse_execution_no_test_multi_plan_generate():
+ code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])"
+ assert (
+ parse_execution(code, False)
+ == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False)"
+ )
+
+
+def test_parse_execution_no_test_multi_plan_edit():
+ code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
+ assert (
+ parse_execution(code, False)
+ == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
+ )
+
+
+def test_parse_execution_custom_tool_names_generate():
+ code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])"
+ assert (
+ parse_execution(
+ code, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
+ )
+ == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"
+ )
+
+
+def test_prase_execution_custom_tool_names_edit():
+ code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
+ assert (
+ parse_execution(
+ code, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
+ )
+ == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"
+ )
+
+
+def test_parse_execution_multiple_executes():
+ code = "print('Hello, World!')print('Hello, World!')"
+ assert parse_execution(code) == "print('Hello, World!')\nprint('Hello, World!')"
diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py
index 14dec56b..a303c631 100644
--- a/vision_agent/agent/vision_agent.py
+++ b/vision_agent/agent/vision_agent.py
@@ -106,9 +106,20 @@ def parse_execution(
customed_tool_names: Optional[List[str]] = None,
) -> Optional[str]:
code = None
- if "" in response:
- code = response[response.find("") + len("") :]
- code = code[: code.find("")]
+ remaining = response
+ all_code = []
+ while "" in remaining:
+ code_i = remaining[
+ remaining.find("") + len("") :
+ ]
+ code_i = code_i[: code_i.find("")]
+ remaining = remaining[
+ remaining.find("") + len("") :
+ ]
+ all_code.append(code_i)
+
+ if len(all_code) > 0:
+ code = "\n".join(all_code)
if code is not None:
code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names)
@@ -306,6 +317,7 @@ def chat_with_code(
)
finished = user_result is not None and user_obs is not None
if user_result is not None and user_obs is not None:
+ # be sure to update the chat with user execution results
chat_elt: Message = {"role": "observation", "content": user_obs}
int_chat.append(chat_elt)
chat_elt["execution"] = user_result
diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py
index ec5ece0b..aa4d83da 100644
--- a/vision_agent/agent/vision_agent_coder.py
+++ b/vision_agent/agent/vision_agent_coder.py
@@ -691,7 +691,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
- customized_tool_names: Optional[List[str]] = None,
+ custom_tool_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
@@ -707,8 +707,8 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
- customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
- If not provided, default to full tool set from vision_agent.tools.
+ custom_tool_names (List[str]): A list of custom tools for the agent to pick
+ and use. If not provided, default to full tool set from vision_agent.tools.
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
@@ -760,7 +760,7 @@ def chat_with_workflow(
success = False
plans = self._create_plans(
- int_chat, customized_tool_names, working_memory, self.planner
+ int_chat, custom_tool_names, working_memory, self.planner
)
if test_multi_plan:
diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py
index 6a5ad289..7d70e031 100644
--- a/vision_agent/tools/meta_tools.py
+++ b/vision_agent/tools/meta_tools.py
@@ -330,7 +330,7 @@ def generate_vision_code(
chat: str,
media: List[str],
test_multi_plan: bool = True,
- customized_tool_names: Optional[List[str]] = None,
+ custom_tool_names: Optional[List[str]] = None,
) -> str:
"""Generates python code to solve vision based tasks.
@@ -340,7 +340,7 @@ def generate_vision_code(
chat (str): The chat message from the user.
media (List[str]): The media files to use.
test_multi_plan (bool): Do not change this parameter.
- customized_tool_names (Optional[List[str]]): Do not change this parameter.
+ custom_tool_names (Optional[List[str]]): Do not change this parameter.
Returns:
str: The generated code.
@@ -368,7 +368,7 @@ def detect_dogs(image_path: str):
response = agent.chat_with_workflow(
fixed_chat,
test_multi_plan=test_multi_plan,
- customized_tool_names=customized_tool_names,
+ custom_tool_names=custom_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
@@ -448,7 +448,7 @@ def detect_dogs(image_path: str):
response = agent.chat_with_workflow(
fixed_chat_history,
test_multi_plan=False,
- customized_tool_names=customized_tool_names,
+ custom_tool_names=customized_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
@@ -513,11 +513,8 @@ def check_and_load_image(code: str) -> List[str]:
return []
pattern = r"view_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)"
- match = re.search(pattern, code)
- if match:
- name = match.group(2)
- return [name]
- return []
+ matches = re.findall(pattern, code)
+ return [match[1] for match in matches]
def view_media_artifact(artifacts: Artifacts, name: str) -> str:
@@ -620,7 +617,7 @@ def generate_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}"
if customized_tool_names is not None:
- out_str += f", customized_tool_names={customized_tool_names})"
+ out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str
@@ -631,7 +628,7 @@ def edit_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"edit_vision_code({arg}"
if customized_tool_names is not None:
- out_str += f", customized_tool_names={customized_tool_names})"
+ out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str
@@ -668,51 +665,28 @@ def use_object_detection_fine_tuning(
patterns_with_fine_tune_id = [
(
- r'florence2_phrase_grounding\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
+ r'florence2_phrase_grounding\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_phrase_grounding("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
- r'owl_v2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
+ r'owl_v2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'owl_v2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
- r'florence2_sam2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
+ r'florence2_sam2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_sam2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
]
- patterns_without_fine_tune_id = [
- (
- r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)",
- lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")',
- ),
- (
- r"owl_v2_image\(\s*([^\)]+)\s*\)",
- lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")',
- ),
- (
- r"florence2_sam2_image\(\s*([^\)]+)\s*\)",
- lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")',
- ),
- ]
-
new_code = code
-
- for index, (pattern_with_fine_tune_id, replacer_with_fine_tune_id) in enumerate(
- patterns_with_fine_tune_id
- ):
+ for (
+ pattern_with_fine_tune_id,
+ replacer_with_fine_tune_id,
+ ) in patterns_with_fine_tune_id:
if re.search(pattern_with_fine_tune_id, new_code):
new_code = re.sub(
pattern_with_fine_tune_id, replacer_with_fine_tune_id, new_code
)
- else:
- (
- pattern_without_fine_tune_id,
- replacer_without_fine_tune_id,
- ) = patterns_without_fine_tune_id[index]
- new_code = re.sub(
- pattern_without_fine_tune_id, replacer_without_fine_tune_id, new_code
- )
if new_code == code:
output_str = (