diff --git a/lmm_tools/data/data.py b/lmm_tools/data/data.py index 7357b6c5..533b5a70 100644 --- a/lmm_tools/data/data.py +++ b/lmm_tools/data/data.py @@ -2,6 +2,7 @@ import uuid from pathlib import Path +from typing import Dict, List, Optional, Union import faiss import numpy as np @@ -28,8 +29,8 @@ def __init__(self, df: pd.DataFrame): """ self.df = df - self.lmm: LMM | None = None - self.emb: Embedder | None = None + self.lmm: Optional[LMM] = None + self.emb: Optional[Embedder] = None self.index = None if "image_paths" not in self.df.columns: raise ValueError("image_paths column must be present in DataFrame") @@ -75,7 +76,7 @@ def get_embeddings(self) -> npt.NDArray[np.float32]: d = self.index.d return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d) - def search(self, query: str, top_k: int = 10) -> list[dict]: + def search(self, query: str, top_k: int = 10) -> List[Dict]: r"""Searches the index for the most similar images to the query and returns the top_k results. Args: query (str): The query to search for. @@ -89,7 +90,7 @@ def search(self, query: str, top_k: int = 10) -> list[dict]: _, idx = self.index.search(query_embedding.reshape(1, -1), top_k) return self.df.iloc[idx[0]].to_dict(orient="records") - def save(self, path: str | Path) -> None: + def save(self, path: Union[str, Path]) -> None: path = Path(path) path.mkdir(parents=True) self.df.to_csv(path / "data.csv") @@ -97,7 +98,7 @@ def save(self, path: str | Path) -> None: write_index(self.index, str(path / "data.index")) @classmethod - def load(cls, path: str | Path) -> DataStore: + def load(cls, path: Union[str, Path]) -> DataStore: path = Path(path) df = pd.read_csv(path / "data.csv", index_col=0) ds = DataStore(df) @@ -106,7 +107,7 @@ def load(cls, path: str | Path) -> DataStore: return ds -def build_data_store(data: str | Path | list[str | Path]) -> DataStore: +def build_data_store(data: Union[str, Path, list[Union[str, Path]]]) -> DataStore: if isinstance(data, Path) or isinstance(data, str): data = Path(data) data_files = list(Path(data).glob("*")) diff --git a/lmm_tools/lmm/lmm.py b/lmm_tools/lmm/lmm.py index 4f20c7f3..e4c80032 100644 --- a/lmm_tools/lmm/lmm.py +++ b/lmm_tools/lmm/lmm.py @@ -1,7 +1,7 @@ import base64 from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional +from typing import Optional, Union def encode_image(image: str | Path) -> str: @@ -12,7 +12,7 @@ def encode_image(image: str | Path) -> str: class LMM(ABC): @abstractmethod - def generate(self, prompt: str, image: Optional[str | Path]) -> str: + def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: pass @@ -22,7 +22,7 @@ class LLaVALMM(LMM): def __init__(self, name: str): self.name = name - def generate(self, prompt: str, image: Optional[str | Path]) -> str: + def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: pass @@ -35,7 +35,7 @@ def __init__(self, name: str): self.name = name self.client = OpenAI() - def generate(self, prompt: str, image: Optional[str | Path]) -> str: + def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: message = [ { "role": "user",