Skip to content

Commit

Permalink
fix type errors
Browse files: browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Jul 16, 2024
1 parent 8e3b8fe commit 33856a7
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 56 deletions.
2 changes: 1 addition & 1 deletion vision_agent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from vision_agent.lmm import Message
from vision_agent.lmm.types import Message


class Agent(ABC):
Expand Down
4 changes: 2 additions & 2 deletions vision_agent/agent/vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
_LOGGER = logging.getLogger(__name__)
WORKSPACE = Path(os.getenv("WORKSPACE", ""))
WORKSPACE.mkdir(parents=True, exist_ok=True)
if WORKSPACE != "":
if str(WORKSPACE) != "":
os.environ["PYTHONPATH"] = f"{WORKSPACE}:{os.getenv('PYTHONPATH', '')}"


Expand Down Expand Up @@ -113,7 +113,7 @@ def __call__(
self,
input: Union[str, List[Message]],
media: Optional[Union[str, Path]] = None,
) -> List[Message]:
) -> str:
if isinstance(input, str):
input = [{"role": "user", "content": input}]
if media is not None:
Expand Down
41 changes: 2 additions & 39 deletions vision_agent/agent/vision_agent_coder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import copy
import difflib
import json
import logging
import os
import sys
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -93,42 +92,6 @@ def format_plans(plans: Dict[str, Any]) -> str:
return plan_str


def extract_code(code: str) -> str:
    """Return the contents of the first ```python fenced block in ``code``.

    Args:
        code: Text that may contain a Markdown ```python code fence.

    Returns:
        The text between the opening ```python marker and the next ```
        marker. If ``code`` has no ```python marker it is returned unchanged.
        If the closing marker is missing, everything after the opening marker
        is returned.
    """
    # Prefer the newline-anchored fence so the preceding line break is
    # consumed together with the marker.
    if "\n```python" in code:
        start = "\n```python"
    elif "```python" in code:
        start = "```python"
    else:
        return code

    # partition() keeps the whole remainder when the closing fence is absent;
    # slicing with find()'s -1 result would silently drop the last character.
    body = code.partition(start)[2]
    body = body.partition("```")[0]
    if body.startswith("python\n"):
        body = body[len("python\n") :]
    return body


def extract_json(json_str: str) -> Dict[str, Any]:
    """Parse JSON from raw text or from a Markdown-fenced block.

    The input is first parsed as-is. If that fails, the contents of a
    ```json fence (or, failing that, a bare ``` fence) are extracted and
    parsed instead.

    Args:
        json_str: Raw JSON text, or text containing a fenced JSON block.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If no parseable JSON could be extracted. The original
            ``json.JSONDecodeError`` is attached as the cause.
    """
    try:
        json_dict = json.loads(json_str)
    except json.JSONDecodeError:
        input_json_str = json_str
        if "```json" in json_str:
            json_str = json_str.partition("```json")[2]
            # partition() keeps the remainder intact when the closing fence
            # is missing instead of chopping the last character.
            json_str = json_str.partition("```")[0]
        elif "```" in json_str:
            json_str = json_str.partition("```")[2]
            # Anchor on "}```" so a ``` that appears inside an intermediate
            # string is not mistaken for the closing fence; keep the brace,
            # which belongs to the JSON document.
            brace_end = json_str.find("}```")
            if brace_end != -1:
                json_str = json_str[: brace_end + 1]
            else:
                # No brace directly before a fence (e.g. a newline sits in
                # between): cut at the last fence if there is one.
                fence_end = json_str.rfind("```")
                if fence_end != -1:
                    json_str = json_str[:fence_end]
        try:
            json_dict = json.loads(json_str)
        except json.JSONDecodeError as e:
            error_msg = f"Could not extract JSON from the given str: {json_str}.\nFunction input:\n{input_json_str}"
            _LOGGER.exception(error_msg)
            raise ValueError(error_msg) from e
    return json_dict  # type: ignore


def extract_image(
media: Optional[Sequence[Union[str, Path]]]
) -> Optional[Sequence[Union[str, Path]]]:
Expand Down Expand Up @@ -610,7 +573,7 @@ def __call__(
input[0]["media"] = [media]
results = self.chat_with_workflow(input)
results.pop("working_memory")
return results # type: ignore
return results["code"] # type: ignore

def chat_with_workflow(
self,
Expand Down
3 changes: 2 additions & 1 deletion vision_agent/lmm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .lmm import LMM, AzureOpenAILMM, Message, OllamaLMM, OpenAILMM
from .lmm import LMM, AzureOpenAILMM, OllamaLMM, OpenAILMM
from .types import Message
9 changes: 3 additions & 6 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

import requests
from openai import AzureOpenAI, OpenAI
from PIL import Image

import vision_agent.tools as T
from vision_agent.tools.prompts import CHOOSE_PARAMS, SYSTEM_PROMPT
from .types import Message

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,10 +52,6 @@ def encode_media(media: Union[str, Path]) -> str:
return encode_image_bytes(image_bytes)


TextOrImage = Union[str, List[Union[str, Path]]]
Message = Dict[str, TextOrImage]


class LMM(ABC):
@abstractmethod
def generate(
Expand Down Expand Up @@ -134,7 +131,7 @@ def chat(
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{encoded_media}", # type: ignore
"url": f"data:image/png;base64,{encoded_media}",
"detail": "low",
},
},
Expand Down
5 changes: 5 additions & 0 deletions vision_agent/lmm/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pathlib import Path
from typing import Dict, Sequence, Union

TextOrImage = Union[str, Sequence[Union[str, Path]]]
Message = Dict[str, TextOrImage]
19 changes: 12 additions & 7 deletions vision_agent/tools/meta_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
from pathlib import Path
from typing import List
from vision_agent.lmm.types import Message

import vision_agent as va
from vision_agent.tools.tool_utils import get_tool_documentation
Expand Down Expand Up @@ -43,7 +44,7 @@ def detect_dogs(image_path: str):

agent = va.agent.VisionAgentCoder()
try:
fixed_chat = [{"role": "user", "content": chat, "media": media}]
fixed_chat: List[Message] = [{"role": "user", "content": chat, "media": media}]
response = agent.chat_with_workflow(fixed_chat)
code = response["code"]
with open(save_file, "w") as f:
Expand Down Expand Up @@ -84,7 +85,7 @@ def detect_dogs(image_path: str):
code = f.read()

# Append latest code to second to last message from assistant
fixed_chat_history = []
fixed_chat_history: List[Message] = []
for i, chat in enumerate(chat_history):
if i == 0:
fixed_chat_history.append({"role": "user", "content": chat, "media": media})
Expand Down Expand Up @@ -118,9 +119,11 @@ def view_lines(
) -> str:
start = max(0, line_num - window_size)
end = min(len(lines), line_num + window_size)
return f"[File: {file_path} ({total_lines} lines total)]\n" + format_lines(
lines[start:end], start
) + ("[End of file]" if end == len(lines) else f"[{len(lines) - end} more lines]")
return (
f"[File: {file_path} ({total_lines} lines total)]\n"
+ format_lines(lines[start:end], start)
+ ("[End of file]" if end == len(lines) else f"[{len(lines) - end} more lines]")
)


def open_file(file_path: str, line_num: int = 0, window_size: int = 100) -> str:
Expand Down Expand Up @@ -245,7 +248,9 @@ def search_file(search_term: str, file_path: str) -> str:
if not search_results:
return f"[No matches found for {search_term} in {file_path}]"

return_str = f"[Found {len(search_results)} matches for {search_term} in {file_path}]\n"
return_str = (
f"[Found {len(search_results)} matches for {search_term} in {file_path}]\n"
)
for result in search_results:
return_str += result

Expand Down Expand Up @@ -330,7 +335,7 @@ def edit_file(file_path: str, start: int, end: int, content: str) -> str:
tmp_file.unlink()
if stdout != "":
stdout = stdout.replace(tmp_file.name, file_path)
error_msg = f"[Edit failed with the following status]\n" + stdout
error_msg = "[Edit failed with the following status]\n" + stdout
original_view = view_lines(
lines,
start + ((end - start) // 2),
Expand Down

0 comments on commit 33856a7

Please sign in to comment.