Skip to content

Commit

Permalink
updates and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Feb 18, 2024
1 parent 369f793 commit 4bbb777
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
63 changes: 53 additions & 10 deletions lmm_tools/data/data.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,71 @@
from __future__ import annotations

import uuid
from pathlib import Path

import faiss
import numpy as np
import numpy.typing as npt
import pandas as pd
from faiss import read_index, write_index
from tqdm import tqdm
from typing_extensions import Self

from lmm_tools import LMM, Embedder
from lmm_tools.emb import Embedder
from lmm_tools.lmm import LMM

tqdm.pandas()


class Data:
class DataStore:
r"""A class to store and manage image data along with its generated metadata from an LMM."""

def __init__(self, df: pd.DataFrame):
self.df = pd.DataFrame()
r"""Initializes the DataStore with a DataFrame containing image paths and image IDs. If the image IDs are not present, they are generated using UUID4. The DataFrame must contain an 'image_paths' column.
Args:
df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns.
"""
self.df = df
self.lmm: LMM | None = None
self.emb: Embedder | None = None
self.index = None
if "image_paths" not in df.columns:
if "image_paths" not in self.df.columns:
raise ValueError("image_paths column must be present in DataFrame")
if "image_id" not in self.df.columns:
self.df["image_id"] = [str(uuid.uuid4()) for _ in range(len(df))]

def add_embedder(self, emb: Embedder):
def add_embedder(self, emb: Embedder) -> Self:
self.emb = emb
return self

def add_lmm(self, lmm: LMM):
def add_lmm(self, lmm: LMM) -> Self:
self.lmm = lmm
return self

def add_column(self, name: str, prompt: str) -> None:
def add_column(self, name: str, prompt: str) -> Self:
r"""Adds a new column to the DataFrame containing the generated metadata from the LMM."""
if self.lmm is None:
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply(
lambda x: self.lmm.generate(prompt, image=x)
)
return self

def add_index(self, target_col: str) -> None:
def build_index(self, target_col: str) -> Self:
r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index.
Args:
target_col (str): The column name containing the data to be indexed."""
if self.emb is None:
raise ValueError("Embedder not set yet")

embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
embeddings = np.array(embeddings.tolist()).astype(np.float32)
self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index.add(embeddings)
return self

def get_embeddings(self) -> npt.NDArray[np.float32]:
if self.index is None:
Expand All @@ -53,6 +76,10 @@ def get_embeddings(self) -> npt.NDArray[np.float32]:
return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)

def search(self, query: str, top_k: int = 10) -> list[dict]:
r"""Searches the index for the most similar images to the query and returns the top_k results.
Args:
query (str): The query to search for.
top_k (int, optional): The number of results to return. Defaults to 10."""
if self.index is None:
raise ValueError("Index not built yet")
if self.emb is None:
Expand All @@ -62,8 +89,24 @@ def search(self, query: str, top_k: int = 10) -> list[dict]:
_, I = self.index.search(query_embedding.reshape(1, -1), top_k)
return self.df.iloc[I[0]].to_dict(orient="records")

def save(self, path: str | Path) -> None:
path = Path(path)
path.mkdir(parents=True)
self.df.to_csv(path / "data.csv")
if self.index is not None:
write_index(self.index, str(path / "data.index"))

@classmethod
def load(cls, path: str | Path) -> DataStore:
path = Path(path)
df = pd.read_csv(path / "data.csv", index_col=0)
ds = DataStore(df)
if Path(path / "data.index").exists():
ds.index = read_index(str(path / "data.index"))
return ds


def build_data(data: str | Path | list[str | Path]) -> Data:
def build_data_store(data: str | Path | list[str | Path]) -> DataStore:
if isinstance(data, Path) or isinstance(data, str):
data = Path(data)
data_files = list(Path(data).glob("*"))
Expand All @@ -73,4 +116,4 @@ def build_data(data: str | Path | list[str | Path]) -> Data:
df = pd.DataFrame()
df["image_paths"] = data_files
df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))]
return Data(df)
return DataStore(df)
4 changes: 3 additions & 1 deletion lmm_tools/emb/emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def get_embedder(name: str) -> Embedder:
elif name == "openai":
return OpenAIEmb()
else:
raise ValueError(f"Unknown embedder name: {name}, currently support sentence-transformer, openai.")
raise ValueError(
f"Unknown embedder name: {name}, currently support sentence-transformer, openai."
)
13 changes: 8 additions & 5 deletions lmm_tools/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:


class LLaVALMM(LMM):
r"""An LMM class for the LLaVA-1.6 34B model."""

def __init__(self, name: str):
self.name = name

Expand All @@ -25,6 +27,8 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:


class OpenAILMM(LMM):
r"""An LMM class for the OpenAI GPT-4 Vision model."""

def __init__(self, name: str):
from openai import OpenAI

Expand All @@ -37,7 +41,7 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
"role": "user",
"content": [
{"type": "text", "text": prompt},
]
],
}
]
if image:
Expand All @@ -54,16 +58,15 @@ def generate(self, prompt: str, image: Optional[str | Path]) -> str:
)

response = self.client.chat.completions.create(
model="gpt-4-vision-preview",
message=message
model="gpt-4-vision-preview", message=message
)
return response.choices[0].message.content


def get_lmm(name: str) -> LMM:
    r"""Returns the LMM wrapper registered under `name`.

    Args:
        name (str): "openai" for OpenAILMM, or "llava" for LLaVALMM. The earlier
            key "llava-v1.6-34b" is still accepted for LLaVALMM.

    Returns:
        LMM: A new wrapper instance constructed with `name`.

    Raises:
        ValueError: If `name` is not one of the supported keys.
    """
    if name == "openai":
        return OpenAILMM(name)
    # Keep the previous key working so existing callers are not broken by the rename.
    elif name in ("llava", "llava-v1.6-34b"):
        return LLaVALMM(name)
    else:
        raise ValueError(f"Unknown LMM: {name}, current support openai, llava")

0 comments on commit 4bbb777

Please sign in to comment.