diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py
index fb51a84..37dbc3c 100644
--- a/mlx_vlm/models/pixtral/pixtral.py
+++ b/mlx_vlm/models/pixtral/pixtral.py
@@ -70,7 +70,7 @@ def get_input_embeddings(
         pixel_values: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids)
 
         # Get the input embeddings from the language model
         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 218ffef..ffc71e4 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -1000,7 +1000,7 @@ def stream_generate(
     Yields:
         Generator[Tuple[mx.array, mx.array]]: A generator producing text.
     """
-    tokenizer = processor if hasattr(processor, "encode") else processor.tokenizer
+    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
     prompt_tokens = mx.array(tokenizer.encode(prompt))
 
     resize_shape = kwargs.pop("resize_shape", None)
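
A minimal sketch of the tokenizer resolution the second hunk introduces, assuming the processor object wraps its tokenizer under a `.tokenizer` attribute (as the changed line implies); the helper name `resolve_tokenizer` is hypothetical and not part of the patch:

    def resolve_tokenizer(processor):
        # Prefer the wrapped tokenizer when a full processor is passed;
        # otherwise assume the object itself already is the tokenizer.
        return processor.tokenizer if hasattr(processor, "tokenizer") else processor

    # Works whether the caller passes a multimodal processor or a bare tokenizer:
    tokenizer = resolve_tokenizer(processor)
    prompt_tokens = tokenizer.encode(prompt)

Checking for a `tokenizer` attribute rather than an `encode` method avoids picking the wrapper object itself when it also happens to expose an `encode`-like interface.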