From 6267d0921165088cfcf616d6c35b9bae1954c4f6 Mon Sep 17 00:00:00 2001
From: Joel Niklaus
Date: Tue, 10 Dec 2024 18:45:44 +0100
Subject: [PATCH] Added cache for openai models.

---
 src/lighteval/models/openai_model.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/src/lighteval/models/openai_model.py b/src/lighteval/models/openai_model.py
index 12fbeb95c..c45320f37 100644
--- a/src/lighteval/models/openai_model.py
+++ b/src/lighteval/models/openai_model.py
@@ -26,6 +26,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
 
+from diskcache import Cache
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
@@ -64,6 +65,7 @@ class OpenAIClient(LightevalModel):
     def __init__(self, config, env_config) -> None:
         api_key = os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=api_key)
+        self.cache = Cache(".cache/openai")  # Initialize the cache
 
         self.model_info = ModelInfo(
             model_name=config.model,
@@ -80,6 +82,20 @@ def __init__(self, config, env_config) -> None:
         self.pairwise_tokenization = False
 
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
+        # Create a unique key for the cache based on the input parameters
+        cache_key = (
+            self.model,
+            prompt,
+            return_logits,
+            max_new_tokens,
+            num_samples,
+            tuple(logit_bias.items()) if logit_bias else None,
+        )
+
+        # Check if the response is already in the cache
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+
         for _ in range(self.API_MAX_RETRY):
             try:
                 response = self.client.chat.completions.create(
@@ -91,6 +107,8 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_b
                     logit_bias=logit_bias,
                     n=num_samples,
                 )
+                # Store the response in the cache
+                self.cache[cache_key] = response
                 return response
             except Exception as e:
                 logger.warning(f"{type(e), e}")
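
Note (not part of the patch): below is a minimal, standalone sketch of the diskcache pattern this change relies on, i.e. keying a persistent Cache on every parameter that influences the response and returning the stored result on a hit. The helper names, cache path, and the stand-in for the OpenAI call are illustrative assumptions, not lighteval code.

    # Sketch of disk-backed memoization with diskcache (assumed names, not lighteval code).
    from diskcache import Cache

    cache = Cache(".cache/openai-demo")  # hypothetical cache directory

    def fake_completion(model, prompt, max_new_tokens):
        # Placeholder for the real client.chat.completions.create(...) call.
        return {"model": model, "prompt": prompt, "max_new_tokens": max_new_tokens}

    def cached_completion(model, prompt, max_new_tokens):
        # Key the cache on every parameter that changes the response;
        # unhashable values (e.g. a logit_bias dict) must be converted,
        # as the patch does with tuple(logit_bias.items()).
        key = (model, prompt, max_new_tokens)
        if key in cache:               # cache hit: skip the expensive API call
            return cache[key]
        response = fake_completion(model, prompt, max_new_tokens)
        cache[key] = response          # diskcache pickles the value to disk
        return response

    first = cached_completion("gpt-4o-mini", "Hello", 16)   # computed and stored
    second = cached_completion("gpt-4o-mini", "Hello", 16)  # served from disk
    assert first == second

Because the cache lives on disk, repeated evaluation runs with identical prompts and sampling parameters reuse earlier responses instead of re-issuing API requests.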