From 4d184bf3c5a15747e88975f882c728e6cae316d9 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Thu, 10 Oct 2024 14:18:42 -0700 Subject: [PATCH] strip extra function calls from generated code --- poetry.lock | 60 +++++++++- pyproject.toml | 1 + tests/unit/test_vac.py | 145 +++++++++++++++++++++++ vision_agent/agent/vision_agent_coder.py | 42 ++++++- 4 files changed, 245 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_vac.py diff --git a/poetry.lock b/poetry.lock index e03f8871..4cb21c70 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -58,6 +58,17 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.23)"] +[[package]] +name = "appdirs" +version = "1.4.4" +description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = "*" +files = [ + {file = "appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128"}, + {file = "appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41"}, +] + [[package]] name = "appnope" version = "0.1.4" @@ -189,6 +200,20 @@ files = [ [package.extras] dev = ["freezegun (>=1.0,<2.0)", "pytest (>=6.0)", "pytest-cov"] +[[package]] +name = "baron" +version = "0.10.1" +description = "Full Syntax Tree for python to make writing refactoring code a realist task" +optional = false +python-versions = "*" +files = [ + {file = "baron-0.10.1-py2.py3-none-any.whl", hash = "sha256:befb33f4b9e832c7cd1e3cf0eafa6dd3cb6ed4cb2544245147c019936f4e0a8a"}, + {file = "baron-0.10.1.tar.gz", hash = "sha256:af822ad44d4eb425c8516df4239ac4fdba9fdb398ef77e4924cd7c9b4045bc2f"}, +] + +[package.dependencies] +rply = "*" + [[package]] name = "black" version = "24.8.0" @@ -2756,6 +2781,23 @@ files = [ [package.dependencies] cffi = {version = "*", markers = "implementation_name == \"pypy\""} +[[package]] +name = "redbaron" +version = "0.9.2" +description = "Abstraction on top of baron, a FST for python to make writing refactoring code a realistic task" +optional = false +python-versions = "*" +files = [ + {file = "redbaron-0.9.2-py2.py3-none-any.whl", hash = "sha256:d01032b6a848b5521a8d6ef72486315c2880f420956870cdd742e2b5a09b9bab"}, + {file = "redbaron-0.9.2.tar.gz", hash = "sha256:472d0739ca6b2240bb2278ae428604a75472c9c12e86c6321e8c016139c0132f"}, +] + +[package.dependencies] +baron = ">=0.7" + +[package.extras] +notebook = ["pygments"] + [[package]] name = "referencing" version = "0.35.1" @@ -3030,6 +3072,20 @@ files = [ {file = "rpds_py-0.20.0.tar.gz", hash = "sha256:d72a210824facfdaf8768cf2d7ca25a042c30320b3020de2fa04640920d4e121"}, ] +[[package]] +name = "rply" +version = "0.7.8" +description = "A pure Python Lex/Yacc that works with RPython" +optional = false +python-versions = "*" +files = [ + {file = "rply-0.7.8-py2.py3-none-any.whl", hash = "sha256:28ffd11d656c48aeb8c508eb382acd6a0bd906662624b34388751732a27807e7"}, + {file = "rply-0.7.8.tar.gz", hash = "sha256:2a808ac25a4580a9991fc304d64434e299a8fc75760574492f242cbb5bb301c9"}, +] + +[package.dependencies] +appdirs = "*" + [[package]] name = "scikit-image" version = "0.22.0" @@ -3603,4 +3659,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "5014a9fd241fb625394843263c73feb27ba35b31d13199db45b06985d9e1fbeb" +content-hash = "c0c568df9865d15015b88942284b8555983c689c37757dc10f3ae3e07558812f" diff --git a/pyproject.toml b/pyproject.toml index ca6726a7..fcd7b299 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ pytube = "15.0.0" anthropic = "^0.31.0" pydantic = "2.7.4" av = "^11.0.0" +redbaron = "^0.9.2" [tool.poetry.group.dev.dependencies] autoflake = "1.*" diff --git a/tests/unit/test_vac.py b/tests/unit/test_vac.py new file mode 100644 index 00000000..c0d33219 --- /dev/null +++ b/tests/unit/test_vac.py @@ -0,0 +1,145 @@ +from vision_agent.agent.vision_agent_coder import strip_function_calls + + +def test_strip_non_function_real_case(): + code = """import os +import numpy as np +from vision_agent.tools import * +from typing import * +from pillow_heif import register_heif_opener +register_heif_opener() +import vision_agent as va +from vision_agent.tools import register_tool + + +from vision_agent.tools import load_image, owl_v2_image, overlay_bounding_boxes, save_image, save_json + +def check_helmets(image_path): + # Load the image + image = load_image(image_path) + + # Detect people and helmets + detections = owl_v2_image("person, helmet", image, box_threshold=0.2) + + # Separate people and helmets + people = [d for d in detections if d['label'] == 'person'] + helmets = [d for d in detections if d['label'] == 'helmet'] + + people_with_helmets = 0 + people_without_helmets = 0 + + height, width = image.shape[:2] + + for person in people: + person_x = (person['bbox'][0] + person['bbox'][2]) / 2 + person_y = person['bbox'][1] # Top of the bounding box + + helmet_found = False + for helmet in helmets: + helmet_x = (helmet['bbox'][0] + helmet['bbox'][2]) / 2 + helmet_y = (helmet['bbox'][1] + helmet['bbox'][3]) / 2 + + # Check if the helmet is within 20 pixels of the person's head + if (abs((helmet_x - person_x) * width) < 20 and + -5 < ((helmet_y - person_y) * height) < 20): + helmet_found = True + break + + if helmet_found: + people_with_helmets += 1 + person['label'] = 'person with helmet' + else: + people_without_helmets += 1 + person['label'] = 'person without helmet' + + # Create the count dictionary + count_dict = { + "people_with_helmets": people_with_helmets, + "people_without_helmets": people_without_helmets + } + + # Visualize the results + visualized_image = overlay_bounding_boxes(image, detections) + + # Save the visualized image + save_image(visualized_image, "/home/user/visualized_result.png") + + # Save the count dictionary as JSON + save_json(count_dict, "/home/user/helmet_counts.json") + + return count_dict + +# The function can be called with the image path +result = check_helmets("/home/user/edQPXGK_workers.png")""" + expected_code = """ +Edit +import os +import numpy as np +from vision_agent.tools import * +from typing import * +from pillow_heif import register_heif_opener +register_heif_opener() +import vision_agent as va +from vision_agent.tools import register_tool + + +from vision_agent.tools import load_image, owl_v2_image, overlay_bounding_boxes, save_image, save_json + +def check_helmets(image_path): + # Load the image + image = load_image(image_path) + + # Detect people and helmets + detections = owl_v2_image("person, helmet", image, box_threshold=0.2) + + # Separate people and helmets + people = [d for d in detections if d['label'] == 'person'] + helmets = [d for d in detections if d['label'] == 'helmet'] + + people_with_helmets = 0 + people_without_helmets = 0 + + height, width = image.shape[:2] + + for person in people: + person_x = (person['bbox'][0] + person['bbox'][2]) / 2 + person_y = person['bbox'][1] # Top of the bounding box + + helmet_found = False + for helmet in helmets: + helmet_x = (helmet['bbox'][0] + helmet['bbox'][2]) / 2 + helmet_y = (helmet['bbox'][1] + helmet['bbox'][3]) / 2 + + # Check if the helmet is within 20 pixels of the person's head + if (abs((helmet_x - person_x) * width) < 20 and + -5 < ((helmet_y - person_y) * height) < 20): + helmet_found = True + break + + if helmet_found: + people_with_helmets += 1 + person['label'] = 'person with helmet' + else: + people_without_helmets += 1 + person['label'] = 'person without helmet' + + # Create the count dictionary + count_dict = { + "people_with_helmets": people_with_helmets, + "people_without_helmets": people_without_helmets + } + + # Visualize the results + visualized_image = overlay_bounding_boxes(image, detections) + + # Save the visualized image + save_image(visualized_image, "/home/user/visualized_result.png") + + # Save the count dictionary as JSON + save_json(count_dict, "/home/user/helmet_counts.json") + + return count_dict + +# The function can be called with the image path""" + code_out = strip_function_calls(code, exclusions=["register_heif_opener"]) + assert code_out == expected_code diff --git a/vision_agent/agent/vision_agent_coder.py b/vision_agent/agent/vision_agent_coder.py index 784a83f9..345c4552 100644 --- a/vision_agent/agent/vision_agent_coder.py +++ b/vision_agent/agent/vision_agent_coder.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast +from redbaron import RedBaron from tabulate import tabulate import vision_agent.tools as T @@ -48,6 +49,44 @@ _MAX_TABULATE_COL_WIDTH = 80 +def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str: + """This will strip out all code that calls functions except for functions included + in exclusions. + """ + if exclusions is None: + exclusions = [] + + red = RedBaron(code) + nodes_to_remove = [] + for node in red: + if node.type == "def": + continue + elif node.type == "import" or node.type == "from_import": + continue + elif node.type == "call": + if node.value and node.value[0].value in exclusions: + continue + nodes_to_remove.append(node) + elif node.type == "atomtrailers": + if node[0].value in exclusions: + continue + nodes_to_remove.append(node) + elif node.type == "assignment": + if node.value.type == "call" or node.value.type == "atomtrailers": + func_name = node.value[0].value + if func_name in exclusions: + continue + nodes_to_remove.append(node) + elif node.type == "endl": + continue + else: + nodes_to_remove.append(node) + for node in nodes_to_remove: + node.parent.remove(node) + cleaned_code = red.dumps().strip() + return cleaned_code + + def write_code( coder: LMM, chat: List[Message], @@ -130,6 +169,7 @@ def write_and_test_code( plan_thoughts, format_memory(working_memory), ) + code = strip_function_calls(code) test = write_test( tester, chat, tool_utils, code, format_memory(working_memory), media ) @@ -252,7 +292,7 @@ def debug_code( fixed_code_and_test["code"] = "" fixed_code_and_test["test"] = code else: # for everything else always assume it's updating code - fixed_code_and_test["code"] = code + fixed_code_and_test["code"] = strip_function_calls(code) fixed_code_and_test["test"] = "" if "which_code" in fixed_code_and_test: del fixed_code_and_test["which_code"]