|
3 | 3 | import json
|
4 | 4 | import logging
|
5 | 5 | import os
|
6 |
| -from typing import Dict, List, Optional, Tuple, Type, TypeVar |
| 6 | +from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union |
7 | 7 |
|
8 | 8 | import transformers
|
9 | 9 | from sqlitedict import SqliteDict
|
@@ -192,15 +192,13 @@ def tokenizer_name(self) -> str:
|
192 | 192 | "To use this model with chat templates, please implement the 'tokenizer_name' property."
|
193 | 193 | )
|
194 | 194 |
|
195 |
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
    """Return the structure of the chat template applied to user/assistant messages.

    This is recorded in experiment results purely for reproducibility. Subclasses
    that implement chat templating should override this method; models that do not
    support chat templates fall back to this base implementation, which ignores
    ``chat_template`` and returns the empty string.

    Args:
        chat_template (Union[bool, str]): Template selector (name, True for the
            default, or False/None for no template). Unused in this base
            implementation.

    Returns:
        Optional[str]: The empty string when not overridden by a subclass.
    """
    return ""
204 | 202 |
|
205 | 203 | def set_cache_hook(self, cache_hook) -> None:
|
206 | 204 | self.cache_hook = cache_hook
|
@@ -316,6 +314,8 @@ class TemplateLM(LM):
|
316 | 314 | and boilerplate often included in other LM subclasses.
|
317 | 315 | """
|
318 | 316 |
|
| 317 | + tokenizer = None |
| 318 | + |
319 | 319 | @property
|
320 | 320 | @abc.abstractmethod
|
321 | 321 | def eot_token_id(self):
|
@@ -386,3 +386,99 @@ def loglikelihood_rolling(
|
386 | 386 | @abc.abstractmethod
|
387 | 387 | def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
|
388 | 388 | pass
|
| 389 | + |
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
    """
    Set and get the appropriate chat template for the model.
    This method selects the tokenizer's chat template and returns the template
    string for reproducibility.

    The template selection logic is adapted from the Transformers library's
    `apply_chat_template` method in the Tokenizer class. The original
    implementation can be found at:
    https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687

    This method ensures that the right template is chosen based on the following:
    0. If the model has no 'tokenizer' attribute set: assumes that there is only a
       single possible chat template, handled on the model provider side
       internally. Returns the empty string.
    1. If the model's tokenizer has multiple templates:
        a. Use the specified template if it exists in the dictionary.
        b. Use the default template from the dict if no specific template is provided.
        c. Raise an error if no default template exists and no specific template
           is provided.
    2. If the model's tokenizer has a single template or no template:
        a. Use the tokenizer's chat template if available.
        b. Fall back to the legacy class-level default chat template otherwise.

    Args:
        chat_template (Union[bool, str]): Specifies the chat template to use.
            - If False or None, no template is applied.
            - If True, the default or only available template is used.
            - If a string, the template with the matching name is used.

    Returns:
        Optional[str]: The selected chat template, or None if no template is applied.

    Raises:
        ValueError: If a named template is not found, or if multiple templates
            exist with no default and no name was provided.
    """
    if self.tokenizer is None:
        # No local tokenizer (e.g. API-hosted models): templating happens on the
        # provider side, so record an empty template string for reproducibility.
        return ""

    if chat_template is False or chat_template is None:
        eval_logger.warning(
            "model.chat_template was called with the chat_template set to False or None. "
            "Therefore no chat template will be applied. Make sure this is an intended behavior."
        )
        return None

    # Convert boolean chat_template to None to ensure compatibility with the adapted logic
    if isinstance(chat_template, bool):
        chat_template = None
    using_default_template = False

    # `default_chat_template` is a legacy attribute removed in newer Transformers
    # releases (v4.43+) — the warning text below says as much. Use getattr so we
    # fall back to None instead of raising AttributeError on modern versions.
    default_template = getattr(self.tokenizer, "default_chat_template", None)

    # First, handle the cases when the model has a dict of multiple templates
    template = self.tokenizer.chat_template or default_template

    if isinstance(template, dict):
        using_default_dict = self.tokenizer.chat_template is None

        if chat_template is not None:
            if chat_template in template:
                selected_template = template[chat_template]
                if using_default_dict:
                    using_default_template = True
            else:
                raise ValueError(
                    f"The specified chat template '{chat_template}' is not available. "
                    f"Available template names are {sorted(template.keys())}."
                )
        else:
            # If user didn't pass a chat template, use the default template from the dict
            if "default" in template:
                selected_template = template["default"]
                using_default_template = True
            else:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template.keys())}."
                )

    # Cases when the model has a single template or no template
    else:
        # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
        if isinstance(chat_template, str):
            eval_logger.warning(
                "Chat template name provided, but the tokenizer's chat template is not a dictionary. "
                "Using the tokenizer's chat template or the default template instead."
            )
        if self.tokenizer.chat_template is not None:
            selected_template = self.tokenizer.chat_template
        else:
            selected_template = default_template
            using_default_template = True

    if using_default_template:
        eval_logger.warning(
            "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
            "very error-prone, because models are often trained with templates different from the class default! "
            "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
            "point any code depending on them will stop working. We recommend setting a valid chat template before "
            "then to ensure that this model continues working without issues."
        )

    return selected_template
0 commit comments