diff --git a/vision_agent/image_utils.py b/vision_agent/image_utils.py
index c48d5620..9ad2bdaa 100644
--- a/vision_agent/image_utils.py
+++ b/vision_agent/image_utils.py
@@ -5,16 +5,17 @@
 
 import numpy as np
 from PIL import Image
+from PIL.Image import Image as ImageType
 
 
-def b64_to_pil(b64_str: str) -> Image.Image:
+def b64_to_pil(b64_str: str) -> ImageType:
     # , can't be encoded in b64 data so must be part of prefix
     if "," in b64_str:
         b64_str = b64_str.split(",")[1]
     return Image.open(BytesIO(base64.b64decode(b64_str)))
 
 
-def get_image_size(data: Union[str, Path, np.ndarray, Image.Image]) -> Tuple[int, ...]:
+def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
     if isinstance(data, (str, Path)):
         data = Image.open(data)
 
@@ -24,7 +25,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, Image.Image]) -> Tuple[int
         return data.shape[:2]
 
 
-def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str:
+def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
     if data is None:
         raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
     if isinstance(data, (str, Path)):
diff --git a/vision_agent/llm/llm.py b/vision_agent/llm/llm.py
index 6931ca03..08197abf 100644
--- a/vision_agent/llm/llm.py
+++ b/vision_agent/llm/llm.py
@@ -1,6 +1,6 @@
 import json
 from abc import ABC, abstractmethod
-from typing import cast
+from typing import Any, Dict, cast
 
 from vision_agent.tools import (
     CHOOSE_PARAMS,
@@ -67,7 +67,8 @@ def generate_detector(self, params: str) -> ImageTool:
         params = json.loads(cast(str, response.choices[0].message.content))[
             "Parameters"
         ]
-        return GroundingDINO(**params)
+        # Unpack as keyword arguments; the cast is only for the type checker.
+        return GroundingDINO(**cast(Dict[str, Any], params))
 
     def generate_segmentor(self, params: str) -> ImageTool:
         params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=params)
@@ -83,4 +84,5 @@ def generate_segmentor(self, params: str) -> ImageTool:
         params = json.loads(cast(str, response.choices[0].message.content))[
             "Parameters"
         ]
-        return GroundingSAM(**params)
+        # Unpack as keyword arguments; the cast is only for the type checker.
+        return GroundingSAM(**cast(Dict[str, Any], params))