fix llava video processing
Blaizzy committed Feb 1, 2025
1 parent b3e7d0a commit 86bdeda
Showing 3 changed files with 28 additions and 50 deletions.
mlx_vlm/models/llava/llava.py: 27 changes (8 additions, 19 deletions)
@@ -110,27 +110,16 @@ def _merge_input_ids_with_image_features(

         # Positions of <image> tokens in input_ids, assuming batch size is 1
         image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+        num_images, _, vision_hidden_size = image_features.shape

-        if len(image_positions) != num_images:
-            raise ValueError(
-                f"The number of image tokens ({len(image_positions)}) does not "
-                f" match the number of image inputs ({num_images})."
-            )
-
-        text_segments = []
-        start_idx = 0
+        reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)

-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
-
-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
+        # cast to the dtype of the input_embeds to support quantized models
+        reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+            inputs_embeds.dtype
+        )
+        inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+        return inputs_embeds

     def __call__(
         self,
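All three diffs make the same change: instead of slicing the text embeddings around each placeholder and re-concatenating the pieces, the flattened image features are scattered directly into the embedding matrix at the <image> token positions. Below is a minimal, self-contained sketch of that scatter with toy shapes and an invented placeholder id of 32000; none of it is from the commit itself.

    import mlx.core as mx
    import numpy as np

    IMAGE_TOKEN_INDEX = 32000  # hypothetical <image> token id

    # Toy setup: a 7-token prompt, 2 images of 2 patches each, hidden size 4.
    # The scatter requires one <image> token per patch: 2 * 2 = 4 placeholders.
    input_ids = mx.array([[1, 32000, 32000, 5, 32000, 32000, 2]])
    inputs_embeds = mx.zeros((1, 7, 4))
    image_features = mx.ones((2, 2, 4))  # (num_images, num_patches, hidden)

    image_positions = np.where(np.array(input_ids[0]) == IMAGE_TOKEN_INDEX)[0].tolist()
    num_images, _, vision_hidden_size = image_features.shape

    # Flatten patches, match the text embedding dtype, then scatter in place.
    patches = image_features.reshape(-1, vision_hidden_size).astype(inputs_embeds.dtype)
    inputs_embeds[:, image_positions, :] = patches
    print(inputs_embeds[0, :, 0])  # zeros for text slots, ones where patches landed

Because every patch claims exactly one placeholder slot, a prompt can carry any number of frames' worth of patches in a single pass, which the old split-and-zip path did not handle cleanly for video inputs.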
mlx_vlm/models/llava_bunny/llava_bunny.py: 30 changes (11 additions, 19 deletions)
@@ -161,28 +161,20 @@ def get_input_embeddings(

     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
         image_token_index = self.config.image_token_index
-        batch_size, seq_length, embed_dim = inputs_embeds.shape
-        num_images, num_image_patches, _ = image_features.shape
+        num_images, num_image_patches, embed_dim = image_features.shape

-        # Positions of <image> tokens in input_ids for each batch
-        image_positions = mx.argmax(input_ids == image_token_index, axis=1)
+        # Positions of <image> tokens in input_ids, assuming batch size is 1
+        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
+        num_images, _, vision_hidden_size = image_features.shape

-        final_embeddings = []
-        for b in range(batch_size):
-            text_segments = []
-            start_idx = 0
-            position = int(image_positions[b].item())
+        reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)

-            text_segments.append(inputs_embeds[b : b + 1, start_idx:position])
-            text_segments.append(image_features[b : b + 1])
-            text_segments.append(inputs_embeds[b : b + 1, position + 1 :])
-
-            batch_embeddings = mx.concatenate(text_segments, axis=1)
-            final_embeddings.append(batch_embeddings)
-
-        # Create a final embedding of shape
-        # (batch_size, num_image_patches + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=0)
+        # cast to the dtype of the input_embeds to support quantized models
+        reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+            inputs_embeds.dtype
+        )
+        inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+        return inputs_embeds

     def __call__(
         self,
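The bunny variant previously located only the first <image> token in each row (via mx.argmax) and spliced in a single image per batch element; the scatter form supports many placeholders but silently assumes the prompt carries exactly one per patch. A hedged sanity check one could layer on top; check_merge_shapes is a name invented here, not part of the commit:

    # Sketch of the shape invariant behind the scatter: the flattened image
    # features land 1:1 on <image> slots, so a video with F frames of P
    # patches each needs F * P placeholder tokens in the prompt.
    def check_merge_shapes(image_features, image_positions):
        num_images, num_image_patches, _ = image_features.shape
        expected = num_images * num_image_patches
        if len(image_positions) != expected:
            raise ValueError(
                f"Got {len(image_positions)} <image> tokens but expected "
                f"{expected} ({num_images} images x {num_image_patches} patches)."
            )

The ValueError removed from llava.py compared the token count against num_images alone, which fits the old one-slot-per-image design rather than the patch-level layout the new assignment requires.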
mlx_vlm/models/llava_next/llava_next.py: 21 changes (9 additions, 12 deletions)
@@ -118,23 +118,20 @@ def _merge_input_ids_with_image_features(
         self, image_features, inputs_embeds, input_ids
     ):
         image_token_index = self.config.image_token_index
-        num_images, num_image_patches, embed_dim = image_features.shape

         # Positions of <image> tokens in input_ids, assuming batch size is 1
         image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
-        text_segments = []
-        start_idx = 0
+        num_images, _, vision_hidden_size = image_features.shape

-        for position in image_positions:
-            text_segments.append(inputs_embeds[:, start_idx:position])
-            start_idx = position + 1
+        reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)

-        image_embeddings = mx.split(image_features, image_features.shape[0])
-        final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
-        final_embeddings += [inputs_embeds[:, start_idx:]]
-
-        # Create a final embedding of shape
-        # (1, num_image_patches*num_images + sequence_len, embed_dim)
-        return mx.concatenate(final_embeddings, axis=1)
+        # cast to the dtype of the input_embeds to support quantized models
+        reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
+            inputs_embeds.dtype
+        )
+        inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
+        return inputs_embeds

     def __call__(
         self,
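The astype call is what the "support quantized models" comment is about: a quantized language model typically holds its embeddings in float16 while the vision tower emits float32, so the scattered rows must be cast to the destination dtype before assignment. A toy illustration, with shapes and dtypes invented for the example:

    import mlx.core as mx

    inputs_embeds = mx.zeros((1, 4, 8), dtype=mx.float16)  # e.g. quantized LM
    image_features = mx.ones((1, 2, 8), dtype=mx.float32)  # vision tower output

    # Cast before scattering so the sequence keeps a single dtype.
    patches = image_features.reshape(-1, 8).astype(inputs_embeds.dtype)
    inputs_embeds[:, [1, 2], :] = patches
    print(inputs_embeds.dtype)  # float16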
