feat(mlx_lm): support batch input in generate() #948

Open · wants to merge 4 commits into base: main
2 changes: 1 addition & 1 deletion llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
# Copyright © 2023-2024 Apple Inc.

from ._version import __version__
from .utils import convert, generate, load, stream_generate
from .utils import convert, generate, load, stream_generate, batch_generate
64 changes: 39 additions & 25 deletions llms/mlx_lm/sample_utils.py
@@ -26,7 +26,10 @@ def min_p_sampling(
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.

temperature: Temperature parameter for softmax distribution reshaping.
Returns:
token(s) selected based on the min-p criterion.
Shape: same as ``logits`` with the last (vocab) dimension removed.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
@@ -39,14 +42,14 @@ def min_p_sampling(
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

# Softmax probabilities
probs = mx.softmax(logits * (1 / temperature), axis=-1)
probs = mx.softmax(logits / temperature, axis=-1)

# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logits).squeeze(0)
sorted_probs = probs[..., sorted_indices]
sorted_indices = mx.argsort(-logits)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

# Top probability
top_probs = probs[..., sorted_indices[0]]
top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)

# Calculate the min_p threshold
scaled_min_p = min_p * top_probs
@@ -58,43 +61,54 @@
# Create pool of tokens with probability less than scaled min_p
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)

# Return sampled token
sorted_token = mx.random.categorical(mx.log(selected_probs))
return sorted_indices[sorted_token]
# Return sampled token(s)
sampled_indices = mx.random.categorical(mx.log(selected_probs))
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
)
return tokens.squeeze(-1)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
def top_p_sampling(
logits: mx.array, top_p: float, temperature: float, axis: int = -1
) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.

Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
axis: The axis along which to apply top-p sampling.
Returns:
token selected based on the top-p criterion.
token(s) selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# Apply temperature and compute softmax
probs = mx.softmax(logits / temperature, axis=axis)

# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
# Sort probs in descending order
sorted_indices = mx.argsort(-probs, axis=axis)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)

cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# Compute cumulative probabilities
cumulative_probs = mx.cumsum(sorted_probs, axis=axis)

# select tokens with cumulative probs below threshold
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
0,
)
# Create a mask for probs above the threshold
mask = cumulative_probs <= top_p

# Apply the mask to the sorted probabilities
masked_probs = sorted_probs * mask

sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
# Sample from the normalized probabilities
sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis)

# Gather the original token indices
tokens = mx.take_along_axis(
sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis
)

return token
return tokens.squeeze(axis)


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
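
For reference, the recurring change in both sampling functions above is replacing 1-D indexing (which assumed a single prompt) with a row-wise gather via mx.take_along_axis. A minimal sketch of that batched gather pattern, with made-up shapes and temperature and without the min-p/top-p masking:

import mlx.core as mx

logits = mx.random.normal((2, 8))          # hypothetical (batch, vocab) logits
probs = mx.softmax(logits / 0.7, axis=-1)  # temperature 0.7 chosen arbitrarily

# Sort each row in descending order and gather the probabilities row-wise,
# instead of indexing with a squeezed 1-D index (which only works for batch size 1).
sorted_indices = mx.argsort(-probs, axis=-1)                       # (2, 8)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)  # (2, 8)

# Sample one sorted position per row, then map it back to the original vocab index.
sampled = mx.random.categorical(mx.log(sorted_probs))  # (2,)
tokens = mx.take_along_axis(
    sorted_indices, mx.expand_dims(sampled, axis=-1), axis=-1
).squeeze(-1)                                           # (2,)
print(tokens.tolist())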
7 changes: 5 additions & 2 deletions llms/mlx_lm/server.py
@@ -411,7 +411,7 @@ def handle_completion(
top_tokens = []
for (token, logprobs), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
temp=self.temperature,
top_p=self.top_p,
Expand All @@ -421,6 +421,8 @@ def handle_completion(
),
range(self.max_tokens),
):
token = token.item()
logprobs = logprobs.squeeze(0)
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
@@ -498,7 +500,7 @@ def handle_stream(

for (token, _), _ in zip(
generate_step(
prompt=prompt,
prompts=prompt[None],
model=self.model,
temp=self.temperature,
top_p=self.top_p,
Expand All @@ -507,6 +509,7 @@ def handle_stream(
),
range(self.max_tokens),
):
token = token.item()
detokenizer.add_token(token)
logging.debug(detokenizer.text)
tokens.append(token)
138 changes: 113 additions & 25 deletions llms/mlx_lm/utils.py
@@ -116,16 +116,16 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
logits (mx.array): Logits with repetition penalty applied to generated tokens.
"""
if len(tokens) > 0:
selected_logits = logits[:, tokens]
selected_logits = mx.take_along_axis(logits, tokens, axis=-1)
selected_logits = mx.where(
selected_logits < 0, selected_logits * penalty, selected_logits / penalty
)
logits[:, tokens] = selected_logits
logits[mx.arange(tokens.shape[0])[:, None], tokens] = selected_logits
return logits


def generate_step(
prompt: mx.array,
prompts: mx.array,
model: nn.Module,
temp: float = 0.0,
repetition_penalty: Optional[float] = None,
@@ -143,7 +143,7 @@ def generate_step(
A generator producing token ids based on the given prompt from the model.

Args:
prompt (mx.array): The input prompt.
prompts (mx.array): The input prompts, with shape ``(bs, seq_len)``.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
@@ -169,23 +169,29 @@

Yields:
Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
one token and a vector of log probabilities.
one token and a vector of log probabilities per prompt.
Shapes: ``(bs, 1), (bs, vocab_size)``.
"""

def sample(logits: mx.array) -> Tuple[mx.array, float]:
logprobs = logits - mx.logsumexp(logits)
if prompts.ndim != 2:
raise ValueError(
f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
)

def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

if temp == 0:
token = mx.argmax(logits, axis=-1)
tokens = mx.argmax(logits, axis=-1)
else:
if top_p > 0 and top_p < 1.0:
token = top_p_sampling(logits, top_p, temp)
tokens = top_p_sampling(logits, top_p, temp)
elif min_p != 0.0:
token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
else:
token = categorical_sampling(logits, temp)
tokens = categorical_sampling(logits, temp)

return token, logprobs
return mx.expand_dims(tokens, axis=-1), logprobs

if repetition_penalty and (
repetition_penalty < 0 or not isinstance(repetition_penalty, float)
@@ -215,7 +221,7 @@ def logit_bias_processor(_, logits):

logits_processor.append(logit_bias_processor)

y = prompt
y = prompts
tokens = None

# Create the KV cache for generation
@@ -225,31 +231,32 @@ def logit_bias_processor(_, logits):
raise ValueError("Wrong number of layers in the prompt cache.")

def _step(y):
logits = model(y[None], cache=prompt_cache)
logits = model(y, cache=prompt_cache)
logits = logits[:, -1, :]

if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
tokens = mx.concat([tokens, y], axis=-1) if tokens is not None else y

for processor in logits_processor:
logits = processor(tokens, logits)

y, logprobs = sample(logits)
return y, logprobs.squeeze(0)
return y, logprobs

while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
y = y[prefill_step_size:]
while y.shape[1] > prefill_step_size:
model(y[:, :prefill_step_size], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
y = y[:, prefill_step_size:]

y, logprobs = _step(y)

mx.async_eval(y)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
yield y.item(), logprobs
mx.eval(y)
yield y, logprobs
y, logprobs = next_y, next_logprobs
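
For context, generate_step now takes a 2-D (bs, seq_len) array of token ids and yields a (bs, 1) token array plus (bs, vocab_size) log probabilities per step. A minimal usage sketch of that contract; the model repo id, prompts, and padding setup are assumptions for illustration, not part of the diff:

import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step

model, tokenizer = load("mlx-community/SomeModel-4bit")  # hypothetical repo id

# Left-pad so every row ends at the generation position, mirroring batch_generate below.
tokenizer._tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer._tokenizer.pad_token = tokenizer.eos_token
    tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id

prompts = mx.array(
    tokenizer._tokenizer(["Hello!", "How are you today?"], padding=True)["input_ids"]
)

for (tokens, logprobs), _ in zip(generate_step(prompts, model), range(5)):
    # tokens: (bs, 1) token ids for this step; logprobs: (bs, vocab_size)
    print(tokens.squeeze(-1).tolist())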


@@ -259,7 +266,7 @@ def stream_generate(
prompt: str,
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> Generator[str, None, None]:
"""
A generator producing text based on the given prompt from the model.

@@ -280,10 +287,11 @@
detokenizer = tokenizer.detokenizer

detokenizer.reset()
for n, (token, _) in zip(
for _, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
token = token.item()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
@@ -303,7 +311,7 @@ def generate(
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
) -> str:
"""
Generate a complete response from the model.

@@ -334,8 +342,9 @@

for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
generate_step(prompt_tokens[None], model, **kwargs),
):
token = token.item()
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
@@ -371,6 +380,85 @@
return detokenizer.text


def batch_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompts: List[str],
max_tokens: int = 100,
verbose: bool = False,
**kwargs,
) -> List[str]:
"""
Generate complete responses from the model for a batch of prompts.

Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompts (List[str]): The string prompts.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
if kwargs.get("max_kv_size", None) is not None:
raise ValueError("max_kv_size is not supported for batch generation")

if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)

tokenizer._tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
tokenizer._tokenizer.pad_token = tokenizer.eos_token
tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
prompt_tokens = mx.array(
tokenizer._tokenizer(prompts, padding=True)["input_ids"]
)
output_toks = []

tic = time.perf_counter()

for (tokens, _), n in zip(
generate_step(prompt_tokens, model, **kwargs),
range(max_tokens),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if (tokens == tokenizer.eos_token_id).all():
break
output_toks.append(tokens)
if verbose:
print(".", end="", flush=True)

output_toks = mx.concatenate(output_toks, axis=1)
token_count = output_toks.size
response = [
response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
for response in tokenizer.batch_decode(output_toks.tolist())
]

if verbose:
gen_time = time.perf_counter() - tic
if token_count <= 0:
print("No tokens generated for this prompt")
else:
print()
for p, resp in zip(prompts, response):
print("=" * 10)
print("Prompt:", p)
print(resp)
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = token_count / gen_time
print("=" * 10)
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")

return response
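
For completeness, a high-level usage sketch of the new batch_generate entry point; the model repo id and prompts are placeholders, not part of the PR:

from mlx_lm import load, batch_generate

model, tokenizer = load("mlx-community/SomeModel-4bit")  # hypothetical repo id

responses = batch_generate(
    model,
    tokenizer,
    prompts=["Write a haiku about the sea.", "Name the capital of France."],
    max_tokens=64,
    verbose=True,
)
for r in responses:
    print(r)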


def load_config(model_path: Path) -> dict:
try:
with open(model_path / "config.json", "r") as f: