Skip to content

Commit

Permalink
strip extra function calls from generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Oct 11, 2024
1 parent 29d934e commit 4d184bf
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 3 deletions.
60 changes: 58 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pytube = "15.0.0"
anthropic = "^0.31.0"
pydantic = "2.7.4"
av = "^11.0.0"
redbaron = "^0.9.2"

[tool.poetry.group.dev.dependencies]
autoflake = "1.*"
Expand Down
145 changes: 145 additions & 0 deletions tests/unit/test_vac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from vision_agent.agent.vision_agent_coder import strip_function_calls


def test_strip_non_function_real_case():
code = """import os
import numpy as np
from vision_agent.tools import *
from typing import *
from pillow_heif import register_heif_opener
register_heif_opener()
import vision_agent as va
from vision_agent.tools import register_tool
from vision_agent.tools import load_image, owl_v2_image, overlay_bounding_boxes, save_image, save_json
def check_helmets(image_path):
# Load the image
image = load_image(image_path)
# Detect people and helmets
detections = owl_v2_image("person, helmet", image, box_threshold=0.2)
# Separate people and helmets
people = [d for d in detections if d['label'] == 'person']
helmets = [d for d in detections if d['label'] == 'helmet']
people_with_helmets = 0
people_without_helmets = 0
height, width = image.shape[:2]
for person in people:
person_x = (person['bbox'][0] + person['bbox'][2]) / 2
person_y = person['bbox'][1] # Top of the bounding box
helmet_found = False
for helmet in helmets:
helmet_x = (helmet['bbox'][0] + helmet['bbox'][2]) / 2
helmet_y = (helmet['bbox'][1] + helmet['bbox'][3]) / 2
# Check if the helmet is within 20 pixels of the person's head
if (abs((helmet_x - person_x) * width) < 20 and
-5 < ((helmet_y - person_y) * height) < 20):
helmet_found = True
break
if helmet_found:
people_with_helmets += 1
person['label'] = 'person with helmet'
else:
people_without_helmets += 1
person['label'] = 'person without helmet'
# Create the count dictionary
count_dict = {
"people_with_helmets": people_with_helmets,
"people_without_helmets": people_without_helmets
}
# Visualize the results
visualized_image = overlay_bounding_boxes(image, detections)
# Save the visualized image
save_image(visualized_image, "/home/user/visualized_result.png")
# Save the count dictionary as JSON
save_json(count_dict, "/home/user/helmet_counts.json")
return count_dict
# The function can be called with the image path
result = check_helmets("/home/user/edQPXGK_workers.png")"""
expected_code = """
Edit
import os
import numpy as np
from vision_agent.tools import *
from typing import *
from pillow_heif import register_heif_opener
register_heif_opener()
import vision_agent as va
from vision_agent.tools import register_tool
from vision_agent.tools import load_image, owl_v2_image, overlay_bounding_boxes, save_image, save_json
def check_helmets(image_path):
# Load the image
image = load_image(image_path)
# Detect people and helmets
detections = owl_v2_image("person, helmet", image, box_threshold=0.2)
# Separate people and helmets
people = [d for d in detections if d['label'] == 'person']
helmets = [d for d in detections if d['label'] == 'helmet']
people_with_helmets = 0
people_without_helmets = 0
height, width = image.shape[:2]
for person in people:
person_x = (person['bbox'][0] + person['bbox'][2]) / 2
person_y = person['bbox'][1] # Top of the bounding box
helmet_found = False
for helmet in helmets:
helmet_x = (helmet['bbox'][0] + helmet['bbox'][2]) / 2
helmet_y = (helmet['bbox'][1] + helmet['bbox'][3]) / 2
# Check if the helmet is within 20 pixels of the person's head
if (abs((helmet_x - person_x) * width) < 20 and
-5 < ((helmet_y - person_y) * height) < 20):
helmet_found = True
break
if helmet_found:
people_with_helmets += 1
person['label'] = 'person with helmet'
else:
people_without_helmets += 1
person['label'] = 'person without helmet'
# Create the count dictionary
count_dict = {
"people_with_helmets": people_with_helmets,
"people_without_helmets": people_without_helmets
}
# Visualize the results
visualized_image = overlay_bounding_boxes(image, detections)
# Save the visualized image
save_image(visualized_image, "/home/user/visualized_result.png")
# Save the count dictionary as JSON
save_json(count_dict, "/home/user/helmet_counts.json")
return count_dict
# The function can be called with the image path"""
code_out = strip_function_calls(code, exclusions=["register_heif_opener"])
assert code_out == expected_code
42 changes: 41 additions & 1 deletion vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from redbaron import RedBaron
from tabulate import tabulate

import vision_agent.tools as T
Expand Down Expand Up @@ -48,6 +49,44 @@
_MAX_TABULATE_COL_WIDTH = 80


def strip_function_calls(code: str, exclusions: Optional[List[str]] = None) -> str:
"""This will strip out all code that calls functions except for functions included
in exclusions.
"""
if exclusions is None:
exclusions = []

red = RedBaron(code)
nodes_to_remove = []
for node in red:
if node.type == "def":
continue
elif node.type == "import" or node.type == "from_import":
continue
elif node.type == "call":
if node.value and node.value[0].value in exclusions:
continue
nodes_to_remove.append(node)
elif node.type == "atomtrailers":
if node[0].value in exclusions:
continue
nodes_to_remove.append(node)
elif node.type == "assignment":
if node.value.type == "call" or node.value.type == "atomtrailers":
func_name = node.value[0].value
if func_name in exclusions:
continue
nodes_to_remove.append(node)
elif node.type == "endl":
continue
else:
nodes_to_remove.append(node)
for node in nodes_to_remove:
node.parent.remove(node)
cleaned_code = red.dumps().strip()
return cleaned_code


def write_code(
coder: LMM,
chat: List[Message],
Expand Down Expand Up @@ -130,6 +169,7 @@ def write_and_test_code(
plan_thoughts,
format_memory(working_memory),
)
code = strip_function_calls(code)
test = write_test(
tester, chat, tool_utils, code, format_memory(working_memory), media
)
Expand Down Expand Up @@ -252,7 +292,7 @@ def debug_code(
fixed_code_and_test["code"] = ""
fixed_code_and_test["test"] = code
else: # for everything else always assume it's updating code
fixed_code_and_test["code"] = code
fixed_code_and_test["code"] = strip_function_calls(code)
fixed_code_and_test["test"] = ""
if "which_code" in fixed_code_and_test:
del fixed_code_and_test["which_code"]
Expand Down

0 comments on commit 4d184bf

Please sign in to comment.