From a3b435c998d717aa75491b26966b756b4814d9ef Mon Sep 17 00:00:00 2001 From: Dillon Laird Date: Tue, 7 May 2024 15:47:38 -0700 Subject: [PATCH] fixed api key client issue --- vision_agent/utils/sim.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vision_agent/utils/sim.py b/vision_agent/utils/sim.py index 4e052494..3a244cd8 100644 --- a/vision_agent/utils/sim.py +++ b/vision_agent/utils/sim.py @@ -5,10 +5,10 @@ from openai import Client from scipy.spatial.distance import cosine # type: ignore -client = Client() - -def get_embedding(text: str, model: str = "text-embedding-3-small") -> List[float]: +def get_embedding( + client: Client, text: str, model: str = "text-embedding-3-small" +) -> List[float]: text = text.replace("\n", " ") return client.embeddings.create(input=[text], model=model).data[0].embedding @@ -18,6 +18,7 @@ def __init__( self, df: pd.DataFrame, sim_key: Optional[str] = None, + api_key: Optional[str] = None, model: str = "text-embedding-3-small", ) -> None: """Creates a similarity object that can be used to find similar items in a @@ -30,13 +31,18 @@ def __init__( model: str: The model to use for embeddings. """ self.df = df + if not api_key: + self.client = Client() + else: + self.client = Client(api_key=api_key) + self.model = model if "embs" not in df.columns and sim_key is None: raise ValueError("key is required if no column 'embs' is present.") if sim_key is not None: self.df["embs"] = self.df[sim_key].apply( - lambda x: get_embedding(x, model=self.model) + lambda x: get_embedding(self.client, x, model=self.model) ) def save(self, sim_file: Union[str, Path]) -> None: @@ -53,7 +59,7 @@ def top_k(self, query: str, k: int = 5) -> Sequence[Dict]: Sequence[Dict]: The top k most similar items. """ - embedding = get_embedding(query, model=self.model) + embedding = get_embedding(self.client, query, model=self.model) self.df["sim"] = self.df.embs.apply(lambda x: 1 - cosine(x, embedding)) res = self.df.sort_values("sim", ascending=False).head(k) return res[[c for c in res.columns if c != "embs"]].to_dict(orient="records")