From 97db62032c4191fb11afae67e58519f90d4a738d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine?=
Date: Thu, 26 Dec 2024 12:35:00 +0100
Subject: [PATCH] rename to_tgi_inferenceendpoint_dict to to_tgi_ie_dict

---
 src/lighteval/models/endpoints/endpoint_model.py | 4 +---
 src/lighteval/models/endpoints/tgi_model.py      | 4 +---
 src/lighteval/models/model_input.py              | 2 +-
 3 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/lighteval/models/endpoints/endpoint_model.py b/src/lighteval/models/endpoints/endpoint_model.py
index 47978adff..942ece410 100644
--- a/src/lighteval/models/endpoints/endpoint_model.py
+++ b/src/lighteval/models/endpoints/endpoint_model.py
@@ -316,9 +316,7 @@ def __init__(  # noqa: C901
             model_size=-1,
         )
         self.generation_parameters = config.generation_parameters
-        self.generation_config = TextGenerationInputGenerateParameters(
-            **self.generation_parameters.to_tgi_inferenceendpoint_dict()
-        )
+        self.generation_config = TextGenerationInputGenerateParameters(**self.generation_parameters.to_tgi_ie_dict())
 
     @staticmethod
     def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
diff --git a/src/lighteval/models/endpoints/tgi_model.py b/src/lighteval/models/endpoints/tgi_model.py
index 9ca5dc053..f0bb712b6 100644
--- a/src/lighteval/models/endpoints/tgi_model.py
+++ b/src/lighteval/models/endpoints/tgi_model.py
@@ -88,9 +88,7 @@ def __init__(self, config: TGIModelConfig) -> None:
 
         self.client = AsyncClient(config.inference_server_address, headers=headers, timeout=240)
         self.generation_parameters = config.generation_parameters
-        self.generation_config = TextGenerationInputGenerateParameters(
-            **self.generation_parameters.to_tgi_inferenceendpoint_dict()
-        )
+        self.generation_config = TextGenerationInputGenerateParameters(**self.generation_parameters.to_tgi_ie_dict())
         self._max_gen_toks = 256
         self.model_info = requests.get(f"{config.inference_server_address}/info", headers=headers).json()
         if "model_id" not in self.model_info:
diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py
index 2635245c3..04e35be17 100644
--- a/src/lighteval/models/model_input.py
+++ b/src/lighteval/models/model_input.py
@@ -97,7 +97,7 @@ def to_transformers_dict(self) -> dict:
         }
         return {k: v for k, v in args.items() if v is not None}
 
-    def to_tgi_inferenceendpoint_dict(self) -> dict:
+    def to_tgi_ie_dict(self) -> dict:
         """Selects relevant generation and sampling parameters for tgi or inference endpoints models.
 
         Doc: https://huggingface.co/docs/huggingface_hub/v0.26.3/en/package_reference/inference_types#huggingface_hub.TextGenerationInputGenerateParameters
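
Note: a minimal usage sketch of the renamed helper, matching the call pattern both
hunks above converge on. It assumes lighteval's GenerationParameters dataclass from
src/lighteval/models/model_input.py and huggingface_hub's
TextGenerationInputGenerateParameters as imported by the two call sites; the concrete
field names and values (temperature, top_p, max_new_tokens) are illustrative
assumptions, not taken from this patch.

    from huggingface_hub import TextGenerationInputGenerateParameters

    from lighteval.models.model_input import GenerationParameters

    # Assumed GenerationParameters fields, shown only to illustrate the call pattern.
    generation_parameters = GenerationParameters(temperature=0.7, top_p=0.9, max_new_tokens=256)

    # to_tgi_ie_dict() selects the generation and sampling parameters relevant to
    # TGI / Inference Endpoints models (see its docstring above); the resulting dict
    # is unpacked directly into TextGenerationInputGenerateParameters.
    generation_config = TextGenerationInputGenerateParameters(**generation_parameters.to_tgi_ie_dict())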