diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py
index 696d9db..5004151 100644
--- a/mlx_vlm/models/llava/llava.py
+++ b/mlx_vlm/models/llava/llava.py
@@ -110,27 +110,16 @@ def _merge_input_ids_with_image_features(
         # Positions of <image> tokens in input_ids, assuming batch size is 1
         image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+        num_images, _, vision_hidden_size = image_features.shape
 
-        if len(image_positions) != num_images:
-            raise ValueError(
-                f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
-            )
-
-        text_segments = []
-        start_idx = 0
+        reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)
 
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
+        # cast to the dtype of the input_embeds to support quantized models
+        reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+            inputs_embeds.dtype
+        )
+        inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+        return inputs_embeds
 
     def __call__(
         self,
diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py
index 730d4dc..9c09bef 100644
--- a/mlx_vlm/models/llava_bunny/llava_bunny.py
+++ b/mlx_vlm/models/llava_bunny/llava_bunny.py
@@ -161,28 +161,20 @@ def get_input_embeddings(
     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
         image_token_index = self.config.image_token_index
-        batch_size, seq_length, embed_dim = inputs_embeds.shape
-        num_images, num_image_patches, _ = image_features.shape
+        num_images, num_image_patches, embed_dim = image_features.shape
 
-        # Positions of <image> tokens in input_ids for each batch
-        image_positions = mx.argmax(input_ids == image_token_index, axis=1)
+        # Positions of <image> tokens in input_ids, assuming batch size is 1
+        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+        num_images, _, vision_hidden_size = image_features.shape
 
-        final_embeddings = []
-        for b in range(batch_size):
-            text_segments = []
-            start_idx = 0
-            position = int(image_positions[b].item())
+        reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)
 
-            text_segments.append(inputs_embeds[b : b + 1, start_idx:position])
-            text_segments.append(image_features[b : b + 1])
-            text_segments.append(inputs_embeds[b : b + 1, position + 1 :])
-
-            batch_embeddings = mx.concatenate(text_segments, axis=1)
-            final_embeddings.append(batch_embeddings)
-
-        # Create a final embedding of shape
-        # (batch_size, num_image_patches + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=0)
+        # cast to the dtype of the input_embeds to support quantized models
+        reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+            inputs_embeds.dtype
+        )
+        inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+        return inputs_embeds
 
     def __call__(
         self,
diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py
index f10649f..3ddf43b 100644
--- a/mlx_vlm/models/llava_next/llava_next.py
+++ b/mlx_vlm/models/llava_next/llava_next.py
@@ -118,23 +118,20 @@ def _merge_input_ids_with_image_features(
         self, image_features, inputs_embeds, input_ids
     ):
         image_token_index = self.config.image_token_index
+        num_images, num_image_patches, embed_dim = image_features.shape
 
         # Positions of <image> tokens in input_ids, assuming batch size is 1
         image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
 
-        text_segments = []
-        start_idx = 0
+        num_images, _, vision_hidden_size = image_features.shape
 
-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
+        reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)
 
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
+        # cast to the dtype of the input_embeds to support quantized models
+        reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+            inputs_embeds.dtype
+        )
+        inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+        return inputs_embeds
 
     def __call__(
         self,
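For reference, below is a minimal standalone sketch of the merge path that all three models now share: the `<image>` placeholder positions are located with `np.where`, the image patch features are flattened and cast to the embedding dtype (so the in-place write also works for quantized models), and then scattered into those positions. The token id `32000` and the toy shapes are assumptions chosen for illustration only, not values taken from the patch.

```python
import mlx.core as mx
import numpy as np

# Toy setup (assumed values): one image whose placeholder was expanded to
# four <image> tokens, hidden size 8, placeholder token id 32000.
image_token_index = 32000
input_ids = mx.array([[1, 32000, 32000, 32000, 32000, 42, 7]])
inputs_embeds = mx.zeros((1, input_ids.shape[1], 8), dtype=mx.float16)
image_features = mx.ones((1, 4, 8))  # (num_images, num_patches, vision_hidden_size)

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(np.array(input_ids[0]) == image_token_index)[0].tolist()

# Flatten patches across images and cast to the dtype of inputs_embeds
num_images, _, vision_hidden_size = image_features.shape
reshaped = image_features.reshape(-1, vision_hidden_size).astype(inputs_embeds.dtype)

# Scatter the image features into the placeholder positions in place
inputs_embeds[:, image_positions, :] = reshaped
print(inputs_embeds.sum())  # 32.0 -> four rows of eight ones were written
```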