diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py
index 502c78e5..ed39bb03 100644
--- a/llms/mlx_lm/__init__.py
+++ b/llms/mlx_lm/__init__.py
@@ -1,4 +1,4 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from ._version import __version__
-from .utils import convert, generate, load, stream_generate
+from .utils import convert, generate, load, stream_generate, batch_generate
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 20b008fa..1b403b39 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -26,7 +26,10 @@ def min_p_sampling(
             0.99-0.8 range.
         min_tokens_to_keep (int, optional): Minimum number of tokens that
             cannot be filtered. Default: ``1``.
-
+        temperature: Temperature parameter for softmax distribution reshaping.
+    Returns:
+        token(s) selected based on the min-p criterion.
+        Shape: the shape of ``logits`` without the last (vocab) dimension.
     """
     if not (0 <= min_p <= 1.0):
         raise ValueError(
@@ -39,14 +42,14 @@ def min_p_sampling(
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
 
     # Softmax probabilities
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    probs = mx.softmax(logits / temperature, axis=-1)
 
     # Indices sorted in decreasing order
-    sorted_indices = mx.argsort(-logits).squeeze(0)
-    sorted_probs = probs[..., sorted_indices]
+    sorted_indices = mx.argsort(-logits)
+    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)
 
     # Top probability
-    top_probs = probs[..., sorted_indices[0]]
+    top_probs = mx.expand_dims(sorted_probs[..., 0], axis=-1)
 
     # Calculate the min_p threshold
     scaled_min_p = min_p * top_probs
@@ -58,13 +61,18 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
 
-    # Return sampled token
-    sorted_token = mx.random.categorical(mx.log(selected_probs))
-    return sorted_indices[sorted_token]
+    # Return sampled token(s)
+    sampled_indices = mx.random.categorical(mx.log(selected_probs))
+    tokens = mx.take_along_axis(
+        sorted_indices, mx.expand_dims(sampled_indices, axis=-1), axis=-1
+    )
+    return tokens.squeeze(-1)
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def top_p_sampling(
+    logits: mx.array, top_p: float, temperature: float, axis: int = -1
+) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.
 
@@ -72,29 +80,35 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
         logits: The logits from the model's output.
         top_p: The cumulative probability threshold for top-p filtering.
         temperature: Temperature parameter for softmax distribution reshaping.
+        axis: The axis along which to apply top-p sampling.
 
     Returns:
-        token selected based on the top-p criterion.
+        token(s) selected based on the top-p criterion.
""" - # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 - probs = mx.softmax(logits * (1 / temperature), axis=-1) + # Apply temperature and compute softmax + probs = mx.softmax(logits / temperature, axis=axis) - # sort probs in ascending order - sorted_indices = mx.argsort(probs, axis=-1) - sorted_probs = probs[..., sorted_indices.squeeze(0)] + # Sort probs in descending order + sorted_indices = mx.argsort(-probs, axis=axis) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis) - cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + # Compute cumulative probabilities + cumulative_probs = mx.cumsum(sorted_probs, axis=axis) - # select tokens with cumulative probs below threshold - top_probs = mx.where( - cumulative_probs > 1 - top_p, - sorted_probs, - 0, - ) + # Create a mask for probs above the threshold + mask = cumulative_probs <= top_p + + # Apply the mask to the sorted probabilities + masked_probs = sorted_probs * mask - sorted_token = mx.random.categorical(mx.log(top_probs)) - token = sorted_indices.squeeze(0)[sorted_token] + # Sample from the normalized probabilities + sampled_indices = mx.random.categorical(mx.log(masked_probs), axis=axis) + + # Gather the original token indices + tokens = mx.take_along_axis( + sorted_indices, mx.expand_dims(sampled_indices, axis=axis), axis=axis + ) - return token + return tokens.squeeze(axis) @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 42962b54..c5682630 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -411,7 +411,7 @@ def handle_completion( top_tokens = [] for (token, logprobs), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -421,6 +421,8 @@ def handle_completion( ), range(self.max_tokens), ): + token = token.item() + logprobs = logprobs.squeeze(0) detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) @@ -498,7 +500,7 @@ def handle_stream( for (token, _), _ in zip( generate_step( - prompt=prompt, + prompts=prompt[None], model=self.model, temp=self.temperature, top_p=self.top_p, @@ -507,6 +509,7 @@ def handle_stream( ), range(self.max_tokens), ): + token = token.item() detokenizer.add_token(token) logging.debug(detokenizer.text) tokens.append(token) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index cfbcf29e..8b9fb27d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -116,16 +116,16 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) logits (mx.array): Logits with repetition penalty applied to generated tokens. """ if len(tokens) > 0: - selected_logits = logits[:, tokens] + selected_logits = mx.take_along_axis(logits, tokens, axis=-1) selected_logits = mx.where( selected_logits < 0, selected_logits * penalty, selected_logits / penalty ) - logits[:, tokens] = selected_logits + logits[mx.arange(tokens.shape[0])[:, None], tokens] = selected_logits return logits def generate_step( - prompt: mx.array, + prompts: mx.array, model: nn.Module, temp: float = 0.0, repetition_penalty: Optional[float] = None, @@ -143,7 +143,7 @@ def generate_step( A generator producing token ids based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. + prompts (mx.array): The input prompt. model (nn.Module): The model to use for generation. 
         temp (float): The temperature for sampling, if 0 the argmax is used.
           Default: ``0``.
@@ -169,23 +169,29 @@ def generate_step(
 
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        one token and a vector of log probabilities per prompt.
+        Shapes: ``(bs, 1), (bs, vocab_size)``.
     """
 
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        logprobs = logits - mx.logsumexp(logits)
+    if prompts.ndim != 2:
+        raise ValueError(
+            f"Shape of prompts should be (bs, seq_len), got {prompts.shape}"
+        )
+
+    def sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
+        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
 
         if temp == 0:
-            token = mx.argmax(logits, axis=-1)
+            tokens = mx.argmax(logits, axis=-1)
         else:
             if top_p > 0 and top_p < 1.0:
-                token = top_p_sampling(logits, top_p, temp)
+                tokens = top_p_sampling(logits, top_p, temp)
             elif min_p != 0.0:
-                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
+                tokens = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
             else:
-                token = categorical_sampling(logits, temp)
+                tokens = categorical_sampling(logits, temp)
 
-        return token, logprobs
+        return mx.expand_dims(tokens, axis=-1), logprobs
 
     if repetition_penalty and (
         repetition_penalty < 0 or not isinstance(repetition_penalty, float)
@@ -215,7 +221,7 @@ def logit_bias_processor(_, logits):
 
         logits_processor.append(logit_bias_processor)
 
-    y = prompt
+    y = prompts
     tokens = None
 
     # Create the KV cache for generation
@@ -225,23 +231,23 @@ def logit_bias_processor(_, logits):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
     def _step(y):
-        logits = model(y[None], cache=prompt_cache)
+        logits = model(y, cache=prompt_cache)
         logits = logits[:, -1, :]
 
         if logits_processor:
             nonlocal tokens
-            tokens = mx.concat([tokens, y]) if tokens is not None else y
+            tokens = mx.concat([tokens, y], axis=-1) if tokens is not None else y
 
             for processor in logits_processor:
                 logits = processor(tokens, logits)
 
         y, logprobs = sample(logits)
-        return y, logprobs.squeeze(0)
+        return y, logprobs
 
-    while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=prompt_cache)
-        mx.eval([c.state for c in prompt_cache])
-        y = y[prefill_step_size:]
+    while y.shape[1] > prefill_step_size:
+        model(y[:, :prefill_step_size], cache=prompt_cache)
+        mx.eval([c.state for c in prompt_cache])
+        y = y[:, prefill_step_size:]
 
     y, logprobs = _step(y)
 
@@ -249,7 +255,8 @@ def _step(y):
     while True:
         next_y, next_logprobs = _step(y)
         mx.async_eval(next_y)
-        yield y.item(), logprobs
+        mx.eval(y)
+        yield y, logprobs
         y, logprobs = next_y, next_logprobs
 
 
@@ -259,7 +266,7 @@ def stream_generate(
     prompt: str,
     max_tokens: int = 100,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Generator[str, None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
@@ -280,10 +287,11 @@ def stream_generate(
     detokenizer = tokenizer.detokenizer
 
     detokenizer.reset()
-    for n, (token, _) in zip(
+    for _, (token, _) in zip(
         range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
+        generate_step(prompt_tokens[None], model, **kwargs),
     ):
+        token = token.item()
         if token == tokenizer.eos_token_id:
             break
         detokenizer.add_token(token)
@@ -303,7 +311,7 @@ def generate(
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> str:
     """
     Generate a complete response from the model.
 
@@ -334,8 +342,9 @@ def generate(
 
     for n, (token, logprobs) in zip(
         range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
+        generate_step(prompt_tokens[None], model, **kwargs),
     ):
+        token = token.item()
         if n == 0:
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
@@ -371,6 +380,85 @@ def generate(
     return detokenizer.text
 
 
+def batch_generate(
+    model: nn.Module,
+    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
+    prompts: List[str],
+    max_tokens: int = 100,
+    verbose: bool = False,
+    **kwargs,
+) -> List[str]:
+    """
+    Generate complete responses from the model for a batch of prompts.
+
+    Args:
+        model (nn.Module): The language model.
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompts (List[str]): The string prompts.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
+        verbose (bool): If ``True``, print tokens and timing information.
+            Default: ``False``.
+        kwargs: The remaining options get passed to :func:`generate_step`.
+            See :func:`generate_step` for more details.
+    """
+    if kwargs.get("max_kv_size", None) is not None:
+        raise ValueError("max_kv_size is not supported for batch generation")
+
+    if not isinstance(tokenizer, TokenizerWrapper):
+        tokenizer = TokenizerWrapper(tokenizer)
+
+    tokenizer._tokenizer.padding_side = "left"
+    if tokenizer.pad_token is None:
+        tokenizer._tokenizer.pad_token = tokenizer.eos_token
+        tokenizer._tokenizer.pad_token_id = tokenizer.eos_token_id
+    prompt_tokens = mx.array(
+        tokenizer._tokenizer(prompts, padding=True)["input_ids"]
+    )
+    output_toks = []
+
+    tic = time.perf_counter()
+
+    for (tokens, _), n in zip(
+        generate_step(prompt_tokens, model, **kwargs),
+        range(max_tokens),
+    ):
+        if n == 0:
+            prompt_time = time.perf_counter() - tic
+            tic = time.perf_counter()
+        if (tokens == tokenizer.eos_token_id).all():
+            break
+        output_toks.append(tokens)
+        if verbose:
+            print(".", end="", flush=True)
+
+    output_toks = mx.concatenate(output_toks, axis=1)
+    token_count = output_toks.size
+    response = [
+        response.split(tokenizer.eos_token)[0].split(tokenizer.pad_token)[0]
+        for response in tokenizer.batch_decode(output_toks.tolist())
+    ]
+
+    if verbose:
+        gen_time = time.perf_counter() - tic
+        if token_count <= 0:
+            print("No tokens generated for this prompt")
+        else:
+            print()
+            for p, resp in zip(prompts, response):
+                print("=" * 10)
+                print("Prompt:", p)
+                print(resp)
+            prompt_tps = prompt_tokens.size / prompt_time
+            gen_tps = token_count / gen_time
+            print("=" * 10)
+            print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+            peak_mem = mx.metal.get_peak_memory() / 2**30
+            print(f"Peak memory: {peak_mem:.3f} GB")
+
+    return response
+
+
 def load_config(model_path: Path) -> dict:
     try:
         with open(model_path / "config.json", "r") as f:
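
Usage sketch for the `batch_generate` API exported above. The model repository name below is illustrative; any MLX-format checkpoint loadable with `mlx_lm.load` should work, and the sampling options accepted by `generate_step` (e.g. `temp`, `top_p`) can be passed through as keyword arguments.

```python
# Minimal batch-generation example; the model repo is illustrative.
from mlx_lm import load, batch_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompts = [
    "Write a haiku about the ocean.",
    "Explain KV caching in one sentence.",
]

# Prompts are left-padded to a common length, run through generate_step as a
# single (bs, seq_len) batch, and decoded into one response string per prompt.
responses = batch_generate(
    model, tokenizer, prompts, max_tokens=64, temp=0.0, verbose=True
)

for prompt, response in zip(prompts, responses):
    print(f"{prompt}\n-> {response}\n")
```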