diff --git a/vision_agent/agent/easytool.py b/vision_agent/agent/easytool.py index ac35f805..b21114ec 100644 --- a/vision_agent/agent/easytool.py +++ b/vision_agent/agent/easytool.py @@ -16,7 +16,7 @@ ) -def parse_json(s: str) -> Dict: +def parse_json(s: str) -> Any: s = ( s.replace(": true", ": True") .replace(": false", ": False") @@ -28,7 +28,7 @@ def parse_json(s: str) -> Dict: return json.loads(s) -def change_name(name: str): +def change_name(name: str) -> str: change_list = ["from", "class", "return", "false", "true", "id", "and", "", "ID"] if name in change_list: name = "is_" + name.lower() @@ -53,7 +53,7 @@ def task_decompose( try: str_result = model(prompt) result = parse_json(str_result) - return result["Tasks"] + return result["Tasks"] # type: ignore except Exception: if tries > 10: raise ValueError(f"Failed task_decompose on: {str_result}") @@ -78,7 +78,7 @@ def task_topology( elt["dep"] = [elt["dep"]] elif isinstance(elt["dep"], list): elt["dep"] = [int(dep) for dep in elt["dep"]] - return result["Tasks"] + return result["Tasks"] # type: ignore except Exception: if tries > 10: raise ValueError(f"Failed task_topology on: {str_result}") @@ -96,7 +96,7 @@ def choose_tool( try: str_result = model(prompt) result = parse_json(str_result) - return result["ID"] + return result["ID"] # type: ignore except Exception: if tries > 10: raise ValueError(f"Failed choose_tool on: {str_result}") @@ -217,15 +217,10 @@ def __init__( task_model: Optional[Union[LLM, LMM]] = None, answer_model: Optional[Union[LLM, LMM]] = None, ): - if task_model is None: - self.task_model = OpenAILLM(json_mode=True) - else: - self.task_model = task_model - - if answer_model is None: - self.answer_model = OpenAILLM() - else: - self.answer_model = answer_model + self.task_model = ( + OpenAILLM(json_mode=True) if task_model is None else task_model + ) + self.answer_model = OpenAILLM() if answer_model is None else answer_model self.retrieval_num = 3 self.tools = TOOLS diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py index dcf61762..bb6e7d7f 100644 --- a/vision_agent/llm/llm.py +++ b/vision_agent/llm/llm.py @@ -55,7 +55,7 @@ def chat(self, chat: List[Dict[str, str]]) -> str: response = self.client.chat.completions.create( model=self.model_name, messages=chat, # type: ignore - **kwargs, # type: ignore + **kwargs, ) return cast(str, response.choices[0].message.content) diff --git a/vision_agent/tools/tools.py b/vision_agent/tools/tools.py index 1c60738d..d6d66168 100644 --- a/vision_agent/tools/tools.py +++ b/vision_agent/tools/tools.py @@ -41,7 +41,13 @@ def rle_decode(mask_rle: str, shape: Tuple[int, int]) -> np.ndarray: return img.reshape(shape) -class ImageTool(ABC): +class Tool(ABC): + name: str + description: str + usage: Dict + + +class ImageTool(Tool): @abstractmethod def __call__(self, image: Union[str, ImageType]) -> List[Dict]: pass @@ -68,7 +74,7 @@ class CLIP(ImageTool): 'Example 2: User Question: "Can you tag this photograph with cat or dog?" {{"Parameters":{{"prompt": ["cat", "dog"]}}}}\n' 'Exmaple 3: User Question: "Can you build me a classifier taht classifies red shirts, green shirts and other?" {{"Parameters":{{"prompt": ["red shirt", "green shirt", "other"]}}}}\n' ) - usage = {} + usage: Dict = {} def __init__(self, prompt: list[str]): self.prompt = prompt @@ -183,7 +189,7 @@ class GroundingSAM(ImageTool): 'Example 2: User Question: "Can you segment the person on the left?" {{"Parameters":{{"prompt": ["person on the left"]}}\n' 'Exmaple 3: User Question: "Can you build me a tool that segments red shirts and green shirts?" {{"Parameters":{{"prompt": ["red shirt", "green shirt"]}}}}\n' ) - usage = {} + usage: Dict = {} def __init__(self, prompt: list[str]): self.prompt = prompt @@ -219,77 +225,69 @@ def __call__(self, image: Union[str, ImageType]) -> List[Dict]: return preds -class Add: +class Add(Tool): name = "add_" description = "'add_' returns the sum of all the arguments passed to it, normalized to 2 decimal places." - usage = ( - { - "required_parameters": {"name": "input", "type": "List[int]"}, - "examples": [ - { - "scenario": "If you want to calculate 2 + 4", - "parameters": {"input": [2, 4]}, - } - ], - }, - ) + usage = { + "required_parameters": {"name": "input", "type": "List[int]"}, + "examples": [ + { + "scenario": "If you want to calculate 2 + 4", + "parameters": {"input": [2, 4]}, + } + ], + } def __call__(self, input: List[int]) -> float: return round(sum(input), 2) -class Subtract: +class Subtract(Tool): name = "subtract_" description = "'subtract_' returns the difference of all the arguments passed to it, normalized to 2 decimal places." - usage = ( - { - "required_parameters": {"name": "input", "type": "List[int]"}, - "examples": [ - { - "scenario": "If you want to calculate 4 - 2", - "parameters": {"input": [4, 2]}, - } - ], - }, - ) + usage = { + "required_parameters": {"name": "input", "type": "List[int]"}, + "examples": [ + { + "scenario": "If you want to calculate 4 - 2", + "parameters": {"input": [4, 2]}, + } + ], + } def __call__(self, input: List[int]) -> float: return round(input[0] - input[1], 2) -class Multiply: +class Multiply(Tool): name = "multiply_" description = "'multiply_' returns the product of all the arguments passed to it, normalized to 2 decimal places." - usage = ( - { - "required_parameters": {"name": "input", "type": "List[int]"}, - "examples": [ - { - "scenario": "If you want to calculate 2 * 4", - "parameters": {"input": [2, 4]}, - } - ], - }, - ) + usage = { + "required_parameters": {"name": "input", "type": "List[int]"}, + "examples": [ + { + "scenario": "If you want to calculate 2 * 4", + "parameters": {"input": [2, 4]}, + } + ], + } def __call__(self, input: List[int]) -> float: return round(input[0] * input[1], 2) -class Divide: +class Divide(Tool): name = "divide_" description = "'divide_' returns the division of all the arguments passed to it, normalized to 2 decimal places." - usage = ( - { - "required_parameters": {"name": "input", "type": "List[int]"}, - "examples": [ - { - "scenario": "If you want to calculate 4 / 2", - "parameters": {"input": [4, 2]}, - } - ], - }, - ) + usage = { + "required_parameters": {"name": "input", "type": "List[int]"}, + "examples": [ + { + "scenario": "If you want to calculate 4 / 2", + "parameters": {"input": [4, 2]}, + } + ], + } def __call__(self, input: List[int]) -> float: return round(input[0] / input[1], 2) @@ -300,4 +298,5 @@ def __call__(self, input: List[int]) -> float: for i, c in enumerate( [CLIP, GroundingDINO, GroundingSAM, Add, Subtract, Multiply, Divide] ) + if (hasattr(c, "name") and hasattr(c, "description") and hasattr(c, "usage")) }