From 16b16e58a563029ae1fb343a7fa54001df7036c6 Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 20 Feb 2024 16:57:27 -0800 Subject: [PATCH] use inner product --- lmm_tools/data/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmm_tools/data/data.py b/lmm_tools/data/data.py index 09156acd..8893ae72 100644 --- a/lmm_tools/data/data.py +++ b/lmm_tools/data/data.py @@ -31,7 +31,7 @@ def __init__(self, df: pd.DataFrame): self.df = df self.lmm: Optional[LMM] = None self.emb: Optional[Embedder] = None - self.index: Optional[faiss.IndexFlatL2] = None # type: ignore + self.index: Optional[faiss.IndexFlatIP] = None # type: ignore if "image_paths" not in self.df.columns: raise ValueError("image_paths column must be present in DataFrame") if "image_id" not in self.df.columns: @@ -64,7 +64,7 @@ def build_index(self, target_col: str) -> Self: embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) # type: ignore embeddings_np = np.array(embeddings.tolist()).astype(np.float32) - self.index = faiss.IndexFlatL2(embeddings_np.shape[1]) + self.index = faiss.IndexFlatIP(embeddings_np.shape[1]) self.index.add(embeddings_np) return self