Skip to content

Commit

Permalink
fix code rewrite issue with ()
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 11, 2024
1 parent 4d184bf commit a740a28
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 32 deletions.
4 changes: 2 additions & 2 deletions tests/integ/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
blip_image_caption,
clip,
closest_mask_distance,
countgd_counting,
countgd_example_based_counting,
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
Expand All @@ -32,8 +34,6 @@
template_match,
vit_image_classification,
vit_nsfw_classification,
countgd_counting,
countgd_example_based_counting,
)

FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from vision_agent.tools.meta_tools import (
Artifacts,
check_and_load_image,
use_extra_vision_agent_args,
use_object_detection_fine_tuning,
)

Expand Down Expand Up @@ -71,3 +72,37 @@ def test_use_object_detection_fine_tuning_twice():
assert 'owl_v2_image("two", image2, "456")' in output
assert 'florence2_sam2_image("three", image3, "456")' in output
assert artifacts["code"] == expected_code2


def test_use_object_detection_fine_tuning_real_case():
artifacts = Artifacts("test")
code = "florence2_phrase_grounding('(strange arg)', image1)"
expected_code = 'florence2_phrase_grounding("(strange arg)", image1, "123")'
artifacts["code"] = code
output = use_object_detection_fine_tuning(artifacts, "code", "123")
assert 'florence2_phrase_grounding("(strange arg)", image1, "123")' in output
assert artifacts["code"] == expected_code


def test_use_extra_vision_agent_args_real_case():
code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)"
out_code = use_extra_vision_agent_args(code)
assert out_code == expected_code

code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)"
out_code = use_extra_vision_agent_args(code)
assert out_code == expected_code


def test_use_extra_vision_args_with_custom_tools():
code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])"
out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"])
assert out_code == expected_code

code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])"
expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])"
out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"])
assert out_code == expected_code
49 changes: 19 additions & 30 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import numpy as np
from IPython.display import display
from redbaron import RedBaron

import vision_agent as va
from vision_agent.agent.agent_utils import extract_json
Expand Down Expand Up @@ -491,15 +492,15 @@ def edit_vision_code(
name: str,
chat_history: List[str],
media: List[str],
customized_tool_names: Optional[List[str]] = None,
custom_tool_names: Optional[List[str]] = None,
) -> str:
"""Edits python code to solve a vision based task.
Parameters:
artifacts (Artifacts): The artifacts object to save the code to.
name (str): The file path to the code.
chat_history (List[str]): The chat history to used to generate the code.
customized_tool_names (Optional[List[str]]): Do not change this parameter.
custom_tool_names (Optional[List[str]]): Do not change this parameter.
Returns:
str: The edited code.
Expand Down Expand Up @@ -542,7 +543,7 @@ def detect_dogs(image_path: str):
response = agent.generate_code(
fixed_chat_history,
test_multi_plan=False,
custom_tool_names=customized_tool_names,
custom_tool_names=custom_tool_names,
)
redisplay_results(response["test_result"])
code = response["code"]
Expand Down Expand Up @@ -705,44 +706,32 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str:
def use_extra_vision_agent_args(
code: str,
test_multi_plan: bool = True,
customized_tool_names: Optional[List[str]] = None,
custom_tool_names: Optional[List[str]] = None,
) -> str:
"""This is for forcing arguments passed by the user to VisionAgent into the
VisionAgentCoder call.
Parameters:
code (str): The code to edit.
test_multi_plan (bool): Do not change this parameter.
customized_tool_names (Optional[List[str]]): Do not change this parameter.
custom_tool_names (Optional[List[str]]): Do not change this parameter.
Returns:
str: The edited code.
"""
generate_pattern = r"generate_vision_code\(\s*([^\)]+)\s*\)"

def generate_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}"
if customized_tool_names is not None:
out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str

edit_pattern = r"edit_vision_code\(\s*([^\)]+)\s*\)"

def edit_replacer(match: re.Match) -> str:
arg = match.group(1)
out_str = f"edit_vision_code({arg}"
if customized_tool_names is not None:
out_str += f", custom_tool_names={customized_tool_names})"
else:
out_str += ")"
return out_str

new_code = re.sub(generate_pattern, generate_replacer, code)
new_code = re.sub(edit_pattern, edit_replacer, new_code)
return new_code
red = RedBaron(code)
for node in red:
# seems to always be atomtrailers not call type
if node.type == "atomtrailers":
if (
node.name.value == "generate_vision_code"
or node.name.value == "edit_vision_code"
):
node.value[1].value.append(f"test_multi_plan={test_multi_plan}")

if custom_tool_names is not None:
node.value[1].value.append(f"custom_tool_names={custom_tool_names}")
return red.dumps().strip()


def use_object_detection_fine_tuning(
Expand Down

0 comments on commit a740a28

Please sign in to comment.