Skip to content

Commit

Permalink
fix typing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Mar 12, 2024
1 parent 482b3cc commit d626bca
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions vision_agent/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@

import numpy as np
from PIL import Image
from PIL.Image import Image as ImageType


def b64_to_pil(b64_str: str) -> Image.Image:
def b64_to_pil(b64_str: str) -> ImageType:
# , can't be encoded in b64 data so must be part of prefix
if "," in b64_str:
b64_str = b64_str.split(",")[1]
return Image.open(BytesIO(base64.b64decode(b64_str)))


def get_image_size(data: Union[str, Path, np.ndarray, Image.Image]) -> Tuple[int, ...]:
def get_image_size(data: Union[str, Path, np.ndarray, ImageType]) -> Tuple[int, ...]:
if isinstance(data, (str, Path)):
data = Image.open(data)

Expand All @@ -24,7 +25,7 @@ def get_image_size(data: Union[str, Path, np.ndarray, Image.Image]) -> Tuple[int
return data.shape[:2]


def convert_to_b64(data: Union[str, Path, np.ndarray, Image.Image]) -> str:
def convert_to_b64(data: Union[str, Path, np.ndarray, ImageType]) -> str:
if data is None:
raise ValueError(f"Invalid input image: {data}. Input image can't be None.")
if isinstance(data, (str, Path)):
Expand Down
6 changes: 3 additions & 3 deletions vision_agent/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from abc import ABC, abstractmethod
from typing import cast
from typing import Any, Dict, cast

from vision_agent.tools import (
CHOOSE_PARAMS,
Expand Down Expand Up @@ -67,7 +67,7 @@ def generate_detector(self, params: str) -> ImageTool:
params = json.loads(cast(str, response.choices[0].message.content))[
"Parameters"
]
return GroundingDINO(**params)
return GroundingDINO(*params)

def generate_segmentor(self, params: str) -> ImageTool:
params = CHOOSE_PARAMS.format(api_doc=GroundingSAM.doc, question=params)
Expand All @@ -83,4 +83,4 @@ def generate_segmentor(self, params: str) -> ImageTool:
params = json.loads(cast(str, response.choices[0].message.content))[
"Parameters"
]
return GroundingSAM(**params)
return GroundingSAM(*params)

0 comments on commit d626bca

Please sign in to comment.