Skip to content

Commit

Permalink
fix type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 19, 2024
1 parent 8fded0b commit 066959c
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 66 deletions.
23 changes: 9 additions & 14 deletions vision_agent/agent/easytool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


def parse_json(s: str) -> Dict:
def parse_json(s: str) -> Any:
s = (
s.replace(": true", ": True")
.replace(": false", ": False")
Expand All @@ -28,7 +28,7 @@ def parse_json(s: str) -> Dict:
return json.loads(s)


def change_name(name: str):
def change_name(name: str) -> str:
change_list = ["from", "class", "return", "false", "true", "id", "and", "", "ID"]
if name in change_list:
name = "is_" + name.lower()
Expand All @@ -53,7 +53,7 @@ def task_decompose(
try:
str_result = model(prompt)
result = parse_json(str_result)
return result["Tasks"]
return result["Tasks"] # type: ignore
except Exception:
if tries > 10:
raise ValueError(f"Failed task_decompose on: {str_result}")
Expand All @@ -78,7 +78,7 @@ def task_topology(
elt["dep"] = [elt["dep"]]
elif isinstance(elt["dep"], list):
elt["dep"] = [int(dep) for dep in elt["dep"]]
return result["Tasks"]
return result["Tasks"] # type: ignore
except Exception:
if tries > 10:
raise ValueError(f"Failed task_topology on: {str_result}")
Expand All @@ -96,7 +96,7 @@ def choose_tool(
try:
str_result = model(prompt)
result = parse_json(str_result)
return result["ID"]
return result["ID"] # type: ignore
except Exception:
if tries > 10:
raise ValueError(f"Failed choose_tool on: {str_result}")
Expand Down Expand Up @@ -217,15 +217,10 @@ def __init__(
task_model: Optional[Union[LLM, LMM]] = None,
answer_model: Optional[Union[LLM, LMM]] = None,
):
if task_model is None:
self.task_model = OpenAILLM(json_mode=True)
else:
self.task_model = task_model

if answer_model is None:
self.answer_model = OpenAILLM()
else:
self.answer_model = answer_model
self.task_model = (
OpenAILLM(json_mode=True) if task_model is None else task_model
)
self.answer_model = OpenAILLM() if answer_model is None else answer_model

self.retrieval_num = 3
self.tools = TOOLS
Expand Down
2 changes: 1 addition & 1 deletion vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def chat(self, chat: List[Dict[str, str]]) -> str:
response = self.client.chat.completions.create(
model=self.model_name,
messages=chat, # type: ignore
**kwargs, # type: ignore
**kwargs,
)

return cast(str, response.choices[0].message.content)
Expand Down
101 changes: 50 additions & 51 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,13 @@ def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray:
return img.reshape(shape)


class ImageTool(ABC):
class Tool(ABC):
name: str
description: str
usage: Dict


class ImageTool(Tool):
@abstractmethod
def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
pass
Expand All @@ -68,7 +74,7 @@ class CLIP(ImageTool):
'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n'
'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n'
)
usage = {}
usage: Dict = {}

def __init__(self, prompt: list[str]):
self.prompt = prompt
Expand Down Expand Up @@ -183,7 +189,7 @@ class GroundingSAM(ImageTool):
'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": ["person on the left"]}}\n'
'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": ["red shirt", "green shirt"]}}}}\n'
)
usage = {}
usage: Dict = {}

def __init__(self, prompt: list[str]):
self.prompt = prompt
Expand Down Expand Up @@ -219,77 +225,69 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]:
return preds


class Add:
class Add(Tool):
name = "add_"
description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places."
usage = (
{
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 2 + 4",
"parameters": {"input": [2, 4]},
}
],
},
)
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 2 + 4",
"parameters": {"input": [2, 4]},
}
],
}

def __call__(self, input: List[int]) -> float:
return round(sum(input), 2)


class Subtract:
class Subtract(Tool):
name = "subtract_"
description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places."
usage = (
{
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 4 - 2",
"parameters": {"input": [4, 2]},
}
],
},
)
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 4 - 2",
"parameters": {"input": [4, 2]},
}
],
}

def __call__(self, input: List[int]) -> float:
return round(input[0] - input[1], 2)


class Multiply:
class Multiply(Tool):
name = "multiply_"
description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places."
usage = (
{
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 2 * 4",
"parameters": {"input": [2, 4]},
}
],
},
)
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 2 * 4",
"parameters": {"input": [2, 4]},
}
],
}

def __call__(self, input: List[int]) -> float:
return round(input[0] * input[1], 2)


class Divide:
class Divide(Tool):
name = "divide_"
description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places."
usage = (
{
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 4 / 2",
"parameters": {"input": [4, 2]},
}
],
},
)
usage = {
"required_parameters": {"name": "input", "type": "List[int]"},
"examples": [
{
"scenario": "If you want to calculate 4 / 2",
"parameters": {"input": [4, 2]},
}
],
}

def __call__(self, input: List[int]) -> float:
return round(input[0] / input[1], 2)
Expand All @@ -300,4 +298,5 @@ def __call__(self, input: List[int]) -> float:
for i, c in enumerate(
[CLIP, GroundingDINO, GroundingSAM, Add, Subtract, Multiply, Divide]
)
if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage"))
}

0 comments on commit 066959c

Please sign in to comment.