Skip to content

Commit

Permalink
Fix multi-image and 2x speed improvements (DS-VL2) (#157)
Browse files Browse the repository at this point in the history
* fix multi-image and use .tolist() for (2.16× prompt and 1.83x for generation speedup)

* format
  • Loading branch information
Blaizzy authored Dec 24, 2024
1 parent f0b0058 commit 050a6d7
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ def process_image_features(
if num_width_tiles == 0 or num_height_tiles == 0:
break

num_tiles_in_image = num_width_tiles * num_height_tiles
num_tiles_in_image = (num_width_tiles * num_height_tiles).tolist()

# Get global features [hw, D]
global_features = images_embeds[tile_index]

# Get local features [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + int(num_tiles_in_image)
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]

tile_index += num_tiles_in_image + 1
Expand Down Expand Up @@ -378,14 +378,18 @@ def get_input_embeddings(

batch_num_tiles = [0 for _ in range(bs)]
total_tiles = []

# Total number of tiles in each batch
for idx in range(bs):
for jdx in range(max_n_images):
num_width_tiles, num_height_tiles = images_spatial_crop[idx][jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
batch_num_tiles[idx] += 1 + num_width_tiles * num_height_tiles
batch_num_tiles[idx] += (
1 + num_width_tiles * num_height_tiles
).tolist()

total_tiles.append(pixel_values[idx, : int(batch_num_tiles[idx])])
total_tiles.append(pixel_values[idx, : batch_num_tiles[idx]])

total_tiles = mx.concatenate(total_tiles, axis=0)
assert total_tiles.shape[0] == sum(batch_num_tiles)
Expand Down

0 comments on commit 050a6d7

Please sign in to comment.