Added custom model inference. #437

Open

Wants to merge 34 commits into base: main from add-custom-model

Commits (34)
5909d4a
Added first version of custom model.
JoelNiklaus Dec 11, 2024
a2d6b63
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 11, 2024
2283c89
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 11, 2024
9563fab
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 12, 2024
319d482
Merge branch 'main' into add-custom-model
clefourrier Dec 12, 2024
464edfe
Merge branch 'main' into add-custom-model
clefourrier Dec 12, 2024
6096042
Moved custom model config.
JoelNiklaus Dec 12, 2024
a7e1fe5
Added warning.
JoelNiklaus Dec 12, 2024
24b8bd3
Added custom model example for google translate.
JoelNiklaus Dec 12, 2024
c177a8e
Added documentation for custom model config.
JoelNiklaus Dec 12, 2024
d712cdb
Added docs.
JoelNiklaus Dec 12, 2024
7553147
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 12, 2024
b41949c
Fixed path error.
JoelNiklaus Dec 12, 2024
aaaadb0
Fixed doc error.
JoelNiklaus Dec 12, 2024
c85065f
Added requirements file for google translate.
JoelNiklaus Dec 12, 2024
f1103da
Moved model loading function to reduce merge conflicts with litellm i…
JoelNiklaus Dec 12, 2024
71f871e
Added diskcache and get source and target language from the task name.
JoelNiklaus Dec 12, 2024
d1af518
Fixed problem with removing languages in the context.
JoelNiklaus Dec 12, 2024
2511158
Added retry logic.
JoelNiklaus Dec 13, 2024
7d5f76d
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 16, 2024
743a284
Update google-translate requirements.
JoelNiklaus Dec 16, 2024
1a37f71
Added another example for a custom model.
JoelNiklaus Dec 17, 2024
2f27645
Made local mt model example more general to support madlad400 as well.
JoelNiklaus Dec 17, 2024
a4d4fee
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 17, 2024
bd08781
Merge branch 'main' into add-custom-model
clefourrier Dec 18, 2024
b7106e4
Make sure generation can happen on the GPU.
JoelNiklaus Dec 18, 2024
a7d176c
Fixed issue with src and tgt lang for seamless model.
JoelNiklaus Dec 19, 2024
f1ba65c
Added cleanup to free the GPU memory again.
JoelNiklaus Dec 19, 2024
ace6e59
Fix dependency issues by switching to deep-translator.
JoelNiklaus Dec 22, 2024
cfd7254
Made inference code more robust against empty responses.
JoelNiklaus Dec 22, 2024
3ddc104
Merge branch 'main' into add-custom-model
JoelNiklaus Dec 23, 2024
f6df2a3
Merge branch 'main' into add-custom-model
clefourrier Jan 2, 2025
348e427
Merge branch 'main' into add-custom-model
JoelNiklaus Jan 7, 2025
a63f4b3
Merge branch 'main' into add-custom-model
JoelNiklaus Jan 11, 2025
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
@@ -15,6 +15,8 @@
title: Add a custom task
- local: adding-a-new-metric
title: Add a custom metric
- local: evaluating-a-custom-model
title: Evaluate a custom model
- local: use-vllm-as-backend
title: Use VLLM as backend
- local: evaluate-the-model-on-a-server-or-container
129 changes: 129 additions & 0 deletions docs/source/evaluating-a-custom-model.mdx
@@ -0,0 +1,129 @@
# Evaluating a Custom Model

Lighteval allows you to evaluate custom model implementations by creating a custom model class that inherits from `LightevalModel`. This is useful when you want to evaluate models that aren't directly supported by the standard backends (transformers, vllm, etc.).

## Creating a Custom Model

1. Create a Python file containing your custom model implementation. The model must inherit from `LightevalModel` and implement all required methods.

Here's a basic example:

```python
from lighteval.models.abstract_model import LightevalModel

class MyCustomModel(LightevalModel):
def __init__(self, config, env_config):
super().__init__(config, env_config)
# Initialize your model here...

    def greedy_until(self, requests, override_bs=None):
# Implement generation logic
pass

    def loglikelihood(self, requests, override_bs=None):
# Implement loglikelihood computation
pass

def loglikelihood_rolling(self, requests):
# Implement rolling loglikelihood computation
pass

def loglikelihood_single_token(self, requests):
# Implement single token loglikelihood computation
pass
```

2. The custom model file should contain exactly one class that inherits from `LightevalModel`. This class will be automatically detected and instantiated when loading the model.

> [!TIP]
> You can find a complete example of a custom model implementation in `examples/custom_models/google_translate_model.py`.

## Running the Evaluation

You can evaluate your custom model using either the command line interface or the Python API.

### Using the Command Line

```bash
python -m lighteval custom \
"google-translate" \
"examples/custom_models/google_translate_model.py" \
"lighteval|wmt20:fr-de|0|0" \
--output-dir results \
--max-samples 10
```

The command takes three required arguments:
- The model name (used for tracking in results/logs)
- The path to your model implementation file
- The tasks to evaluate on (same format as other backends)
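
The task string follows the same `suite|task|num_fewshot|truncation` layout as the other backends; broken down for the example above (the final 0/1 flag controls whether lighteval may automatically reduce the few-shot count when a prompt is too long):

```bash
# suite   | task (with language pair) | few-shot examples | allow automatic few-shot truncation (0/1)
"lighteval|wmt20:fr-de|0|0"
```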

### Using the Python API

```python
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.custom.custom_model import CustomModelConfig
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

# Set up evaluation tracking
evaluation_tracker = EvaluationTracker(
output_dir="results",
save_details=True
)

# Configure the pipeline
pipeline_params = PipelineParameters(
launcher_type=ParallelismManager.CUSTOM,
env_config=EnvConfig(cache_dir="tmp/")
)

# Configure your custom model
model_config = CustomModelConfig(
model="my-custom-model",
model_definition_file_path="path/to/my_model.py"
)

# Create and run the pipeline
pipeline = Pipeline(
tasks="leaderboard|truthfulqa:mc|0|0",
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config
)

pipeline.evaluate()
pipeline.save_and_push_results()
```
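
If you also want the aggregated scores back in Python (the CLI entry point added in this PR does the same), you can read them from the pipeline after `evaluate()`:

```python
pipeline.show_results()           # print the score table
results = pipeline.get_results()  # aggregated results as a dictionary
```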

## Required Methods

Your custom model must implement these core methods:

- `greedy_until`: For generating text until a stop sequence or max tokens is reached
- `loglikelihood`: For computing log probabilities of specific continuations
- `loglikelihood_rolling`: For computing rolling log probabilities of sequences
- `loglikelihood_single_token`: For computing log probabilities of single tokens

See the `LightevalModel` base class documentation for detailed method signatures and requirements.
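
If your model only supports generation (like the translation examples in this PR), the log-likelihood methods can simply raise `NotImplementedError`. Below is a minimal, illustrative sketch of a generation-only `greedy_until`; `MyGenerativeOnlyModel` and `_call_backend` are hypothetical placeholders for your own class and inference call, and the remaining required methods and properties are omitted:

```python
from typing import Optional

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.model_output import GenerativeResponse
from lighteval.tasks.requests import GreedyUntilRequest


class MyGenerativeOnlyModel(LightevalModel):
    # __init__, tokenizer, max_length, etc. are omitted here; see the full examples.

    def _call_backend(self, prompt: str) -> str:
        # Hypothetical stand-in for your real inference call (HTTP request, local generate(), ...).
        return ""

    def greedy_until(
        self, requests: list[GreedyUntilRequest], override_bs: Optional[int] = None
    ) -> list[GenerativeResponse]:
        responses = []
        for request in requests:
            generated_text = self._call_backend(request.context)
            responses.append(
                GenerativeResponse(
                    result=generated_text,
                    logits=None,          # only needed for logprob-based metrics
                    generated_tokens=[],  # fill in if your backend exposes token ids
                    input_tokens=[],
                )
            )
        return responses

    def loglikelihood(self, requests, override_bs: Optional[int] = None):
        raise NotImplementedError  # generation-only model, like the translation examples
```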

## Best Practices

1. **Error Handling**: Implement robust error handling in your model methods to gracefully handle edge cases.

2. **Batching**: Consider implementing efficient batching in your model methods to improve performance.

3. **Resource Management**: Properly manage any resources (e.g., API connections, model weights) in your model's `__init__` and `cleanup`/`__del__` methods; a GPU-memory cleanup sketch follows this list.

4. **Documentation**: Add clear docstrings to your model class and methods explaining any specific requirements or limitations.
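
For models that hold GPU memory, a `cleanup` method in the spirit of `examples/custom_models/local_mt_model.py` can release it once evaluation is finished. This is a minimal sketch that assumes your class stores the underlying network in a `self._model` attribute:

```python
import gc

import torch

from lighteval.models.abstract_model import LightevalModel


class MyGPUModel(LightevalModel):
    # Other required methods are omitted here; see examples/custom_models/local_mt_model.py.

    def cleanup(self):
        # Drop the model reference and release cached GPU memory once evaluation is done.
        if getattr(self, "_model", None) is not None:
            self._model.cpu()
            self._model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
```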

## Example Use Cases

Custom models are particularly useful for:

- Evaluating models accessed through custom APIs
- Wrapping models with specialized preprocessing/postprocessing
- Testing novel model architectures
- Evaluating ensemble models
- Integrating with external services or tools

For a complete example of a custom model that wraps the Google Translate API, see `examples/custom_models/google_translate_model.py`.
4 changes: 4 additions & 0 deletions docs/source/package_reference/models.mdx
@@ -28,6 +28,10 @@
[[autodoc]] models.endpoints.tgi_model.TGIModelConfig
[[autodoc]] models.endpoints.tgi_model.ModelClient

### Custom Model
[[autodoc]] models.custom.custom_model.CustomModelConfig
[[autodoc]] models.custom.custom_model.CustomModel

### Open AI Models
[[autodoc]] models.endpoints.openai_model.OpenAIClient

200 changes: 200 additions & 0 deletions examples/custom_models/google_translate_model.py
@@ -0,0 +1,200 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import hashlib
import logging
import os
import time
from typing import Optional

import diskcache
import tenacity
from deep_translator import GoogleTranslator
from tqdm import tqdm
from transformers import AutoTokenizer

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class GoogleTranslateClient(LightevalModel):
def __init__(self, config, env_config) -> None:
self.model = config.model
self.model_definition_file_path = config.model_definition_file_path

self.model_info = ModelInfo(
model_name=config.model,
model_sha="",
model_dtype=None,
model_size="",
)

self._tokenizer = AutoTokenizer.from_pretrained("gpt2") # Use a dummy tokenizer for compatibility

# Deep-translator also supports other translators
self.translator = GoogleTranslator()

# Initialize disk cache
cache_dir = os.path.join(os.getcwd(), ".translation_cache")
self.cache = diskcache.Cache(cache_dir)

self.max_retries = 3
self.retry_delay = 1

def _get_cache_key(self, context: str, src_lang: str, tgt_lang: str) -> str:
"""Generate a unique cache key for the translation request."""
# IMPORTANT: In case we want to support other translators, we can add the translator name to the key
key_string = f"{context}|{src_lang}|{tgt_lang}"
return hashlib.md5(key_string.encode()).hexdigest()

@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
retry=tenacity.retry_if_exception_type((Exception)),
before_sleep=lambda retry_state: time.sleep(1),
)
def _translate_with_cache(self, context: str, src_lang: str, tgt_lang: str) -> str:
"""Translate text using cache if available, otherwise call Google Translate with retry logic."""
cache_key = self._get_cache_key(context, src_lang, tgt_lang)

# Try to get from cache
if cache_key in self.cache:
result = self.cache[cache_key]
if result is not None and result != "":
return result
logger.warning("Translation in cache is empty. Removing from cache and retrying...")
del self.cache[cache_key]

try:
# Updated translation call for deep-translator
self.translator.source = src_lang
self.translator.target = tgt_lang
result = self.translator.translate(context)
if result is None or result == "":
result = ""

self.cache[cache_key] = result
return result
except Exception as e:
logger.warning(f"Translation error: {str(e)}. Retrying...")
raise # Let tenacity handle the retry

def greedy_until(
self,
requests: list[GreedyUntilRequest],
override_bs: Optional[int] = None,
) -> list[GenerativeResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
Results are cached to disk to avoid repeated translations.

Args:
requests (list[Request]): list of requests containing the context and ending conditions.
override_bs (int, optional): Override the batch size for generation. Defaults to None.

Returns:
list[GenerativeResponse]: list of generated responses.
"""
for request in requests:
request.tokenized_context = self.tok_encode(request.context)

dataset = GenerativeTaskDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
results = []

for _ in tqdm(
dataset.splits_start_end_iterator(),
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=False, # self.disable_tqdm,
):
for r in tqdm(dataset, desc="Batch", position=1, disable=False):
# Extract source and target languages from task name
# Format is like "community|sdst-text_level:de-fr|0"
src_lang, tgt_lang = r.task_name.split("|")[1].split(":")[-1].split("-")

context = r.context.replace(f"{src_lang.upper()}: ", "").replace(f"\n{tgt_lang.upper()}: ", "")
result = self._translate_with_cache(context, src_lang, tgt_lang)
if result is None:
result = "" # Set to empty string to prevent errors in metric computation

cur_response = GenerativeResponse(
result=result,
logits=None,
generated_tokens=[],
input_tokens=[],
)
results.append(cur_response)

return dataset.get_original_order(results)

@property
def tokenizer(self):
return self._tokenizer

def tok_encode(self, text: str):
return text

@property
def add_special_tokens(self) -> bool:
return False

@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
return 4096

def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError

def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
raise NotImplementedError

def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError
285 changes: 285 additions & 0 deletions examples/custom_models/local_mt_model.py
@@ -0,0 +1,285 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
from typing import Optional

import pycountry
import torch
from tqdm import tqdm
from transformers import (
AutoModelForSeq2SeqLM,
AutoProcessor,
AutoTokenizer,
SeamlessM4Tv2ForTextToText,
)

from lighteval.data import GenerativeTaskDataset
from lighteval.models.abstract_model import LightevalModel, ModelInfo, TokenSequence
from lighteval.models.model_output import (
GenerativeResponse,
LoglikelihoodResponse,
LoglikelihoodSingleTokenResponse,
)
from lighteval.tasks.requests import (
GreedyUntilRequest,
LoglikelihoodRequest,
LoglikelihoodRollingRequest,
LoglikelihoodSingleTokenRequest,
)


logger = logging.getLogger(__name__)


class LocalMTClient(LightevalModel):
"""
A custom model implementation for local machine translation models, specifically supporting:
- SeamlessM4T v2 models from Meta
- MADLAD-400 models from Google
This class provides a unified interface for both model families while handling their different
tokenization and generation approaches transparently.
Args:
config (CustomModelConfig): Configuration containing:
- model (str): Model identifier/path (e.g. "facebook/seamless-m4t-v2-large" or "google/madlad400-7b-mt")
- model_definition_file_path (str): Path to this model definition file
env_config: Environment configuration (unused)
The model automatically detects whether to load SeamlessM4T or MADLAD based on the model identifier
and initializes the appropriate tokenizer and model.
Translation tasks should specify the source and target languages in the format:
"{task_name}|{...}:{src}-{tgt}"
where src and tgt are ISO language codes (2 or 3 letter codes supported).
Example:
        ```
        lighteval custom facebook/seamless-m4t-v2-large examples/custom_models/local_mt_model.py "lighteval|wmt20:fr-de|0|0" --max-samples 10 --save-details
        ```
Note:
- SeamlessM4T models use the AutoProcessor for tokenization
- MADLAD models use the standard AutoTokenizer
- Language codes are automatically converted to 3-letter ISO codes for SeamlessM4T
"""

def __init__(self, config, env_config) -> None:
self.model = config.model
self.model_definition_file_path = config.model_definition_file_path
self.batch_size = 32
self.device = "cuda" if torch.cuda.is_available() else "cpu"

self.model_info = ModelInfo(
model_name=config.model,
model_sha="",
model_dtype=None,
model_size="",
)

# Update model initialization to handle both models
if "seamless-m4t" in config.model:
self._tokenizer = AutoProcessor.from_pretrained(config.model)
self._model = SeamlessM4Tv2ForTextToText.from_pretrained(config.model)
self.model_type = "seamless-4mt"
self.batch_size = 1
logger.info(
"Using batch size of 1 for seamless-4mt model because it the target language needs to be set for the entire batch."
)
elif "madlad400" in config.model:
self._tokenizer = AutoTokenizer.from_pretrained(config.model)
self._model = AutoModelForSeq2SeqLM.from_pretrained(config.model)
self.model_type = "madlad400"
else:
raise ValueError(f"Unsupported model: {config.model}")

self._model.to(self.device)
self._model.eval()

def _convert_to_iso3(self, lang_code: str) -> str:
"""Convert 2-letter ISO code to 3-letter ISO code."""
try:
return pycountry.languages.get(alpha_2=lang_code.lower()).alpha_3
except AttributeError:
# If conversion fails, return the original code
return lang_code

def greedy_until(
self,
requests: list[GreedyUntilRequest],
override_bs: Optional[int] = None,
) -> list[GenerativeResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.
Results are cached to disk to avoid repeated translations.
Args:
requests (list[Request]): list of requests containing the context and ending conditions.
override_bs (int, optional): Override the batch size for generation. Defaults to None.
Returns:
list[GenerativeResponse]: list of generated responses.
"""

def get_langs(task_name: str) -> tuple[str, str]:
src, tgt = task_name.split("|")[1].split(":")[-1].split("-")
if self.model_type == "seamless-4mt":
return self._convert_to_iso3(src), self._convert_to_iso3(tgt)
return src, tgt

# Prepare all inputs first for creating the GenerativeTaskDataset
prepared_requests = []
for request in requests:
src_lang, tgt_lang = get_langs(request.task_name)
request.context = request.context.replace(f"{src_lang.upper()}: ", "").replace(
f"\n{tgt_lang.upper()}: ", ""
)
if self.model_type == "madlad400":
request.context = f"<2{tgt_lang}> {request.context}"

request.tokenized_context = self.tok_encode(request.context)
prepared_requests.append(request)

# Create dataset after preparation
dataset = GenerativeTaskDataset(requests=prepared_requests, num_dataset_splits=self.DATASET_SPLITS)
results = []
batch_size = override_bs or self.batch_size

for split_start, split_end in tqdm(
dataset.splits_start_end_iterator(),
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=False,
):
# Get all requests for this split directly from sorted_data
current_requests = dataset.sorted_data[split_start:split_end]

# Process in batches
for batch_idx in tqdm(
range(0, len(current_requests), batch_size), desc="Batches", position=1, disable=False
):
batch = current_requests[batch_idx : batch_idx + batch_size]

                # Tokenize all inputs in the batch together (rather than concatenating pre-tokenized inputs) so padding is applied consistently
batch_texts = [r.context for r in batch]

                # This tokenization is the one actually used for generation
tokenizer_kwargs = {"text": batch_texts, "return_tensors": "pt", "padding": True}
if self.model_type == "seamless-4mt":
src_lang = get_langs(batch[0].task_name)[0]
tokenizer_kwargs["src_lang"] = src_lang

                tokenized = self._tokenizer(**tokenizer_kwargs).to(self.device)
                input_ids, attention_mask = tokenized["input_ids"], tokenized["attention_mask"]

generation_sizes = [r.generation_size for r in batch]
assert set(generation_sizes) == {generation_sizes[0]}, "All generation sizes must be the same"

# Use unpacked values directly
generate_kwargs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"max_new_tokens": generation_sizes[0],
}
if self.model_type == "seamless-4mt":
tgt_lang = get_langs(batch[0].task_name)[1]
generate_kwargs["tgt_lang"] = tgt_lang

output_ids = self._model.generate(**generate_kwargs)
translations = self._tokenizer.batch_decode(output_ids, skip_special_tokens=True)

# Create responses for the batch
for input_tokens, output_tokens, translation in zip(input_ids, output_ids, translations):
results.append(
GenerativeResponse(
input_tokens=input_tokens,
generated_tokens=output_tokens,
result=translation,
logits=None,
)
)

return dataset.get_original_order(results)

def cleanup(self):
import gc

logger.info("Cleaning up GPU memory for local MT client.")

# Show GPU memory before cleanup
if torch.cuda.is_available():
logger.info(f"GPU memory before cleanup: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

# Delete model and move to CPU
if hasattr(self, "_model"):
self._model.cpu()
del self._model
self._model = None

if hasattr(self, "_tokenizer"):
del self._tokenizer
self._tokenizer = None

torch.cuda.empty_cache()
gc.collect()

# Show GPU memory after cleanup
if torch.cuda.is_available():
logger.info(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

@property
def tokenizer(self):
return self._tokenizer

def tok_encode(self, str_to_encode: str | list[str], add_special_tokens: Optional[bool] = None) -> TokenSequence:
return self._tokenizer(text=str_to_encode, add_special_tokens=add_special_tokens or False).to(self.device)

@property
def add_special_tokens(self) -> bool:
return False

@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
return 4096

def loglikelihood(
self, requests: list[LoglikelihoodRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError

def loglikelihood_rolling(
self, requests: list[LoglikelihoodRollingRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodResponse]:
"""This function is used to compute the log likelihood of the context for perplexity metrics."""
raise NotImplementedError

def loglikelihood_single_token(
self, requests: list[LoglikelihoodSingleTokenRequest], override_bs: Optional[int] = None
) -> list[LoglikelihoodSingleTokenResponse]:
"""Tokenize the context and continuation and compute the log likelihood of those
tokenized sequences.
"""
raise NotImplementedError
2 changes: 2 additions & 0 deletions src/lighteval/__main__.py
@@ -27,6 +27,7 @@

import lighteval.main_accelerate
import lighteval.main_baseline
import lighteval.main_custom
import lighteval.main_endpoint
import lighteval.main_nanotron
import lighteval.main_tasks
@@ -64,6 +65,7 @@
app.command(rich_help_panel="Evaluation Utils")(lighteval.main_baseline.baseline)
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_nanotron.nanotron)
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_vllm.vllm)
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_custom.custom)
app.add_typer(
lighteval.main_endpoint.app,
name="endpoint",
145 changes: 145 additions & 0 deletions src/lighteval/main_custom.py
@@ -0,0 +1,145 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
from typing import Optional

import typer
from typer import Argument, Option
from typing_extensions import Annotated

from lighteval.models.custom.custom_model import CustomModelConfig


app = typer.Typer()


TOKEN = os.getenv("HF_TOKEN")
CACHE_DIR: str = os.getenv("HF_HOME", "/scratch")

HELP_PANNEL_NAME_1 = "Common Paramaters"
HELP_PANNEL_NAME_2 = "Logging Parameters"
HELP_PANNEL_NAME_3 = "Debug Paramaters"
HELP_PANNEL_NAME_4 = "Modeling Paramaters"


@app.command(rich_help_panel="Evaluation Backends")
def custom(
# === general ===
model_name: Annotated[str, Argument(help="The model name to evaluate")],
model_definition_file_path: Annotated[str, Argument(help="The model definition file path to evaluate")],
tasks: Annotated[str, Argument(help="Comma-separated list of tasks to evaluate on.")],
# === Common parameters ===
use_chat_template: Annotated[
bool, Option(help="Use chat template for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4)
] = False,
system_prompt: Annotated[
Optional[str], Option(help="Use system prompt for evaluation.", rich_help_panel=HELP_PANNEL_NAME_4)
] = None,
dataset_loading_processes: Annotated[
int, Option(help="Number of processes to use for dataset loading.", rich_help_panel=HELP_PANNEL_NAME_1)
] = 1,
custom_tasks: Annotated[
Optional[str], Option(help="Path to custom tasks directory.", rich_help_panel=HELP_PANNEL_NAME_1)
] = None,
cache_dir: Annotated[
str, Option(help="Cache directory for datasets and models.", rich_help_panel=HELP_PANNEL_NAME_1)
] = CACHE_DIR,
num_fewshot_seeds: Annotated[
int, Option(help="Number of seeds to use for few-shot evaluation.", rich_help_panel=HELP_PANNEL_NAME_1)
] = 1,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANNEL_NAME_2)
] = "results",
push_to_hub: Annotated[
bool, Option(help="Push results to the huggingface hub.", rich_help_panel=HELP_PANNEL_NAME_2)
] = False,
push_to_tensorboard: Annotated[
bool, Option(help="Push results to tensorboard.", rich_help_panel=HELP_PANNEL_NAME_2)
] = False,
public_run: Annotated[
bool, Option(help="Push results and details to a public repo.", rich_help_panel=HELP_PANNEL_NAME_2)
] = False,
results_org: Annotated[
Optional[str], Option(help="Organization to push results to.", rich_help_panel=HELP_PANNEL_NAME_2)
] = None,
save_details: Annotated[
bool, Option(help="Save detailed, sample per sample, results.", rich_help_panel=HELP_PANNEL_NAME_2)
] = False,
# === debug ===
max_samples: Annotated[
Optional[int], Option(help="Maximum number of samples to evaluate on.", rich_help_panel=HELP_PANNEL_NAME_3)
] = None,
override_batch_size: Annotated[
int, Option(help="Override batch size for evaluation.", rich_help_panel=HELP_PANNEL_NAME_3)
] = None,
job_id: Annotated[
int, Option(help="Optional job id for future refenrence.", rich_help_panel=HELP_PANNEL_NAME_3)
] = 0,
):
"""
Evaluate custom models (can be anything).
"""
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import EnvConfig, ParallelismManager, Pipeline, PipelineParameters

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)
evaluation_tracker = EvaluationTracker(
output_dir=output_dir,
save_details=save_details,
push_to_hub=push_to_hub,
push_to_tensorboard=push_to_tensorboard,
public=public_run,
hub_results_org=results_org,
)

parallelism_manager = ParallelismManager.CUSTOM
model_config = CustomModelConfig(model=model_name, model_definition_file_path=model_definition_file_path)

pipeline_params = PipelineParameters(
launcher_type=parallelism_manager,
env_config=env_config,
job_id=job_id,
dataset_loading_processes=dataset_loading_processes,
custom_tasks_directory=custom_tasks,
override_batch_size=override_batch_size,
num_fewshot_seeds=num_fewshot_seeds,
max_samples=max_samples,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
)
pipeline = Pipeline(
tasks=tasks,
pipeline_parameters=pipeline_params,
evaluation_tracker=evaluation_tracker,
model_config=model_config,
)

pipeline.evaluate()

pipeline.show_results()

results = pipeline.get_results()

pipeline.save_and_push_results()

return results
78 changes: 78 additions & 0 deletions src/lighteval/models/custom/custom_model.py
@@ -0,0 +1,78 @@
# MIT License
#
# Copyright (c) 2024 The HuggingFace Team
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from dataclasses import dataclass


@dataclass
class CustomModelConfig:
"""
Configuration class for loading custom model implementations in Lighteval.
This config allows users to define and load their own model implementations by specifying
a Python file containing a custom model class that inherits from LightevalModel.
The custom model file should contain exactly one class that inherits from LightevalModel.
This class will be automatically detected and instantiated when loading the model.
Args:
model (str):
An identifier for the model. This can be used to track which model was evaluated
in the results and logs.
model_definition_file_path (str):
Path to a Python file containing the custom model implementation. This file must
define exactly one class that inherits from LightevalModel. The class should
implement all required methods from the LightevalModel interface.
Example usage:
```python
# Define config
config = CustomModelConfig(
model="my-custom-model",
model_definition_file_path="path/to/my_model.py"
)
# Example custom model file (my_model.py):
from lighteval.models.abstract_model import LightevalModel
class MyCustomModel(LightevalModel):
def __init__(self, config, env_config):
super().__init__(config, env_config)
# Custom initialization...
def greedy_until(self, *args, **kwargs):
# Custom generation logic...
pass
```
An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
Notes:
- The custom model class must inherit from LightevalModel and implement all required methods
- Only one class inheriting from LightevalModel should be defined in the file
- The model file is dynamically loaded at runtime, so ensure all dependencies are available
- Exercise caution when loading custom model files as they can execute arbitrary code
"""

model: str
model_definition_file_path: str
35 changes: 35 additions & 0 deletions src/lighteval/models/model_loader.py
@@ -23,6 +23,8 @@
import logging
from typing import Union

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.custom.custom_model import CustomModelConfig
from lighteval.models.dummy.dummy_model import DummyModel, DummyModelConfig
from lighteval.models.endpoints.endpoint_model import (
InferenceEndpointModel,
@@ -60,6 +62,7 @@ def load_model( # noqa: C901
InferenceEndpointModelConfig,
DummyModelConfig,
VLLMModelConfig,
CustomModelConfig,
OpenAIModelConfig,
LiteLLMModelConfig,
],
@@ -96,6 +99,9 @@ def load_model( # noqa: C901
if isinstance(config, VLLMModelConfig):
return load_model_with_accelerate_or_default(config=config, env_config=env_config)

if isinstance(config, CustomModelConfig):
return load_custom_model(config=config, env_config=env_config)

if isinstance(config, OpenAIModelConfig):
return load_openai_model(config=config, env_config=env_config)

@@ -131,6 +137,35 @@ def load_openai_model(config: OpenAIModelConfig, env_config: EnvConfig):
return model


def load_custom_model(config: CustomModelConfig, env_config: EnvConfig):
logger.warning(f"Executing custom model code loaded from {config.model_definition_file_path}.")

import importlib.util

# Load the Python file
spec = importlib.util.spec_from_file_location("custom_model_module", config.model_definition_file_path)
if spec is None or spec.loader is None:
raise ImportError(f"Could not load file: {config.model_definition_file_path}")

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find the first class that inherits from LightevalModel
model_class = None
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, LightevalModel) and attr != LightevalModel:
model_class = attr
break

if model_class is None:
raise ValueError(f"No class inheriting from LightevalModel found in {config.model_definition_file_path}")

model = model_class(config, env_config)

return model


def load_model_with_inference_endpoints(
config: Union[InferenceEndpointModelConfig, ServerlessEndpointModelConfig], env_config: EnvConfig
):
1 change: 1 addition & 0 deletions src/lighteval/pipeline.py
@@ -76,6 +76,7 @@ class ParallelismManager(Enum):
TGI = auto()
OPENAI = auto()
VLLM = auto()
CUSTOM = auto()
NONE = auto()