Skip to content

Commit

Permalink
Rename `to_tgi_inferenceendpoint_dict` to `to_tgi_ie_dict`
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier committed Dec 26, 2024
1 parent e233190 commit 97db620
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 7 deletions.
4 changes: 1 addition & 3 deletions src/lighteval/models/endpoints/endpoint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,7 @@ def __init__( # noqa: C901
model_size=-1,
)
self.generation_parameters = config.generation_parameters
self.generation_config = TextGenerationInputGenerateParameters(
**self.generation_parameters.to_tgi_inferenceendpoint_dict()
)
self.generation_config = TextGenerationInputGenerateParameters(**self.generation_parameters.to_tgi_ie_dict())

@staticmethod
def get_larger_hardware_suggestion(cur_instance_type: str = None, cur_instance_size: str = None):
Expand Down
4 changes: 1 addition & 3 deletions src/lighteval/models/endpoints/tgi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def __init__(self, config: TGIModelConfig) -> None:

self.client = AsyncClient(config.inference_server_address, headers=headers, timeout=240)
self.generation_parameters = config.generation_parameters
self.generation_config = TextGenerationInputGenerateParameters(
**self.generation_parameters.to_tgi_inferenceendpoint_dict()
)
self.generation_config = TextGenerationInputGenerateParameters(**self.generation_parameters.to_tgi_ie_dict())
self._max_gen_toks = 256
self.model_info = requests.get(f"{config.inference_server_address}/info", headers=headers).json()
if "model_id" not in self.model_info:
Expand Down
2 changes: 1 addition & 1 deletion src/lighteval/models/model_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def to_transformers_dict(self) -> dict:
}
return {k: v for k, v in args.items() if v is not None}

def to_tgi_inferenceendpoint_dict(self) -> dict:
def to_tgi_ie_dict(self) -> dict:
"""Selects relevant generation and sampling parameters for tgi or inference endpoints models.
Doc: https://huggingface.co/docs/huggingface_hub/v0.26.3/en/package_reference/inference_types#huggingface_hub.TextGenerationInputGenerateParameters
Expand Down

0 comments on commit 97db620

Please sign in to comment.