diff --git a/lmm_tools/__init__.py b/lmm_tools/__init__.py
index e69de29b..75e4f900 100644
--- a/lmm_tools/__init__.py
+++ b/lmm_tools/__init__.py
@@ -0,0 +1,3 @@
+from .emb import Embedder, get_embedder
+from .lmm import LMM, get_lmm
+from .data import Data, build_data
diff --git a/lmm_tools/data/__init__.py b/lmm_tools/data/__init__.py
new file mode 100644
index 00000000..5d816078
--- /dev/null
+++ b/lmm_tools/data/__init__.py
@@ -0,0 +1 @@
+from .data import Data, build_data
diff --git a/lmm_tools/data/data.py b/lmm_tools/data/data.py
new file mode 100644
index 00000000..b0864f5d
--- /dev/null
+++ b/lmm_tools/data/data.py
@@ -0,0 +1,69 @@
+import uuid
+from pathlib import Path
+from typing import Optional
+
+import faiss
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+from tqdm import tqdm
+
+# Import from the submodules directly to avoid a circular import through
+# the package __init__, which itself imports this module.
+from lmm_tools.emb import Embedder
+from lmm_tools.lmm import LMM
+
+tqdm.pandas()
+
+
+class Data:
+    def __init__(self, df: pd.DataFrame):
+        if "image_paths" not in df.columns:
+            raise ValueError("image_paths column must be present in DataFrame")
+        self.df = df
+        self.lmm: Optional[LMM] = None
+        self.emb: Optional[Embedder] = None
+        self.index = None
+
+    def add_embedder(self, emb: Embedder) -> None:
+        self.emb = emb
+
+    def add_lmm(self, lmm: LMM) -> None:
+        self.lmm = lmm
+
+    def add_column(self, name: str, prompt: str) -> None:
+        if self.lmm is None:
+            raise ValueError("LMM not set, call add_lmm() first")
+        self.df[name] = self.df["image_paths"].progress_apply(
+            lambda x: self.lmm.generate(prompt, image=x)
+        )
+
+    def add_index(self, target_col: str) -> None:
+        if self.emb is None:
+            raise ValueError("Embedder not set, call add_embedder() first")
+        embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
+        embeddings = np.array(embeddings.tolist()).astype(np.float32)
+        self.index = faiss.IndexFlatL2(embeddings.shape[1])
+        self.index.add(embeddings)
+
+    def get_embeddings(self) -> npt.NDArray[np.float32]:
+        ntotal = self.index.ntotal
+        d = self.index.d
+        return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)
+
+    def search(self, query: str, top_k: int = 10) -> list[dict]:
+        query_embedding = self.emb.embed(query)
+        _, I = self.index.search(query_embedding.reshape(1, -1), top_k)
+        return self.df.iloc[I[0]].to_dict(orient="records")
+
+
+def build_data(data: str | Path | list[str | Path]) -> Data:
+    if isinstance(data, (str, Path)):
+        data_files = list(Path(data).glob("*"))
+    elif isinstance(data, list):
+        data_files = [Path(d) for d in data]
+
+    df = pd.DataFrame()
+    df["image_paths"] = data_files
+    df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))]
+    return Data(df)
diff --git a/lmm_tools/emb/__init__.py b/lmm_tools/emb/__init__.py
new file mode 100644
index 00000000..8c68f801
--- /dev/null
+++ b/lmm_tools/emb/__init__.py
@@ -0,0 +1 @@
+from .emb import Embedder, SentenceTransformerEmb, OpenAIEmb, get_embedder
diff --git a/lmm_tools/emb/emb.py b/lmm_tools/emb/emb.py
new file mode 100644
index 00000000..d8d0ad1f
--- /dev/null
+++ b/lmm_tools/emb/emb.py
@@ -0,0 +1,41 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import numpy.typing as npt
+
+
+class Embedder(ABC):
+    @abstractmethod
+    def embed(self, text: str) -> npt.NDArray[np.float32]:
+        pass
+
+
+class SentenceTransformerEmb(Embedder):
+    def __init__(self, model_name: str = "all-MiniLM-L12-v2"):
+        from sentence_transformers import SentenceTransformer
+
+        self.model = SentenceTransformer(model_name)
+
+    def embed(self, text: str) -> npt.NDArray[np.float32]:
+        return self.model.encode([text]).flatten().astype(np.float32)
+
+
+class OpenAIEmb(Embedder):
+    def __init__(self, model_name: str = "text-embedding-3-small"):
+        from openai import OpenAI
+
+        self.client = OpenAI()
+        self.model_name = model_name
+
+    def embed(self, text: str) -> npt.NDArray[np.float32]:
+        response = self.client.embeddings.create(input=text, model=self.model_name)
+        return np.array(response.data[0].embedding).astype(np.float32)
+
+
+def get_embedder(name: str) -> Embedder:
+    if name == "sentence-transformer":
+        return SentenceTransformerEmb()
+    elif name == "openai":
+        return OpenAIEmb()
+    else:
+        raise ValueError(f"Unknown embedder name: {name}, currently supported: sentence-transformer, openai.")
diff --git a/lmm_tools/lmm/__init__.py b/lmm_tools/lmm/__init__.py
new file mode 100644
index 00000000..dfef1e35
--- /dev/null
+++ b/lmm_tools/lmm/__init__.py
@@ -0,0 +1 @@
+from .lmm import LMM, OpenAILMM, LLaVALMM, get_lmm
diff --git a/lmm_tools/lmm/lmm.py b/lmm_tools/lmm/lmm.py
new file mode 100644
index 00000000..81a08c96
--- /dev/null
+++ b/lmm_tools/lmm/lmm.py
@@ -0,0 +1,69 @@
+import base64
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Optional
+
+
+def encode_image(image: str | Path) -> str:
+    with open(image, "rb") as f:
+        encoded_image = base64.b64encode(f.read()).decode("utf-8")
+    return encoded_image
+
+
+class LMM(ABC):
+    @abstractmethod
+    def generate(self, prompt: str, image: Optional[str | Path] = None) -> str:
+        pass
+
+
+class LLaVALMM(LMM):
+    def __init__(self, name: str):
+        self.name = name
+
+    def generate(self, prompt: str, image: Optional[str | Path] = None) -> str:
+        raise NotImplementedError("LLaVA generation is not implemented yet")
+
+
+class OpenAILMM(LMM):
+    def __init__(self, name: str):
+        from openai import OpenAI
+
+        self.name = name
+        self.client = OpenAI()
+
+    def generate(self, prompt: str, image: Optional[str | Path] = None) -> str:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prompt},
+                ],
+            }
+        ]
+        if image:
+            # Path.suffix keeps the leading dot; strip it so the data URL is valid.
+            extension = Path(image).suffix[1:]
+            encoded_image = encode_image(image)
+            messages[0]["content"].append(
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{extension};base64,{encoded_image}",
+                        "detail": "low",
+                    },
+                },
+            )
+
+        response = self.client.chat.completions.create(
+            model="gpt-4-vision-preview", messages=messages
+        )
+        return response.choices[0].message.content
+
+
+def get_lmm(name: str) -> LMM:
+    if name == "openai":
+        return OpenAILMM(name)
+    elif name == "llava-v1.6-34b":
+        return LLaVALMM(name)
+    else:
+        raise ValueError(f"Unknown LMM: {name}, currently supported: openai, llava-v1.6-34b")
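For reviewers, a minimal end-to-end sketch of how the pieces introduced in this diff compose, assuming `OPENAI_API_KEY` is set in the environment; the `./images` directory, the `"description"` column name, and the query string are hypothetical:

```python
from lmm_tools import build_data, get_embedder, get_lmm

# Collect image files from a directory (path is hypothetical) into a DataFrame.
data = build_data("./images")

# Attach a multimodal model for captioning and an embedder for indexing.
data.add_lmm(get_lmm("openai"))
data.add_embedder(get_embedder("sentence-transformer"))

# Generate a text column from the images, then build a FAISS index over it.
data.add_column("description", "Describe the contents of this image.")
data.add_index("description")

# Nearest-neighbor search over the generated descriptions.
for row in data.search("a dog playing fetch", top_k=3):
    print(row["image_paths"], row["description"])
```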