|
4 | 4 | from .registry import get_encoding
|
5 | 5 |
|
6 | 6 | # TODO: these will likely be replaced by an API endpoint
|
7 |
| -_MODEL_PREFIX_TO_ENCODING: dict[str, str] = { |
| 7 | +MODEL_PREFIX_TO_ENCODING: dict[str, str] = { |
8 | 8 | # chat
|
9 | 9 | "gpt-4-": "cl100k_base", # e.g., gpt-4-0314, etc., plus gpt-4-32k
|
10 | 10 | "gpt-3.5-turbo-": "cl100k_base", # e.g, gpt-3.5-turbo-0301, -0401, etc.
|
|
16 | 16 | "ft:babbage-002": "cl100k_base",
|
17 | 17 | }
|
18 | 18 |
|
19 |
| -_MODEL_TO_ENCODING: dict[str, str] = { |
| 19 | +MODEL_TO_ENCODING: dict[str, str] = { |
20 | 20 | # chat
|
21 | 21 | "gpt-4": "cl100k_base",
|
22 | 22 | "gpt-3.5-turbo": "cl100k_base",
|
|
64 | 64 | }
|
65 | 65 |
|
66 | 66 |
|
67 |
def encoding_name_for_model(model_name: str) -> str:
    """Returns the name of the encoding used by a model.

    Raises a KeyError if the model name is not recognised.
    """
    # An exact model-name match wins outright.
    if model_name in MODEL_TO_ENCODING:
        return MODEL_TO_ENCODING[model_name]

    # Check if the model matches a known prefix
    # Prefix matching avoids needing library updates for every model version release
    # Note that this can match on non-existent models (e.g., gpt-3.5-turbo-FAKE)
    for model_prefix, model_encoding_name in MODEL_PREFIX_TO_ENCODING.items():
        if model_name.startswith(model_prefix):
            return model_encoding_name

    # Nothing matched: fail loudly, suppressing any exception context.
    raise KeyError(
        f"Could not automatically map {model_name} to a tokeniser. "
        "Please use `tiktoken.get_encoding` to explicitly get the tokeniser you expect."
    ) from None
| 90 | + |
| 91 | + |
def encoding_for_model(model_name: str) -> Encoding:
    """Returns the encoding used by a model.

    Raises a KeyError if the model name is not recognised.
    """
    # Resolve the model name to an encoding name, then load that encoding.
    encoding_name = encoding_name_for_model(model_name)
    return get_encoding(encoding_name)
0 commit comments