feat(mlx_lm): basic speculative decoding support in mlx_lm.generate / mlx_lm.server #954

Draft · wants to merge 3 commits into base: main
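For readers unfamiliar with the technique: speculative decoding lets a small draft model propose a short run of tokens that the larger target model then verifies in one pass; accepted tokens are kept and generation resumes from the first disagreement. The toy sketch below illustrates only that control flow in the greedy case; it is not the implementation in this PR, and the two stand-in "models" are plain functions.

```python
# Toy, runnable illustration of the speculative-decoding control flow: a cheap
# "draft" model proposes k tokens, the "target" model verifies them greedily,
# and generation continues from the first disagreement.  The two stand-in
# models below are deterministic functions, not mlx_lm models.
def toy_model(bias):
    # Next token is a deterministic function of the previous token.
    return lambda token: (token * 3 + bias) % 7

target = toy_model(bias=1)
draft = toy_model(bias=1)   # identical here, so every draft token is accepted

def speculative_step(prompt_last, num_draft=4):
    # Draft model proposes num_draft tokens autoregressively.
    proposed, tok = [], prompt_last
    for _ in range(num_draft):
        tok = draft(tok)
        proposed.append(tok)
    # Target model verifies the proposals; keep the longest agreeing prefix.
    accepted, tok = [], prompt_last
    for p in proposed:
        expected = target(tok)
        if p != expected:
            accepted.append(expected)  # take the target's own token and stop
            break
        accepted.append(p)
        tok = p
    return accepted

print(speculative_step(prompt_last=2))  # all 4 draft tokens accepted here
```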
12 changes: 10 additions & 2 deletions llms/mlx_lm/generate.py
@@ -23,6 +23,12 @@ def setup_arg_parser():
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--draft-model",
type=str,
required=False,
help="The path to the local model directory or Hugging Face repo for speculative decoding.",
)
parser.add_argument(
"--adapter-path",
type=str,
@@ -81,7 +87,7 @@ def setup_arg_parser():
"--max-kv-size",
type=int,
default=1024,
help="Set the maximum key-value cache size",
help="Set the maximum key-value cache size (0 for unlimited)",
)
return parser

@@ -132,6 +138,7 @@ def main():
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
draft_model = load(args.draft_model)[0] if args.draft_model is not None else None

if args.use_default_chat_template:
if tokenizer.chat_template is None:
@@ -159,7 +166,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=args.max_kv_size,
max_kv_size=args.max_kv_size if args.max_kv_size > 0 else None,
draft_model=draft_model,
)


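With the flag above, something like `python -m mlx_lm.generate --model <target> --draft-model <draft> --prompt "..."` should pick the draft model up automatically. A minimal Python-level sketch of the same flow, assuming the `load`/`generate` signatures this diff touches; the repo ids are placeholders:

```python
# Hedged sketch: drive the new draft_model keyword from Python rather than the CLI.
# The repo ids below are placeholders; any compatible target/draft pair works.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Target-Model-4bit")      # placeholder id
draft_model, _ = load("mlx-community/Small-Draft-Model-4bit")   # placeholder id

text = generate(
    model,
    tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    draft_model=draft_model,  # new keyword introduced by this PR
    max_tokens=64,
)
print(text)
```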
18 changes: 11 additions & 7 deletions llms/mlx_lm/models/base.py
@@ -9,7 +9,6 @@


class KVCache:

def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
@@ -23,6 +22,13 @@ def __init__(self, head_dim, n_kv_heads):
self.offset = 0
self.step = 256

def drop(self, n):
if n >= self.offset:
self.keys = self.values = None
self.offset = 0
elif n > 0:
self.offset -= n

def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
@@ -33,11 +39,10 @@ def update_and_fetch(self, keys, values):
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
self.keys = mx.concatenate([self.keys[..., :prev, :], new_k], axis=2)
self.values = mx.concatenate(
[self.values[..., :prev, :], new_v], axis=2
)
else:
self.keys, self.values = new_k, new_v

@@ -51,7 +56,6 @@ def state(self):


class RotatingKVCache:

def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
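The new `drop(n)` hook is the piece that lets speculative decoding rewind the cache when the target model rejects some of the draft tokens. A minimal sketch of the intended usage, with made-up shapes (`head_dim` and `n_kv_heads` would normally come from the model config):

```python
# Sketch of KVCache.drop(n): rewind the cache after rejected draft tokens.
# Shapes are illustrative only.
import mlx.core as mx
from mlx_lm.models.base import KVCache

cache = KVCache(head_dim=64, n_kv_heads=8)

# Suppose 5 speculative positions were just written into the cache.
keys = mx.zeros((1, 8, 5, 64))
values = mx.zeros((1, 8, 5, 64))
cache.update_and_fetch(keys, values)
assert cache.offset == 5

# The target model accepted only 3 of them: drop the trailing 2 entries.
cache.drop(2)
assert cache.offset == 3

# Dropping everything that remains resets the cache completely.
cache.drop(cache.offset)
assert cache.offset == 0 and cache.keys is None
```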
64 changes: 39 additions & 25 deletions llms/mlx_lm/sample_utils.py
@@ -26,7 +26,10 @@ def min_p_sampling(
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.

temperature (float): Temperature parameter for softmax distribution reshaping.

Returns:
token(s) selected based on the min-p criterion.
Shape: same as ``logits`` with the last (vocabulary) dimension removed.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
@@ -39,14 +42,14 @@
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

# Softmax probabilities
probs = mx.softmax(logits * (1 / temperature), axis=-1)
probs = mx.softmax(logits / temperature, axis=-1)

# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logits).squeeze(0)
sorted_probs = probs[..., sorted_indices]
sorted_indices = mx.argsort(-logits)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

# Top probability
top_probs = probs[..., sorted_indices[0]]
top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)

# Calculate the min_p threshold
scaled_min_p = min_p * top_probs
@@ -58,43 +61,54 @@
# Create pool of tokens with probability less than scaled min_p
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

# Return sampled token
sorted_token = mx.random.categorical(mx.log(selected_probs))
return sorted_indices[sorted_token]
# Return sampled token(s)
sampled_indices = mx.random.categorical(mx.log(selected_probs))
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
)
return tokens.squeeze(-1)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
def top_p_sampling(
logits: mx.array, top_p: float, temperature: float, axis: int = -1
) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.

Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
axis: The axis along which to apply top-p sampling.
Returns:
token selected based on the top-p criterion.
token(s) selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# Apply temperature and compute softmax
probs = mx.softmax(logits / temperature, axis=axis)

# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
# Sort probs in descending order
sorted_indices = mx.argsort(-probs, axis=axis)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)

cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# Compute cumulative probabilities
cumulative_probs = mx.cumsum(sorted_probs, axis=axis)

# select tokens with cumulative probs below threshold
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
0,
)
# Mask out probs whose cumulative sum already exceeds the threshold,
# always keeping at least the most probable token
mask = (cumulative_probs - sorted_probs) < top_p

# Apply the mask to the sorted probabilities
masked_probs = sorted_probs * mask

sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
# Sample from the masked probabilities (categorical normalizes internally)
sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)

# Gather the original token indices
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
)

return token
return tokens.squeeze(axis)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
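The reworked samplers gather with `mx.take_along_axis` instead of fancy indexing, so they now accept a full `(batch, vocab)` array of logits, matching the batched `prompts=prompt[None]` call used elsewhere in this PR. A quick usage sketch with made-up logits:

```python
# Hedged usage sketch of the batched top-p sampler above; the logits are made up.
import mlx.core as mx
from mlx_lm.sample_utils import top_p_sampling

logits = mx.array([[1.0, 1.2, 0.9, 1.1],    # "sequence" 0, vocab size 4
                   [0.2, 0.1, 0.4, 0.3]])   # "sequence" 1
tokens = top_p_sampling(logits, 0.9, 1.0)   # top_p=0.9, temperature=1.0
print(tokens.shape)  # (2,): one sampled token id per row
```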
19 changes: 17 additions & 2 deletions llms/mlx_lm/server.py
@@ -104,6 +104,8 @@ def __init__(self, cli_args: argparse.Namespace):
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
if self.cli_args.draft_model is not None:
self.draft_model, _ = load(self.cli_args.draft_model)

def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
Expand Down Expand Up @@ -161,6 +163,7 @@ def __init__(self, model_provider: ModelProvider, *args, **kwargs):
"""
self.created = int(time.time())
self.model_provider = model_provider
self.draft_model = model_provider.draft_model
super().__init__(*args, **kwargs)

def _set_cors_headers(self):
Expand Down Expand Up @@ -410,8 +413,10 @@ def handle_completion(
top_tokens = []
for (token, logprobs), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
draft_model=self.draft_model,
tokenizer=self.tokenizer,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
@@ -420,6 +425,8 @@
),
range(self.max_tokens),
):
token = token.item()
logprobs = logprobs.squeeze(0)
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
@@ -497,15 +504,18 @@ def handle_stream(

for (token, _), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
draft_model=self.draft_model,
tokenizer=self.tokenizer,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
),
range(self.max_tokens),
):
token = token.item()
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
@@ -646,6 +656,11 @@ def main():
type=str,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
"--draft-model",
type=str,
help="The path to the MLX model weights and config for speculative decoding",
)
parser.add_argument(
"--adapter-path",
type=str,
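Once the server is launched with both `--model` and `--draft-model`, completion requests should take the speculative path transparently. Below is a hedged example request against the OpenAI-style `/v1/completions` endpoint the server exposes; the host, port, and payload fields are assumptions based on the existing server defaults:

```python
# Hedged example request; assumes mlx_lm.server is running on its default port
# with --model and --draft-model set.  Prompt and parameters are placeholders.
import json
import urllib.request

payload = {
    "prompt": "Speculative decoding lets a small draft model propose tokens that",
    "max_tokens": 64,
    "temperature": 0.7,
}
req = urllib.request.Request(
    "http://localhost:8080/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["text"])
```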