Skip to content

Commit

Permalink
fix typing
Browse files (browse the repository at this point in the history)
  • Loading branch information
dillonalaird committed Feb 17, 2024
1 parent 58c2289 commit fc6a6ba
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions lmm_tools/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import faiss
import numpy as np
import numpy.typing as np
import numpy.typing as npt
import pandas as pd
from tqdm import tqdm

Expand All @@ -15,8 +15,8 @@
class Data:
def __init__(self, df: pd.DataFrame):
self.df = pd.DataFrame()
self.lmm: LMM = None
self.emb: Embedder = None
self.lmm: LMM | None = None
self.emb: Embedder | None = None
self.index = None
if "image_paths" not in df.columns:
raise ValueError("image_paths column must be present in DataFrame")
Expand All @@ -28,28 +28,42 @@ def add_lmm(self, lmm: LMM):
self.lmm = lmm

def add_column(self, name: str, prompt: str) -> None:
    """Fill column ``name`` with the LMM's output for every image path.

    Each entry of ``self.df["image_paths"]`` is passed to
    ``self.lmm.generate(prompt, image=...)`` and the results are stored
    in ``self.df[name]``.

    Args:
        name: Name of the DataFrame column to write.
        prompt: Prompt forwarded unchanged to the LMM for every image.

    Raises:
        ValueError: If no LMM has been attached yet (``self.lmm`` is None).
    """
    if self.lmm is None:
        raise ValueError("LMM not set yet")

    def _generate_for(image_path):
        return self.lmm.generate(prompt, image=image_path)

    image_paths = self.df["image_paths"]
    # NOTE(review): progress_apply is presumably registered via
    # tqdm.pandas() elsewhere in this module - confirm.
    self.df[name] = image_paths.progress_apply(_generate_for)

def add_index(self, target_col: str) -> None:
    """Embed every value of ``target_col`` and index the vectors with faiss.

    The per-row embeddings are stacked into a float32 array and added to
    a new ``faiss.IndexFlatL2`` stored on ``self.index``.

    Args:
        target_col: Name of the DataFrame column whose values are embedded.

    Raises:
        ValueError: If no embedder has been attached (``self.emb`` is None).
    """
    if self.emb is None:
        raise ValueError("Embedder not set yet")

    def _embed(value):
        return self.emb.embed(value)

    per_row = self.df[target_col].progress_apply(_embed)
    stacked = np.array(per_row.tolist())
    vectors = stacked.astype(np.float32)
    # The second axis of the stacked array is used as the index dimension.
    dim = vectors.shape[1]
    self.index = faiss.IndexFlatL2(dim)
    self.index.add(vectors)

def get_embeddings(self) -> npt.NDArray[np.float32]:
    """Return the vectors held by the index as an ``(ntotal, d)`` array.

    The flat storage exposed by ``self.index.get_xb()`` is wrapped with
    ``faiss.rev_swig_ptr`` and reshaped to one row per indexed vector.

    Returns:
        Array with ``self.index.ntotal`` rows and ``self.index.d`` columns.

    Raises:
        ValueError: If the index has not been built (``self.index`` is None).
    """
    if self.index is None:
        raise ValueError("Index not built yet")

    ntotal = self.index.ntotal
    d = self.index.d
    # NOTE(review): rev_swig_ptr presumably wraps the index's own buffer
    # without copying - confirm, and copy the result if it must outlive
    # or stay independent of the index.
    return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)

def search(self, query: str, top_k: int = 10) -> list[dict]:
if self.index is None:
raise ValueError("Index not built yet")
if self.emb is None:
raise ValueError("Embedder not set yet")

query_embedding = self.emb.embed(query)
_, I = self.index.search(query_embedding.reshape(1, -1), top_k)
return self.df.iloc[I[0]].to_dict(orient="records")


def build_data(self, data: str | Path | list[str | Path]) -> Data:
def build_data(data: str | Path | list[str | Path]) -> Data:
if isinstance(data, Path) or isinstance(data, str):
data = Path(data)
data_files = list(Path(data).glob("*"))
Expand Down

0 comments on commit fc6a6ba

Please sign in to comment.