From 7907a11ce902ae6bdb8ae6e007202d3bd5784b7b Mon Sep 17 00:00:00 2001 From: Asia <2736300+humpydonkey@users.noreply.github.com> Date: Tue, 9 Jul 2024 14:23:40 -0700 Subject: [PATCH] Minor fix and improvement (#165) 1. Fix the issue that str result is wrapped with single quotes by notebook. 2. Cache top_k() for better performance --- vision_agent/utils/execute.py | 5 +++++ vision_agent/utils/sim.py | 8 +++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vision_agent/utils/execute.py b/vision_agent/utils/execute.py index e7526384..fb9b004b 100644 --- a/vision_agent/utils/execute.py +++ b/vision_agent/utils/execute.py @@ -112,6 +112,11 @@ def __init__(self, is_main_result: bool, data: Dict[str, Any]): self.raw = copy.deepcopy(data) self.text = data.pop(MimeType.TEXT_PLAIN, None) + if self.text and (self.text.startswith("'") and self.text.endswith("'")): + # This is a workaround for the issue that str result is wrapped with single quotes by notebook. + # E.g. input text: "'flower'". what we want: "flower" + self.text = self.text[1:-1] + self.html = data.pop(MimeType.TEXT_HTML, None) self.markdown = data.pop(MimeType.TEXT_MARKDOWN, None) self.svg = data.pop(MimeType.IMAGE_SVG, None) diff --git a/vision_agent/utils/sim.py b/vision_agent/utils/sim.py index 1ac8069b..6a2bbdca 100644 --- a/vision_agent/utils/sim.py +++ b/vision_agent/utils/sim.py @@ -1,4 +1,5 @@ import os +from functools import lru_cache from pathlib import Path from typing import Dict, List, Optional, Sequence, Union @@ -33,11 +34,7 @@ def __init__( model: str: The model to use for embeddings. """ self.df = df - if not api_key: - self.client = OpenAI() - else: - self.client = OpenAI(api_key=api_key) - + self.client = OpenAI(api_key=api_key) self.model = model if "embs" not in df.columns and sim_key is None: raise ValueError("key is required if no column 'embs' is present.") @@ -57,6 +54,7 @@ def save(self, sim_file: Union[str, Path]) -> None: df = df.drop("embs", axis=1) df.to_csv(sim_file / "df.csv", index=False) + @lru_cache(maxsize=256) def top_k( self, query: str, k: int = 5, thresh: Optional[float] = None ) -> Sequence[Dict]: