Merge branch 'main' into pc/video
Blaizzy authored Jan 31, 2025
2 parents cb082d6 + 5c7d159 commit 7efef0e
Showing 25 changed files with 1,494 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -39,4 +39,4 @@ jobs:
- name: Run Python tests
run: |
cd mlx_vlm/
pytest -s ./tests
pytest -s ./tests --ignore=tests/test_smoke.py
9 changes: 8 additions & 1 deletion mlx_vlm/__init__.py
@@ -1,3 +1,10 @@
from .prompt_utils import apply_chat_template, get_message_json
from .utils import convert, generate, load, prepare_inputs, process_image
from .utils import (
convert,
generate,
load,
prepare_inputs,
process_image,
quantize_model,
)
from .version import __version__
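For reference, a minimal sketch of the expanded top-level API (the checkpoint and image paths are illustrative, and passing the model path to load_config is an assumption based on its import in generate.py below):

from mlx_vlm import apply_chat_template, generate, load
from mlx_vlm.utils import load_config

model_path = "mlx-community/nanoLLaVA-1.5-8bit"  # illustrative checkpoint
model, processor = load(model_path)
config = load_config(model_path)  # assumed to accept the same path

prompt = apply_chat_template(processor, config, "Describe this image.", num_images=1)
output = generate(
    model,
    processor,
    prompt,
    image=["example.jpg"],  # illustrative image path
    temperature=0.0,
    max_tokens=100,
    verbose=False,
)
print(output)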
89 changes: 69 additions & 20 deletions mlx_vlm/generate.py
@@ -2,7 +2,14 @@
import codecs

from .prompt_utils import apply_chat_template
from .utils import generate, get_model_path, load, load_config, load_image_processor
from .utils import (
generate,
get_model_path,
load,
load_config,
load_image_processor,
stream_generate,
)

DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
DEFAULT_IMAGE = []
@@ -39,7 +46,7 @@ def parse_arguments():
parser.add_argument(
"--resize-shape",
type=int,
nargs=2,
nargs="+",
default=None,
help="Resize shape for the image.",
)
@@ -49,16 +56,27 @@ def parse_arguments():
default=DEFAULT_PROMPT,
help="Message to be processed by the model.",
)
parser.add_argument(
"--system",
type=str,
default=None,
help="System message for the model.",
)
parser.add_argument(
"--max-tokens",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Temperature for sampling."
"--temperature",
type=float,
default=DEFAULT_TEMP,
help="Temperature for sampling.",
)
parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.")
parser.add_argument("--verbose", action="store_false", help="Detailed output.")

return parser.parse_args()


@@ -83,24 +101,55 @@ def main():
prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image))

kwargs = {}

if args.resize_shape is not None:
assert (
len(args.resize_shape) == 2
), "Resize shape must be a tuple of two integers"
kwargs["resize_shape"] = args.resize_shape

output = generate(
model,
processor,
prompt,
image=args.image,
temp=args.temp,
max_tokens=args.max_tokens,
verbose=args.verbose,
**kwargs,
)
if not args.verbose:
print(output)
if len(args.resize_shape) not in [1, 2]:
raise ValueError("Resize shape must be 1 or 2 integers")
kwargs["resize_shape"] = (
(args.resize_shape[0],) * 2
if len(args.resize_shape) == 1
else tuple(args.resize_shape)
)

if args.chat:
chat = []
if args.system:
chat.append({"role": "system", "content": args.system})
while user := input("User:"):
chat.append({"role": "user", "content": user})
prompt = apply_chat_template(
processor, config, chat, num_images=len(args.image)
)
response = ""
print("Assistant:", end="")
for chunk in stream_generate(
model,
processor,
prompt,
args.image,
max_tokens=args.max_tokens,
temperature=args.temperature,
**kwargs,
):
response += chunk.text
print(chunk.text, end="")

chat.append({"role": "assistant", "content": response})
print()

else:
output = generate(
model,
processor,
prompt,
image=args.image,
temperature=args.temperature,
max_tokens=args.max_tokens,
verbose=args.verbose,
**kwargs,
)
if not args.verbose:
print(output)


if __name__ == "__main__":
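A rough sketch of one streamed chat turn mirroring the new --chat branch above (same assumptions as earlier; stream_generate is imported from mlx_vlm.utils as in this file, and chunk objects are assumed to expose a .text field as used here):

from mlx_vlm import apply_chat_template, load
from mlx_vlm.utils import load_config, stream_generate

model, processor = load("mlx-community/nanoLLaVA-1.5-8bit")
config = load_config("mlx-community/nanoLLaVA-1.5-8bit")

images = ["example.jpg"]  # illustrative image path
chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is in this image?"},
]
prompt = apply_chat_template(processor, config, chat, num_images=len(images))

response = ""
for chunk in stream_generate(
    model, processor, prompt, images, max_tokens=100, temperature=0.0
):
    response += chunk.text
    print(chunk.text, end="")
print()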
2 changes: 1 addition & 1 deletion mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
@@ -401,7 +401,7 @@ def get_input_embeddings(
total_tiles.append(pixel_values[idx, : batch_num_tiles[idx]])

total_tiles = mx.concatenate(total_tiles, axis=0)
assert total_tiles.shape[0] == sum(batch_num_tiles)

if total_tiles.shape[0] == 0:
return self.language_model.model.embed_tokens(input_ids)

6 changes: 3 additions & 3 deletions mlx_vlm/models/deepseek_vl_v2/language.py
@@ -408,9 +408,9 @@ def __call__(self, x):

# Calculate group scores using top-2 sum per group
scores_reshaped = scores_for_choice.reshape(bsz * seq_len, self.n_group, -1)
k = 2
group_scores_topk = mx.sort(scores_reshaped, axis=-1)[..., -k:]
group_scores = group_scores_topk.sum(axis=-1)

# Get top 2 scores per group
group_scores = mx.topk(scores_reshaped, 2, axis=-1).sum(axis=-1)

# Get top groups
k = self.n_group - self.topk_group
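A small equivalence check for the refactor above: summing the top-2 scores per group with mx.topk gives the same result as the previous sort-and-slice approach (shapes below are illustrative, tokens x n_group x experts_per_group):

import mlx.core as mx

scores = mx.random.uniform(shape=(4, 8, 16))

# Old approach: sort ascending, take the last two, sum.
old_way = mx.sort(scores, axis=-1)[..., -2:].sum(axis=-1)
# New approach: top-2 values directly, then sum.
new_way = mx.topk(scores, 2, axis=-1).sum(axis=-1)

assert mx.allclose(old_way, new_way).item()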
33 changes: 16 additions & 17 deletions mlx_vlm/models/idefics2/idefics2.py
@@ -43,8 +43,9 @@ class ModelConfig:
perceiver_config: PerceiverConfig
model_type: str
ignore_index: int = -100
image_token_index: int = 32001
image_token_id: int = 32001
vocab_size: int = 151936
image_token_index: Optional[int] = None

@classmethod
def from_dict(cls, params):
@@ -56,6 +57,10 @@ def from_dict(cls, params):
}
)

def __post_init__(self):
if self.image_token_index is None:
self.image_token_index = self.image_token_id


class Idefics2PerceiverAttention(nn.Module):
def __init__(self, config: ModelConfig):
@@ -219,9 +224,7 @@ def get_input_embeddings(
pooler_output, embeddings, hidden_state = self.vision_model(
pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
)

image_features = pooler_output[None, :].astype(pixel_values.dtype)

image_features = pooler_output.astype(pixel_values.dtype)
image_features = self.connector(image_features, mask=None)

final_inputs_embeds = self._prepare_inputs_for_multimodal(
@@ -231,25 +234,21 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):

def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
image_positions = np.where(input_ids == image_token_index)[1].tolist()
num_images, _, vision_hidden_size = image_features.shape

text_segments = []
start_idx = 0
reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)

for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
inputs_embeds.dtype
)

image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states

# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
return inputs_embeds

def __call__(
self,
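A toy illustration of the scatter-based placement above: patch embeddings are written directly into the <image> token positions instead of rebuilding the sequence with split and concatenate (all ids and shapes below are made up):

import mlx.core as mx
import numpy as np

image_token_index = 99
input_ids = mx.array([[1, 2, 99, 99, 99, 3, 4]])  # (batch=1, seq=7)
inputs_embeds = mx.zeros((1, 7, 4))               # (batch, seq, hidden)
image_features = mx.ones((1, 3, 4))               # (num_images, patches, hidden)

# Positions of <image> tokens; index [1] picks the sequence axis.
image_positions = np.where(np.array(input_ids) == image_token_index)[1].tolist()

# Flatten to (num_images * patches, hidden), match dtype, scatter in place.
flat = image_features.reshape(-1, image_features.shape[-1]).astype(inputs_embeds.dtype)
inputs_embeds[:, image_positions, :] = flat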
3 changes: 1 addition & 2 deletions mlx_vlm/models/idefics2/vision.py
@@ -150,7 +150,7 @@ def __call__(
if output_hidden_states:
encoder_states = encoder_states + (x,)

h = x[0]
h = x

return (h, encoder_states)

@@ -243,7 +243,6 @@ def __call__(
)

x = self.embeddings(x, mask=patch_attention_mask)

encoder_outputs = self.encoder(x=x, output_hidden_states=output_hidden_states)

pooler_output = self.post_layernorm(encoder_outputs[0])
23 changes: 9 additions & 14 deletions mlx_vlm/models/idefics3/idefics3.py
@@ -106,7 +106,6 @@ def get_input_embeddings(
pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
)

# image_features = pooler_output[None, :].astype(pixel_values.dtype)
image_features = pooler_output.astype(pixel_values.dtype)
image_features = self.connector(image_features)

@@ -117,25 +116,21 @@ def get_input_embeddings(

def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
image_token_index = self.config.image_token_index
num_images, num_image_patches, embed_dim = image_features.shape

# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
image_positions = np.where(input_ids == image_token_index)[1].tolist()

text_segments = []
start_idx = 0
num_images, _, vision_hidden_size = image_features.shape

for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size)

image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
inputs_embeds.dtype
)
inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states

# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
return inputs_embeds

def __call__(
self,
10 changes: 10 additions & 0 deletions mlx_vlm/models/llava_bunny/llava_bunny.py
@@ -42,6 +42,16 @@ class ModelConfig:

@classmethod
def from_dict(cls, params):
if not params.get("text_config", {}):
# Copy text config parameters from root level
excluded_keys = {"vision_config"}
params["text_config"] = dict(
filter(lambda x: x[0] not in excluded_keys, params.items())
)
if not params.get("vision_config", {}).get("model_type", {}):
# Set default model type
params["vision_config"]["model_type"] = "siglip_vision_model"

return cls(
**{
k: v
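A toy walk-through of the new from_dict fallback (the config dict below is made up): when text_config is missing, the root-level keys are copied into it, and a missing vision model_type defaults to siglip_vision_model:

params = {
    "model_type": "bunny-llama",
    "hidden_size": 2048,
    "vision_config": {"hidden_size": 1152},  # no "model_type" key
}

if not params.get("text_config", {}):
    # Copy text config parameters from the root level.
    excluded_keys = {"vision_config"}
    params["text_config"] = dict(
        filter(lambda x: x[0] not in excluded_keys, params.items())
    )
if not params.get("vision_config", {}).get("model_type", {}):
    params["vision_config"]["model_type"] = "siglip_vision_model"

print(params["text_config"])    # {'model_type': 'bunny-llama', 'hidden_size': 2048}
print(params["vision_config"])  # {'hidden_size': 1152, 'model_type': 'siglip_vision_model'}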
2 changes: 1 addition & 1 deletion mlx_vlm/models/multi_modality/multi_modality.py
@@ -303,7 +303,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

image_token_index = self.config.image_token_index
num_image_tokens = self.config.num_image_tokens
2 changes: 1 addition & 1 deletion mlx_vlm/models/paligemma/paligemma.py
@@ -66,7 +66,7 @@ def get_input_embeddings(
mask: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids), None

inputs_embeds = self.language_model.model.embed_tokens(input_ids)

8 changes: 8 additions & 0 deletions mlx_vlm/models/qwen2_5_vl/__init__.py
@@ -0,0 +1,8 @@
from .qwen2_5_vl import (
LanguageModel,
Model,
ModelConfig,
TextConfig,
VisionConfig,
VisionModel,
)