Commit

initial commit for structure
dillonalaird committed Feb 14, 2024
1 parent fee778a commit 58c2289
Showing 7 changed files with 178 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lmm_tools/__init__.py
@@ -0,0 +1,3 @@
from .lmm import LMM, get_lmm
from .emb import Embedder, get_embedder
from .data import Data, build_data
1 change: 1 addition & 0 deletions lmm_tools/data/__init__.py
@@ -0,0 +1 @@
from .data import Data, build_data
62 changes: 62 additions & 0 deletions lmm_tools/data/data.py
@@ -0,0 +1,62 @@
import uuid
from pathlib import Path

import faiss
import numpy as np
import numpy.typing as npt
import pandas as pd
from tqdm import tqdm

from lmm_tools.emb import Embedder
from lmm_tools.lmm import LMM

tqdm.pandas()


class Data:
    def __init__(self, df: pd.DataFrame):
        if "image_paths" not in df.columns:
            raise ValueError("image_paths column must be present in DataFrame")
        self.df = df
        self.lmm: LMM | None = None
        self.emb: Embedder | None = None
        self.index: faiss.IndexFlatL2 | None = None

def add_embedder(self, emb: Embedder):
self.emb = emb

def add_lmm(self, lmm: LMM):
self.lmm = lmm

    def add_column(self, name: str, prompt: str) -> None:
        if self.lmm is None:
            raise ValueError("LMM not set, call add_lmm() first")
        self.df[name] = self.df["image_paths"].progress_apply(
            lambda x: self.lmm.generate(prompt, image=x)
        )

    def add_index(self, target_col: str) -> None:
        if self.emb is None:
            raise ValueError("Embedder not set, call add_embedder() first")
        embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
        embeddings = np.array(embeddings.tolist()).astype(np.float32)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)

    def get_embeddings(self) -> npt.NDArray[np.float32]:
ntotal = self.index.ntotal
d = self.index.d
return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)

def search(self, query: str, top_k: int = 10) -> list[dict]:
query_embedding = self.emb.embed(query)
_, I = self.index.search(query_embedding.reshape(1, -1), top_k)
return self.df.iloc[I[0]].to_dict(orient="records")


def build_data(data: str | Path | list[str | Path]) -> Data:
    if isinstance(data, (str, Path)):
        data_files = list(Path(data).glob("*"))
    elif isinstance(data, list):
        data_files = [Path(d) for d in data]
    else:
        raise ValueError("data must be a str, Path, or list of str/Path")

    df = pd.DataFrame()
    df["image_paths"] = data_files
    df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))]
    return Data(df)
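
A minimal usage sketch of the Data pipeline above, assuming a hypothetical images/ directory, the sentence-transformer embedder, and an OPENAI_API_KEY set in the environment for the LMM; the column name and prompts are illustrative only.

from lmm_tools import get_lmm, get_embedder
from lmm_tools.data import build_data

# collect image paths from a directory into a DataFrame-backed Data object
data = build_data("images/")  # hypothetical directory of images

# attach the models used for captioning and for embedding the captions
data.add_lmm(get_lmm("openai"))
data.add_embedder(get_embedder("sentence-transformer"))

# caption every image, then embed the captions and build a FAISS index over them
data.add_column("description", "Describe the contents of this image in one sentence.")
data.add_index("description")

# retrieve the rows whose captions are closest to the query
print(data.search("a dog playing outside", top_k=3))
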
1 change: 1 addition & 0 deletions lmm_tools/emb/__init__.py
@@ -0,0 +1 @@
from .emb import Embedder, SentenceTransformerEmb, OpenAIEmb, get_embedder
41 changes: 41 additions & 0 deletions lmm_tools/emb/emb.py
@@ -0,0 +1,41 @@
from abc import ABC, abstractmethod

import numpy as np
import numpy.typing as npt


class Embedder(ABC):
    @abstractmethod
    def embed(self, text: str) -> npt.NDArray[np.float32]:
        pass


class SentenceTransformerEmb(Embedder):
def __init__(self, model_name: str = "all-MiniLM-L12-v2"):
from sentence_transformers import SentenceTransformer

self.model = SentenceTransformer(model_name)

def embed(self, text: str) -> npt.NDArray[np.float32]:
return self.model.encode([text]).flatten().astype(np.float32)


class OpenAIEmb(Embedder):
def __init__(self, model_name: str = "text-embedding-3-small"):
from openai import OpenAI

self.client = OpenAI()
self.model_name = model_name

def embed(self, text: str) -> npt.NDArray[np.float32]:
response = self.client.embeddings.create(input=text, model=self.model_name)
return np.array(response.data[0].embedding).astype(np.float32)


def get_embedder(name: str) -> Embedder:
if name == "sentence-transformer":
return SentenceTransformerEmb()
elif name == "openai":
return OpenAIEmb()
else:
        raise ValueError(f"Unknown embedder name: {name}, currently supported embedders are: sentence-transformer, openai.")
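
A short sketch exercising the embedders above, assuming the sentence-transformers package is installed (and an OPENAI_API_KEY if the openai variant is chosen instead); the example sentences are illustrative.

import numpy as np

from lmm_tools.emb import get_embedder

# each embedder returns a flat float32 vector for a single piece of text
emb = get_embedder("sentence-transformer")
vec = emb.embed("a photo of a red bicycle leaning against a wall")
print(vec.shape, vec.dtype)  # e.g. (384,) float32 for all-MiniLM-L12-v2

# cosine similarity between two embeddings as a quick sanity check
other = emb.embed("a red bike parked next to a brick wall")
cos = float(np.dot(vec, other) / (np.linalg.norm(vec) * np.linalg.norm(other)))
print(f"cosine similarity: {cos:.3f}")
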
1 change: 1 addition & 0 deletions lmm_tools/lmm/__init__.py
@@ -0,0 +1 @@
from .lmm import LMM, OpenAILMM, LLaVALMM, get_lmm
69 changes: 69 additions & 0 deletions lmm_tools/lmm/lmm.py
@@ -0,0 +1,69 @@
import base64
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional


def encode_image(image: str | Path) -> str:
with open(image, "rb") as f:
encoded_image = base64.b64encode(f.read()).decode("utf-8")
return encoded_image


class LMM(ABC):
@abstractmethod
    def generate(self, prompt: str, image: Optional[str | Path] = None) -> str:
pass


class LLaVALMM(LMM):
def __init__(self, name: str):
self.name = name

    def generate(self, prompt: str, image: Optional[str | Path] = None) -> str:
        # LLaVA inference is not wired up in this initial commit
        raise NotImplementedError("LLaVA generation is not implemented yet")


class OpenAILMM(LMM):
def __init__(self, name: str):
from openai import OpenAI

self.name = name
self.client = OpenAI()

    def generate(self, prompt: str, image: Optional[str | Path] = None) -> str:
message = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
]
}
]
if image:
            # Path.suffix includes the leading dot; strip it for the data URL
            extension = Path(image).suffix.lstrip(".")
encoded_image = encode_image(image)
message[0]["content"].append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/{extension};base64,{encoded_image}",
"detail": "low",
},
},
)

        response = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=message,
        )
        return response.choices[0].message.content


def get_lmm(name: str) -> LMM:
if name == "openai":
return OpenAILMM(name)
elif name == "llava-v1.6-34b":
return LLaVALMM(name)
else:
        raise ValueError(f"Unknown LMM: {name}, currently supported LMMs are: openai, llava-v1.6-34b")
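
A brief sketch of calling the OpenAI-backed LMM above on a single image; the image path is hypothetical and an OPENAI_API_KEY is assumed to be set in the environment.

from lmm_tools.lmm import get_lmm

# the name selects the backend; "openai" uses gpt-4-vision-preview under the hood
lmm = get_lmm("openai")

# send a text prompt together with a base64-encoded image
caption = lmm.generate(
    "Describe this image in one sentence.",
    image="images/example.jpg",  # hypothetical path
)
print(caption)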
