attempt at async code
clefourrier committed Jan 29, 2025
1 parent 59ce9e1 commit 23ad476
Showing 1 changed file with 71 additions and 5 deletions.
76 changes: 71 additions & 5 deletions src/lighteval/models/endpoints/openai_model.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import asyncio
import logging
import os
import time
@@ -28,6 +29,7 @@
from typing import Optional

from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
@@ -55,7 +57,7 @@
import logging

import tiktoken
from openai import OpenAI
from openai import AsyncOpenAI, OpenAI

logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
@@ -87,8 +89,12 @@ def from_path(cls, path: str) -> "OpenAIModelConfig":
class OpenAIClient(LightevalModel):
_DEFAULT_MAX_LENGTH: int = 4096

def __init__(self, config: OpenAIModelConfig, env_config) -> None:
self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
def __init__(self, config: OpenAIModelConfig, env_config, is_async: bool = False) -> None:
self.is_async = is_async
if is_async:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
else:
self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
self.config = config
self.generation_parameters = config.generation_parameters
self.sampling_params = self.generation_parameters.to_vllm_openai_dict()
@@ -124,11 +130,12 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_b
**self.sampling_params,
**response_format,
)
self.API_RETRY_SLEEP = 3
return response
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER
self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER
raise Exception("Failed to get response from the API")

def __call_api_parallel(
@@ -162,6 +169,62 @@ def __call_api_parallel(

return results

async def __call_api_async_one(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
for _ in range(self.API_MAX_RETRY):
try:
response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {}
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=max_new_tokens if max_new_tokens > 0 else None,
logprobs=return_logits,
logit_bias=logit_bias,
n=num_samples,
**self.sampling_params,
**response_format,
)
return response
except Exception as e:
logger.warning(f"{type(e), e}")
await asyncio.sleep(self.API_RETRY_SLEEP)  # back off without blocking the event loop
self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER
raise Exception("Failed to get response from the API")

async def __call_api_async(
self,
prompts,
return_logits: bool | list[bool],
max_new_tokens: int | list[int],
num_samples: int | list[int],
logit_bias: list[dict[int, float]] | None = None,
):
# Convert single values to lists
return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits
max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens
num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples
logit_biass = [logit_bias for _ in prompts] if logit_bias is None else logit_bias

# Validate input lengths
assert (
    len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(logit_biass)
), "prompts, return_logitss, max_new_tokenss, num_sampless, and logit_biass must all have the same length"

# 10 = num workers: a shared semaphore caps how many requests are in flight at once.
semaphore = asyncio.Semaphore(10)

async def bounded_call(prompt, ret_log, max_tok, num_samp, log_bias):
    async with semaphore:
        return await self.__call_api_async_one(prompt, ret_log, max_tok, num_samp, log_bias)

# Create one coroutine per prompt; they are awaited concurrently by gather below.
tasks = [
    bounded_call(prompt, ret_log, max_tok, num_samp, log_bias)
    for prompt, ret_log, max_tok, num_samp, log_bias in zip(
        prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass
    )
]

results = await tqdm_asyncio.gather(*tasks, return_exceptions=True)

if any(isinstance(r, Exception) for r in results):
    raise ValueError("Some API calls failed after exhausting retries; inspect the logged errors and retry.")

return results

def greedy_until(
self,
requests: list[GreedyUntilRequest],
@@ -195,7 +258,10 @@ def greedy_until(
num_samples = dataset[0].num_samples
contexts = [c.context for c in dataset]

responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)
if self.is_async:
responses = asyncio.run(self.__call_api_async(contexts, return_logits, max_new_tokens, num_samples))
else:
responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples)

for response in responses:
result: list[str] = [output.message.content for output in response.choices]
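
For orientation, here is a minimal usage sketch of the new flag (not part of the commit). The config path, env_config=None, and the empty request list are placeholder assumptions; only the is_async=True switch and the greedy_until dispatch come from the diff above.

# Hypothetical usage sketch -- the YAML path, env_config=None, and the request list are placeholders.
from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig

config = OpenAIModelConfig.from_path("examples/model_configs/openai_model.yaml")  # placeholder path
client = OpenAIClient(config, env_config=None, is_async=True)  # uses AsyncOpenAI under the hood

# With is_async=True, greedy_until() runs the whole batch through a single
# asyncio.run(...__call_api_async(...)) call; with the default is_async=False it
# keeps the existing threaded __call_api_parallel() path.
requests = []  # in practice: the list of GreedyUntilRequest objects built by the task pipeline
results = client.greedy_until(requests)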
