
Commit 195f2b9

Set multiple eos tokens for VLM (#119)
* Set multiple eos tokens for VLM
* attempt 2
1 parent f2e7ec7 commit 195f2b9

File tree

3 files changed: +17 -6 lines changed

mlx_engine/generate.py
mlx_engine/vision/vision_model_kit.py
tests/test_vision_models.py


mlx_engine/generate.py

Lines changed: 1 addition & 5 deletions
@@ -267,11 +267,7 @@ def sampler_func_wrapper(*args, **kwargs):
     tokenizer = model_kit.tokenizer
 
     # Set up stop string processor if non-empty stop_strings are provided
-    eos_token_ids = (
-        tokenizer.eos_token_ids
-        if isinstance(tokenizer.eos_token_ids, Iterable)
-        else [tokenizer.eos_token_ids]
-    )
+    eos_token_ids = tokenizer.eos_token_ids
     stop_string_processor = None
     if stop_strings is not None and len(stop_strings) > 0:
         stop_string_processor = StopStringProcessor(stop_strings, tokenizer)
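
The deleted branch normalized a possibly-scalar eos_token_ids into a list. A minimal before/after sketch, assuming (as the vision_model_kit.py change below arranges) that the tokenizer reaching this code now always exposes eos_token_ids as a collection; the helper names here are illustrative, not part of mlx_engine:

from collections.abc import Iterable

# Old contract (deleted above): tolerate a bare int by wrapping it.
def normalize_eos(eos):
    return eos if isinstance(eos, Iterable) else [eos]

# New contract: the tokenizer is trusted to hand back a collection,
# so generation code can test membership directly.
def is_eos(token_id, tokenizer):
    return token_id in tokenizer.eos_token_ids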

mlx_engine/vision/vision_model_kit.py

Lines changed: 15 additions & 1 deletion
@@ -6,6 +6,7 @@
 from .vision_model_wrapper import VisionModelWrapper
 
 import mlx_vlm
+import mlx_lm
 from pathlib import Path
 import mlx.core as mx
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -100,8 +101,21 @@ def _full_model_init(self):
         else:
             self.model, self.processor, self.model_weights = return_tuple
         self.model = VisionModelWrapper(self.model)
-        self.tokenizer = mlx_vlm.tokenizer_utils.load_tokenizer(self.model_path)
+
+        # Set the eos_token_ids
+        eos_token_ids = []
+        if (eos_tokens := self.config.get("eos_token_ids", None)) is not None:
+            eos_token_ids = list(set(eos_tokens))
+            log_info(f"Setting eos token ids: {eos_token_ids}")
+        elif (eos_tokens := self.config.get("eos_token_id", None)) is not None:
+            eos_token_ids = [eos_tokens]
+
+        # Use the mlx_lm tokenizer since it's more robust
+        self.tokenizer = mlx_lm.tokenizer_utils.load_tokenizer(
+            self.model_path, eos_token_ids=list(eos_token_ids)
+        )
         self.detokenizer = self.tokenizer.detokenizer
+
         self.cache_wrapper = None
         mx.metal.clear_cache()
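Taken on its own, the config lookup added here reduces to the sketch below. The "eos_token_ids"/"eos_token_id" key names and their precedence come from the diff; the helper name and the asserts are illustrative:

def resolve_eos_token_ids(config: dict) -> list:
    # Prefer the plural key and deduplicate, since configs can repeat ids.
    if (eos_tokens := config.get("eos_token_ids", None)) is not None:
        return list(set(eos_tokens))
    # Fall back to the singular key, wrapping the scalar in a list.
    if (eos_tokens := config.get("eos_token_id", None)) is not None:
        return [eos_tokens]
    # No EOS info in the config: defer to the tokenizer's own defaults.
    return []

assert sorted(resolve_eos_token_ids({"eos_token_ids": [2, 2, 106]})) == [2, 106]
assert resolve_eos_token_ids({"eos_token_id": 2}) == [2]
assert resolve_eos_token_ids({}) == []

Note that set() makes the resulting list order arbitrary, which should be harmless if the ids are only ever used for membership tests when deciding whether to stop generation.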

tests/test_vision_models.py

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ def test_gemma3_text_only(self):
         prompt = f"{self.text_only_prompt}"
         self.model_helper("mlx-community/gemma-3-4b-it-4bit", prompt, text_only=True)
 
+
 """
 To find the correct prompt format for new models, run this command for your model in the terminal and check the prompt dump:
 python -m mlx_vlm.generate --model ~/.cache/lm-studio/models/mlx-community/MODEL-NAME --max-tokens 100 --temp 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg --prompt "What do you see?"
