Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed May 6, 2024
1 parent 0d763ee commit 4856b7b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
3 changes: 1 addition & 2 deletions vision_agent/agent/automated_vision_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def write_plan(
context = USER_REQ_CONTEXT.format(user_requirement=user_requirements)
prompt = PLAN.format(context=context, plan="", tool_desc=tool_desc)
plan = json.loads(model(prompt).replace("```", "").strip())
return plan["plan"]
return plan["plan"] # type: ignore


def write_code(
Expand Down Expand Up @@ -217,7 +217,6 @@ def __init__(
self.tool_recommender = Sim(TOOLS_DF, sim_key="desc")
else:
self.tool_recommender = tool_recommender
self.long_term_memory = []
self.verbose = verbose
if self.verbose:
_LOGGER.setLevel(logging.INFO)
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def generate(self, prompt: str) -> str:

response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
messages=messages, # type: ignore
**self.kwargs,
)

Expand Down
16 changes: 8 additions & 8 deletions vision_agent/tools/tools_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,15 +369,15 @@ def display_segmentation_masks(
return np.array(pil_image.convert("RGB"))


def get_tool_documentation(funcs: List[Callable]) -> str:
def get_tool_documentation(funcs: List[Callable[..., Any]]) -> str:
docstrings = ""
for func in funcs:
docstrings += f"{func.__name__}{inspect.signature(func)}:\n{func.__doc__}\n\n"

return docstrings


def get_tool_descriptions(funcs: List[Callable]) -> str:
def get_tool_descriptions(funcs: List[Callable[..., Any]]) -> str:
descriptions = ""
for func in funcs:
description = func.__doc__
Expand All @@ -392,8 +392,8 @@ def get_tool_descriptions(funcs: List[Callable]) -> str:
return descriptions


def get_tools_df(funcs: List[Callable]) -> pd.DataFrame:
data = {"desc": [], "doc": []}
def get_tools_df(funcs: List[Callable[..., Any]]) -> pd.DataFrame:
data: Dict[str, List[str]] = {"desc": [], "doc": []}

for func in funcs:
desc = func.__doc__
Expand All @@ -406,7 +406,7 @@ def get_tools_df(funcs: List[Callable]) -> pd.DataFrame:
data["desc"].append(desc)
data["doc"].append(doc)

return pd.DataFrame(data)
return pd.DataFrame(data) # type: ignore


TOOLS = [
Expand All @@ -419,9 +419,9 @@ def get_tools_df(funcs: List[Callable]) -> pd.DataFrame:
display_bounding_boxes,
display_segmentation_masks,
]
TOOLS_DF = get_tools_df(TOOLS)
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS)
TOOL_DOCSTRING = get_tool_documentation(TOOLS)
TOOLS_DF = get_tools_df(TOOLS) # type: ignore
TOOL_DESCRIPTIONS = get_tool_descriptions(TOOLS) # type: ignore
TOOL_DOCSTRING = get_tool_documentation(TOOLS) # type: ignore
UTILITIES_DOCSTRING = get_tool_documentation(
[load_image, save_image, display_bounding_boxes]
)
2 changes: 1 addition & 1 deletion vision_agent/utils/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pandas as pd
from openai import Client
from scipy.spatial.distance import cosine
from scipy.spatial.distance import cosine # type: ignore

client = Client()

Expand Down

0 comments on commit 4856b7b

Please sign in to comment.