Skip to content

Commit 4176f1e

Browse files
committed
add doc to lighteval_tasks
1 parent 8b36fe0 commit 4176f1e

File tree

2 files changed

+128
-19
lines changed

2 files changed

+128
-19
lines changed

src/lighteval/metrics/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def higher_is_better():
501501
return res
502502

503503
@staticmethod
504-
def corpus_level_fns():
504+
def corpus_level_fns() -> dict[str, callable]:
505505
res = {}
506506
for metric in Metrics:
507507
if metric.value.category == MetricCategory.IGNORED:

src/lighteval/tasks/lighteval_task.py

Lines changed: 127 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,19 @@
4040

4141

4242
class LightevalTask:
43-
def __init__(self, name: str, cfg: dict, cache_dir: str = None, custom_tasks_module=None):
43+
def __init__(self, name: str, cfg: dict, cache_dir: Optional[str] = None, custom_tasks_module=None):
44+
"""
45+
Initialize a LightEval task.
46+
47+
Args:
48+
name (str): The name of the task.
49+
cfg (dict): The configuration dictionary containing
50+
task-specific settings (from the task_table.json file).
51+
cache_dir (Optional[str], optional): The directory to cache the
52+
dataset. Defaults to None.
53+
custom_tasks_module (module, optional): A custom module
54+
containing task-specific functions. Defaults to None.
55+
"""
4456
self.name = name
4557
self.VERSION = 0
4658
self.is_main_process = False
@@ -108,24 +120,53 @@ def cfg(self):
108120
return self._cfg
109121

110122
def doc_to_text_without_instructions(self, doc: Doc) -> str:
123+
"""
124+
Returns the query of the document without the instructions. If the
125+
document has instructions, it removes them from the query.
126+
127+
Args:
128+
doc (Doc): The document.
129+
130+
Returns:
131+
str: The query of the document without the instructions.
132+
"""
111133
if doc.instruction is not None:
112134
if not doc.query.startswith(doc.instruction):
113135
raise ValueError(f"Prompt query {doc.query} is not starting with instruction {doc.instruction}")
114136
return doc.query[len(doc.instruction) :]
115137
return doc.query
116138

117139
def doc_to_text_and_instructions(self, doc: Doc) -> Tuple[str, str]:
140+
"""
141+
Returns a tuple with the query of the document and the instructions.
142+
If the document has no instructions, the second element of the tuple is
143+
an empty string.
144+
145+
Args:
146+
doc (Doc): The document.
147+
148+
Returns:
149+
Tuple[str, str]: A tuple with the query of the document and the
150+
instructions.
151+
"""
118152
if doc.instruction is not None:
119153
if not doc.query.startswith(doc.instruction):
120154
raise ValueError(f"Prompt query {doc.query} is not starting with instruction {doc.instruction}")
121155
return (doc.query[len(doc.instruction) :], doc.instruction)
122156
return (doc.query, "")
123157

124158
def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[str]:
125-
"""Parses the possible fewshot split keys in order:
126-
train, then validation keys
127-
and matches them with the available keys.
128-
Returns the first available.
159+
"""
160+
Parses the possible fewshot split keys in order: train, then validation
161+
keys and matches them with the available keys. Returns the first
162+
available.
163+
164+
Args:
165+
number_of_splits (int, optional): The number of splits to return.
166+
Defaults to 1.
167+
168+
Returns:
169+
list[str]: The list of the first available fewshot splits.
129170
"""
130171
# Possible few shot splits are the available splits not used for evaluation
131172
possible_fewshot_splits = [k for k in self.all_available_splits if k not in self.evaluation_split]
@@ -145,6 +186,17 @@ def get_first_possible_fewshot_splits(self, number_of_splits: int = 1) -> list[s
145186
return None
146187

147188
def _get_docs_from_split(self, keys, few_shots=False) -> list[Doc]:
189+
"""
190+
Get the documents from the dataset for the given keys (splits).
191+
192+
Args:
193+
keys (list): The list of keys (splits).
194+
few_shots (bool, optional): Whether the documents are used for few
195+
shot examples. Defaults to False.
196+
197+
Returns:
198+
list[Doc]: The list of documents.
199+
"""
148200
if self.dataset is None:
149201
self.dataset = download_dataset_worker((self.dataset_path, self.dataset_config_name))
150202

@@ -159,6 +211,13 @@ def _get_docs_from_split(self, keys, few_shots=False) -> list[Doc]:
159211
return docs
160212

161213
def fewshot_docs(self) -> list[Doc]:
214+
"""
215+
Returns the few shot documents. If the few shot documents are not
216+
available, it gets them from the few shot split or the evaluation split.
217+
218+
Returns:
219+
list[Doc]: The few shot documents.
220+
"""
162221
if self._fewshot_docs is None:
163222
self._fewshot_docs = []
164223

@@ -170,11 +229,28 @@ def fewshot_docs(self) -> list[Doc]:
170229
return self._fewshot_docs
171230

172231
def eval_docs(self) -> list[Doc]:
232+
"""
233+
Returns the evaluation documents.
234+
235+
Returns:
236+
list[Doc]: The evaluation documents.
237+
"""
173238
if self._docs is None:
174239
self._docs = self._get_docs_from_split(self.evaluation_split)
175240
return self._docs
176241

177-
def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False):
242+
def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False) -> str:
243+
"""
244+
Returns the target of the given document.
245+
246+
Args:
247+
formatted_doc (Doc): The formatted document.
248+
few_shot (bool, optional): Whether the document is used for few
249+
shot examples. Defaults to False.
250+
251+
Returns:
252+
str: The target of the document.
253+
"""
178254
if few_shot:
179255
if formatted_doc.target_for_fewshot_sorting is not None:
180256
return formatted_doc.target_for_fewshot_sorting
@@ -184,6 +260,16 @@ def doc_to_target(self, formatted_doc: Doc, few_shot: bool = False):
184260

185261
# Requests
186262
def get_request_type(self) -> list[RequestType]:
263+
"""
264+
Returns the request types for the task.
265+
266+
Returns:
267+
list[RequestType]: The request types for the task.
268+
269+
Raises:
270+
NotImplementedError: If the request type is not implemented for the
271+
task.
272+
"""
187273
request_types = []
188274
if self.has_metric_category[MetricCategory.TARGET_PERPLEXITY]:
189275
request_types.append(RequestType.LOGLIKELIHOOD)
@@ -207,7 +293,7 @@ def construct_requests(
207293
self, formatted_doc: Doc, context: str, document_id_seed: str, current_task_name: str
208294
) -> List[Request]:
209295
"""
210-
Constructs a list of requests based on the given parameters.
296+
Constructs a list of requests from the task based on the given parameters.
211297
212298
Args:
213299
formatted_doc (Doc): The formatted document almost straight from the dataset.
@@ -282,7 +368,17 @@ def construct_requests(
282368

283369
return requests
284370

285-
def process_results(self, formatted_doc: Doc, results: list[ModelReturn]):
371+
def process_results(self, formatted_doc: Doc, results: list[ModelReturn]) -> dict[str, float]:
372+
"""
373+
Processes the results of the task and stores them in the output dict.
374+
375+
Args:
376+
formatted_doc (Doc): The formatted document of the task.
377+
results (list[ModelReturn]): The results of the task, returned by the model class after evaluation.
378+
379+
Returns:
380+
dict[str, float]: The output dictionary containing the results of the task.
381+
"""
286382
# Metrics management is done in metrics.__init__
287383
outputs = {}
288384
if self.has_metric_category[MetricCategory.TARGET_PERPLEXITY]:
@@ -319,6 +415,10 @@ def process_results(self, formatted_doc: Doc, results: list[ModelReturn]):
319415
return outputs
320416

321417
def aggregation(self):
418+
"""
419+
Return a dict with metric name and its aggregation function for all
420+
metrics.
421+
"""
322422
return Metrics.corpus_level_fns()
323423

324424
@staticmethod
@@ -349,6 +449,10 @@ def load_datasets(tasks: list["LightevalTask"], dataset_loading_processes: int =
349449

350450

351451
def download_dataset_worker(args):
452+
"""
453+
Worker function to download a dataset from the HuggingFace Hub.
454+
Used for parallel dataset loading.
455+
"""
352456
dataset_path, dataset_config_name = args
353457
dataset = load_dataset(
354458
path=dataset_path,
@@ -370,22 +474,27 @@ def create_requests_from_tasks( # noqa: C901
370474
use_chat_template: bool,
371475
) -> Tuple[dict[RequestType, list[Request]], dict[TaskExampleId, Doc]]:
372476
"""
373-
Takes a task dict and a fewshot dict and returns a dict of requests, a dict of docs, and a dict of requests origins.
374-
The construction of prompts and thus the managing of few shots is done here.
477+
Takes a task dict and a fewshot dict and returns a dict of requests, a dict
478+
of docs, and a dict of requests origins. The construction of prompts and
479+
thus the managing of few shots is done here.
375480
376481
Args:
377-
task_dict (_type_): _description_
378-
fewshot_dict (_type_): _description_
379-
num_fewshot_seeds (_type_): _description_
380-
lm (_type_): _description_
381-
max_samples (_type_): _description_
382-
evaluation_tracker (_type_): _description_
482+
task_dict (dict[str, LightevalTask]): A dictionary of tasks.
483+
fewshot_dict (dict[str, list[Tuple[int, bool]]]): A dictionary of few
484+
shot examples.
485+
num_fewshot_seeds (int): The number of few shot seeds.
486+
lm (BaseModel): The language model.
487+
max_samples (int): The maximum number of samples.
488+
evaluation_tracker (EvaluationTracker): The evaluation tracker.
489+
use_chat_template (bool): Whether to use the chat template.
383490
384491
Raises:
385-
RuntimeError: _description_
492+
NotImplementedError: If the request type is not implemented for the
493+
task.
386494
387495
Returns:
388-
_type_: _description_
496+
Tuple[dict[RequestType, list[Request]], dict[TaskExampleId, Doc]]: A
497+
tuple containing the requests and the documents.
389498
"""
390499
docs: dict[TaskExampleId, Doc] = {}
391500
requests: dict[RequestType, list[Request]] = collections.defaultdict(list)

0 commit comments

Comments
 (0)