Skip to content

Commit

Permalink
Fix typing errors
Browse files: browse the repository at this point in the history
  • Loading branch information
AsiaCao committed Feb 20, 2024
1 parent 8efbae0 commit 4004792
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 17 deletions.
25 changes: 14 additions & 11 deletions lmm_tools/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

import faiss
import numpy as np
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, df: pd.DataFrame):
self.df = df
self.lmm: Optional[LMM] = None
self.emb: Optional[Embedder] = None
self.index = None
self.index: Optional[faiss.IndexFlatL2] = None # type: ignore
if "image_paths" not in self.df.columns:
raise ValueError("image_paths column must be present in DataFrame")
if "image_id" not in self.df.columns:
Expand All @@ -50,7 +50,7 @@ def add_column(self, name: str, prompt: str) -> Self:
if self.lmm is None:
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply(
self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
lambda x: self.lmm.generate(prompt, image=x)
)
return self
Expand All @@ -62,19 +62,22 @@ def build_index(self, target_col: str) -> Self:
if self.emb is None:
raise ValueError("Embedder not set yet")

embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
embeddings = np.array(embeddings.tolist()).astype(np.float32)
self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index.add(embeddings)
embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) # type: ignore
embeddings_np = np.array(embeddings.tolist()).astype(np.float32)
self.index = faiss.IndexFlatL2(embeddings_np.shape[1])
self.index.add(embeddings_np)
return self

def get_embeddings(self) -> npt.NDArray[np.float32]:
if self.index is None:
raise ValueError("Index not built yet")

ntotal = self.index.ntotal
d = self.index.d
return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)
d: int = self.index.d
return cast(
npt.NDArray[np.float32],
faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d),
)

def search(self, query: str, top_k: int = 10) -> List[Dict]:
r"""Searches the index for the most similar images to the query and returns the top_k results.
Expand All @@ -86,9 +89,9 @@ def search(self, query: str, top_k: int = 10) -> List[Dict]:
if self.emb is None:
raise ValueError("Embedder not set yet")

query_embedding = self.emb.embed(query)
query_embedding: npt.NDArray[np.float32] = self.emb.embed(query)
_, idx = self.index.search(query_embedding.reshape(1, -1), top_k)
return self.df.iloc[idx[0]].to_dict(orient="records")
return cast(List[Dict], self.df.iloc[idx[0]].to_dict(orient="records"))

def save(self, path: Union[str, Path]) -> None:
path = Path(path)
Expand Down
8 changes: 6 additions & 2 deletions lmm_tools/emb/emb.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import cast

import numpy as np
import numpy.typing as npt


class Embedder(ABC):
@abstractmethod
def embed(self, text: str) -> list:
def embed(self, text: str) -> npt.NDArray[np.float32]:
pass


Expand All @@ -17,7 +18,10 @@ def __init__(self, model_name: str = "all-MiniLM-L12-v2"):
self.model = SentenceTransformer(model_name)

def embed(self, text: str) -> npt.NDArray[np.float32]:
return self.model.encode([text]).flatten().astype(np.float32)
return cast(
npt.NDArray[np.float32],
self.model.encode([text]).flatten().astype(np.float32),
)


class OpenAIEmb(Embedder):
Expand Down
8 changes: 4 additions & 4 deletions lmm_tools/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Any, Dict, Optional, Union, cast


def encode_image(image: Union[str, Path]) -> str:
Expand All @@ -23,7 +23,7 @@ def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
pass
raise NotImplementedError("LLaVA LMM not implemented yet")


class OpenAILMM(LMM):
Expand All @@ -36,7 +36,7 @@ def __init__(self, name: str):
self.client = OpenAI()

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
message = [
message: list[Dict[str, Any]] = [
{
"role": "user",
"content": [
Expand All @@ -60,7 +60,7 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
response = self.client.chat.completions.create(
model="gpt-4-vision-preview", message=message
)
return response.choices[0].message.content
return cast(str, response.choices[0].message.content)


def get_lmm(name: str) -> LMM:
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,7 @@ disallow_any_unimported = true
ignore_missing_imports = true
module = [
"cv2.*",
"faiss.*",
"openai.*",
"sentence_transformers.*",
]

0 comments on commit 4004792

Please sign in to comment.