Added an option for inference endpoints in the OpenAI client
clefourrier committed Jan 28, 2025
1 parent f9bb2a1 commit db4c4e8
Showing 2 changed files with 21 additions and 7 deletions.
5 changes: 5 additions & 0 deletions examples/model_configs/serverless_model_with_openai.yaml
@@ -0,0 +1,5 @@
+model:
+  model_name: "deepseek-ai/DeepSeek-R1" #meta-llama/Llama-3.1-8B-Instruct" #Qwen/Qwen2.5-14B" #Qwen/Qwen2.5-7B"
+api:
+  base_url: "https://huggingface.co/api/inference-proxy/together"
+  api_key: "hf_"
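
For reference, a minimal sketch of how this example config would be consumed through the updated OpenAIModelConfig.from_path shown below (assuming the module path mirrors the file location; the printed values simply restate the YAML above):

from lighteval.models.endpoints.openai_model import OpenAIModelConfig

# Load the example config; the top-level "api" block is forwarded to the
# dataclass as keyword arguments, overriding the OpenAI defaults.
config = OpenAIModelConfig.from_path("examples/model_configs/serverless_model_with_openai.yaml")

print(config.model)     # deepseek-ai/DeepSeek-R1
print(config.base_url)  # https://huggingface.co/api/inference-proxy/together
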
23 changes: 16 additions & 7 deletions src/lighteval/models/endpoints/openai_model.py
@@ -28,6 +28,7 @@
 from typing import Optional
 
 from tqdm import tqdm
+from transformers import AutoTokenizer
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
 from lighteval.models.abstract_model import LightevalModel
@@ -64,6 +65,8 @@
 class OpenAIModelConfig:
     model: str
     generation_parameters: GenerationParameters = None
+    base_url: str = "https://api.openai.com/v1"
+    api_key: str = os.environ.get("OPENAI_API_KEY", None)
 
     def __post_init__(self):
         if not self.generation_parameters:
@@ -74,17 +77,19 @@ def from_path(cls, path: str) -> "OpenAIModelConfig":
         import yaml
 
         with open(path, "r") as f:
-            config = yaml.safe_load(f)["model"]
+            loaded_file = yaml.safe_load(f)
+            config = loaded_file["model"]
+            api = loaded_file.get("api", {})
         generation_parameters = GenerationParameters.from_dict(config)
-        return cls(model=config["model_name"], generation_parameters=generation_parameters)
+        return cls(model=config["model_name"], generation_parameters=generation_parameters, **api)


 class OpenAIClient(LightevalModel):
     _DEFAULT_MAX_LENGTH: int = 4096
 
     def __init__(self, config: OpenAIModelConfig, env_config) -> None:
-        api_key = os.environ["OPENAI_API_KEY"]
-        self.client = OpenAI(api_key=api_key)
+        self.client = OpenAI(api_key=config.api_key, base_url=config.base_url)
         self.config = config
         self.generation_parameters = config.generation_parameters
         self.sampling_params = self.generation_parameters.to_vllm_openai_dict()
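
With the two new dataclass fields, the config can also be built directly in Python. A rough sketch (the model names and the HF_TOKEN environment variable are illustrative placeholders, not part of this commit):

import os

from lighteval.models.endpoints.openai_model import OpenAIModelConfig

# Default behaviour is unchanged: the official OpenAI endpoint, with the key
# taken from OPENAI_API_KEY via the field default.
openai_config = OpenAIModelConfig(model="gpt-4o-mini")

# Any OpenAI-compatible endpoint: override base_url and api_key explicitly.
proxy_config = OpenAIModelConfig(
    model="deepseek-ai/DeepSeek-R1",
    base_url="https://huggingface.co/api/inference-proxy/together",
    api_key=os.environ.get("HF_TOKEN"),
)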

@@ -99,21 +104,25 @@ def __init__(self, config: OpenAIModelConfig, env_config) -> None:
         self.API_RETRY_MULTIPLIER = 2
         self.CONCURENT_CALLS = 100
         self.model = config.model
-        self._tokenizer = tiktoken.encoding_for_model(self.model)
+        try:
+            self._tokenizer = tiktoken.encoding_for_model(self.model)
+        except KeyError:
+            self._tokenizer = AutoTokenizer.from_pretrained(self.model)
         self.pairwise_tokenization = False
 
     def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias):
         for _ in range(self.API_MAX_RETRY):
             try:
+                response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {}
                 response = self.client.chat.completions.create(
                     model=self.model,
                     messages=[{"role": "user", "content": prompt}],
-                    response_format={"type": "text"},
                     max_tokens=max_new_tokens if max_new_tokens > 0 else None,
                     logprobs=return_logits,
                     logit_bias=logit_bias,
                     n=num_samples,
                     **self.sampling_params,
+                    **response_format,
                 )
                 return response
             except Exception as e:
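
The tokenizer fallback above exists because tiktoken only recognizes OpenAI model names. A standalone sketch of the same pattern (the helper name is hypothetical):

import tiktoken
from transformers import AutoTokenizer

def load_tokenizer(model_name: str):
    # tiktoken raises KeyError for models it cannot map to an encoding,
    # e.g. "deepseek-ai/DeepSeek-R1"; fall back to the Hugging Face tokenizer.
    try:
        return tiktoken.encoding_for_model(model_name)
    except KeyError:
        return AutoTokenizer.from_pretrained(model_name)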
@@ -181,7 +190,7 @@ def greedy_until(
             position=0,
             disable=False, # self.disable_tqdm,
         ):
-            max_new_tokens = dataset[0].generation_size # could be none
+            max_new_tokens = 500 # dataset[0].generation_size # could be none
             return_logits = dataset[0].use_logits
             num_samples = dataset[0].num_samples
             contexts = [c.context for c in dataset]
