-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fee778a
commit 58c2289
Showing
7 changed files
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .data import Data, get_data | ||
from .lmm import LMM, get_lmm | ||
from .emb import Embedder, get_embedder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .data import Data, build_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import uuid | ||
from pathlib import Path | ||
|
||
import faiss | ||
import numpy as np | ||
import numpy.typing as np | ||
import pandas as pd | ||
from tqdm import tqdm | ||
|
||
from lmm_tools import LMM, Embedder | ||
|
||
tqdm.pandas() | ||
|
||
|
||
class Data: | ||
def __init__(self, df: pd.DataFrame): | ||
self.df = pd.DataFrame() | ||
self.lmm: LMM = None | ||
self.emb: Embedder = None | ||
self.index = None | ||
if "image_paths" not in df.columns: | ||
raise ValueError("image_paths column must be present in DataFrame") | ||
|
||
def add_embedder(self, emb: Embedder): | ||
self.emb = emb | ||
|
||
def add_lmm(self, lmm: LMM): | ||
self.lmm = lmm | ||
|
||
def add_column(self, name: str, prompt: str) -> None: | ||
self.df[name] = self.df["image_paths"].progress_apply( | ||
lambda x: self.lmm.generate(prompt, image=x) | ||
) | ||
|
||
def add_index(self, target_col: str) -> None: | ||
embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) | ||
embeddings = np.array(embeddings.tolist()).astype(np.float32) | ||
self.index = faiss.IndexFlatL2(embeddings.shape[1]) | ||
self.index.add(embeddings) | ||
|
||
def get_embeddings(self) -> np.ndarray: | ||
ntotal = self.index.ntotal | ||
d = self.index.d | ||
return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d) | ||
|
||
def search(self, query: str, top_k: int = 10) -> list[dict]: | ||
query_embedding = self.emb.embed(query) | ||
_, I = self.index.search(query_embedding.reshape(1, -1), top_k) | ||
return self.df.iloc[I[0]].to_dict(orient="records") | ||
|
||
|
||
def build_data(self, data: str | Path | list[str | Path]) -> Data: | ||
if isinstance(data, Path) or isinstance(data, str): | ||
data = Path(data) | ||
data_files = list(Path(data).glob("*")) | ||
elif isinstance(data, list): | ||
data_files = [Path(d) for d in data] | ||
|
||
df = pd.DataFrame() | ||
df["image_paths"] = data_files | ||
df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))] | ||
return Data(df) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .emb import Embedder, SentenceTransformerEmb, OpenAIEmb, get_embedder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
|
||
|
||
class Embedder(ABC): | ||
@abstractmethod | ||
def embed(self, text: str) -> list: | ||
pass | ||
|
||
|
||
class SentenceTransformerEmb(Embedder): | ||
def __init__(self, model_name: str = "all-MiniLM-L12-v2"): | ||
from sentence_transformers import SentenceTransformer | ||
|
||
self.model = SentenceTransformer(model_name) | ||
|
||
def embed(self, text: str) -> npt.NDArray[np.float32]: | ||
return self.model.encode([text]).flatten().astype(np.float32) | ||
|
||
|
||
class OpenAIEmb(Embedder): | ||
def __init__(self, model_name: str = "text-embedding-3-small"): | ||
from openai import OpenAI | ||
|
||
self.client = OpenAI() | ||
self.model_name = model_name | ||
|
||
def embed(self, text: str) -> npt.NDArray[np.float32]: | ||
response = self.client.embeddings.create(input=text, model=self.model_name) | ||
return np.array(response.data[0].embedding).astype(np.float32) | ||
|
||
|
||
def get_embedder(name: str) -> Embedder: | ||
if name == "sentence-transformer": | ||
return SentenceTransformerEmb() | ||
elif name == "openai": | ||
return OpenAIEmb() | ||
else: | ||
raise ValueError(f"Unknown embedder name: {name}, currently support sentence-transformer, openai.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .lmm import LMM, OpenAILMM, LLaVALMM, get_lmm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import base64 | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
|
||
def encode_image(image: str | Path) -> str: | ||
with open(image, "rb") as f: | ||
encoded_image = base64.b64encode(f.read()).decode("utf-8") | ||
return encoded_image | ||
|
||
|
||
class LMM(ABC): | ||
@abstractmethod | ||
def generate(self, prompt: str, image: Optional[str | Path]) -> str: | ||
pass | ||
|
||
|
||
class LLaVALMM(LMM): | ||
def __init__(self, name: str): | ||
self.name = name | ||
|
||
def generate(self, prompt: str, image: Optional[str | Path]) -> str: | ||
pass | ||
|
||
|
||
class OpenAILMM(LMM): | ||
def __init__(self, name: str): | ||
from openai import OpenAI | ||
|
||
self.name = name | ||
self.client = OpenAI() | ||
|
||
def generate(self, prompt: str, image: Optional[str | Path]) -> str: | ||
message = [ | ||
{ | ||
"role": "user", | ||
"content": [ | ||
{"type": "text", "text": prompt}, | ||
] | ||
} | ||
] | ||
if image: | ||
extension = Path(image).suffix | ||
encoded_image = encode_image(image) | ||
message[0]["content"].append( | ||
{ | ||
"type": "image_url", | ||
"image_url": { | ||
"url": f"data:image/{extension};base64,{encoded_image}", | ||
"detail": "low", | ||
}, | ||
}, | ||
) | ||
|
||
response = self.client.chat.completions.create( | ||
model="gpt-4-vision-preview", | ||
message=message | ||
) | ||
return response.choices[0].message.content | ||
|
||
|
||
def get_lmm(name: str) -> LMM: | ||
if name == "openai": | ||
return OpenAILMM(name) | ||
elif name == "llava-v1.6-34b": | ||
return LLaVALMM(name) | ||
else: | ||
raise ValueError(f"Unknown LMM: {name}, current support openai, llava-v1.6-34b") |