Skip to content

Commit

Permalink
added more test cases for string replacement funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Sep 26, 2024
1 parent 48b820d commit 4111527
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 48 deletions.
73 changes: 73 additions & 0 deletions tests/unit/test_meta_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from vision_agent.tools.meta_tools import (
Artifacts,
check_and_load_image,
use_object_detection_fine_tuning,
)


def test_check_and_load_image_none():
assert check_and_load_image("print('Hello, World!')") == []


def test_check_and_load_image_one():
assert check_and_load_image("view_media_artifact(artifacts, 'image.jpg')") == [
"image.jpg"
]


def test_check_and_load_image_two():
code = "view_media_artifact(artifacts, 'image1.jpg')\nview_media_artifact(artifacts, 'image2.jpg')"
assert check_and_load_image(code) == ["image1.jpg", "image2.jpg"]


def test_use_object_detection_fine_tuning_none():
artifacts = Artifacts("test")
code = "print('Hello, World!')"
artifacts["code"] = code
output = use_object_detection_fine_tuning(artifacts, "code", "123")
assert (
output == "[No function calls to replace with fine tuning id in artifact code]"
)
assert artifacts["code"] == code


def test_use_object_detection_fine_tuning():
artifacts = Artifacts("test")
code = """florence2_phrase_grounding('one', image1)
owl_v2_image('two', image2)
florence2_sam2_image('three', image3)"""
expected_code = """florence2_phrase_grounding("one", image1, "123")
owl_v2_image("two", image2, "123")
florence2_sam2_image("three", image3, "123")"""
artifacts["code"] = code

output = use_object_detection_fine_tuning(artifacts, "code", "123")
assert 'florence2_phrase_grounding("one", image1, "123")' in output
assert 'owl_v2_image("two", image2, "123")' in output
assert 'florence2_sam2_image("three", image3, "123")' in output
assert artifacts["code"] == expected_code


def test_use_object_detection_fine_tuning_twice():
artifacts = Artifacts("test")
code = """florence2_phrase_grounding('one', image1)
owl_v2_image('two', image2)
florence2_sam2_image('three', image3)"""
expected_code1 = """florence2_phrase_grounding("one", image1, "123")
owl_v2_image("two", image2, "123")
florence2_sam2_image("three", image3, "123")"""
expected_code2 = """florence2_phrase_grounding("one", image1, "456")
owl_v2_image("two", image2, "456")
florence2_sam2_image("three", image3, "456")"""
artifacts["code"] = code
output = use_object_detection_fine_tuning(artifacts, "code", "123")
assert 'florence2_phrase_grounding("one", image1, "123")' in output
assert 'owl_v2_image("two", image2, "123")' in output
assert 'florence2_sam2_image("three", image3, "123")' in output
assert artifacts["code"] == expected_code1

output = use_object_detection_fine_tuning(artifacts, "code", "456")
assert 'florence2_phrase_grounding("one", image1, "456")' in output
assert 'owl_v2_image("two", image2, "456")' in output
assert 'florence2_sam2_image("three", image3, "456")' in output
assert artifacts["code"] == expected_code2
52 changes: 52 additions & 0 deletions tests/unit/test_va.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from vision_agent.agent.vision_agent import parse_execution


def test_parse_execution_zero():
code = "print('Hello, World!')"
assert parse_execution(code) == None


def test_parse_execution_one():
code = "<execute_python>print('Hello, World!')</execute_python>"
assert parse_execution(code) == "print('Hello, World!')"


def test_parse_execution_no_test_multi_plan_generate():
code = "<execute_python>generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])</execute_python>"
assert (
parse_execution(code, False)
== "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False)"
)


def test_parse_execution_no_test_multi_plan_edit():
code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
assert (
parse_execution(code, False)
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])"
)


def test_parse_execution_custom_tool_names_generate():
code = "<execute_python>generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])</execute_python>"
assert (
parse_execution(
code, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
)
== "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])"
)


def test_prase_execution_custom_tool_names_edit():
code = "<execute_python>edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])</execute_python>"
assert (
parse_execution(
code, test_multi_plan=False, customed_tool_names=["owl_v2_image"]
)
== "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])"
)


def test_parse_execution_multiple_executes():
code = "<execute_python>print('Hello, World!')</execute_python><execute_python>print('Hello, World!')</execute_python>"
assert parse_execution(code) == "print('Hello, World!')\nprint('Hello, World!')"
18 changes: 15 additions & 3 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,20 @@ def parse_execution(
customed_tool_names: Optional[List[str]] = None,
) -> Optional[str]:
code = None
if "<execute_python>" in response:
code = response[response.find("<execute_python>") + len("<execute_python>") :]
code = code[: code.find("</execute_python>")]
remaining = response
all_code = []
while "<execute_python>" in remaining:
code_i = remaining[
remaining.find("<execute_python>") + len("<execute_python>") :
]
code_i = code_i[: code_i.find("</execute_python>")]
remaining = remaining[
remaining.find("</execute_python>") + len("</execute_python>") :
]
all_code.append(code_i)

if len(all_code) > 0:
code = "\n".join(all_code)

if code is not None:
code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names)
Expand Down Expand Up @@ -306,6 +317,7 @@ def chat_with_code(
)
finished = user_result is not None and user_obs is not None
if user_result is not None and user_obs is not None:
# be sure to update the chat with user execution results
chat_elt: Message = {"role": "observation", "content": user_obs}
int_chat.append(chat_elt)
chat_elt["execution"] = user_result
Expand Down
8 changes: 4 additions & 4 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def chat_with_workflow(
chat: List[Message],
test_multi_plan: bool = True,
display_visualization: bool = False,
customized_tool_names: Optional[List[str]] = None,
custom_tool_names: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""Chat with VisionAgentCoder and return intermediate information regarding the
task.
Expand All @@ -707,8 +707,8 @@ def chat_with_workflow(
with the first plan.
display_visualization (bool): If True, it opens a new window locally to
show the image(s) created by visualization code (if there is any).
customized_tool_names (List[str]): A list of customized tools for agent to pick and use.
If not provided, default to full tool set from vision_agent.tools.
custom_tool_names (List[str]): A list of custom tools for the agent to pick
and use. If not provided, default to full tool set from vision_agent.tools.
Returns:
Dict[str, Any]: A dictionary containing the code, test, test result, plan,
Expand Down Expand Up @@ -760,7 +760,7 @@ def chat_with_workflow(
success = False

plans = self._create_plans(
int_chat, customized_tool_names, working_memory, self.planner
int_chat, custom_tool_names, working_memory, self.planner
)

if test_multi_plan:
Expand Down
56 changes: 15 additions & 41 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def generate_vision_code(
chat: str,
media: List[str],
test_multi_plan: bool = True,
customized_tool_names: Optional[List[str]] = None,
custom_tool_names: Optional[List[str]] = None,
) -> str:
"""Generates python code to solve vision based tasks.
Expand All @@ -340,7 +340,7 @@ def generate_vision_code(
chat (str): The chat message from the user.
media (List[str]): The media files to use.
test_multi_plan (bool): Do not change this parameter.
customized_tool_names (Optional[List[str]]): Do not change this parameter.
custom_tool_names (Optional[List[str]]): Do not change this parameter.
Returns:
str: The generated code.
Expand Down Expand Up @@ -368,7 +368,7 @@ def detect_dogs(image_path: str):
response = agent.chat_with_workflow(
fixed_chat,
test_multi_plan=test_multi_plan,
customized_tool_names=customized_tool_names,
custom_tool_names=custom_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
Expand Down Expand Up @@ -448,7 +448,7 @@ def detect_dogs(image_path: str):
response = agent.chat_with_workflow(
fixed_chat_history,
test_multi_plan=False,
customized_tool_names=customized_tool_names,
custom_tool_names=customized_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
Expand Down Expand Up @@ -513,11 +513,8 @@ def check_and_load_image(code: str) -> List[str]:
return []

pattern = r"view_media_artifact\(\s*([^\)]+),\s*['\"]([^\)]+)['\"]\s*\)"
match = re.search(pattern, code)
if match:
name = match.group(2)
return [name]
return []
matches = re.findall(pattern, code)
return [match[1] for match in matches]


def view_media_artifact(artifacts: Artifacts, name: str) -> str:
Expand Down Expand Up @@ -620,7 +617,7 @@ def generate_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}"
if customized_tool_names is not None:
out_str += f", customized_tool_names={customized_tool_names})"
out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str
Expand All @@ -631,7 +628,7 @@ def edit_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"edit_vision_code({arg}"
if customized_tool_names is not None:
out_str += f", customized_tool_names={customized_tool_names})"
out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str
Expand Down Expand Up @@ -668,51 +665,28 @@ def use_object_detection_fine_tuning(

patterns_with_fine_tune_id = [
(
r'florence2_phrase_grounding\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
r'florence2_phrase_grounding\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_phrase_grounding("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
r'owl_v2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
r'owl_v2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'owl_v2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
(
r'florence2_sam2_image\(\s*"([^"]+)"\s*,\s*([^,]+)(?:,\s*"[^"]+")?\s*\)',
r'florence2_sam2_image\(\s*["\']([^"\']+)["\']\s*,\s*([^,]+)(?:,\s*["\'][^"\']+["\'])?\s*\)',
lambda match: f'florence2_sam2_image("{match.group(1)}", {match.group(2)}, "{fine_tune_id}")',
),
]

patterns_without_fine_tune_id = [
(
r"florence2_phrase_grounding\(\s*([^\)]+)\s*\)",
lambda match: f'florence2_phrase_grounding({match.group(1)}, "{fine_tune_id}")',
),
(
r"owl_v2_image\(\s*([^\)]+)\s*\)",
lambda match: f'owl_v2_image({match.group(1)}, "{fine_tune_id}")',
),
(
r"florence2_sam2_image\(\s*([^\)]+)\s*\)",
lambda match: f'florence2_sam2_image({match.group(1)}, "{fine_tune_id}")',
),
]

new_code = code

for index, (pattern_with_fine_tune_id, replacer_with_fine_tune_id) in enumerate(
patterns_with_fine_tune_id
):
for (
pattern_with_fine_tune_id,
replacer_with_fine_tune_id,
) in patterns_with_fine_tune_id:
if re.search(pattern_with_fine_tune_id, new_code):
new_code = re.sub(
pattern_with_fine_tune_id, replacer_with_fine_tune_id, new_code
)
else:
(
pattern_without_fine_tune_id,
replacer_without_fine_tune_id,
) = patterns_without_fine_tune_id[index]
new_code = re.sub(
pattern_without_fine_tune_id, replacer_without_fine_tune_id, new_code
)

if new_code == code:
output_str = (
Expand Down

0 comments on commit 4111527

Please sign in to comment.