Skip to content

Commit 58c2289

Browse files
committed
initial commit for structure
1 parent fee778a commit 58c2289

File tree

7 files changed

+178
-0
lines changed

7 files changed

+178
-0
lines changed

lmm_tools/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
# Public API of the lmm_tools package.
# NOTE: the data subpackage exports ``build_data`` (see lmm_tools/data/__init__.py),
# not ``get_data`` — importing get_data raised ImportError at package import time.
from .data import Data, build_data
from .lmm import LMM, get_lmm
from .emb import Embedder, get_embedder

lmm_tools/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .data import Data, build_data

lmm_tools/data/data.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
import uuid
from pathlib import Path

import faiss
import numpy as np
# BUG FIX: this was ``import numpy.typing as np``, which silently shadowed the
# numpy module — np.array / np.float32 / np.ndarray below would then fail,
# since numpy.typing exposes none of those names.
import numpy.typing as npt
import pandas as pd
from tqdm import tqdm

from lmm_tools import LMM, Embedder

# Register tqdm's progress_apply on pandas objects.
tqdm.pandas()
13+
14+
15+
class Data:
    """Wrap a DataFrame of image paths; add LMM-generated columns and a faiss search index.

    The frame must contain an ``image_paths`` column. Attach an LMM with
    ``add_lmm`` before calling ``add_column``, and an Embedder with
    ``add_embedder`` before calling ``add_index``/``search``.
    """

    def __init__(self, df: pd.DataFrame):
        """Validate and store *df*.

        Raises:
            ValueError: if *df* has no ``image_paths`` column.
        """
        if "image_paths" not in df.columns:
            raise ValueError("image_paths column must be present in DataFrame")
        # BUG FIX: the original assigned an empty pd.DataFrame() here, silently
        # discarding the caller's data after validating it.
        self.df = df
        self.lmm = None    # LMM used by add_column; set via add_lmm
        self.emb = None    # Embedder used by add_index/search; set via add_embedder
        self.index = None  # faiss index built by add_index

    def add_embedder(self, emb: "Embedder") -> None:
        """Attach the embedder used for indexing and search."""
        self.emb = emb

    def add_lmm(self, lmm: "LMM") -> None:
        """Attach the LMM used to generate new columns."""
        self.lmm = lmm

    def add_column(self, name: str, prompt: str) -> None:
        """Create column *name* by prompting the LMM once per image path."""
        if self.lmm is None:
            raise ValueError("No LMM attached; call add_lmm first")
        self.df[name] = self.df["image_paths"].progress_apply(
            lambda x: self.lmm.generate(prompt, image=x)
        )

    def add_index(self, target_col: str) -> None:
        """Embed every value of *target_col* and build a flat L2 faiss index."""
        if self.emb is None:
            raise ValueError("No embedder attached; call add_embedder first")
        embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
        embeddings = np.array(embeddings.tolist()).astype(np.float32)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)

    def get_embeddings(self) -> np.ndarray:
        """Return the (ntotal, d) matrix of vectors stored in the faiss index."""
        if self.index is None:
            raise ValueError("No index built; call add_index first")
        ntotal = self.index.ntotal
        d = self.index.d
        # faiss keeps the vectors in one flat C buffer; rev_swig_ptr exposes it
        # as a numpy array without copying.
        return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)

    def search(self, query: str, top_k: int = 10) -> list[dict]:
        """Return up to *top_k* rows closest to *query* in embedding space."""
        if self.emb is None or self.index is None:
            raise ValueError("Embedder and index required; call add_embedder and add_index first")
        query_embedding = self.emb.embed(query)
        _, I = self.index.search(query_embedding.reshape(1, -1), top_k)
        return self.df.iloc[I[0]].to_dict(orient="records")
50+
51+
52+
def build_data(data: str | Path | list[str | Path]) -> Data:
    """Build a Data object from a directory of images or a list of image paths.

    BUG FIX: the original signature took a spurious ``self`` parameter even
    though this is a module-level function (exported as such from
    ``lmm_tools/data/__init__.py``), so every call bound the data argument to
    ``self`` and failed.

    Args:
        data: a directory (str or Path) whose files are collected via glob,
            or an explicit list of file paths.

    Returns:
        A Data instance with ``image_paths`` and a fresh uuid4 ``image_id``
        per file.

    Raises:
        TypeError: if *data* is neither path-like nor a list (the original
            fell through with ``data_files`` unbound, raising NameError).
    """
    if isinstance(data, (str, Path)):
        data_files = list(Path(data).glob("*"))
    elif isinstance(data, list):
        data_files = [Path(d) for d in data]
    else:
        raise TypeError("data must be a str, Path, or list of str/Path")

    df = pd.DataFrame()
    df["image_paths"] = data_files
    df["image_id"] = [uuid.uuid4() for _ in range(len(data_files))]
    return Data(df)

lmm_tools/emb/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .emb import Embedder, SentenceTransformerEmb, OpenAIEmb, get_embedder

lmm_tools/emb/emb.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from abc import ABC, abstractmethod
2+
3+
import numpy as np
4+
import numpy.typing as npt
5+
6+
7+
class Embedder(ABC):
    """Abstract interface for text embedders: map a string to a numeric vector."""

    @abstractmethod
    def embed(self, text: str) -> list:
        """Return the embedding vector for *text*."""
11+
12+
13+
class SentenceTransformerEmb(Embedder):
    """Embedder backed by a local sentence-transformers model."""

    def __init__(self, model_name: str = "all-MiniLM-L12-v2"):
        # Imported lazily so the sentence-transformers package is only
        # required when this backend is actually used.
        from sentence_transformers import SentenceTransformer

        self.model = SentenceTransformer(model_name)

    def embed(self, text: str) -> npt.NDArray[np.float32]:
        """Encode *text* and return a flat float32 vector."""
        encoded = self.model.encode([text])
        return encoded.flatten().astype(np.float32)
21+
22+
23+
class OpenAIEmb(Embedder):
    """Embedder backed by the OpenAI embeddings API."""

    def __init__(self, model_name: str = "text-embedding-3-small"):
        # Imported lazily so the openai package is only required when this
        # backend is actually used.
        from openai import OpenAI

        self.client = OpenAI()
        self.model_name = model_name

    def embed(self, text: str) -> npt.NDArray[np.float32]:
        """Request an embedding for *text* and return it as a float32 array."""
        response = self.client.embeddings.create(input=text, model=self.model_name)
        vector = response.data[0].embedding
        return np.array(vector).astype(np.float32)
33+
34+
35+
def get_embedder(name: str) -> Embedder:
    """Return the embedder backend registered under *name*.

    Supported names: "sentence-transformer", "openai".

    Raises:
        ValueError: for any unrecognized *name*.
    """
    if name == "sentence-transformer":
        return SentenceTransformerEmb()
    if name == "openai":
        return OpenAIEmb()
    raise ValueError(f"Unknown embedder name: {name}, currently support sentence-transformer, openai.")

lmm_tools/lmm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .lmm import LMM, OpenAILMM, LLaVALMM, get_lmm

lmm_tools/lmm/lmm.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import base64
2+
from abc import ABC, abstractmethod
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
7+
def encode_image(image: str | Path) -> str:
8+
with open(image, "rb") as f:
9+
encoded_image = base64.b64encode(f.read()).decode("utf-8")
10+
return encoded_image
11+
12+
13+
class LMM(ABC):
14+
@abstractmethod
15+
def generate(self, prompt: str, image: Optional[str | Path]) -> str:
16+
pass
17+
18+
19+
class LLaVALMM(LMM):
    """Placeholder LLaVA-backed LMM; generation is not implemented yet."""

    def __init__(self, name: str):
        self.name = name

    def generate(self, prompt: str, image: Optional[str | Path]) -> str:
        """Stub: not implemented yet, so this currently returns None."""
25+
26+
27+
class OpenAILMM(LMM):
    """LMM backed by the OpenAI chat completions API (GPT-4 with vision)."""

    def __init__(self, name: str):
        # Imported lazily so the openai package is only required when this
        # backend is actually used.
        from openai import OpenAI

        self.name = name
        self.client = OpenAI()

    def generate(self, prompt: str, image: Optional[str | Path]) -> str:
        """Send *prompt* (plus an optional image) to the model and return the reply text.

        Args:
            prompt: text prompt for the model.
            image: optional path to an image file, attached as a base64 data URL.

        Returns:
            The content of the first completion choice.
        """
        message = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ]
            }
        ]
        if image:
            # BUG FIX: Path.suffix includes the leading dot, so the original
            # produced a malformed MIME type like "image/.png"; strip the dot.
            extension = Path(image).suffix.lstrip(".")
            encoded_image = encode_image(image)
            message[0]["content"].append(
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{extension};base64,{encoded_image}",
                        "detail": "low",
                    },
                },
            )

        # BUG FIX: the chat completions API takes the conversation via the
        # ``messages`` keyword; ``message=`` raised TypeError at call time.
        response = self.client.chat.completions.create(
            model="gpt-4-vision-preview",
            messages=message
        )
        return response.choices[0].message.content
61+
62+
63+
def get_lmm(name: str) -> LMM:
    """Return the LMM backend registered under *name*.

    Supported names: "openai", "llava-v1.6-34b".

    Raises:
        ValueError: for any unrecognized *name*.
    """
    if name == "openai":
        return OpenAILMM(name)
    if name == "llava-v1.6-34b":
        return LLaVALMM(name)
    raise ValueError(f"Unknown LMM: {name}, current support openai, llava-v1.6-34b")

0 commit comments

Comments
 (0)