Skip to content

Commit

Permalink
use inner product
Browse files (browse the repository at this point in the history)
  • Loading branch information
dillonalaird committed Feb 21, 2024
1 parent c913583 commit 16b16e5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lmm_tools/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, df: pd.DataFrame):
self.df = df
self.lmm: Optional[LMM] = None
self.emb: Optional[Embedder] = None
- self.index: Optional[faiss.IndexFlatL2] = None # type: ignore
+ self.index: Optional[faiss.IndexFlatIP] = None # type: ignore
if "image_paths" not in self.df.columns:
raise ValueError("image_paths column must be present in DataFrame")
if "image_id" not in self.df.columns:
Expand Down Expand Up @@ -64,7 +64,7 @@ def build_index(self, target_col: str) -> Self:

embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) # type: ignore
embeddings_np = np.array(embeddings.tolist()).astype(np.float32)
- self.index = faiss.IndexFlatL2(embeddings_np.shape[1])
+ self.index = faiss.IndexFlatIP(embeddings_np.shape[1])
self.index.add(embeddings_np)
return self

Expand Down

0 comments on commit 16b16e5

Please sign in to comment.