Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing errors #3

Merged
merged 3 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
Test:
strategy:
matrix:
python-version: [3.8, 3.10.11]
python-version: [3.10.11]
os: [ ubuntu-22.04, windows-2022, macos-12 ]
runs-on: ${{ matrix.os }}
steps:
Expand Down
25 changes: 14 additions & 11 deletions lmm_tools/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, cast

import faiss
import numpy as np
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(self, df: pd.DataFrame):
self.df = df
self.lmm: Optional[LMM] = None
self.emb: Optional[Embedder] = None
self.index = None
self.index: Optional[faiss.IndexFlatL2] = None # type: ignore
if "image_paths" not in self.df.columns:
raise ValueError("image_paths column must be present in DataFrame")
if "image_id" not in self.df.columns:
Expand All @@ -50,7 +50,7 @@ def add_column(self, name: str, prompt: str) -> Self:
if self.lmm is None:
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply(
self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
lambda x: self.lmm.generate(prompt, image=x)
)
return self
Expand All @@ -62,19 +62,22 @@ def build_index(self, target_col: str) -> Self:
if self.emb is None:
raise ValueError("Embedder not set yet")

embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
embeddings = np.array(embeddings.tolist()).astype(np.float32)
self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index.add(embeddings)
embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x)) # type: ignore
embeddings_np = np.array(embeddings.tolist()).astype(np.float32)
self.index = faiss.IndexFlatL2(embeddings_np.shape[1])
self.index.add(embeddings_np)
return self

def get_embeddings(self) -> npt.NDArray[np.float32]:
if self.index is None:
raise ValueError("Index not built yet")

ntotal = self.index.ntotal
d = self.index.d
return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)
d: int = self.index.d
return cast(
npt.NDArray[np.float32],
faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d),
)

def search(self, query: str, top_k: int = 10) -> List[Dict]:
r"""Searches the index for the most similar images to the query and returns the top_k results.
Expand All @@ -86,9 +89,9 @@ def search(self, query: str, top_k: int = 10) -> List[Dict]:
if self.emb is None:
raise ValueError("Embedder not set yet")

query_embedding = self.emb.embed(query)
query_embedding: npt.NDArray[np.float32] = self.emb.embed(query)
_, idx = self.index.search(query_embedding.reshape(1, -1), top_k)
return self.df.iloc[idx[0]].to_dict(orient="records")
return cast(List[Dict], self.df.iloc[idx[0]].to_dict(orient="records"))

def save(self, path: Union[str, Path]) -> None:
path = Path(path)
Expand Down
8 changes: 6 additions & 2 deletions lmm_tools/emb/emb.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
from typing import cast

import numpy as np
import numpy.typing as npt


class Embedder(ABC):
@abstractmethod
def embed(self, text: str) -> list:
def embed(self, text: str) -> npt.NDArray[np.float32]:
pass


Expand All @@ -17,7 +18,10 @@ def __init__(self, model_name: str = "all-MiniLM-L12-v2"):
self.model = SentenceTransformer(model_name)

def embed(self, text: str) -> npt.NDArray[np.float32]:
return self.model.encode([text]).flatten().astype(np.float32)
return cast(
npt.NDArray[np.float32],
self.model.encode([text]).flatten().astype(np.float32),
)


class OpenAIEmb(Embedder):
Expand Down
10 changes: 5 additions & 5 deletions lmm_tools/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from typing import Any, Dict, Optional, Union, cast


def encode_image(image: Union[str, Path]) -> str:
Expand All @@ -23,7 +23,7 @@ def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
pass
raise NotImplementedError("LLaVA LMM not implemented yet")


class OpenAILMM(LMM):
Expand All @@ -36,7 +36,7 @@ def __init__(self, name: str):
self.client = OpenAI()

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
message = [
message: list[Dict[str, Any]] = [
{
"role": "user",
"content": [
Expand All @@ -58,9 +58,9 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
)

response = self.client.chat.completions.create(
model="gpt-4-vision-preview", message=message
model="gpt-4-vision-preview", messages=message # type: ignore
)
return response.choices[0].message.content
return cast(str, response.choices[0].message.content)


def get_lmm(name: str) -> LMM:
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ packages = [{include = "lmm_tools"}]
"documentation" = "https://github.com/landing-ai/lmm-tools"

[tool.poetry.dependencies] # main dependency group
python = ">=3.8,<4.0"
python = ">=3.10,<4.0"

numpy = ">=1.21.0,<2.0.0"
pillow = "10.*"
Expand Down Expand Up @@ -81,4 +81,7 @@ disallow_any_unimported = true
ignore_missing_imports = true
module = [
"cv2.*",
"faiss.*",
"openai.*",
"sentence_transformers.*",
]
Loading