
Commit f62cc89

init

clefourrier committed Dec 12, 2024
1 parent 0135c2e
Showing 4 changed files with 25 additions and 13 deletions.
docs/source/package_reference/models.mdx (1 addition, 1 deletion)
@@ -21,7 +21,7 @@
 ## Endpoints-based Models
 ### InferenceEndpointModel
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModelConfig
-[[autodoc]] models.endpoints.endpoint_model.InferenceModelConfig
+[[autodoc]] models.endpoints.endpoint_model.ServerlessEndpointModelConfig
 [[autodoc]] models.endpoints.endpoint_model.InferenceEndpointModel

 ### TGI ModelClient
src/lighteval/main_endpoint.py (12 additions, 8 deletions)
@@ -146,6 +146,12 @@ def inference_endpoint(
         str, Argument(help="Path to model config yaml file. (examples/model_configs/endpoint_model.yaml)")
     ],
     tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
+    free_endpoint: Annotated[
+        str,
+        Argument(
+            help="True if you want to use the serverless free endpoints, False (default) if you want to spin up your own inference endpoint."
+        ),
+    ] = False,
     # === Common parameters ===
     use_chat_template: Annotated[
         bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4)
@@ -200,9 +206,7 @@ def inference_endpoint(
     """

     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.endpoints.endpoint_model import (
-        InferenceEndpointModelConfig,
-    )
+    from lighteval.models.endpoints.endpoint_model import InferenceEndpointModelConfig, ServerlessEndpointModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

     env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
@@ -220,10 +224,10 @@
     parallelism_manager = ParallelismManager.NONE  # since we're using inference endpoints in remote

-    # Find a way to add this back
-    # if config["base_params"].get("endpoint_name", None):
-    #     return InferenceModelConfig(model=config["base_params"]["endpoint_name"])

-    model_config = InferenceEndpointModelConfig.from_path(model_config_path)
+    if free_endpoint:
+        model_config = ServerlessEndpointModelConfig.from_path(model_config_path)
+    else:
+        model_config = InferenceEndpointModelConfig.from_path(model_config_path)

     pipeline_params = PipelineParameters(
         launcher_type=parallelism_manager,
@@ -317,7 +321,7 @@ def tgi(
     import yaml

     from lighteval.logging.evaluation_tracker import EvaluationTracker
-    from lighteval.models.model_config import TGIModelConfig
+    from lighteval.models.endpoints.tgi_model import TGIModelConfig
     from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

     env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
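The new free_endpoint argument added to inference_endpoint above selects between the two endpoint config classes. A minimal sketch of that selection logic in isolation, assuming the YAML path points at a config like the examples/model_configs/endpoint_model.yaml referenced in the help text (the helper name build_model_config is hypothetical):

from lighteval.models.endpoints.endpoint_model import (
    InferenceEndpointModelConfig,
    ServerlessEndpointModelConfig,
)

def build_model_config(model_config_path: str, free_endpoint: bool):
    # Hypothetical helper mirroring the branch added above: serverless free
    # endpoints call an already-hosted model, while the default path spins up
    # a dedicated inference endpoint from the same YAML file.
    if free_endpoint:
        return ServerlessEndpointModelConfig.from_path(model_config_path)
    return InferenceEndpointModelConfig.from_path(model_config_path)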
src/lighteval/models/endpoints/endpoint_model.py (10 additions, 2 deletions)
@@ -75,10 +75,18 @@


 @dataclass
-class InferenceModelConfig:
+class ServerlessEndpointModelConfig:
     model: str
     add_special_tokens: bool = True
+
+    @classmethod
+    def from_path(cls, path: str) -> "InferenceEndpointModelConfig":
+        import yaml
+
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)["model"]
+        return cls(**config["base_params"])


 @dataclass
 class InferenceEndpointModelConfig:
@@ -142,7 +150,7 @@ class InferenceEndpointModel(LightevalModel):
     """

     def __init__(  # noqa: C901
-        self, config: Union[InferenceEndpointModelConfig, InferenceModelConfig], env_config: EnvConfig
+        self, config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
     ) -> None:
         self.reuse_existing = getattr(config, "reuse_existing", False)
         self._max_length = None
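The from_path classmethod added in the first hunk of this file pulls base_params out of the model section of a YAML file and feeds it to the dataclass. A minimal sketch of a compatible config, assuming a hypothetical model id and using only the two fields the dataclass declares:

import yaml

from lighteval.models.endpoints.endpoint_model import ServerlessEndpointModelConfig

# Hypothetical YAML mirroring the structure from_path reads:
# a top-level "model" key containing a "base_params" mapping.
EXAMPLE_CONFIG = """
model:
  base_params:
    model: "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical model id
    add_special_tokens: true
"""

params = yaml.safe_load(EXAMPLE_CONFIG)["model"]["base_params"]
config = ServerlessEndpointModelConfig(**params)
print(config.model, config.add_special_tokens)

Note that the serverless config has no reuse_existing field, so the getattr(config, "reuse_existing", False) call in __init__ above falls back to False for it.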
src/lighteval/models/model_loader.py (2 additions, 2 deletions)
@@ -27,7 +27,7 @@
 from lighteval.models.endpoints.endpoint_model import (
     InferenceEndpointModel,
     InferenceEndpointModelConfig,
-    InferenceModelConfig,
+    ServerlessEndpointModelConfig,
 )
 from lighteval.models.endpoints.openai_model import OpenAIClient, OpenAIModelConfig
 from lighteval.models.endpoints.tgi_model import ModelClient, TGIModelConfig
@@ -80,7 +80,7 @@ def load_model( # noqa: C901
     if isinstance(config, TGIModelConfig):
         return load_model_with_tgi(config)

-    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, InferenceModelConfig):
+    if isinstance(config, InferenceEndpointModelConfig) or isinstance(config, ServerlessEndpointModelConfig):
         return load_model_with_inference_endpoints(config, env_config=env_config)

     if isinstance(config, BaseModelConfig):
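In the loader, both endpoints config classes now route to the same inference-endpoints path. A tuple isinstance check is an equivalent way to express the dispatch condition above (the helper name is an assumption, not part of the commit):

from lighteval.models.endpoints.endpoint_model import (
    InferenceEndpointModelConfig,
    ServerlessEndpointModelConfig,
)

def is_endpoints_config(config) -> bool:
    # Equivalent to the `isinstance(...) or isinstance(...)` check above.
    return isinstance(config, (InferenceEndpointModelConfig, ServerlessEndpointModelConfig))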
