From d9445e3f8b2cf0e41ff028dcfbdcefea130839b5 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Fri, 11 Oct 2024 08:38:50 -0700 Subject: [PATCH] Fix Bugs (#265) * separated out planner, renamed chat methods * fixed circular imports * added type for plan context * add planner as separate call to vision agent * export plan context * fixed circular imports * fixed wrong key * better json parsing * more test cases for json parsing * have planner visualize results * add more guard rails to remove double chat * revert changes with planning step for now * revert to original prompts * fix type issue * fix format issue * skip examples for flake8 * fix names and readme * fixed type error * fix countgd integ test * synced code with new code interpreter arg * separated out planner, renamed chat methods * add planner as separate call to vision agent * revert changes with planning step for now * strip extra function calls from generated code * fix code rewrite issue with () * fix issue if plan format is incorrect * increase count threshold and size * switch to using tags to fix issue of mixing up code and tests * skip tests for flake8 * fix type issues * fix test case * remove extra planning import * fixed type issues * fixed type issues * fix test case * fix format issue --- .github/workflows/ci_cd.yml | 2 +- poetry.lock | 60 +++++++- pyproject.toml | 1 + tests/integ/test_tools.py | 4 +- tests/unit/test_meta_tools.py | 35 +++++ tests/unit/test_va.py | 14 +- tests/unit/test_vac.py | 143 ++++++++++++++++++ vision_agent/agent/agent_utils.py | 22 +++ vision_agent/agent/vision_agent.py | 10 +- vision_agent/agent/vision_agent_coder.py | 86 ++++++++--- .../agent/vision_agent_coder_prompts.py | 30 ++-- vision_agent/agent/vision_agent_planner.py | 34 ++++- vision_agent/tools/meta_tools.py | 55 +++---- vision_agent/tools/tools.py | 4 +- 14 files changed, 402 insertions(+), 98 deletions(-) create mode 100644 tests/unit/test_vac.py diff --git a/.github/workflows/ci_cd.yml b/.github/workflows/ci_cd.yml index 17757846..d2a2f1e3 100644 --- a/.github/workflows/ci_cd.yml +++ b/.github/workflows/ci_cd.yml @@ -43,7 +43,7 @@ jobs: - name: Linting run: | # stop the build if there are Python syntax errors or undefined names - poetry run flake8 . --exclude .venv,examples --count --show-source --statistics + poetry run flake8 . --exclude .venv,examples,tests --count --show-source --statistics - name: Check Format run: | poetry run black --check --diff --color . diff --git a/poetry.lock b/poetry.lock index e03f8871..4cb21c70 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -58,6 +58,17 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.23)"] +[[package]] +name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = "*" +files = [ + {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, + {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, +] + [[package]] name = "appnope" version = "0.1.4" @@ -189,6 +200,20 @@ files = [ [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] +[[package]] +name = "baron" +version = "0.10.1" +description = "Full Syntax Tree for python to make writing refactoring code a realist task" +optional = false +python-versions = "*" +files = [ + {file = "baron-0.10.1-py2.py3-none-any.whl", hash = "sha256:befb33f4b9e832c7cd1e3cf0eafa6dd3cb6ed4cb2544245147c019936f4e0a8a"}, + {file = "baron-0.10.1.tar.gz", hash = "sha256:af822ad44d4eb425c8516df4239ac4fdba9fdb398ef77e4924cd7c9b4045bc2f"}, +] + +[package.dependencies] +rply = "*" + [[package]] name = "black" version = "24.8.0" @@ -2756,6 +2781,23 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} +[[package]] +name = "redbaron" +version = "0.9.2" +description = "Abstraction on top of baron, a FST for python to make writing refactoring code a realistic task" +optional = false +python-versions = "*" +files = [ + {file = "redbaron-0.9.2-py2.py3-none-any.whl", hash = "sha256:d01032b6a848b5521a8d6ef72486315c2880f420956870cdd742e2b5a09b9bab"}, + {file = "redbaron-0.9.2.tar.gz", hash = "sha256:472d0739ca6b2240bb2278ae428604a75472c9c12e86c6321e8c016139c0132f"}, +] + +[package.dependencies] +baron = ">=0.7" + +[package.extras] +notebook = ["pygments"] + [[package]] name = "referencing" version = "0.35.1" @@ -3030,6 +3072,20 @@ files = [ {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, ] +[[package]] +name = "rply" +version = "0.7.8" +description = "A pure Python Lex/Yacc that works with RPython" +optional = false +python-versions = "*" +files = [ + {file = "rply-0.7.8-py2.py3-none-any.whl", hash = "sha256:28ffd11d656c48aeb8c508eb382acd6a0bd906662624b34388751732a27807e7"}, + {file = "rply-0.7.8.tar.gz", hash = "sha256:2a808ac25a4580a9991fc304d64434e299a8fc75760574492f242cbb5bb301c9"}, +] + +[package.dependencies] +appdirs = "*" + [[package]] name = "scikit-image" version = "0.22.0" @@ -3603,4 +3659,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "5014a9fd241fb625394843263c73feb27ba35b31d13199db45b06985d9e1fbeb" +content-hash = "c0c568df9865d15015b88942284b8555983c689c37757dc10f3ae3e07558812f" diff --git a/pyproject.toml b/pyproject.toml index 05222562..3ae0ff21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ pytube = "15.0.0" anthropic = "^0.31.0" pydantic = "2.7.4" av = "^11.0.0" +redbaron = "^0.9.2" [tool.poetry.group.dev.dependencies] autoflake = "1.*" diff --git a/tests/integ/test_tools.py b/tests/integ/test_tools.py index 9fd9f15c..690795f0 100644 --- a/tests/integ/test_tools.py +++ b/tests/integ/test_tools.py @@ -6,6 +6,8 @@ blip_image_caption, clip, closest_mask_distance, + countgd_counting, + countgd_example_based_counting, depth_anything_v2, detr_segmentation, dpt_hybrid_midas, @@ -32,8 +34,6 @@ template_match, vit_image_classification, vit_nsfw_classification, - countgd_counting, - countgd_example_based_counting, ) FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da" diff --git a/tests/unit/test_meta_tools.py b/tests/unit/test_meta_tools.py index fced644b..6cac95ce 100644 --- a/tests/unit/test_meta_tools.py +++ b/tests/unit/test_meta_tools.py @@ -1,6 +1,7 @@ from vision_agent.tools.meta_tools import ( Artifacts, check_and_load_image, + use_extra_vision_agent_args, use_object_detection_fine_tuning, ) @@ -71,3 +72,37 @@ def test_use_object_detection_fine_tuning_twice(): assert 'owl_v2_image("two", image2, "456")' in output assert 'florence2_sam2_image("three", image3, "456")' in output assert artifacts["code"] == expected_code2 + + +def test_use_object_detection_fine_tuning_real_case(): + artifacts = Artifacts("test") + code = "florence2_phrase_grounding('(strange arg)', image1)" + expected_code = 'florence2_phrase_grounding("(strange arg)", image1, "123")' + artifacts["code"] = code + output = use_object_detection_fine_tuning(artifacts, "code", "123") + assert 'florence2_phrase_grounding("(strange arg)", image1, "123")' in output + assert artifacts["code"] == expected_code + + +def test_use_extra_vision_agent_args_real_case(): + code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)" + out_code = use_extra_vision_agent_args(code) + assert out_code == expected_code + + code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "edit_vision_code(artifacts, 'code.py', ['write code 1', 'write code 2'], ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True)" + out_code = use_extra_vision_agent_args(code) + assert out_code == expected_code + + +def test_use_extra_vision_args_with_custom_tools(): + code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "generate_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])" + out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"]) + assert out_code == expected_code + + code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'])" + expected_code = "edit_vision_code(artifacts, 'code.py', 'write code', ['/home/user/n0xn5X6_IMG_2861%20(1).mov'], test_multi_plan=True, custom_tool_names=['tool1', 'tool2'])" + out_code = use_extra_vision_agent_args(code, custom_tool_names=["tool1", "tool2"]) + assert out_code == expected_code diff --git a/tests/unit/test_va.py b/tests/unit/test_va.py index ff4e9b46..3fe619e8 100644 --- a/tests/unit/test_va.py +++ b/tests/unit/test_va.py @@ -23,27 +23,23 @@ def test_parse_execution_no_test_multi_plan_edit(): code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" assert ( parse_execution(code, False) - == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False)" ) def test_parse_execution_custom_tool_names_generate(): code = "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'])" assert ( - parse_execution( - code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] - ) + parse_execution(code, test_multi_plan=False, custom_tool_names=["owl_v2_image"]) == "generate_vision_code(artifacts, 'code.py', 'Generate code', ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])" ) -def test_prase_execution_custom_tool_names_edit(): +def test_parse_execution_custom_tool_names_edit(): code = "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'])" assert ( - parse_execution( - code, test_multi_plan=False, customed_tool_names=["owl_v2_image"] - ) - == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], custom_tool_names=['owl_v2_image'])" + parse_execution(code, test_multi_plan=False, custom_tool_names=["owl_v2_image"]) + == "edit_vision_code(artifacts, 'code.py', ['Generate code'], ['image.png'], test_multi_plan=False, custom_tool_names=['owl_v2_image'])" ) diff --git a/tests/unit/test_vac.py b/tests/unit/test_vac.py new file mode 100644 index 00000000..4f8ead55 --- /dev/null +++ b/tests/unit/test_vac.py @@ -0,0 +1,143 @@ +from vision_agent.agent.vision_agent_coder import strip_function_calls + + +def test_strip_non_function_real_case(): + code = """import os +import numpy as np +from vision_agent.tools import * +from typing import * +from pillow_heif import register_heif_opener +register_heif_opener() +import vision_agent as va +from vision_agent.tools import register_tool + + +from vision_agent.tools import load_image, owl_v2_image, overlay_bounding_boxes, save_image, save_json + +def check_helmets(image_path): + # Load the image + image = load_image(image_path) + + # Detect people and helmets + detections = owl_v2_image("person, helmet", image, box_threshold=0.2) + + # Separate people and helmets + people = [d for d in detections if d['label'] == 'person'] + helmets = [d for d in detections if d['label'] == 'helmet'] + + people_with_helmets = 0 + people_without_helmets = 0 + + height, width = image.shape[:2] + + for person in people: + person_x = (person['bbox'][0] + person['bbox'][2]) / 2 + person_y = person['bbox'][1] # Top of the bounding box + + helmet_found = False + for helmet in helmets: + helmet_x = (helmet['bbox'][0] + helmet['bbox'][2]) / 2 + helmet_y = (helmet['bbox'][1] + helmet['bbox'][3]) / 2 + + # Check if the helmet is within 20 pixels of the person's head + if (abs((helmet_x - person_x) * width) < 20 and + -5 < ((helmet_y - person_y) * height) < 20): + helmet_found = True + break + + if helmet_found: + people_with_helmets += 1 + person['label'] = 'person with helmet' + else: + people_without_helmets += 1 + person['label'] = 'person without helmet' + + # Create the count dictionary + count_dict = { + "people_with_helmets": people_with_helmets, + "people_without_helmets": people_without_helmets + } + + # Visualize the results + visualized_image = overlay_bounding_boxes(image, detections) + + # Save the visualized image + save_image(visualized_image, "/home/user/visualized_result.png") + + # Save the count dictionary as JSON + save_json(count_dict, "/home/user/helmet_counts.json") + + return count_dict + +# The function can be called with the image path +result = check_helmets("/home/user/edQPXGK_workers.png")""" + expected_code = """import os +import numpy as np +from vision_agent.tools import * +from typing import * +from pillow_heif import register_heif_opener +register_heif_opener() +import vision_agent as va +from vision_agent.tools import register_tool + + +from vision_agent.tools import load_image, owl_v2_image, overlay_bounding_boxes, save_image, save_json + +def check_helmets(image_path): + # Load the image + image = load_image(image_path) + + # Detect people and helmets + detections = owl_v2_image("person, helmet", image, box_threshold=0.2) + + # Separate people and helmets + people = [d for d in detections if d['label'] == 'person'] + helmets = [d for d in detections if d['label'] == 'helmet'] + + people_with_helmets = 0 + people_without_helmets = 0 + + height, width = image.shape[:2] + + for person in people: + person_x = (person['bbox'][0] + person['bbox'][2]) / 2 + person_y = person['bbox'][1] # Top of the bounding box + + helmet_found = False + for helmet in helmets: + helmet_x = (helmet['bbox'][0] + helmet['bbox'][2]) / 2 + helmet_y = (helmet['bbox'][1] + helmet['bbox'][3]) / 2 + + # Check if the helmet is within 20 pixels of the person's head + if (abs((helmet_x - person_x) * width) < 20 and + -5 < ((helmet_y - person_y) * height) < 20): + helmet_found = True + break + + if helmet_found: + people_with_helmets += 1 + person['label'] = 'person with helmet' + else: + people_without_helmets += 1 + person['label'] = 'person without helmet' + + # Create the count dictionary + count_dict = { + "people_with_helmets": people_with_helmets, + "people_without_helmets": people_without_helmets + } + + # Visualize the results + visualized_image = overlay_bounding_boxes(image, detections) + + # Save the visualized image + save_image(visualized_image, "/home/user/visualized_result.png") + + # Save the count dictionary as JSON + save_json(count_dict, "/home/user/helmet_counts.json") + + return count_dict + +# The function can be called with the image path""" + code_out = strip_function_calls(code, exclusions=["register_heif_opener"]) + assert code_out == expected_code diff --git a/vision_agent/agent/agent_utils.py b/vision_agent/agent/agent_utils.py index 9b7ea02a..cb7e1b44 100644 --- a/vision_agent/agent/agent_utils.py +++ b/vision_agent/agent/agent_utils.py @@ -13,6 +13,7 @@ logging.basicConfig(stream=sys.stdout) _LOGGER = logging.getLogger(__name__) _CONSOLE = Console() +_MAX_TABULATE_COL_WIDTH = 80 def _extract_sub_json(json_str: str) -> Optional[Dict[str, Any]]: @@ -91,6 +92,27 @@ def extract_code(code: str) -> str: return code +def extract_tag( + content: str, + tag: str, +) -> Optional[str]: + inner_content = None + remaning = content + all_inner_content = [] + + while f"<{tag}>" in remaning: + inner_content_i = remaning[remaning.find(f"<{tag}>") + len(f"<{tag}>") :] + if f"" not in inner_content_i: + break + inner_content_i = inner_content_i[: inner_content_i.find(f"")] + remaning = remaning[remaning.find(f"") + len(f"") :] + all_inner_content.append(inner_content_i) + + if len(all_inner_content) > 0: + inner_content = "\n".join(all_inner_content) + return inner_content + + def remove_installs_from_code(code: str) -> str: pattern = r"\n!pip install.*?(\n|\Z)\n" code = re.sub(pattern, "", code, flags=re.DOTALL) diff --git a/vision_agent/agent/vision_agent.py b/vision_agent/agent/vision_agent.py index 6e1621f0..42541d33 100644 --- a/vision_agent/agent/vision_agent.py +++ b/vision_agent/agent/vision_agent.py @@ -103,7 +103,7 @@ def execute_code_action( def parse_execution( response: str, test_multi_plan: bool = True, - customed_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> Optional[str]: code = None remaining = response @@ -122,7 +122,7 @@ def parse_execution( code = "\n".join(all_code) if code is not None: - code = use_extra_vision_agent_args(code, test_multi_plan, customed_tool_names) + code = use_extra_vision_agent_args(code, test_multi_plan, custom_tool_names) return code @@ -278,7 +278,7 @@ def chat_with_artifacts( chat: List[Message], artifacts: Optional[Artifacts] = None, test_multi_plan: bool = True, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> Tuple[List[Message], Artifacts]: """Chat with VisionAgent, it will use code to execute actions to accomplish its tasks. @@ -292,7 +292,7 @@ def chat_with_artifacts( test_multi_plan (bool): If True, it will test tools for multiple plans and pick the best one based off of the tool results. If False, it will go with the first plan. - customized_tool_names (List[str]): A list of customized tools for agent to + custom_tool_names (List[str]): A list of customized tools for agent to pick and use. If not provided, default to full tool set from vision_agent.tools. @@ -411,7 +411,7 @@ def chat_with_artifacts( finished = response["let_user_respond"] code_action = parse_execution( - response["response"], test_multi_plan, customized_tool_names + response["response"], test_multi_plan, custom_tool_names ) if last_response == response: diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index f1246f09..9a30f334 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -5,14 +5,16 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast +from redbaron import RedBaron # type: ignore from tabulate import tabulate import vision_agent.tools as T from vision_agent.agent.agent import Agent from vision_agent.agent.agent_utils import ( + _MAX_TABULATE_COL_WIDTH, DefaultImports, extract_code, - extract_json, + extract_tag, format_memory, print_code, remove_installs_from_code, @@ -45,7 +47,44 @@ logging.basicConfig(stream=sys.stdout) WORKSPACE = Path(os.getenv("WORKSPACE", "")) _LOGGER = logging.getLogger(__name__) -_MAX_TABULATE_COL_WIDTH = 80 + + +def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str: + """This will strip out all code that calls functions except for functions included + in exclusions. + """ + if exclusions is None: + exclusions = [] + + red = RedBaron(code) + nodes_to_remove = [] + for node in red: + if node.type == "def": + continue + elif node.type == "import" or node.type == "from_import": + continue + elif node.type == "call": + if node.value and node.value[0].value in exclusions: + continue + nodes_to_remove.append(node) + elif node.type == "atomtrailers": + if node[0].value in exclusions: + continue + nodes_to_remove.append(node) + elif node.type == "assignment": + if node.value.type == "call" or node.value.type == "atomtrailers": + func_name = node.value[0].value + if func_name in exclusions: + continue + nodes_to_remove.append(node) + elif node.type == "endl": + continue + else: + nodes_to_remove.append(node) + for node in nodes_to_remove: + node.parent.remove(node) + cleaned_code = red.dumps().strip() + return cleaned_code if isinstance(cleaned_code, str) else code def write_code( @@ -130,6 +169,7 @@ def write_and_test_code( plan_thoughts, format_memory(working_memory), ) + code = strip_function_calls(code) test = write_test( tester, chat, tool_utils, code, format_memory(working_memory), media ) @@ -220,7 +260,9 @@ def debug_code( } ) - fixed_code_and_test = {"code": "", "test": "", "reflections": ""} + fixed_code = None + fixed_test = None + thoughts = "" success = False count = 0 while not success and count < 3: @@ -243,21 +285,16 @@ def debug_code( stream=False, ) fixed_code_and_test_str = cast(str, fixed_code_and_test_str) - fixed_code_and_test = extract_json(fixed_code_and_test_str) - code = extract_code(fixed_code_and_test_str) - if ( - "which_code" in fixed_code_and_test - and fixed_code_and_test["which_code"] == "test" - ): - fixed_code_and_test["code"] = "" - fixed_code_and_test["test"] = code - else: # for everything else always assume it's updating code - fixed_code_and_test["code"] = code - fixed_code_and_test["test"] = "" - if "which_code" in fixed_code_and_test: - del fixed_code_and_test["which_code"] - - success = True + thoughts_tag = extract_tag(fixed_code_and_test_str, "thoughts") + thoughts = thoughts_tag if thoughts_tag is not None else "" + fixed_code = extract_tag(fixed_code_and_test_str, "code") + fixed_test = extract_tag(fixed_code_and_test_str, "test") + + if fixed_code is None and fixed_test is None: + success = False + else: + success = True + except Exception as e: _LOGGER.exception(f"Error while extracting JSON: {e}") @@ -266,15 +303,15 @@ def debug_code( old_code = code old_test = test - if fixed_code_and_test["code"].strip() != "": - code = fixed_code_and_test["code"] - if fixed_code_and_test["test"].strip() != "": - test = fixed_code_and_test["test"] + if fixed_code is not None and fixed_code.strip() != "": + code = fixed_code + if fixed_test is not None and fixed_test.strip() != "": + test = fixed_test new_working_memory.append( { "code": f"{code}\n{test}", - "feedback": fixed_code_and_test["reflections"], + "feedback": thoughts, "edits": get_diff(f"{old_code}\n{old_test}", f"{code}\n{test}"), } ) @@ -310,7 +347,7 @@ def debug_code( if verbosity == 2: print_code("Code and test after attempted fix:", code, test) _LOGGER.info( - f"Reflection: {fixed_code_and_test['reflections']}\nCode execution result after attempted fix: {result.text(include_logs=True)}" + f"Reflection: {thoughts}\nCode execution result after attempted fix: {result.text(include_logs=True)}" ) return code, test, result @@ -514,7 +551,6 @@ def generate_code_from_plan( code = remove_installs_from_code(cast(str, results["code"])) test = remove_installs_from_code(cast(str, results["test"])) working_memory.extend(results["working_memory"]) - execution_result = cast(Execution, results["test_result"]) return { diff --git a/vision_agent/agent/vision_agent_coder_prompts.py b/vision_agent/agent/vision_agent_coder_prompts.py index 66eb4c29..ffb83cc2 100644 --- a/vision_agent/agent/vision_agent_coder_prompts.py +++ b/vision_agent/agent/vision_agent_coder_prompts.py @@ -238,35 +238,29 @@ def find_text(image_path: str, text: str) -> str: {docstring} **Instructions**: -Please re-complete the code to fix the error message. Here is the previous version: -```python +Please re-complete the code to fix the error message. Here is the current version of the CODE: + {code} -``` + -When we run this test code: -```python +When we run the TEST code: + {tests} -``` + It raises this error: -``` + {result} -``` + This is previous feedback provided on the code: {feedback} -Please fix the bug by correcting the error. Return the following JSON object followed by the fixed code in the below format: -```json -{{ - "reflections": str # any thoughts you have about the bug and how you fixed it - "which_code": str # the code that was fixed, can only be 'code' or 'test' -}} -``` +Please fix the bug by correcting the error. Return thoughts you have about the bug and how you fixed in tags followed by the fixed CODE in tags and the fixed TEST in tags. For example: -```python -# Your fixed code here -``` +Your thoughts here... +# your fixed code here +# your fixed test here """ diff --git a/vision_agent/agent/vision_agent_planner.py b/vision_agent/agent/vision_agent_planner.py index bb7ac3ba..834cb594 100644 --- a/vision_agent/agent/vision_agent_planner.py +++ b/vision_agent/agent/vision_agent_planner.py @@ -5,10 +5,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from pydantic import BaseModel +from tabulate import tabulate import vision_agent.tools as T from vision_agent.agent import Agent from vision_agent.agent.agent_utils import ( + _MAX_TABULATE_COL_WIDTH, DefaultImports, extract_code, extract_json, @@ -90,6 +92,18 @@ def retrieve_tools( return tool_lists_unique +def _check_plan_format(plan: Dict[str, Any]) -> bool: + if not isinstance(plan, dict): + return False + + for k in plan: + if "thoughts" not in plan[k] or "instructions" not in plan[k]: + return False + if not isinstance(plan[k]["instructions"], list): + return False + return True + + def write_plans( chat: List[Message], tool_desc: str, working_memory: str, model: LMM ) -> Dict[str, Any]: @@ -105,7 +119,16 @@ def write_plans( feedback=working_memory, ) chat[-1]["content"] = prompt - return extract_json(model(chat, stream=False)) # type: ignore + plans = extract_json(model(chat, stream=False)) # type: ignore + + count = 0 + while not _check_plan_format(plans) and count < 3: + _LOGGER.info("Invalid plan format. Retrying.") + plans = extract_json(model(chat, stream=False)) # type: ignore + count += 1 + if count == 3: + raise ValueError("Failed to generate valid plans after 3 attempts.") + return plans def write_and_exec_plan_tests( @@ -307,7 +330,6 @@ def pick_plan( "payload": plans[plan_thoughts["best_plan"]], } ) - # return plan_thoughts, "```python\n" + code + "\n```\n" + tool_output_str return plan_thoughts, code, tool_output @@ -404,6 +426,14 @@ def generate_plan( format_memory(working_memory), self.planner, ) + if self.verbosity >= 1: + for plan in plans: + plan_fixed = [ + {"instructions": e} for e in plans[plan]["instructions"] + ] + _LOGGER.info( + f"\n{tabulate(tabular_data=plan_fixed, headers='keys', tablefmt='mixed_grid', maxcolwidths=_MAX_TABULATE_COL_WIDTH)}" + ) tool_docs = retrieve_tools( plans, diff --git a/vision_agent/tools/meta_tools.py b/vision_agent/tools/meta_tools.py index 0fb46cee..7f59c685 100644 --- a/vision_agent/tools/meta_tools.py +++ b/vision_agent/tools/meta_tools.py @@ -11,6 +11,7 @@ import numpy as np from IPython.display import display +from redbaron import RedBaron # type: ignore import vision_agent as va from vision_agent.agent.agent_utils import extract_json @@ -24,8 +25,6 @@ from vision_agent.utils.image_utils import convert_to_b64, numpy_to_bytes from vision_agent.utils.video import frames_to_bytes -# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent - CURRENT_FILE = None CURRENT_LINE = 0 DEFAULT_WINDOW_SIZE = 100 @@ -154,6 +153,9 @@ def __contains__(self, name: str) -> bool: return name in self.artifacts +# These tools are adapted from SWE-Agent https://github.com/princeton-nlp/SWE-agent + + def format_lines(lines: List[str], start_idx: int) -> str: output = "" for i, line in enumerate(lines): @@ -491,7 +493,7 @@ def edit_vision_code( name: str, chat_history: List[str], media: List[str], - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> str: """Edits python code to solve a vision based task. @@ -499,7 +501,7 @@ def edit_vision_code( artifacts (Artifacts): The artifacts object to save the code to. name (str): The file path to the code. chat_history (List[str]): The chat history to used to generate the code. - customized_tool_names (Optional[List[str]]): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. Returns: str: The edited code. @@ -542,7 +544,7 @@ def detect_dogs(image_path: str): response = agent.generate_code( fixed_chat_history, test_multi_plan=False, - custom_tool_names=customized_tool_names, + custom_tool_names=custom_tool_names, ) redisplay_results(response["test_result"]) code = response["code"] @@ -705,7 +707,7 @@ def get_diff_with_prompts(name: str, before: str, after: str) -> str: def use_extra_vision_agent_args( code: str, test_multi_plan: bool = True, - customized_tool_names: Optional[List[str]] = None, + custom_tool_names: Optional[List[str]] = None, ) -> str: """This is for forcing arguments passed by the user to VisionAgent into the VisionAgentCoder call. @@ -713,36 +715,25 @@ def use_extra_vision_agent_args( Parameters: code (str): The code to edit. test_multi_plan (bool): Do not change this parameter. - customized_tool_names (Optional[List[str]]): Do not change this parameter. + custom_tool_names (Optional[List[str]]): Do not change this parameter. Returns: str: The edited code. """ - generate_pattern = r"generate_vision_code\(\s*([^\)]+)\s*\)" - - def generate_replacer(match: re.Match) -> str: - arg = match.group(1) - out_str = f"generate_vision_code({arg}, test_multi_plan={test_multi_plan}" - if customized_tool_names is not None: - out_str += f", custom_tool_names={customized_tool_names})" - else: - out_str += ")" - return out_str - - edit_pattern = r"edit_vision_code\(\s*([^\)]+)\s*\)" - - def edit_replacer(match: re.Match) -> str: - arg = match.group(1) - out_str = f"edit_vision_code({arg}" - if customized_tool_names is not None: - out_str += f", custom_tool_names={customized_tool_names})" - else: - out_str += ")" - return out_str - - new_code = re.sub(generate_pattern, generate_replacer, code) - new_code = re.sub(edit_pattern, edit_replacer, new_code) - return new_code + red = RedBaron(code) + for node in red: + # seems to always be atomtrailers not call type + if node.type == "atomtrailers": + if ( + node.name.value == "generate_vision_code" + or node.name.value == "edit_vision_code" + ): + node.value[1].value.append(f"test_multi_plan={test_multi_plan}") + + if custom_tool_names is not None: + node.value[1].value.append(f"custom_tool_names={custom_tool_names}") + cleaned_code = red.dumps().strip() + return cleaned_code if isinstance(cleaned_code, str) else code def use_object_detection_fine_tuning( diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index bf4da892..27a95463 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -1923,7 +1923,7 @@ def overlay_bounding_boxes( bboxes = bbox_int[i] bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True) - if len(bboxes) > 20: + if len(bboxes) > 40: pil_image = _plot_counting(pil_image, bboxes, color) else: width, height = pil_image.size @@ -2117,7 +2117,7 @@ def _plot_counting( colors: Dict[str, Tuple[int, int, int]], ) -> Image.Image: width, height = image.size - fontsize = max(10, int(min(width, height) / 80)) + fontsize = max(12, int(min(width, height) / 40)) draw = ImageDraw.Draw(image) font = ImageFont.truetype( str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),