diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py
index 37b8ca347..cb65790a9 100644
--- a/src/lighteval/models/endpoints/openai_model.py
+++ b/src/lighteval/models/endpoints/openai_model.py
@@ -27,6 +27,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
+from diskcache import Cache
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
@@ -85,6 +86,7 @@ class OpenAIClient(LightevalModel):
     def __init__(self, config: OpenAIModelConfig, env_config) -> None:
         api_key = os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=api_key)
+        self.cache = Cache(".cache/openai")  # Initialize the cache
         self.generation_parameters = config.generation_parameters
 
         self.sampling_params = self.generation_parameters.to_vllm_openai_dict()
@@ -103,6 +105,20 @@ def __init__(self, config: OpenAIModelConfig, env_config) -> None:
         self.pairwise_tokenization = False
 
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
+        # Create a unique key for the cache based on the input parameters
+        cache_key = (
+            self.model,
+            prompt,
+            return_logits,
+            max_new_tokens,
+            num_samples,
+            tuple(logit_bias.items()) if logit_bias else None,
+        )
+
+        # Check if the response is already in the cache
+        if cache_key in self.cache:
+            return self.cache[cache_key]
+
         for _ in range(self.API_MAX_RETRY):
             try:
                 response = self.client.chat.completions.create(
@@ -115,6 +131,8 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_b
                     n=num_samples,
                     **self.sampling_params,
                 )
+                # Store the response in the cache
+                self.cache[cache_key] = response
                 return response
             except Exception as e:
                 logger.warning(f"{type(e), e}")
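
The diff keys a persistent diskcache on every parameter that can change the completion, so an identical request on a later run is served from disk instead of hitting the API again. Below is a minimal standalone sketch of that pattern, assuming diskcache is installed; query_api and its placeholder response are hypothetical stand-ins for the real client.chat.completions.create call, not lighteval code.

    # Standalone sketch of the parameter-keyed response cache used in the diff.
    from diskcache import Cache

    cache = Cache(".cache/openai")  # on-disk cache, persists across runs

    def query_api(model, prompt, max_new_tokens, num_samples, logit_bias=None):
        # Build a hashable key from every argument that affects the response.
        cache_key = (
            model,
            prompt,
            max_new_tokens,
            num_samples,
            tuple(sorted(logit_bias.items())) if logit_bias else None,
        )

        # Serve repeated requests from the cache.
        if cache_key in cache:
            return cache[cache_key]

        # Placeholder for the real API call (client.chat.completions.create in the diff).
        response = f"completion for {prompt!r} from {model}"

        # diskcache pickles values, so real response objects can be stored as well.
        cache[cache_key] = response
        return response

    if __name__ == "__main__":
        print(query_api("gpt-4o-mini", "Hello", 32, 1))  # miss: computes and stores
        print(query_api("gpt-4o-mini", "Hello", 32, 1))  # hit: served from disk

One small design note: the sketch sorts logit_bias.items() before building the key so the key does not depend on dict insertion order, whereas the diff uses tuple(logit_bias.items()) directly, which is equivalent only when the dict is always constructed in the same order.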