From 23ad47669b5209b145be5be40b1b06216beb38c6 Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Wed, 29 Jan 2025 14:09:20 +0000
Subject: [PATCH] attempt at async code

---
 .../models/endpoints/openai_model.py | 76 +++++++++++++++++--
 1 file changed, 71 insertions(+), 5 deletions(-)

diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py
index fa3d179f..56c81b0d 100644
--- a/src/lighteval/models/endpoints/openai_model.py
+++ b/src/lighteval/models/endpoints/openai_model.py
@@ -20,6 +20,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import asyncio
 import logging
 import os
 import time
@@ -28,6 +29,7 @@
 from typing import Optional
 
 from tqdm import tqdm
+from tqdm.asyncio import tqdm_asyncio
 from transformers import AutoTokenizer
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
@@ -55,7 +57,7 @@
     import logging
 
     import tiktoken
-    from openai import OpenAI
+    from openai import AsyncOpenAI, OpenAI
 
     logging.getLogger("openai").setLevel(logging.ERROR)
     logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -87,8 +89,12 @@ def from_path(cls, path: str) -> "OpenAIModelConfig":
 class OpenAIClient(LightevalModel):
     _DEFAULT_MAX_LENGTH: int = 4096
 
-    def __init__(self, config: OpenAIModelConfig, env_config) -> None:
-        self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
+    def __init__(self, config: OpenAIModelConfig, env_config, is_async: bool = False) -> None:
+        self.is_async = is_async
+        if is_async:
+            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+        else:
+            self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
         self.config = config
         self.generation_parameters = config.generation_parameters
         self.sampling_params = self.generation_parameters.to_vllm_openai_dict()
@@ -124,11 +130,12 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_b
                     **self.sampling_params,
                     **response_format,
                 )
+                self.API_RETRY_SLEEP = 3  # reset the backoff after a successful call
                 return response
             except Exception as e:
                 logger.warning(f"{type(e), e}")
                 time.sleep(self.API_RETRY_SLEEP)
-                self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER
+                self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER
         raise Exception("Failed to get response from the API")
 
     def __call_api_parallel(
@@ -162,6 +169,62 @@ def __call_api_parallel(
 
         return results
 
+    async def __call_api_async_one(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
+        for _ in range(self.API_MAX_RETRY):
+            try:
+                response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {}
+                response = await self.client.chat.completions.create(
+                    model=self.model,
+                    messages=[{"role": "user", "content": prompt}],
+                    max_tokens=max_new_tokens if max_new_tokens > 0 else None,
+                    logprobs=return_logits,
+                    logit_bias=logit_bias,
+                    n=num_samples,
+                    **self.sampling_params,
+                    **response_format,
+                )
+                return response
+            except Exception as e:
+                logger.warning(f"{type(e), e}")
+                await asyncio.sleep(self.API_RETRY_SLEEP)  # non-blocking backoff
+                self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER
+        raise Exception("Failed to get response from the API")
+
+    async def __call_api_async(
+        self,
+        prompts,
+        return_logits: bool | list[bool],
+        max_new_tokens: int | list[int],
+        num_samples: int | list[int],
+        logit_bias: list[dict[int, float]] | None = None,
+    ):
+        # Convert single values to lists
+        return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
+        max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
+        num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
+        logit_biass = [logit_bias for _ in prompts] if logit_bias is None else logit_bias
+
+        # Validate input lengths
+        assert (
+            len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(logit_biass)
+        ), "Length of prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass should be the same"
+
+        # Cap the number of in-flight requests with a semaphore
+        semaphore = asyncio.Semaphore(10)  # 10 = num workers
+
+        async def bounded_call(prompt, ret_log, max_tok, num_samp, log_bias):
+            async with semaphore:
+                return await self.__call_api_async_one(prompt, ret_log, max_tok, num_samp, log_bias)
+
+        # One coroutine per prompt; any call that exhausts its retries raises and propagates
+        tasks = [
+            bounded_call(prompt, ret_log, max_tok, num_samp, log_bias)
+            for prompt, ret_log, max_tok, num_samp, log_bias in zip(
+                prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass
+            )
+        ]
+        return await tqdm_asyncio.gather(*tasks)
+
     def greedy_until(
         self,
         requests: list[GreedyUntilRequest],
@@ -195,7 +258,10 @@
             num_samples = dataset[0].num_samples
             contexts = [c.context for c in dataset]
 
-            responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)
+            if self.is_async:
+                responses = asyncio.run(self.__call_api_async(contexts, return_logits, max_new_tokens, num_samples))
+            else:
+                responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)
 
             for response in responses:
                 result: list[str] = [output.message.content for output in response.choices]
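
Note (not part of the patch): a minimal standalone sketch of the concurrency pattern the new __call_api_async relies on, namely a shared asyncio.Semaphore capping in-flight coroutines while tqdm_asyncio.gather awaits them with a progress bar. The fake_api_call helper, the prompt list, and the limit of 10 workers are illustrative placeholders, not lighteval code.

# Illustrative sketch: bounded-concurrency async calls with a progress bar.
import asyncio
import random

from tqdm.asyncio import tqdm_asyncio


async def fake_api_call(prompt: str) -> str:
    # Stand-in for an AsyncOpenAI chat.completions.create call.
    await asyncio.sleep(random.uniform(0.1, 0.3))
    return f"response to {prompt!r}"


async def run_all(prompts: list[str], num_workers: int = 10) -> list[str]:
    semaphore = asyncio.Semaphore(num_workers)  # at most num_workers calls in flight

    async def bounded(prompt: str) -> str:
        async with semaphore:
            return await fake_api_call(prompt)

    # Build coroutines without awaiting them; gather runs them concurrently
    # and returns results in input order.
    return await tqdm_asyncio.gather(*(bounded(p) for p in prompts))


if __name__ == "__main__":
    print(asyncio.run(run_all([f"prompt {i}" for i in range(25)])))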