Commit e22d478

Added initial proposal for lazy loading model initialization.
1 parent ca8331a commit e22d478

1 file changed (+93, −64)

src/lighteval/pipeline.py

Lines changed: 93 additions & 64 deletions
@@ -29,12 +29,14 @@
 from dataclasses import dataclass, field
 from datetime import timedelta
 from enum import Enum, auto
+from typing import Dict
 
 import numpy as np
 from tqdm import tqdm
 
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.metrics.utils.metric_utils import MetricCategory
+from lighteval.models.abstract_model import ModelInfo
 from lighteval.models.model_loader import TransformersModel, load_model
 from lighteval.models.model_output import (
     GenerativeMultiturnResponse,
@@ -43,6 +45,9 @@
     LoglikelihoodSingleTokenResponse,
     ModelResponse,
 )
+from lighteval.models.transformers.transformers_model import TransformersModelConfig
+from lighteval.models.utils import _simplify_name
+from lighteval.models.vllm.vllm_model import VLLMModelConfig
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
 from lighteval.tasks.requests import RequestType, SampleUid
@@ -142,42 +147,49 @@ def __init__(
                 "--max_samples WAS SET. THESE NUMBERS ARE ONLY PARTIAL AND SHOULD NOT BE USED FOR COMPARISON UNLESS YOU KNOW WHAT YOU ARE DOING."
             )
 
+        self.tasks = tasks
+        self.model = model
         self.model_config = model_config
         self.evaluation_tracker = evaluation_tracker
-        self.accelerator, self.parallel_context = self._init_parallelism_manager()
-        self.model = self._init_model(model_config, model)
-
-        self.evaluation_tracker.general_config_logger.log_model_info(self.model.model_info)
-        self._init_tasks_and_requests(tasks=tasks)
+        self._init_parallelism_manager()
         self._init_random_seeds()
+
+        self.evaluation_tracker.general_config_logger.log_model_info(self._get_model_info())
+        self._init_tasks()
+
         # Final results
         self.final_dict: dict = None
 
+    def _get_model_info(self):
+        if isinstance(self.model_config, (VLLMModelConfig, TransformersModelConfig)):
+            # At this point we only need the model name to know the details path
+            return ModelInfo(model_name=_simplify_name(self.model_config.pretrained))
+        else:
+            return self._init_model().model_info
+
     def _init_parallelism_manager(self):
-        accelerator, parallel_context = None, None
+        self.accelerator, self.parallel_context = None, None
         if self.launcher_type == ParallelismManager.ACCELERATE:
             if not is_accelerate_available():
                 raise ValueError("You are trying to launch an accelerate model, but accelerate is not installed")
-            accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
-            test_all_gather(accelerator=accelerator)
+            self.accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
+            test_all_gather(accelerator=self.accelerator)
         elif self.launcher_type == ParallelismManager.NANOTRON:
             if not is_nanotron_available():
                 raise ValueError("You are trying to launch a nanotron model, but nanotron is not installed")
             dist.initialize_torch_distributed()
-            parallel_context = ParallelContext(
+            self.parallel_context = ParallelContext(
                 tensor_parallel_size=self.model_config.lighteval_config.parallelism.tp,
                 pipeline_parallel_size=self.model_config.lighteval_config.parallelism.pp,
                 data_parallel_size=self.model_config.lighteval_config.parallelism.dp,
             )
-            test_all_gather(parallel_context=parallel_context)
 
-        return accelerator, parallel_context
-
-    def _init_model(self, model_config, model):
+            test_all_gather(parallel_context=self.parallel_context)
 
+    def _init_model(self):
         logger.info("--- LOADING MODEL ---")
-        if model_config is not None:
+        if self.model_config is not None:
             if self.parallel_context:
-                return NanotronLightevalModel(
+                self.model = NanotronLightevalModel(
                     checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
                     if self.pipeline_parameters.nanotron_checkpoint_path
                     else "",
@@ -188,46 +200,42 @@ def _init_model(self, model_config, model):
                     env_config=self.pipeline_parameters.env_config,
                 )
             else:
-                return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
-        if isinstance(model, TransformersModel):
-            return model
-        else:
-            return TransformersModel.from_model(
-                model=model,
+                self.model = load_model(config=self.model_config, env_config=self.pipeline_parameters.env_config)
+        if not isinstance(self.model, TransformersModel):
+            self.model = TransformersModel.from_model(
+                model=self.model,
                 use_chat_template=self.pipeline_parameters.use_chat_template,
                 env_config=self.pipeline_parameters.env_config,
                 accelerator=self.accelerator,
             )
+        return self.model
 
-    def _init_tasks_and_requests(self, tasks: str):
+    def _init_tasks(self):
         with local_ranks_zero_first() if self.launcher_type == ParallelismManager.NANOTRON else nullcontext():
             logger.info("--- LOADING TASKS ---")
             registry = Registry(
                 cache_dir=self.pipeline_parameters.env_config.cache_dir,
                 custom_tasks=self.pipeline_parameters.custom_tasks_directory,
             )
-            task_names_list, fewshots_dict = taskinfo_selector(tasks, registry)
-            task_dict = registry.get_task_dict(task_names_list)
-            LightevalTask.load_datasets(list(task_dict.values()), self.pipeline_parameters.dataset_loading_processes)
-
-            self.evaluation_tracker.task_config_logger.log(task_dict)
-
-            requests, docs = create_requests_from_tasks(
-                task_dict=task_dict,
-                fewshot_dict=fewshots_dict,
-                num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
-                lm=self.model,
-                max_samples=self.pipeline_parameters.max_samples,
-                evaluation_tracker=self.evaluation_tracker,
-                use_chat_template=self.pipeline_parameters.use_chat_template,
-                system_prompt=self.pipeline_parameters.system_prompt,
+            self.task_names_list, self.fewshots_dict = taskinfo_selector(self.tasks, registry)
+            self.task_dict = registry.get_task_dict(self.task_names_list)
+            LightevalTask.load_datasets(
+                list(self.task_dict.values()), self.pipeline_parameters.dataset_loading_processes
             )
 
-            self.task_names_list = task_names_list
-            self.task_dict = task_dict
-            self.fewshot_dict = fewshots_dict
-            self.requests = requests
-            self.docs = docs
+            self.evaluation_tracker.task_config_logger.log(self.task_dict)
+
+    def _init_requests(self):
+        self.requests, self.docs = create_requests_from_tasks(
+            task_dict=self.task_dict,
+            fewshot_dict=self.fewshots_dict,
+            num_fewshot_seeds=self.pipeline_parameters.num_fewshot_seeds,
+            lm=self.model,
+            max_samples=self.pipeline_parameters.max_samples,
+            evaluation_tracker=self.evaluation_tracker,
+            use_chat_template=self.pipeline_parameters.use_chat_template,
+            system_prompt=self.pipeline_parameters.system_prompt,
+        )
 
     def _init_random_seeds(self):
         logger.info("--- INIT SEEDS ---")
@@ -280,16 +288,37 @@ def evaluate(self):
         except OSError:
             pass
 
+    @staticmethod
+    def _metric_category_to_request_type() -> Dict[MetricCategory, RequestType]:
+        """Maps MetricCategory to their corresponding RequestType."""
+        return {
+            MetricCategory.TARGET_PERPLEXITY: RequestType.LOGLIKELIHOOD,
+            MetricCategory.PERPLEXITY: RequestType.LOGLIKELIHOOD_ROLLING,
+            MetricCategory.GENERATIVE_SAMPLING: RequestType.GREEDY_UNTIL,
+            MetricCategory.GENERATIVE: RequestType.GREEDY_UNTIL,
+            MetricCategory.GENERATIVE_LOGPROB: RequestType.GREEDY_UNTIL,
+            MetricCategory.MULTICHOICE: RequestType.LOGLIKELIHOOD,
+            MetricCategory.MULTICHOICE_PMI: RequestType.LOGLIKELIHOOD,
+            MetricCategory.MULTICHOICE_ONE_TOKEN: RequestType.LOGLIKELIHOOD_SINGLE_TOKEN,
+            MetricCategory.LLM_AS_JUDGE_MULTI_TURN: RequestType.GREEDY_UNTIL_MULTI_TURN,
+            MetricCategory.LLM_AS_JUDGE: RequestType.GREEDY_UNTIL,
+        }
+
+    @staticmethod
+    def _request_type_to_response() -> Dict[RequestType, type[ModelResponse]]:
+        return {
+            RequestType.LOGLIKELIHOOD: LoglikelihoodResponse,
+            RequestType.LOGLIKELIHOOD_SINGLE_TOKEN: LoglikelihoodSingleTokenResponse,
+            RequestType.LOGLIKELIHOOD_ROLLING: LoglikelihoodResponse,
+            RequestType.GREEDY_UNTIL_MULTI_TURN: GenerativeMultiturnResponse,
+            RequestType.GREEDY_UNTIL: GenerativeResponse,
+        }
+
     def _load_responses_from_details(self):
         logger.info("--- LOADING RESPONSES FROM DETAILS ---")
         sample_id_to_responses: dict[(SampleUid, MetricCategory), list[ModelResponse]] = collections.defaultdict(list)
 
-        request_types = list(self.requests.keys())
-        if len(request_types) > 1:
-            raise ValueError(
-                "Loading responses from details when there are multiple request types is currently not supported"
-            )
-        model_response_type = self._get_model_response_type(request_types[0])
+        model_response_type = self._get_model_response_type()
 
         details_datasets = self.evaluation_tracker.load_details_datasets(
             self.pipeline_parameters.load_responses_from_details_date_id, self.task_names_list
@@ -333,25 +362,25 @@ def _load_responses_from_details(self):
                     sample_id_to_responses[(SampleUid(task_name, f"{idx}_{0}"), metric_category)] = [response]
         return sample_id_to_responses
 
-    def _get_model_response_type(self, request_type):
-        if request_type == RequestType.LOGLIKELIHOOD:
-            model_response_type = LoglikelihoodResponse
-        elif request_type == RequestType.LOGLIKELIHOOD_SINGLE_TOKEN:
-            model_response_type = LoglikelihoodSingleTokenResponse
-        elif request_type == RequestType.LOGLIKELIHOOD_ROLLING:
-            model_response_type = LoglikelihoodResponse
-        elif request_type == RequestType.GREEDY_UNTIL_MULTI_TURN:
-            model_response_type = GenerativeMultiturnResponse
-        elif request_type == RequestType.GREEDY_UNTIL:
-            model_response_type = GenerativeResponse
-        else:
-            raise ValueError(
-                f"Loading responses from details for request type {request_type} is currently not supported"
-            )
-
+    def _get_model_response_type(self):
+        model_response_type = None
+        for task in self.task_dict.values():
+            for metric_category, has_metric_category in task.has_metric_category.items():
+                if has_metric_category:
+                    request_type = self._metric_category_to_request_type()[metric_category]
+                    new_model_response_type = self._request_type_to_response()[request_type]
+                    if model_response_type and new_model_response_type != model_response_type:
+                        raise ValueError(
+                            f"Loading responses from details with multiple model response types ({model_response_type} and {new_model_response_type}) is currently not supported"
+                        )
+                    model_response_type = new_model_response_type
         return model_response_type
 
     def _run_model(self):
+        # Init the model lazily to avoid loading it if it is not needed
+        self._init_model()
+        self._init_requests()  # Needs the model to be initialized
+
         # Running all requests depending on the model call type (log likelihood, generative, ...)
         # to be able to batch them
         logger.info("--- RUNNING MODEL ---")
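
For readers skimming the diff, the following is a minimal, self-contained sketch of the lazy-initialization pattern the commit message describes: the pipeline's constructor stays cheap, logging-oriented metadata is derived from the config alone, and the model is only built when requests actually have to run. All names below (ToyConfig, ToyPipeline, model_name, run) are illustrative stand-ins, not lighteval's API.

# A minimal sketch of lazy model initialization, with hypothetical names.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyConfig:
    pretrained: str  # enough to derive a display name without loading any weights


class ToyPipeline:
    def __init__(self, config: ToyConfig):
        self.config = config
        self._model: Optional[str] = None  # nothing heavy happens in __init__

    def model_name(self) -> str:
        # Metadata needed for logging comes from the config alone,
        # mirroring what _get_model_info() does in the diff.
        return self.config.pretrained.split("/")[-1]

    def _init_model(self) -> str:
        # Built on first use only; later calls reuse the same instance.
        if self._model is None:
            self._model = f"<weights for {self.config.pretrained}>"  # stand-in for a real load
        return self._model

    def run(self, load_responses_from_details: bool = False) -> str:
        if load_responses_from_details:
            # Re-scoring cached responses never touches the model.
            return "scored cached responses without loading the model"
        model = self._init_model()  # deferred until actually needed, as in _run_model()
        return f"ran requests against {model}"


if __name__ == "__main__":
    pipe = ToyPipeline(ToyConfig("org/model"))
    print(pipe.model_name())                           # no weights loaded
    print(pipe.run(load_responses_from_details=True))  # still no weights loaded
    print(pipe.run())                                  # weights "loaded" here, on demand

The design choice this illustrates: callers that only re-score cached details, or only need the model name for output paths, never pay the model-loading cost.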
