Skip to content

Commit

Permalink
fixed api key client issue
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed May 7, 2024
1 parent 70e6eac commit a3b435c
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions vision_agent/utils/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from openai import Client
from scipy.spatial.distance import cosine # type: ignore

client = Client()


def get_embedding(
    client: Client, text: str, model: str = "text-embedding-3-small"
) -> List[float]:
    """Returns the embedding vector for a piece of text.

    Parameters:
        client: Client: The client whose `embeddings.create` endpoint is called.
            It is passed in explicitly rather than read from a module-level
            global, so the caller decides how the client is configured.
        text: str: The text to embed. Newlines are replaced with spaces before
            the request is made.
        model: str: The model to use for embeddings.

    Returns:
        List[float]: The `embedding` of the first item in the response's `data`.
    """
    # Flatten newlines to spaces so the text is sent as a single line.
    text = text.replace("\n", " ")
    return client.embeddings.create(input=[text], model=model).data[0].embedding

Expand All @@ -18,6 +18,7 @@ def __init__(
self,
df: pd.DataFrame,
sim_key: Optional[str] = None,
api_key: Optional[str] = None,
model: str = "text-embedding-3-small",
) -> None:
"""Creates a similarity object that can be used to find similar items in a
Expand All @@ -30,13 +31,18 @@ def __init__(
model: str: The model to use for embeddings.
"""
self.df = df
if not api_key:
self.client = Client()
else:
self.client = Client(api_key=api_key)

self.model = model
if "embs" not in df.columns and sim_key is None:
raise ValueError("key is required if no column 'embs' is present.")

if sim_key is not None:
self.df["embs"] = self.df[sim_key].apply(
lambda x: get_embedding(x, model=self.model)
lambda x: get_embedding(self.client, x, model=self.model)
)

def save(self, sim_file: Union[str, Path]) -> None:
Expand All @@ -53,7 +59,7 @@ def top_k(self, query: str, k: int = 5) -> Sequence[Dict]:
Sequence[Dict]: The top k most similar items.
"""

embedding = get_embedding(query, model=self.model)
embedding = get_embedding(self.client, query, model=self.model)
self.df["sim"] = self.df.embs.apply(lambda x: 1 - cosine(x, embedding))
res = self.df.sort_values("sim", ascending=False).head(k)
return res[[c for c in res.columns if c != "embs"]].to_dict(orient="records")
Expand Down

0 comments on commit a3b435c

Please sign in to comment.