diff --git a/lmm_tools/data/data.py b/lmm_tools/data/data.py index 533b5a70..09156acd 100644 --- a/lmm_tools/data/data.py +++ b/lmm_tools/data/data.py @@ -2,7 +2,7 @@ import uuid from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, cast import faiss import numpy as np @@ -31,7 +31,7 @@ def __init__(self, df: pd.DataFrame): self.df = df self.lmm: Optional[LMM] = None self.emb: Optional[Embedder] = None - self.index = None + self.index: Optional[faiss.IndexFlatL2] = None # type: ignore if "image_paths" not in self.df.columns: raise ValueError("image_paths column must be present in DataFrame") if "image_id" not in self.df.columns: @@ -50,7 +50,7 @@ def add_column(self, name: str, prompt: str) -> Self: if self.lmm is None: raise ValueError("LMM not set yet") - self.df[name] = self.df["image_paths"].progress_apply( + self.df[name] = self.df["image_paths"].progress_apply( # type: ignore lambda x: self.lmm.generate(prompt, image=x) ) return self @@ -62,10 +62,10 @@ def build_index(self, target_col: str) -> Self: if self.emb is None: raise ValueError("Embedder not set yet") - embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) - embeddings = np.array(embeddings.tolist()).astype(np.float32) - self.index = faiss.IndexFlatL2(embeddings.shape[1]) - self.index.add(embeddings) + embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) # type: ignore + embeddings_np = np.array(embeddings.tolist()).astype(np.float32) + self.index = faiss.IndexFlatL2(embeddings_np.shape[1]) + self.index.add(embeddings_np) return self def get_embeddings(self) -> npt.NDArray[np.float32]: @@ -73,8 +73,11 @@ def get_embeddings(self) -> npt.NDArray[np.float32]: raise ValueError("Index not built yet") ntotal = self.index.ntotal - d = self.index.d - return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d) + d: int = self.index.d + return cast( + npt.NDArray[np.float32], + faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d), + ) def search(self, query: str, top_k: int = 10) -> List[Dict]: r"""Searches the index for the most similar images to the query and returns the top_k results. @@ -86,9 +89,9 @@ def search(self, query: str, top_k: int = 10) -> List[Dict]: if self.emb is None: raise ValueError("Embedder not set yet") - query_embedding = self.emb.embed(query) + query_embedding: npt.NDArray[np.float32] = self.emb.embed(query) _, idx = self.index.search(query_embedding.reshape(1, -1), top_k) - return self.df.iloc[idx[0]].to_dict(orient="records") + return cast(List[Dict], self.df.iloc[idx[0]].to_dict(orient="records")) def save(self, path: Union[str, Path]) -> None: path = Path(path) diff --git a/lmm_tools/emb/emb.py b/lmm_tools/emb/emb.py index 91d53b07..25696bc5 100644 --- a/lmm_tools/emb/emb.py +++ b/lmm_tools/emb/emb.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import cast import numpy as np import numpy.typing as npt @@ -6,7 +7,7 @@ class Embedder(ABC): @abstractmethod - def embed(self, text: str) -> list: + def embed(self, text: str) -> npt.NDArray[np.float32]: pass @@ -17,7 +18,10 @@ def __init__(self, model_name: str = "all-MiniLM-L12-v2"): self.model = SentenceTransformer(model_name) def embed(self, text: str) -> npt.NDArray[np.float32]: - return self.model.encode([text]).flatten().astype(np.float32) + return cast( + npt.NDArray[np.float32], + self.model.encode([text]).flatten().astype(np.float32), + ) class OpenAIEmb(Embedder): diff --git a/lmm_tools/lmm/lmm.py b/lmm_tools/lmm/lmm.py index c9c38945..c2efde5d 100644 --- a/lmm_tools/lmm/lmm.py +++ b/lmm_tools/lmm/lmm.py @@ -1,7 +1,7 @@ import base64 from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Union +from typing import Any, Dict, Optional, Union, cast def encode_image(image: Union[str, Path]) -> str: @@ -23,7 +23,7 @@ def __init__(self, name: str): self.name = name def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: - pass + raise NotImplementedError("LLaVA LMM not implemented yet") class OpenAILMM(LMM): @@ -36,7 +36,7 @@ def __init__(self, name: str): self.client = OpenAI() def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: - message = [ + message: list[Dict[str, Any]] = [ { "role": "user", "content": [ @@ -60,7 +60,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str: response = self.client.chat.completions.create( model="gpt-4-vision-preview", message=message ) - return response.choices[0].message.content + return cast(str, response.choices[0].message.content) def get_lmm(name: str) -> LMM: diff --git a/pyproject.toml b/pyproject.toml index c640028a..b59f30c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,4 +81,7 @@ disallow_any_unimported = true ignore_missing_imports = true module = [ "cv2.*", + "faiss.*", + "openai.*", + "sentence_transformers.*", ]