From 4bbb7771b5b482e189bcbdbdcf393b30948690f8 Mon Sep 17 00:00:00 2001
From: Dillon Laird
Date: Sat, 17 Feb 2024 17:40:41 -0800
Subject: [PATCH] updateds and docs

---
 lmm_tools/data/data.py | 63 +++++++++++++++++++++++++++++++++++-------
 lmm_tools/emb/emb.py   |  4 ++-
 lmm_tools/lmm/lmm.py   | 13 +++++----
 3 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/lmm_tools/data/data.py b/lmm_tools/data/data.py
index 5d14d2fb..bbafcb2f 100644
--- a/lmm_tools/data/data.py
+++ b/lmm_tools/data/data.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import uuid
 from pathlib import Path
 
@@ -5,37 +7,57 @@
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
+from faiss import read_index, write_index
 from tqdm import tqdm
+from typing_extensions import Self
 
-from lmm_tools import LMM, Embedder
+from lmm_tools.emb import Embedder
+from lmm_tools.lmm import LMM
 
 tqdm.pandas()
 
 
-class Data:
+class DataStore:
+    r"""A class to store and manage image data along with its generated metadata from an LMM."""
+
     def __init__(self, df: pd.DataFrame):
-        self.df = pd.DataFrame()
+        r"""Initializes the DataStore with a DataFrame containing image paths and image IDs. If the image IDs are not present, they are generated using UUID4. The DataFrame must contain an 'image_paths' column.
+
+        Args:
+            df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns.
+
+        """
+        self.df = df
         self.lmm: LMM | None = None
         self.emb: Embedder | None = None
         self.index = None
-        if "image_paths" not in df.columns:
+        if "image_paths" not in self.df.columns:
             raise ValueError("image_paths column must be present in DataFrame")
+        if "image_id" not in self.df.columns:
+            self.df["image_id"] = [str(uuid.uuid4()) for _ in range(len(df))]
 
-    def add_embedder(self, emb: Embedder):
+    def add_embedder(self, emb: Embedder) -> Self:
         self.emb = emb
+        return self
 
-    def add_lmm(self, lmm: LMM):
+    def add_lmm(self, lmm: LMM) -> Self:
         self.lmm = lmm
+        return self
 
-    def add_column(self, name: str, prompt: str) -> None:
+    def add_column(self, name: str, prompt: str) -> Self:
+        r"""Adds a new column to the DataFrame containing the generated metadata from the LMM."""
         if self.lmm is None:
             raise ValueError("LMM not set yet")
 
         self.df[name] = self.df["image_paths"].progress_apply(
             lambda x: self.lmm.generate(prompt, image=x)
         )
+        return self
 
-    def add_index(self, target_col: str) -> None:
+    def build_index(self, target_col: str) -> Self:
+        r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index.
+        Args:
+            target_col (str): The column name containing the data to be indexed."""
         if self.emb is None:
             raise ValueError("Embedder not set yet")
 
@@ -43,6 +65,7 @@ def add_index(self, target_col: str) -> None:
         embeddings = np.array(embeddings.tolist()).astype(np.float32)
         self.index = faiss.IndexFlatL2(embeddings.shape[1])
         self.index.add(embeddings)
+        return self
 
     def get_embeddings(self) -> npt.NDArray[np.float32]:
         if self.index is None:
@@ -53,6 +76,10 @@ def get_embeddings(self) -> npt.NDArray[np.float32]:
         return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)
 
     def search(self, query: str, top_k: int = 10) -> list[dict]:
+        r"""Searches the index for the most similar images to the query and returns the top_k results.
+        Args:
+            query (str): The query to search for.
+            top_k (int, optional): The number of results to return. Defaults to 10."""
         if self.index is None:
             raise ValueError("Index not built yet")
         if self.emb is None:
@@ -62,8 +89,24 @@ def search(self, query: str, top_k: int = 10) -> list[dict]:
         _, I = self.index.search(query_embedding.reshape(1, -1), top_k)
         return self.df.iloc[I[0]].to_dict(orient="records")
 
+    def save(self, path: str | Path) -> None:
+        path = Path(path)
+        path.mkdir(parents=True)
+        self.df.to_csv(path / "data.csv")
+        if self.index is not None:
+            write_index(self.index, str(path / "data.index"))
+
+    @classmethod
+    def load(cls, path: str | Path) -> DataStore:
+        path = Path(path)
+        df = pd.read_csv(path / "data.csv", index_col=0)
+        ds = DataStore(df)
+        if Path(path / "data.index").exists():
+            ds.index = read_index(str(path / "data.index"))
+        return ds
+
 
-def build_data(data: str | Path | list[str | Path]) -> Data:
+def build_data_store(data: str | Path | list[str | Path]) -> DataStore:
     if isinstance(data, Path) or isinstance(data, str):
         data = Path(data)
         data_files = list(Path(data).glob("*"))
@@ -73,4 +116,4 @@ def build_data(data: str | Path | list[str | Path]) -> Data:
     df = pd.DataFrame()
     df["image_paths"] = data_files
     df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))]
-    return Data(df)
+    return DataStore(df)
diff --git a/lmm_tools/emb/emb.py b/lmm_tools/emb/emb.py
index d8d0ad1f..91d53b07 100644
--- a/lmm_tools/emb/emb.py
+++ b/lmm_tools/emb/emb.py
@@ -38,4 +38,6 @@ def get_embedder(name: str) -> Embedder:
     elif name == "openai":
         return OpenAIEmb()
     else:
-        raise ValueError(f"Unknown embedder name: {name}, currently support sentence-transformer, openai.")
+        raise ValueError(
+            f"Unknown embedder name: {name}, currently support sentence-transformer, openai."
+        )
diff --git a/lmm_tools/lmm/lmm.py b/lmm_tools/lmm/lmm.py
index 81a08c96..4f20c7f3 100644
--- a/lmm_tools/lmm/lmm.py
+++ b/lmm_tools/lmm/lmm.py
@@ -17,6 +17,8 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
 
 
 class LLaVALMM(LMM):
+    r"""An LMM class for the LLaVA-1.6 34B model."""
+
     def __init__(self, name: str):
         self.name = name
 
@@ -25,6 +27,8 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
 
 
 class OpenAILMM(LMM):
+    r"""An LMM class for the OpenAI GPT-4 Vision model."""
+
     def __init__(self, name: str):
         from openai import OpenAI
 
@@ -37,7 +41,7 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
                 "role": "user",
                 "content": [
                     {"type": "text", "text": prompt},
-                ]
+                ],
             }
         ]
         if image:
@@ -54,8 +58,7 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
         )
 
         response = self.client.chat.completions.create(
-            model="gpt-4-vision-preview",
-            message=message
+            model="gpt-4-vision-preview", message=message
        )
         return response.choices[0].message.content
 
@@ -63,7 +66,7 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
 def get_lmm(name: str) -> LMM:
     if name == "openai":
         return OpenAILMM(name)
-    elif name == "llava-v1.6-34b":
+    elif name == "llava":
         return LLaVALMM(name)
     else:
-        raise ValueError(f"Unknown LMM: {name}, current support openai, llava-v1.6-34b")
+        raise ValueError(f"Unknown LMM: {name}, current support openai, llava")
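
Below is a minimal usage sketch of the DataStore API this patch introduces, as review context. Every call it makes (build_data_store, get_lmm, get_embedder, the chained add_*/build_index setters, search, save, load) comes from the diff above; the image directory, column name, prompt, query, and save path are placeholder values, and the fully qualified module imports assume no package-level re-exports.

    from lmm_tools.data.data import DataStore, build_data_store
    from lmm_tools.emb.emb import get_embedder
    from lmm_tools.lmm.lmm import get_lmm

    # Build a DataStore from a directory of images (placeholder path);
    # image_id values are generated automatically with uuid4.
    ds = build_data_store("data/images")

    # Attach an LMM and an embedder (the "openai" backends need an API key).
    # Each setter returns self after this patch, so calls can be chained.
    ds.add_lmm(get_lmm("openai")).add_embedder(get_embedder("openai"))

    # Generate a metadata column with the LMM, then build a FAISS index over it.
    ds.add_column("description", "Describe this image in one sentence.")
    ds.build_index("description")

    # Query the index, then persist and reload the store. save() creates the
    # directory itself, so point it at a path that does not exist yet.
    results = ds.search("a person riding a bicycle", top_k=5)
    ds.save("data/store")
    ds = DataStore.load("data/store")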