
Commit 1e837a9

Adding inference endpoints models (#12)
This PR:
- uses Requests instead of passing tuples (which are more error prone) in the Datasets
- introduces an Abstract Model class which defines the minimum functions a model must implement to be lighteval compatible
- cleans up the BaseModel code
- introduces inference endpoints models

Inference endpoints models are these ones: https://huggingface.co/docs/huggingface_hub/v0.20.3/en/package_reference/inference_endpoints
Not to be confused with TGI models (which need a local deployment).

---------

Co-authored-by: Nathan Habib <[email protected]>
1 parent 77eee8c commit 1e837a9

16 files changed: +940 −415 lines
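For context on the Requests change described above, the move from positional tuples to request objects looks roughly like the sketch below. The field names tokenized_context, tokenized_continuation, stop_tokens and generation_size appear in the diff; the dataclass itself is an illustrative stand-in, not the real classes from lighteval.tasks.requests.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class IllustrativeGreedyUntilRequest:
    """Stand-in for lighteval.tasks.requests.GreedyUntilRequest; the real fields may differ."""

    context: str
    stop_tokens: list[str]
    generation_size: int
    tokenized_context: Optional[list[int]] = None  # filled in once the model tokenizes the context


# Before: datasets unpacked bare tuples positionally, e.g.
#     toks, (stop_tokens, gen_length) = x
# After: the same information is read by name, which is harder to get wrong:
#     toks, gen_length = request.tokenized_context, request.generation_size
```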

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ keywords = ["evaluation", "nlp", "llm"]
 dependencies = [
   # Base dependencies
   "transformers>=4.36.0",
-  "huggingface_hub==0.19.4",
+  "huggingface_hub==0.20.3",
   "torch>=2.0",
   "GitPython==3.1.31", # for logging
   "datasets>=2.14.0",

src/lighteval/data.py

Lines changed: 20 additions & 15 deletions
@@ -5,7 +5,14 @@
 from torch.utils.data.distributed import DistributedSampler, T_co

 from lighteval.logging.hierarchical_logger import hlog_warn
-from lighteval.tasks.requests import Request
+from lighteval.tasks.requests import (
+    GreedyUntilRequest,
+    GreedyUntilWithLogitsRequest,
+    LoglikelihoodRequest,
+    LoglikelihoodRollingRequest,
+    LoglikelihoodSingleTokenRequest,
+    Request,
+)


 class DynamicBatchDataset(Dataset):
@@ -28,6 +35,9 @@ def __init__(
             requests (List): A list of requests.
             dataset_splits (int): The number of dataset splits.
         """
+        # We make sure the requests contain the tokenized versions of their values
+        if any(r.tokenized_context is None for r in requests):
+            raise ValueError("You passed a request for which tokenization had not happened yet.")

         # sort the requests using the collate function and save the original order
         enumerated_requests = list(enumerate(requests))
@@ -124,12 +134,12 @@ def __len__(self) -> int:
         """
         return self.split_end - self.split_start

-    def _sorting_criteria(self, x) -> int:
+    def _sorting_criteria(self, request) -> int:
         raise NotImplementedError()


 class LoglikelihoodDataset(DynamicBatchDataset):
-    def _sorting_criteria(self, x) -> int:
+    def _sorting_criteria(self, request: LoglikelihoodRequest | LoglikelihoodRollingRequest) -> int:
         """
         Collates the input data for batching.

@@ -149,13 +159,12 @@ def _sorting_criteria(self, x) -> int:
         Returns:
             tuple: A tuple containing the sorted input data.
         """
-
-        toks = x[1] + x[2]
+        toks = request.tokenized_context + request.tokenized_continuation
         return -len(toks)


 class LoglikelihoodSingleTokenDataset(DynamicBatchDataset):
-    def _sorting_criteria(self, x) -> int:
+    def _sorting_criteria(self, request: LoglikelihoodSingleTokenRequest) -> int:
         """
         Collates the input data for batching.

@@ -167,19 +176,14 @@ def _sorting_criteria(self, x) -> int:
         is useful to simplify the batching logic and more importantly to make
         automatic adaptive batches much much easier to implement
         - any OOMs will happen right away rather than near the end
-
-        Args:
-            x (tuple): A tuple containing the input data.
-
-        Returns:
-            tuple: A tuple containing the collated data.
         """
-        toks = x[1]  # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+        # We take only the prompt, no need for the continuation (since it's a list of single tokens)
+        toks = request.tokenized_context
         return -len(toks)


 class GenerativeTaskDataset(DynamicBatchDataset):
-    def _sorting_criteria(self, x) -> int:
+    def _sorting_criteria(self, request: GreedyUntilRequest | GreedyUntilWithLogitsRequest) -> int:
         """
         Collate function for generating batches.

@@ -189,7 +193,8 @@ def _sorting_criteria(self, x) -> int:
         Returns:
             Any: The collated data.
         """
-        toks, (stop_tokens, gen_length) = x
+        toks = request.tokenized_context
+        gen_length = request.generation_size
        return -(len(toks) + gen_length)
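To make the new request contract concrete, here is a small standalone sketch (not part of the diff) of the guard and the generative sort key shown above; any object exposing tokenized_context and generation_size attributes would do.

```python
def check_tokenized(requests) -> None:
    # Mirrors the guard added to DynamicBatchDataset.__init__: every request must
    # already carry its tokenized context before the dataset sorts and splits it.
    if any(r.tokenized_context is None for r in requests):
        raise ValueError("You passed a request for which tokenization had not happened yet.")


def generative_sort_key(request) -> int:
    # Standalone version of GenerativeTaskDataset._sorting_criteria: longest
    # (context + expected generation) first, so any OOM surfaces at the start of a split.
    return -(len(request.tokenized_context) + request.generation_size)
```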

src/lighteval/evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.logging.hierarchical_logger import hlog
 from lighteval.models.base_model import BaseModel
-from lighteval.models.inference_client import ModelClient
+from lighteval.models.tgi_model import ModelClient
 from lighteval.tasks.lighteval_task import LightevalTask
 from lighteval.tasks.requests import Doc, Request, RequestType, TaskExampleId

src/lighteval/metrics/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -94,7 +94,9 @@ def apply_multichoice_metric(results: list[ModelReturn], formatted_doc: Doc, met
         raise ValueError(
             "You can't use a multi choice metric with only one choice. Use `acc_golds_likelihood` instead."
         )
-    choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]
+
+    # Todo: make better system with return_bool_score instead of taking first element
+    choices_logprob = [results[i].result[0] for i in range(len(formatted_doc.choices))]  # sum(
     gold_ixs = as_list(formatted_doc.gold_index)

     for metric in metrics:
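As a reading aid, the per-choice logprobs gathered here would typically be consumed along these lines; this is illustrative only, the actual metric implementations live elsewhere in lighteval.metrics.

```python
def illustrative_multichoice_acc(choices_logprob: list[float], gold_ixs: list[int]) -> int:
    # Pick the choice with the highest log-likelihood and check whether it is one of the golds.
    best_choice = max(range(len(choices_logprob)), key=lambda i: choices_logprob[i])
    return int(best_choice in gold_ixs)
```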
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@

```python
from abc import ABC, abstractmethod
from typing import Optional, Union

import torch
from transformers import BatchEncoding

from lighteval.models.model_config import EnvConfig
from lighteval.models.model_output import GenerateReturn, LoglikelihoodReturn, LoglikelihoodSingleTokenReturn
from lighteval.tasks.requests import (
    GreedyUntilRequest,
    GreedyUntilWithLogitsRequest,
    LoglikelihoodRequest,
    LoglikelihoodRollingRequest,
    LoglikelihoodSingleTokenRequest,
)


TokenSequence = Union[list[int], torch.LongTensor, torch.Tensor, BatchEncoding]


class LightevalModel(ABC):
    DATASET_SPLITS = 4

    """Abstract model class defining the API that every model to plug into lighteval must follow."""

    @abstractmethod
    def __init__(
        self,
        config,
        env_config: EnvConfig,
    ):
        return NotImplemented

    def cleanup(self):
        """Clean up operations if needed, such as closing an endpoint."""
        return

    @property
    @abstractmethod
    def tokenizer(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def add_special_tokens(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def max_length(self) -> int:
        """Return the maximum sequence length of the model."""
        raise NotImplementedError

    @property
    def disable_tqdm(self) -> bool:
        raise NotImplementedError

    def greedy_until_with_logits(
        self,
        requests: list[GreedyUntilWithLogitsRequest],
        override_bs: Optional[int] = None,
    ) -> list[GenerateReturn]:
        """
        Generates sequences greedily until a stopping condition is met,
        returning both the generated sequences and the logits.

        Args:
            requests (list[tuple[str, dict]]): A list of input requests,
                where each request is a tuple containing a prompt string and a dictionary of additional parameters.
            disable_tqdm (bool, optional): Whether to disable the tqdm progress bar. Defaults to False.
            override_bs (Optional[int], optional): Overrides the batch size for generation. Defaults to None.

        Returns:
            list[GenerateReturn]: A list of GenerateReturn objects,
                where each object contains the generated sequence and the corresponding logits.
        """
        return self.greedy_until(
            requests=requests,
            override_bs=override_bs,
            returns_logits=True,
        )

    @abstractmethod
    def greedy_until(
        self,
        requests: list[GreedyUntilRequest],
        returns_logits: bool = False,
        override_bs: Optional[int] = None,
    ) -> list[GenerateReturn]:
        """
        Generates responses using a greedy decoding strategy until certain ending conditions are met.

        Args:
            requests (list[Request]): list of requests containing the context and ending conditions.
            returns_logits (bool, optional): Whether to return the logits of the generated responses. Defaults to False.
            disable_tqdm (bool, optional): Whether to disable the progress bar. Defaults to False.
            override_bs (int, optional): Override the batch size for generation. Defaults to None.

        Returns:
            list[GenerateReturn]: list of generated responses.
        """
        return NotImplemented

    @abstractmethod
    def loglikelihood(
        self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodReturn]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        return NotImplemented

    @abstractmethod
    def loglikelihood_rolling(
        self, requests: list[LoglikelihoodRollingRequest], override_bs=None
    ) -> list[LoglikelihoodReturn]:
        """This function is used to compute the log likelihood of the context for perplexity metrics."""
        return NotImplemented

    @abstractmethod
    def loglikelihood_single_token(
        self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
    ) -> list[LoglikelihoodSingleTokenReturn]:
        """Tokenize the context and continuation and compute the log likelihood of those
        tokenized sequences.
        """
        return NotImplemented

    # Tokenization utils
    def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
        if add_special_tokens is None:
            add_special_tokens = self.add_special_tokens
        if isinstance(str_to_encode, str):
            return self.tokenizer.encode(str_to_encode, add_special_tokens=add_special_tokens)
        return self.tokenizer(
            str_to_encode,
            padding=True,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )

    def tok_encode_pair(self, context, continuation):
        """Encodes a context, continuation pair by taking care of the spaces in between."""
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]
        whole_enc = self.tok_encode(context + continuation)
        context_enc = self.tok_encode(context)
        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]
        return context_enc, continuation_enc

    def tok_decode(self, tokens: torch.LongTensor) -> list[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
```
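For orientation, a minimal (hypothetical) subclass only needs to supply the abstract pieces: a tokenizer, add_special_tokens, max_length, and the four request-handling methods. The sketch below is not part of the commit; it stubs the request handlers, uses a gpt2 tokenizer purely for illustration, and assumes the LightevalModel class above is importable.

```python
from transformers import AutoTokenizer


class MinimalModel(LightevalModel):
    """Hypothetical example: the smallest class satisfying the LightevalModel API."""

    def __init__(self, config=None, env_config=None):
        self._tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works for the demo

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def add_special_tokens(self):
        return False

    @property
    def max_length(self) -> int:
        return 1024

    def greedy_until(self, requests, returns_logits=False, override_bs=None):
        return []  # a real backend would return one GenerateReturn per request

    def loglikelihood(self, requests, override_bs=None):
        return []  # idem, with LoglikelihoodReturn objects

    def loglikelihood_rolling(self, requests, override_bs=None):
        return []

    def loglikelihood_single_token(self, requests, override_bs=None):
        return []  # idem, with LoglikelihoodSingleTokenReturn objects


# The shared tokenization helpers come for free from the base class:
model = MinimalModel()
context_enc, continuation_enc = model.tok_encode_pair("Question: 1 + 1 =", " 2")
print(len(context_enc), len(continuation_enc))
```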

src/lighteval/models/adapter_model.py

Lines changed: 4 additions & 4 deletions
@@ -1,7 +1,7 @@
 from contextlib import nullcontext

 import torch
-from transformers import AutoModel, PreTrainedTokenizer
+from transformers import AutoModelForCausalLM, PreTrainedTokenizer

 from lighteval.logging.hierarchical_logger import hlog
 from lighteval.models.base_model import BaseModel
@@ -20,7 +20,7 @@ def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConf
         # (= the parent model, not the model of interest)
         return self._create_auto_tokenizer_with_name(config.base_model, config=config, env_config=env_config)

-    def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModel:
+    def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModelForCausalLM:
         """Returns a PeftModel from a base model and a version fined tuned using PEFT."""
         torch_dtype = _get_dtype(config.dtype, self._config)
         config.model_parallel, max_memory, device_map = self.init_model_parallel(config.model_parallel)
@@ -31,7 +31,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

         if self.accelerator.is_local_main_process if self.accelerator is not None else nullcontext():
             hlog(f"Loading model from {adapter_weights} and applying adapter to {config.base_model}")
-            base = self.AUTO_MODEL_CLASS.from_pretrained(
+            base = AutoModelForCausalLM.from_pretrained(
                 config.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, token=env_config.token
             )
             # Should pass revision
@@ -43,7 +43,7 @@ def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig)

         hlog(f"Loading model from {merged_path}")

-        model = self.AUTO_MODEL_CLASS.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             merged_path,
             max_memory=max_memory,
             device_map=device_map,
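The adapter path now goes through AutoModelForCausalLM end to end. Below is a minimal sketch of the merge flow this method performs, with hypothetical model and adapter names; the real method also handles device_map, max_memory, revisions and the env token.

```python
import torch
from peft import PeftModel  # assumes the peft package is available
from transformers import AutoModelForCausalLM

# Hypothetical identifiers, for illustration only.
base_model_name = "my-org/base-model"
adapter_weights = "my-org/peft-adapter"
merged_path = "/tmp/merged-model"

# Load the base causal LM (the parent model the adapter was trained on) ...
base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True)

# ... apply the PEFT adapter and merge its weights into the base model ...
merged = PeftModel.from_pretrained(base, adapter_weights).merge_and_unload()

# ... and save the merged checkpoint so it can later be reloaded as a plain causal LM.
merged.save_pretrained(merged_path)
model = AutoModelForCausalLM.from_pretrained(merged_path, torch_dtype=torch.float16)
```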
