Fix typing errors (#3)
* Fix typing errors

* Update python version

* Fix ubuntu-specific type error

---------

Co-authored-by: Yazhou Cao <[email protected]>
humpydonkey and AsiaCao authored Feb 20, 2024
1 parent 8efbae0 commit acfc4f5
Showing 5 changed files with 30 additions and 20 deletions.
.github/workflows/ci_cd.yml: 2 changes (1 addition & 1 deletion)
@@ -9,7 +9,7 @@ jobs:
Test:
strategy:
matrix:
- python-version: [3.8, 3.10.11]
+ python-version: [3.10.11]
os: [ ubuntu-22.04, windows-2022, macos-12 ]
runs-on: ${{ matrix.os }}
steps:
lmm_tools/data/data.py: 25 changes (14 additions & 11 deletions)
@@ -2,7 +2,7 @@

import uuid
from pathlib import Path
- from typing import Dict, List, Optional, Union
+ from typing import Dict, List, Optional, Union, cast

import faiss
import numpy as np
@@ -31,7 +31,7 @@ def __init__(self, df: pd.DataFrame):
self.df = df
self.lmm: Optional[LMM] = None
self.emb: Optional[Embedder] = None
- self.index = None
+ self.index: Optional[faiss.IndexFlatL2] = None  # type: ignore
if "image_paths" not in self.df.columns:
raise ValueError("image_paths column must be present in DataFrame")
if "image_id" not in self.df.columns:
@@ -50,7 +50,7 @@ def add_column(self, name: str, prompt: str) -> Self:
if self.lmm is None:
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply(
self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
lambda x: self.lmm.generate(prompt, image=x)
)
return self
@@ -62,19 +62,22 @@ def build_index(self, target_col: str) -> Self:
if self.emb is None:
raise ValueError("Embedder not set yet")

- embeddings = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))
- embeddings = np.array(embeddings.tolist()).astype(np.float32)
- self.index = faiss.IndexFlatL2(embeddings.shape[1])
- self.index.add(embeddings)
+ embeddings: pd.Series = self.df[target_col].progress_apply(lambda x: self.emb.embed(x))  # type: ignore
+ embeddings_np = np.array(embeddings.tolist()).astype(np.float32)
+ self.index = faiss.IndexFlatL2(embeddings_np.shape[1])
+ self.index.add(embeddings_np)
return self

def get_embeddings(self) -> npt.NDArray[np.float32]:
if self.index is None:
raise ValueError("Index not built yet")

ntotal = self.index.ntotal
- d = self.index.d
- return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)
+ d: int = self.index.d
+ return cast(
+ npt.NDArray[np.float32],
+ faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d),
+ )

def search(self, query: str, top_k: int = 10) -> List[Dict]:
r"""Searches the index for the most similar images to the query and returns the top_k results.
@@ -86,9 +89,9 @@ def search(self, query: str, top_k: int = 10) -> List[Dict]:
if self.emb is None:
raise ValueError("Embedder not set yet")

- query_embedding = self.emb.embed(query)
+ query_embedding: npt.NDArray[np.float32] = self.emb.embed(query)
_, idx = self.index.search(query_embedding.reshape(1, -1), top_k)
- return self.df.iloc[idx[0]].to_dict(orient="records")
+ return cast(List[Dict], self.df.iloc[idx[0]].to_dict(orient="records"))

def save(self, path: Union[str, Path]) -> None:
path = Path(path)
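For context, the faiss calls that build_index, get_embeddings, and search now type-check against can be exercised in isolation. The sketch below is not part of this commit: random vectors stand in for real embeddings, and the 384-dimension size is an arbitrary assumption.

from typing import cast

import faiss
import numpy as np
import numpy.typing as npt

d = 384  # assumed embedding dimension; any positive int works for this sketch
embeddings_np = np.random.rand(8, d).astype(np.float32)  # faiss expects float32

index = faiss.IndexFlatL2(d)  # same index type DataStore.index is now annotated with
index.add(embeddings_np)

# Mirrors get_embeddings(): read the stored vectors back out of the index.
stored = cast(
    npt.NDArray[np.float32],
    faiss.rev_swig_ptr(index.get_xb(), index.ntotal * index.d).reshape(index.ntotal, index.d),
)

# Mirrors search(): one query vector reshaped to (1, d), top-3 neighbours.
distances, idx = index.search(embeddings_np[0].reshape(1, -1), 3)
print(stored.shape, idx[0])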
lmm_tools/emb/emb.py: 8 changes (6 additions & 2 deletions)
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
+ from typing import cast

import numpy as np
import numpy.typing as npt


class Embedder(ABC):
@abstractmethod
- def embed(self, text: str) -> list:
+ def embed(self, text: str) -> npt.NDArray[np.float32]:
pass


@@ -17,7 +18,10 @@ def __init__(self, model_name: str = "all-MiniLM-L12-v2"):
self.model = SentenceTransformer(model_name)

def embed(self, text: str) -> npt.NDArray[np.float32]:
- return self.model.encode([text]).flatten().astype(np.float32)
+ return cast(
+ npt.NDArray[np.float32],
+ self.model.encode([text]).flatten().astype(np.float32),
+ )


class OpenAIEmb(Embedder):
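The cast() added above is the usual way to keep a precise return type when the value comes from an untyped third-party call (sentence_transformers is added to the mypy ignore_missing_imports list below, so encode() is seen as Any). A minimal sketch of the pattern, with a hypothetical stub standing in for SentenceTransformer.encode:

from typing import cast

import numpy as np
import numpy.typing as npt


def encode_stub(texts: list[str]) -> np.ndarray:
    # Hypothetical stand-in for SentenceTransformer.encode; returns float64 like many models do.
    return np.ones((len(texts), 8))


def embed(text: str) -> npt.NDArray[np.float32]:
    # cast() is a no-op at runtime; it only tells mypy which type the Any-typed value has.
    return cast(npt.NDArray[np.float32], encode_stub([text]).flatten().astype(np.float32))


print(embed("hello").dtype)  # float32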
lmm_tools/lmm/lmm.py: 10 changes (5 additions & 5 deletions)
@@ -1,7 +1,7 @@
import base64
from abc import ABC, abstractmethod
from pathlib import Path
- from typing import Optional, Union
+ from typing import Any, Dict, Optional, Union, cast


def encode_image(image: Union[str, Path]) -> str:
@@ -23,7 +23,7 @@ def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
- pass
+ raise NotImplementedError("LLaVA LMM not implemented yet")


class OpenAILMM(LMM):
@@ -36,7 +36,7 @@ def __init__(self, name: str):
self.client = OpenAI()

def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
- message = [
+ message: list[Dict[str, Any]] = [
{
"role": "user",
"content": [
@@ -58,9 +58,9 @@ def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
)

response = self.client.chat.completions.create(
model="gpt-4-vision-preview", message=message
model="gpt-4-vision-preview", messages=message # type: ignore
)
- return response.choices[0].message.content
+ return cast(str, response.choices[0].message.content)


def get_lmm(name: str) -> LMM:
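The message annotation and the messages= keyword fix both concern the payload sent to chat.completions.create. A sketch of that payload shape follows; the image block is an assumption based on the OpenAI vision message format, since the middle of generate() is not visible in this diff.

import base64
from typing import Any, Dict

prompt = "Describe this image."
# Stand-in for encode_image(): base64 of some raw bytes, as the real helper produces.
image_b64 = base64.b64encode(b"not a real image").decode("utf-8")

message: list[Dict[str, Any]] = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
        ],
    }
]

# The request then passes this list via the messages= keyword; the old message=
# keyword was the typo this commit fixes:
# response = client.chat.completions.create(model="gpt-4-vision-preview", messages=message)
print(message[0]["content"][0]["type"])  # text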
pyproject.toml: 5 changes (4 additions & 1 deletion)
@@ -16,7 +16,7 @@ packages = [{include = "lmm_tools"}]
"documentation" = "https://github.com/landing-ai/lmm-tools"

[tool.poetry.dependencies] # main dependency group
python = ">=3.8,<4.0"
python = ">=3.10,<4.0"

numpy = ">=1.21.0,<2.0.0"
pillow = "10.*"
@@ -81,4 +81,7 @@ disallow_any_unimported = true
ignore_missing_imports = true
module = [
"cv2.*",
"faiss.*",
"openai.*",
"sentence_transformers.*",
]
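Dropping Python 3.8 from both the CI matrix and the python constraint means installs on older interpreters should fail at dependency resolution. If an explicit runtime guard were ever wanted, a minimal sketch (not part of this commit) would be:

import sys

if sys.version_info < (3, 10):
    raise RuntimeError("lmm-tools requires Python >= 3.10")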
