Skip to content

Commit

Permalink
fix type errors'
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Apr 29, 2024
1 parent adf5efe commit bb8550f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 58 deletions.
5 changes: 3 additions & 2 deletions vision_agent/agent/agent_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def run_visual_tests(
feedback=feedback,
)
completion = model(prompt, images=[viz_file])
return json.loads(completion)
# type is from the prompt
return json.loads(completion) # type: ignore


def fix_bugs(code: str, tests: str, result: str, feedback: str, model: LLM) -> str:
Expand Down Expand Up @@ -147,7 +148,7 @@ def chat(

if not results["passed"]:
code = fix_bugs(
code, debug, results["result"].strip(), feedback, self.coder_agent
code, debug, results["result"].strip(), feedback, self.coder_agent # type: ignore
)
_LOGGER.info(f"fixed code:\n{code}")
else:
Expand Down
101 changes: 50 additions & 51 deletions vision_agent/agent/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import traceback
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, Generator, List, Optional, Union

IMPORT_HELPER = """
import math
Expand Down Expand Up @@ -53,7 +53,6 @@ def unsafe_execute(code: str, timeout: float, result: List) -> None:
reliability_guard()

try:
exec_globals = {}
with swallow_io() as s:
with time_limit(timeout):
# WARNING
Expand All @@ -65,8 +64,8 @@ def unsafe_execute(code: str, timeout: float, result: List) -> None:
# does not perform destructive actions on their host or network.
# Once you have read this disclaimer and taken appropriate precautions,
# uncomment the following line and proceed at your own risk:
code = compile(code, code_path, "exec")
exec(code, exec_globals)
code = compile(code, code_path, "exec") # type: ignore
exec(code)
result.append({"output": s.getvalue(), "passed": True})
except TimeoutError:
result.append({"output": "Timed out", "passed": False})
Expand Down Expand Up @@ -133,12 +132,12 @@ def check_correctness(
# ============================================================================


class redirect_stdin(contextlib._RedirectStream): # type: ignore
class redirect_stdin(contextlib._RedirectStream):
_stream = "stdin"


@contextlib.contextmanager
def chdir(root):
def chdir(root: str) -> Generator[None, None, None]:
if root == ".":
yield
return
Expand All @@ -155,29 +154,29 @@ def chdir(root):
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from"""

def read(self, *args, **kwargs):
def read(self, *args, **kwargs): # type: ignore
raise IOError

def readline(self, *args, **kwargs):
def readline(self, *args, **kwargs): # type: ignore
raise IOError

def readlines(self, *args, **kwargs):
def readlines(self, *args, **kwargs): # type: ignore
raise IOError

def readable(self, *args, **kwargs):
def readable(self, *args, **kwargs): # type: ignore
"""Returns True if the IO object can be read."""
return False


@contextlib.contextmanager
def create_tempdir():
def create_tempdir() -> Generator[str, None, None]:
with tempfile.TemporaryDirectory() as dirname:
with chdir(dirname):
yield dirname


@contextlib.contextmanager
def swallow_io():
def swallow_io() -> Generator[WriteOnlyStringIO, None, None]:
stream = WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
Expand All @@ -186,8 +185,8 @@ def swallow_io():


@contextlib.contextmanager
def time_limit(seconds: float):
def signal_handler(signum, frame):
def time_limit(seconds: float) -> Generator[None, None, None]:
def signal_handler(signum, frame): # type: ignore
raise TimeoutError("Timed out!")

signal.setitimer(signal.ITIMER_REAL, seconds)
Expand All @@ -198,7 +197,7 @@ def signal_handler(signum, frame):
signal.setitimer(signal.ITIMER_REAL, 0)


def reliability_guard(maximum_memory_bytes: Optional[int] = None):
def reliability_guard(maximum_memory_bytes: Optional[int] = None) -> None:
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
Expand Down Expand Up @@ -229,46 +228,46 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):

import builtins

builtins.exit = None
builtins.quit = None
builtins.exit = None # type: ignore
builtins.quit = None # type: ignore

import os

os.environ["OMP_NUM_THREADS"] = "1"

os.kill = None
os.system = None
os.kill = None # type: ignore
os.system = None # type: ignore
# os.putenv = None # this causes numpy to fail on import
os.remove = None
os.removedirs = None
os.rmdir = None
os.fchdir = None
os.setuid = None
os.fork = None
os.forkpty = None
os.killpg = None
os.rename = None
os.renames = None
os.truncate = None
os.replace = None
os.unlink = None
os.fchmod = None
os.fchown = None
os.chmod = None
os.chown = None
os.chroot = None
os.fchdir = None
os.lchflags = None
os.lchmod = None
os.lchown = None
os.getcwd = None
os.chdir = None
os.remove = None # type: ignore
os.removedirs = None # type: ignore
os.rmdir = None # type: ignore
os.fchdir = None # type: ignore
os.setuid = None # type: ignore
os.fork = None # type: ignore
os.forkpty = None # type: ignore
os.killpg = None # type: ignore
os.rename = None # type: ignore
os.renames = None # type: ignore
os.truncate = None # type: ignore
os.replace = None # type: ignore
os.unlink = None # type: ignore
os.fchmod = None # type: ignore
os.fchown = None # type: ignore
os.chmod = None # type: ignore
os.chown = None # type: ignore
os.chroot = None # type: ignore
os.fchdir = None # type: ignore
os.lchflags = None # type: ignore
os.lchmod = None # type: ignore
os.lchown = None # type: ignore
os.getcwd = None # type: ignore
os.chdir = None # type: ignore

import shutil

shutil.rmtree = None
shutil.move = None
shutil.chown = None
shutil.rmtree = None # type: ignore
shutil.move = None # type: ignore
shutil.chown = None # type: ignore

import subprocess

Expand All @@ -278,8 +277,8 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):

import sys

sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None
sys.modules["ipdb"] = None # type: ignore
sys.modules["joblib"] = None # type: ignore
sys.modules["resource"] = None # type: ignore
sys.modules["psutil"] = None # type: ignore
sys.modules["tkinter"] = None # type: ignore
10 changes: 5 additions & 5 deletions vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import tempfile
from importlib import resources
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List

import numpy as np
from PIL import Image, ImageDraw, ImageFont
Expand Down Expand Up @@ -112,8 +112,8 @@ def save_image(image: np.ndarray) -> str:
"""

with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
image = Image.fromarray(image)
image.save(f, "PNG")
pil_image = Image.fromarray(image.astype(np.uint8))
pil_image.save(f, "PNG")
return f.name


Expand All @@ -133,7 +133,7 @@ def display_bounding_boxes(
-------
>>> image_with_bboxes = display_bounding_boxes(image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}])
"""
pil_image = Image.fromarray(image)
pil_image = Image.fromarray(image.astype(np.uint8))

color = {
label: COLORS[i % len(COLORS)]
Expand Down Expand Up @@ -167,7 +167,7 @@ def display_bounding_boxes(
return np.array(pil_image.convert("RGB"))


def get_tool_documentation(funcs):
def get_tool_documentation(funcs: List[Callable]) -> str:
docstrings = ""
for func in funcs:
docstrings += f"{func.__name__}: {inspect.signature(func)}\n{func.__doc__}\n\n"
Expand Down

0 comments on commit bb8550f

Please sign in to comment.