diff --git a/vision_agent/agent/agent_coder.py b/vision_agent/agent/agent_coder.py index 38b9515e..c49d32be 100644 --- a/vision_agent/agent/agent_coder.py +++ b/vision_agent/agent/agent_coder.py @@ -72,7 +72,8 @@ def run_visual_tests( feedback=feedback, ) completion = model(prompt, images=[viz_file]) - return json.loads(completion) + # type is from the prompt + return json.loads(completion) # type: ignore def fix_bugs(code: str, tests: str, result: str, feedback: str, model: LLM) -> str: @@ -147,7 +148,7 @@ def chat( if not results["passed"]: code = fix_bugs( - code, debug, results["result"].strip(), feedback, self.coder_agent + code, debug, results["result"].strip(), feedback, self.coder_agent # type: ignore ) _LOGGER.info(f"fixed code:\n{code}") else: diff --git a/vision_agent/agent/execution.py b/vision_agent/agent/execution.py index 86b4dcc6..1c61e075 100644 --- a/vision_agent/agent/execution.py +++ b/vision_agent/agent/execution.py @@ -12,7 +12,7 @@ import tempfile import traceback from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, Generator, List, Optional, Union IMPORT_HELPER = """ import math @@ -53,7 +53,6 @@ def unsafe_execute(code: str, timeout: float, result: List) -> None: reliability_guard() try: - exec_globals = {} with swallow_io() as s: with time_limit(timeout): # WARNING @@ -65,8 +64,8 @@ def unsafe_execute(code: str, timeout: float, result: List) -> None: # does not perform destructive actions on their host or network. # Once you have read this disclaimer and taken appropriate precautions, # uncomment the following line and proceed at your own risk: - code = compile(code, code_path, "exec") - exec(code, exec_globals) + code = compile(code, code_path, "exec") # type: ignore + exec(code) result.append({"output": s.getvalue(), "passed": True}) except TimeoutError: result.append({"output": "Timed out", "passed": False}) @@ -133,12 +132,12 @@ def check_correctness( # ============================================================================ -class redirect_stdin(contextlib._RedirectStream): # type: ignore +class redirect_stdin(contextlib._RedirectStream): _stream = "stdin" @contextlib.contextmanager -def chdir(root): +def chdir(root: str) -> Generator[None, None, None]: if root == ".": yield return @@ -155,29 +154,29 @@ def chdir(root): class WriteOnlyStringIO(io.StringIO): """StringIO that throws an exception when it's read from""" - def read(self, *args, **kwargs): + def read(self, *args, **kwargs): # type: ignore raise IOError - def readline(self, *args, **kwargs): + def readline(self, *args, **kwargs): # type: ignore raise IOError - def readlines(self, *args, **kwargs): + def readlines(self, *args, **kwargs): # type: ignore raise IOError - def readable(self, *args, **kwargs): + def readable(self, *args, **kwargs): # type: ignore """Returns True if the IO object can be read.""" return False @contextlib.contextmanager -def create_tempdir(): +def create_tempdir() -> Generator[str, None, None]: with tempfile.TemporaryDirectory() as dirname: with chdir(dirname): yield dirname @contextlib.contextmanager -def swallow_io(): +def swallow_io() -> Generator[WriteOnlyStringIO, None, None]: stream = WriteOnlyStringIO() with contextlib.redirect_stdout(stream): with contextlib.redirect_stderr(stream): @@ -186,8 +185,8 @@ def swallow_io(): @contextlib.contextmanager -def time_limit(seconds: float): - def signal_handler(signum, frame): +def time_limit(seconds: float) -> Generator[None, None, None]: + def signal_handler(signum, frame): # type: ignore raise TimeoutError("Timed out!") signal.setitimer(signal.ITIMER_REAL, seconds) @@ -198,7 +197,7 @@ def signal_handler(signum, frame): signal.setitimer(signal.ITIMER_REAL, 0) -def reliability_guard(maximum_memory_bytes: Optional[int] = None): +def reliability_guard(maximum_memory_bytes: Optional[int] = None) -> None: """ This disables various destructive functions and prevents the generated code from interfering with the test (e.g. fork bomb, killing other processes, @@ -229,46 +228,46 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): import builtins - builtins.exit = None - builtins.quit = None + builtins.exit = None # type: ignore + builtins.quit = None # type: ignore import os os.environ["OMP_NUM_THREADS"] = "1" - os.kill = None - os.system = None + os.kill = None # type: ignore + os.system = None # type: ignore # os.putenv = None # this causes numpy to fail on import - os.remove = None - os.removedirs = None - os.rmdir = None - os.fchdir = None - os.setuid = None - os.fork = None - os.forkpty = None - os.killpg = None - os.rename = None - os.renames = None - os.truncate = None - os.replace = None - os.unlink = None - os.fchmod = None - os.fchown = None - os.chmod = None - os.chown = None - os.chroot = None - os.fchdir = None - os.lchflags = None - os.lchmod = None - os.lchown = None - os.getcwd = None - os.chdir = None + os.remove = None # type: ignore + os.removedirs = None # type: ignore + os.rmdir = None # type: ignore + os.fchdir = None # type: ignore + os.setuid = None # type: ignore + os.fork = None # type: ignore + os.forkpty = None # type: ignore + os.killpg = None # type: ignore + os.rename = None # type: ignore + os.renames = None # type: ignore + os.truncate = None # type: ignore + os.replace = None # type: ignore + os.unlink = None # type: ignore + os.fchmod = None # type: ignore + os.fchown = None # type: ignore + os.chmod = None # type: ignore + os.chown = None # type: ignore + os.chroot = None # type: ignore + os.fchdir = None # type: ignore + os.lchflags = None # type: ignore + os.lchmod = None # type: ignore + os.lchown = None # type: ignore + os.getcwd = None # type: ignore + os.chdir = None # type: ignore import shutil - shutil.rmtree = None - shutil.move = None - shutil.chown = None + shutil.rmtree = None # type: ignore + shutil.move = None # type: ignore + shutil.chown = None # type: ignore import subprocess @@ -278,8 +277,8 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None): import sys - sys.modules["ipdb"] = None - sys.modules["joblib"] = None - sys.modules["resource"] = None - sys.modules["psutil"] = None - sys.modules["tkinter"] = None + sys.modules["ipdb"] = None # type: ignore + sys.modules["joblib"] = None # type: ignore + sys.modules["resource"] = None # type: ignore + sys.modules["psutil"] = None # type: ignore + sys.modules["tkinter"] = None # type: ignore diff --git a/vision_agent/tools/tools_v2.py b/vision_agent/tools/tools_v2.py index 7dc2167f..3e39107d 100644 --- a/vision_agent/tools/tools_v2.py +++ b/vision_agent/tools/tools_v2.py @@ -1,7 +1,7 @@ import inspect import tempfile from importlib import resources -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List import numpy as np from PIL import Image, ImageDraw, ImageFont @@ -112,8 +112,8 @@ def save_image(image: np.ndarray) -> str: """ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: - image = Image.fromarray(image) - image.save(f, "PNG") + pil_image = Image.fromarray(image.astype(np.uint8)) + pil_image.save(f, "PNG") return f.name @@ -133,7 +133,7 @@ def display_bounding_boxes( ------- >>> image_with_bboxes = display_bounding_boxes(image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}]) """ - pil_image = Image.fromarray(image) + pil_image = Image.fromarray(image.astype(np.uint8)) color = { label: COLORS[i % len(COLORS)] @@ -167,7 +167,7 @@ def display_bounding_boxes( return np.array(pil_image.convert("RGB")) -def get_tool_documentation(funcs): +def get_tool_documentation(funcs: List[Callable]) -> str: docstrings = "" for func in funcs: docstrings += f"{func.__name__}: {inspect.signature(func)}\n{func.__doc__}\n\n"