Skip to content

Commit

Permalink
make compatible with py <3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Feb 18, 2024
1 parent bc6a47b commit 169249d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
13 changes: 7 additions & 6 deletions lmm_tools/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import uuid
from pathlib import Path
from typing import Dict, List, Optional, Union

import faiss
import numpy as np
Expand All @@ -28,8 +29,8 @@ def __init__(self, df: pd.DataFrame):
"""
self.df = df
self.lmm: LMM | None = None
self.emb: Embedder | None = None
self.lmm: Optional[LMM] = None
self.emb: Optional[Embedder] = None
self.index = None
if "image_paths" not in self.df.columns:
raise ValueError("image_paths column must be present in DataFrame")
Expand Down Expand Up @@ -75,7 +76,7 @@ def get_embeddings(self) -> npt.NDArray[np.float32]:
d = self.index.d
return faiss.rev_swig_ptr(self.index.get_xb(), ntotal * d).reshape(ntotal, d)

def search(self, query: str, top_k: int = 10) -> list[dict]:
def search(self, query: str, top_k: int = 10) -> List[Dict]:
r"""Searches the index for the most similar images to the query and returns the top_k results.
Args:
query (str): The query to search for.
Expand All @@ -89,15 +90,15 @@ def search(self, query: str, top_k: int = 10) -> list[dict]:
_, idx = self.index.search(query_embedding.reshape(1, -1), top_k)
return self.df.iloc[idx[0]].to_dict(orient="records")

def save(self, path: str | Path) -> None:
def save(self, path: Union[str, Path]) -> None:
path = Path(path)
path.mkdir(parents=True)
self.df.to_csv(path / "data.csv")
if self.index is not None:
write_index(self.index, str(path / "data.index"))

@classmethod
def load(cls, path: str | Path) -> DataStore:
def load(cls, path: Union[str, Path]) -> DataStore:
path = Path(path)
df = pd.read_csv(path / "data.csv", index_col=0)
ds = DataStore(df)
Expand All @@ -106,7 +107,7 @@ def load(cls, path: str | Path) -> DataStore:
return ds


def build_data_store(data: str | Path | list[str | Path]) -> DataStore:
def build_data_store(data: Union[str, Path, list[Union[str, Path]]]) -> DataStore:
if isinstance(data, Path) or isinstance(data, str):
data = Path(data)
data_files = list(Path(data).glob("*"))
Expand Down
8 changes: 4 additions & 4 deletions lmm_tools/lmm/lmm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional
from typing import Optional, Union


def encode_image(image: str | Path) -> str:
Expand All @@ -12,7 +12,7 @@ def encode_image(image: str | Path) -> str:

class LMM(ABC):
@abstractmethod
def generate(self, prompt: str, image: Optional[str | Path]) -> str:
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
pass


Expand All @@ -22,7 +22,7 @@ class LLaVALMM(LMM):
def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[str | Path]) -> str:
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
pass


Expand All @@ -35,7 +35,7 @@ def __init__(self, name: str):
self.name = name
self.client = OpenAI()

def generate(self, prompt: str, image: Optional[str | Path]) -> str:
def generate(self, prompt: str, image: Optional[Union[str, Path]]) -> str:
message = [
{
"role": "user",
Expand Down

0 comments on commit 169249d

Please sign in to comment.