diff --git a/.gitignore b/.gitignore index f3dfe929..45445fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# Vim +*.swp + # Distribution / packaging .Python build/ diff --git a/llms/README.md b/llms/README.md index 75677865..20863041 100644 --- a/llms/README.md +++ b/llms/README.md @@ -20,6 +20,31 @@ The `mlx-lm` package also has: - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md) - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md) +### Quick Start + +To generate text with an LLM use: + +```bash +mlx_lm.generate --prompt "Hi!" +``` + +To chat with an LLM use: + +```bash +mlx_lm.chat +``` + +This will give you a chat REPL that you can use to interact with the LLM. The +chat context is preserved during the lifetime of the REPL. + +Commands in `mlx-lm` typically take command line options which let you specify +the model, sampling parameters, and more. Use `-h` to see a list of available +options for a command, e.g.: + +```bash +mlx_lm.generate -h +``` + ### Python API You can use `mlx-lm` as a module: @@ -138,7 +163,7 @@ mlx_lm.convert \ ### Long Prompts and Generations -MLX LM has some tools to scale efficiently to long prompts and generations: +`mlx-lm` has some tools to scale efficiently to long prompts and generations: - A rotating fixed-size key-value cache. - Prompt caching @@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example: cat prompt.txt | mlx_lm.cache_prompt \ --model mistralai/Mistral-7B-Instruct-v0.3 \ --prompt - \ - --kv-cache-file mistral_prompt.safetensors + --prompt-cache-file mistral_prompt.safetensors ``` Then use the cached prompt with `mlx_lm.generate`: ``` mlx_lm.generate \ - --kv-cache-file mistral_prompt.safetensors \ + --prompt-cache-file mistral_prompt.safetensors \ --prompt "\nSummarize the above text." ``` @@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice when using a cached prompt, the model to use is read from the cache and need not be supplied explicitly. +Prompt caching can also be used in the Python API in order to avoid +recomputing the prompt. This is useful in multi-turn dialogues or across +requests that use the same context. See the +[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py) +for more usage details. + ### Supported Models -MLX LM supports thousands of Hugging Face format LLMs. If the model you want to +`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to run is not supported, file an [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, submit a pull request. diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 8110c823..70239db6 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc.
-__version__ = "0.18.2" +__version__ = "0.19.1" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9829efb4..04e75a3e 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -7,13 +7,14 @@ import mlx.core as mx -from .utils import load, make_kv_caches +from .models.cache import make_prompt_cache, save_prompt_cache +from .utils import load def setup_arg_parser(): """Set up and return the argument parser.""" parser = argparse.ArgumentParser( - description="Cache the KV cache of a prompt to be reused with mlx_lm.generate" + description="Cache the state of a prompt to be reused with mlx_lm.generate" ) parser.add_argument( "--model", @@ -60,7 +61,9 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", ) parser.add_argument( - "--kv-cache-file", help="The file to save the KV caches in", required=True + "--prompt-cache-file", + help="The file to save the prompt cache in", + required=True, ) parser.add_argument( "--prompt", @@ -115,7 +118,7 @@ def main(): else: prompt = args.prompt - cache = make_kv_caches(model, args.max_kv_size) + cache = make_prompt_cache(model, args.max_kv_size) y = mx.array(tokenizer.encode(prompt)) # Process the prompt @@ -137,16 +140,12 @@ def main(): print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") print("Saving...") - cache_dict = {} - for i, c in enumerate(cache): - cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :] - cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :] metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - metadata["max_kv_size"] = str(args.max_kv_size) - mx.save_safetensors(args.kv_cache_file, cache_dict, metadata) + print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + save_prompt_cache(args.prompt_cache_file, cache, metadata) if __name__ == "__main__": diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py new file mode 100644 index 00000000..7968a868 --- /dev/null +++ b/llms/mlx_lm/chat.py @@ -0,0 +1,82 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +import argparse +import json + +import mlx.core as mx + +from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache +from .utils import load, stream_generate + +DEFAULT_TEMP = 0.0 +DEFAULT_TOP_P = 1.0 +DEFAULT_SEED = 0 +DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" + + +def setup_arg_parser(): + """Set up and return the argument parser.""" + parser = argparse.ArgumentParser(description="Chat with an LLM") + parser.add_argument( + "--model", + type=str, + help="The path to the local model directory or Hugging Face repo.", + default=DEFAULT_MODEL, + ) + parser.add_argument( + "--adapter-path", + type=str, + help="Optional path for the trained adapter weights and config.", + ) + parser.add_argument( + "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" + ) + parser.add_argument( + "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" + ) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--max-kv-size", + type=int, + help="Set the maximum key-value cache size", + default=None, + ) + return parser + + +def main(): + parser = setup_arg_parser() + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load( + args.model, + adapter_path=args.adapter_path, + tokenizer_config={"trust_remote_code": True}, + ) + + print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.") + prompt_cache = make_prompt_cache(model, args.max_kv_size) + while True: + query = input(">> ") + if query == "q": + break + messages = [{"role": "user", "content": query}] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + for response in stream_generate( + model, + tokenizer, + prompt, + temp=args.temp, + top_p=args.top_p, + prompt_cache=prompt_cache, + ): + print(response, flush=True, end="") + print() + + +if __name__ == "__main__": + main() diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py new file mode 100644 index 00000000..3bf01688 --- /dev/null +++ b/llms/mlx_lm/examples/chat.py @@ -0,0 +1,53 @@ +# Copyright © 2024 Apple Inc. + +""" +An example of a multi-turn chat with prompt caching. +""" + +from mlx_lm import generate, load +from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache + +model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") + +# Make the initial prompt cache for the model +prompt_cache = make_prompt_cache(model) + +# User turn +prompt = "Hi my name is <Name>." +messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + temp=0.0, + prompt_cache=prompt_cache, +) + +# User turn +prompt = "What's my name?"
+messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + temp=0.0, + prompt_cache=prompt_cache, +) + +# Save the prompt cache to disk to reuse it at a later time +save_prompt_cache("mistral_prompt.safetensors", prompt_cache) + +# Load the prompt cache from disk +prompt_cache = load_prompt_cache("mistral_prompt.safetensors") diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index af599c1b..25730617 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + from mlx_lm import generate, load # Specify the checkpoint diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 537bd853..0bf98ab2 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,13 +6,15 @@ import mlx.core as mx +from .models.cache import load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 -DEFAULT_TEMP = 0.6 +DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" def str2bool(string): @@ -25,7 +27,11 @@ def setup_arg_parser(): parser.add_argument( "--model", type=str, - help="The path to the local model directory or Hugging Face repo.", + help=( + "The path to the local model directory or Hugging Face repo. " + f"If no model is specified, then {DEFAULT_MODEL} is used." + ), + default=None, ) parser.add_argument( "--adapter-path", @@ -96,7 +102,7 @@ def setup_arg_parser(): default=None, ) parser.add_argument( - "--kv-cache-file", + "--prompt-cache-file", type=str, default=None, help="A file containing saved KV caches to avoid recomputing them", @@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0): colorprint(color, s) -def load_kv_cache_from_file(kv_cache_file): - if kv_cache_file is None: - return None, None - - kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True) - cache_per_layer = {} - for k, x in kv_cache.items(): - layer, kv_type = k.split("_") - if layer not in cache_per_layer: - cache_per_layer[layer] = {} - cache_per_layer[layer][kv_type] = x - - cache_history = [None] * len(cache_per_layer) - for layer, c in cache_per_layer.items(): - cache_history[int(layer)] = (c["keys"], c["values"]) - return cache_history, metadata - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -158,22 +146,33 @@ def main(): if args.cache_limit_gb is not None: mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the kv cache and metadata if a kv cache file is provided - cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file) + # Load the prompt cache and metadata if a cache file is provided + using_cache = args.prompt_cache_file is not None + if using_cache: + prompt_cache, metadata = load_prompt_cache( + args.prompt_cache_file, return_metadata=True + ) # Building tokenizer_config tokenizer_config = ( - {} if cache_history is None else json.loads(metadata["tokenizer_config"]) + {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) if args.trust_remote_code: tokenizer_config["trust_remote_code"] = True if args.eos_token is not None: tokenizer_config["eos_token"] = args.eos_token - # If no model path is provided then use the one in the kv cache history model_path = args.model - if 
cache_history is not None and model_path is None: - model_path = metadata["model"] + if using_cache: + if model_path is None: + model_path = metadata["model"] + elif model_path != metadata["model"]: + raise ValueError( + f"Providing a different model ({model_path}) than that " + f"used to create the prompt cache ({metadata['model']}) " + "is an error." + ) + model_path = model_path or DEFAULT_MODEL model, tokenizer = load( model_path, @@ -184,7 +183,7 @@ def main(): if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - elif cache_history is not None: + elif using_cache: tokenizer.chat_template = metadata["chat_template"] if not args.ignore_chat_template and ( @@ -203,7 +202,7 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. - if cache_history is not None: + if using_cache: test_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": ""}], tokenize=False, @@ -217,12 +216,6 @@ def main(): raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None - # Determine the max kv size from the kv cache or passed arguments - max_kv_size = args.max_kv_size - if cache_history is not None: - max_kv_size = metadata["max_kv_size"] - max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None - response = generate( model, tokenizer, @@ -232,8 +225,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, - max_kv_size=max_kv_size, - cache_history=cache_history, + max_kv_size=args.max_kv_size, + prompt_cache=prompt_cache if using_cache else None, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index dc19dd05..3628a808 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -2,145 +2,9 @@ import inspect from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Optional import mlx.core as mx -import mlx.nn as nn - - -class KVCache: - - def __init__(self, head_dim, n_kv_heads): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keys = None - self.values = None - self.offset = 0 - self.step = 256 - - def update_and_fetch(self, keys, values): - prev = self.offset - if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: - B = keys.shape[0] - n_steps = (self.step + keys.shape[2] - 1) // self.step - k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim) - v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - if prev % self.step != 0: - self.keys = self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - - self.offset += keys.shape[2] - self.keys[..., prev : self.offset, :] = keys - self.values[..., prev : self.offset, :] = values - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - - @property - def state(self): - return self.keys, self.values - - -class RotatingKVCache: - - def 
__init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keep = keep - self.keys = None - self.values = None - self.offset = 0 - self.max_size = max_size - self.step = step - self._idx = 0 - - def _trim(self, trim_size, v, append=None): - to_cat = [] - if trim_size > 0: - to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - return mx.concatenate(to_cat, axis=2) - - def update_and_fetch(self, keys, values): - prev = self.offset - B, _, S = keys.shape[:3] - - # Prefill mode - if S > 1: - if self.keys is None: - self.keys = keys - self.values = values - else: - # The largest size is self.max_size + S - 1 to ensure - # every token gets at least self.max_size context - trim_size = self.keys.shape[2] - self.max_size + 1 - self.keys = self._trim(trim_size, self.keys, keys) - self.values = self._trim(trim_size, self.values, values) - self.offset += S - self._idx = self.keys.shape[2] - return self.keys, self.values - - # Generation mode - # May not have hit the max size yet, so potentially - # keep growing the cache - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - new_size = min(self.step, self.max_size - prev) - k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim) - v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self.keys = self._trim(trim_size, self.keys) - self.values = self._trim(trim_size, self.values) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + 1, :] = keys - self.values[..., self._idx : self._idx + 1, :] = values - self.offset += 1 - self._idx += 1 - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values - - @property - def state(self): - return self.keys, self.values @dataclass @@ -156,25 +20,30 @@ def from_dict(cls, params): ) -def create_additive_causal_mask(N: int, offset: int = 0): +def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds - mask = linds[:, None] < rinds[None] + linds = linds[:, None] + rinds = rinds[None] + mask = linds < rinds + if window_size is not None: + mask = mask | (linds > rinds + window_size) return mask * -1e9 def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: + window_size = None + offset = 0 if cache is not None and cache[0] is not None: c = cache[0] - if isinstance(c, RotatingKVCache): + if hasattr(c, "max_size"): offset = min(c.max_size - 1, c.offset) + window_size = c.max_size else: offset = c.offset - else: - offset 
= 0 - mask = create_additive_causal_mask(T, offset) + mask = create_causal_mask(T, offset, window_size=window_size) mask = mask.astype(h.dtype) else: mask = None diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py new file mode 100644 index 00000000..b06422e5 --- /dev/null +++ b/llms/mlx_lm/models/cache.py @@ -0,0 +1,333 @@ +# Copyright © 2023-2024 Apple Inc. + +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten, tree_unflatten + + +def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: + """ + Construct the model's cache for use when generating. + + This function will defer the cache construction to the model if it has a + ``make_cache`` method, otherwise it will make a default KV cache. + + Args: + model (nn.Module): The language model. + max_kv_size (Optional[int]): If provided and the model does not have a + ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum + size of ``max_kv_size``. + """ + if hasattr(model, "make_cache"): + return model.make_cache() + + num_layers = len(model.layers) + if max_kv_size is not None: + return [ + RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) + ] + else: + return [KVCache() for _ in range(num_layers)] + + +def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}): + """ + Save a pre-computed prompt cache to a file. + + Args: + file_name (str): The ``.safetensors`` file name. + cache (List[Any]): The model state. + metadata (Dict[str, str]): Optional metadata to save along with model + state. + """ + cache_data = [c.state for c in cache] + cache_info = [c.meta_state for c in cache] + cache_data = dict(tree_flatten(cache_data)) + cache_classes = [type(c).__name__ for c in cache] + cache_metadata = [cache_info, metadata, cache_classes] + cache_metadata = dict(tree_flatten(cache_metadata)) + mx.save_safetensors(file_name, cache_data, cache_metadata) + + +def load_prompt_cache(file_name, return_metadata=False): + """ + Load a prompt cache from a file. + + Args: + file_name (str): The ``.safetensors`` file name. + return_metadata (bool): Whether or not to return metadata. + Default: ``False``. + + Returns: + List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and + the metadata if requested. + """ + arrays, cache_metadata = mx.load(file_name, return_metadata=True) + arrays = tree_unflatten(list(arrays.items())) + cache_metadata = tree_unflatten(list(cache_metadata.items())) + info, metadata, classes = cache_metadata + cache = [globals()[c]() for c in classes] + for c, state, meta_state in zip(cache, arrays, info): + c.state = state + c.meta_state = meta_state + if return_metadata: + return cache, metadata + return cache + + +def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int: + """ + Trim the model's cache by the given number of tokens. + + This function will trim the cache if possible (in-place) and return the + number of tokens that were trimmed. + + Args: + cache (List[Any]): The model's cache. + num_tokens (int): The number of tokens to trim. + + Returns: + (int): The number of tokens that were trimmed.
+ """ + if not all(c.is_trimmable() for c in cache) or len(cache) == 0: + return 0 + return [c.trim(num_tokens) for c in cache][0] + + +class _BaseCache: + @property + def state(self): + return [] + + @state.setter + def state(self, v): + if v is not None and v: + raise ValueError("This cache has no state but a state was set.") + + @property + def meta_state(self): + return "" + + @meta_state.setter + def meta_state(self, v): + if v is not None and v: + raise ValueError("This cache has no meta_state but a meta_state was set.") + + def is_trimmable(self): + return False + + +class KVCache(_BaseCache): + def __init__(self): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + + def update_and_fetch(self, keys, values): + prev = self.offset + if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: + B, n_kv_heads, _, k_head_dim = keys.shape + v_head_dim = values.shape[3] + n_steps = (self.step + keys.shape[2] - 1) // self.step + k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) + v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + if prev % self.step != 0: + self.keys = self.keys[..., :prev, :] + self.values = self.values[..., :prev, :] + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + + self.offset += keys.shape[2] + self.keys[..., prev : self.offset, :] = keys + self.values[..., prev : self.offset, :] = values + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + + @property + def state(self): + if self.offset == self.keys.shape[2]: + return self.keys, self.values + else: + return ( + self.keys[..., : self.offset, :], + self.values[..., : self.offset, :], + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + self.offset = self.keys.shape[2] + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + +class RotatingKVCache(_BaseCache): + + def __init__(self, max_size=None, keep=0, step=256): + self.keep = keep + self.keys = None + self.values = None + self.offset = 0 + self.max_size = max_size + self.step = step + self._idx = 0 + + def _trim(self, trim_size, v, append=None): + to_cat = [] + if trim_size > 0: + to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] + else: + to_cat = [v] + if append is not None: + to_cat.append(append) + return mx.concatenate(to_cat, axis=2) + + def _temporal_order(self, v): + """ + Rearrange the cache into temporal order, slicing off the end if unused. 
+ """ + if self._idx == v.shape[2]: + return v + elif self._idx < self.offset: + return mx.concatenate( + [ + v[..., : self.keep, :], + v[..., self._idx :, :], + v[..., self.keep : self._idx, :], + ], + axis=2, + ) + else: + return v[..., : self._idx, :] + + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + # Put the keys/values in temporal order to + # preserve context + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + + # The largest size is self.max_size + S - 1 to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + 1 + self.keys = self._trim(trim_size, self.keys, keys) + self.values = self._trim(trim_size, self.values, values) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values + + def _update_in_place(self, keys, values): + # May not have hit the max size yet, so potentially + # keep growing the cache + B, n_kv_heads, S, k_head_dim = keys.shape + prev = self.offset + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size + ): + v_head_dim = values.shape[3] + new_size = min(self.step, self.max_size - prev) + k_shape = (B, n_kv_heads, new_size, k_head_dim) + v_shape = (B, n_kv_heads, new_size, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + self._idx = prev + + # Trim if needed + trim_size = self.keys.shape[2] - self.max_size + if trim_size > 0: + self.keys = self._trim(trim_size, self.keys) + self.values = self._trim(trim_size, self.values) + self._idx = self.max_size + + # Rotate + if self._idx == self.max_size: + self._idx = self.keep + + # Assign + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S + + # If the buffer is not full, slice off the end + if self.offset < self.max_size: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + return self.keys, self.values + + def update_and_fetch(self, keys, values): + if keys.shape[2] == 1: + return self._update_in_place(keys, values) + return self._update_concat(keys, values) + + @property + def state(self): + if self.offset < self.keys.shape[2]: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + else: + return self.keys, self.values + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple( + map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) + ) + + @meta_state.setter + def meta_state(self, v): + self.keep, self.max_size, self.step, self.offset, self._idx = map( + int, + v, + ) + + def is_trimmable(self): + return self.offset < self.max_size + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + self._idx -= n + return n + + +class MambaCache(_BaseCache): + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + @state.setter + def state(self, v): + self.cache = v diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 
cfcf2945..057c816d 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -69,7 +69,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -129,7 +129,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.input_layernorm(x) attn_h = self.self_attn(h, mask, cache) @@ -190,11 +190,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index f0214549..3b7e83d7 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -49,7 +49,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: qkv = self.Wqkv(x) @@ -92,7 +92,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.attn(self.norm_1(x), mask=mask, cache=cache) x = h + x @@ -179,7 +179,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r, h = self.norm_attn_norm(x, mask, cache) out = self.ffn(h) + r @@ -249,11 +249,3 @@ def sanitize(self, weights): experts = [(s, sv.T) for s, sv in experts] new_weights.update(experts) return new_weights - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.attn_config["kv_n_heads"] diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index dcfa331c..03cb3b1a 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Optional +from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -77,7 +77,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module): def __init__( self, config: ModelArgs, - hidden_size: int | None = None, - intermediate_size: int | None = None, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, ): super().__init__() self.config = config @@ -188,7 +188,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = 
self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -210,7 +210,7 @@ def __init__(self, config: ModelArgs): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -235,7 +235,7 @@ def __init__(self, config: ModelArgs): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -256,11 +256,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 602a9710..17d061a8 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs): max_position_embeddings: int = 2048 rms_norm_eps: float = 1e-6 rope_theta: float = 10000.0 - rope_scaling: Optional[Dict] = None + rope_scaling: Dict = None attention_bias: bool = False @@ -172,12 +172,11 @@ def __init__(self, config: ModelArgs): bias=config.attention_bias, ) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale rope_kwargs = { key: self.config.rope_scaling[key] @@ -202,7 +201,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -347,7 +346,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -370,7 +369,7 @@ def __init__(self, config: ModelArgs): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -395,7 +394,7 @@ def __init__(self, config: ModelArgs): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -416,14 +415,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, - self.args.v_head_dim, - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index c6150284..61de781e 
100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -113,7 +113,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -173,11 +173,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 1d410a15..ccc327a8 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -64,7 +64,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) @@ -135,13 +135,11 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: - r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + self.post_attention_layernorm(r) - r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype( - mx.float32 - ) + r = self.mlp(self.pre_feedforward_layernorm(h)) out = h + self.post_feedforward_layernorm(r) return out @@ -200,11 +198,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 8a770936..97d9a8ff 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -46,7 +46,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 652eb9e4..068046ea 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -57,7 +57,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -114,7 +114,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -184,11 +184,3 @@ def __call__( @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index c2aaa9ea..9f662491 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -120,7 +120,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: residual = x # NeoX runs attention and feedforward network in parallel. @@ -214,11 +214,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index bcc0cf0c..5264cb57 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -116,7 +116,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -171,7 +171,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attention(self.attention_norm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index c4a947a5..7da6b333 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -171,7 +171,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -233,7 +233,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -303,13 +303,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 26408426..d2740dc1 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -7,6 +7,7 @@ import mlx.nn as nn from .base import BaseModelArgs +from .cache import MambaCache @dataclass @@ -45,21 +46,6 @@ def __post_init__(self): self.time_step_rank = math.ceil(self.hidden_size / 16) -class MambaCache: - def __init__(self): - self.cache = [None, None] - - def __setitem__(self, idx, value): - self.cache[idx] = value - - def __getitem__(self, idx): - return self.cache[idx] - - @property - def state(self): - return self.cache - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -223,7 +209,7 @@ def sanitize(self, weights): weights[k] = v.moveaxis(2, 1) return weights - def make_cache(self, batch_size: int = 1): + def make_cache(self): return [MambaCache() for _ in range(len(self.layers))] @property diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index df0670be..4ac3c3b4 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -85,7 +85,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ): B, L, _ = x.shape @@ -135,7 +135,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) @@ -205,11 +205,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 2db57752..20944fe3 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -66,7 +66,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -138,7 +138,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,11 +215,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index ef55d1d7..3ea06e27 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -2,12 +2,12 @@ from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -94,7 +94,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -151,7 +151,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,13 +215,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 59849c96..3627df06 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -1,8 +1,8 @@ # Copyright © 2023-2024 Apple Inc. 
+import sys from dataclasses import dataclass -from sys import exit -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -13,7 +13,7 @@ import hf_olmo except ImportError: print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) + sys.exit(1) @dataclass @@ -68,7 +68,7 @@ def attend( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -98,7 +98,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attend(self.att_norm(x), mask, cache) h = x + r @@ -174,11 +174,3 @@ def __call__( @property def layers(self): return self.model.transformer.blocks - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.n_heads diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 19d3c027..090e21c6 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -80,7 +80,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -152,7 +152,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.attn_norm(x), mask, cache) h = x + r @@ -218,11 +218,3 @@ def __call__( @property def layers(self): return self.transformer.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_kv_heads diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index fd3fd709..56b383b2 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -162,19 +162,11 @@ def __init__(self, config: ModelArgs): def __call__( self, x: mx.array, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.model(x, cache) return self.lm_head(y) @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 112ade7d..9ef76f04 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .su_rope import SuScaledRotaryEmbedding @@ -84,7 +84,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -143,7 +143,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -202,11 +202,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 665dbc73..6b0759b4 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -3,12 +3,12 @@ import math from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int layer_norm_epsilon: float vocab_size: int - num_key_value_heads: Optional[int] = None + num_key_value_heads: int mup_attn_multiplier: float = 1.0 mup_use_scaling: bool = True mup_embedding_multiplier: float = 10.0 mup_width_multiplier: float = 8.0 rope_embedding_base: float = 1000000 rope_position_scale: float = 1.0 - blocksparse_block_size: Tuple[int] = (64,) + blocksparse_block_size: int = 64 blocksparse_num_local_blocks: int = 16 blocksparse_vert_stride: int = 8 @@ -61,7 +61,6 @@ def __init__(self, args: ModelArgs, layer_idx): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_q_per_kv = n_heads // n_kv_heads @@ -161,7 +160,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -230,7 +229,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -304,16 +303,8 @@ def __call__( def layers(self): return self.model.layers - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - def sanitize(self, weights): # Remove unused precomputed rotary freqs return { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index db6bd4b5..ca20a388 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -173,6 +173,7 @@ def __call__( class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = 
args.model_type self.args = args self.model = PhiMoEModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) @@ -208,11 +209,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index bb67615d..865d0d8e 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -168,8 +168,8 @@ def __call__( self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) @@ -193,11 +193,3 @@ def sanitize(self, weights): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.model_dim // self.args.num_heads - - @property - def n_kv_heads(self): - return self.args.num_heads diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 5d2b7586..090922ae 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn @@ -62,8 +62,8 @@ def __call__( self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + cache: Optional[Any] = None, + ) -> mx.array: bsz, q_len, _ = hidden_states.shape queries = self.q_proj(hidden_states) @@ -127,8 +127,8 @@ def __call__( self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[Any, ...]: + cache: Optional[Any] = None, + ): # from LlamaDecoder residual = hidden_states @@ -169,8 +169,8 @@ def __init__(self, config: ModelArgs): def __call__( self, inputs: mx.array, - cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None, - ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]: + cache: Optional[Any] = None, + ) -> mx.array: h = self.embed_tokens(inputs) mask = create_attention_mask(h, cache) @@ -197,19 +197,11 @@ def __init__(self, args: ModelArgs) -> None: def __call__( self, inputs: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ) -> Tuple[mx.array, mx.array]: + cache: Optional[Any] = None, + ) -> mx.array: out = self.model(inputs, cache) return self.lm_head(out) @property def layers(self): return self.model.layers.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads // self.args.n_shared_head diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 6d2c7bbf..2b69d5ec 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -149,19 +148,11 @@ def __call__( self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.transformer(x, mask, cache) return self.lm_head(y) @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index b3ce02a3..4e7858de 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -70,7 +70,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -124,7 +124,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index ff7831f3..d199116f 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -70,7 +70,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -162,7 +162,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ def sanitize(self, weights): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 34750ace..06a307a6 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,13 +7,13 @@ import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, create_attention_mask +from .cache import MambaCache, RotatingKVCache @dataclass class ModelArgs(BaseModelArgs): model_type: str - hidden_size: int attention_bias: bool conv1d_width: 
int hidden_size: int @@ -36,59 +36,6 @@ def __post_init__(self): self.block_types = self._block_types -def create_window_causal_mask(N: int, window_size: int): - inds = mx.arange(N) - linds = inds[:, None] - rinds = inds[None] - mask = (linds < rinds) | (linds > rinds + window_size) - return mask * -1e9 - - -class RecurrentCache: - - def __init__(self): - self._cache = (None, None) - - def __getitem__(self, idx): - return self._cache[idx] - - def update(self, conv_state, recurrent_state): - self._cache = (conv_state, recurrent_state) - - def state(self): - return self._cache - - -class WindowKVCache: - - def __init__(self, window_size): - self.keys = None - self.values = None - self.offset = 0 - self.window_size = window_size - - def update_and_fetch(self, keys, values): - # TODO consider using rotating buffer here - # especially for very long generations - def _update(x, v): - t = x.shape[2] - self.window_size - if t > 0: - x = x[..., t:, :] - return mx.concatenate([x, v], axis=2) - - self.offset += keys.shape[2] - if self.keys is None: - self.keys = keys - self.values = values - else: - self.keys = _update(self.keys, keys) - self.values = _update(self.values, values) - return self.keys, self.values - - def state(self): - return self.keys, self.values - - class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() @@ -136,31 +83,22 @@ def __init__( kernel_size: int, ): super().__init__() - self.weight = mx.zeros((kernel_size, channels)) + self.weight = mx.zeros((channels, kernel_size, 1)) self.bias = mx.zeros((channels,)) def __call__(self, x, cache=None): - w = self.weight.T[..., None] - kw, groups = self.weight.shape + B, L, C = x.shape + groups, K, _ = self.weight.shape + if cache is not None: - l = [] - # Pad the cache if needed - if cache.shape[1] < kw - 1: - l.append( - mx.zeros( - (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype - ) - ) - l.extend([cache, x]) - x = mx.concatenate(l, axis=1) - y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True) + x = mx.concatenate([cache, x], axis=1) else: - y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups) + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - # The cache is always kw - 1 - cache = x[:, max(x.shape[1] - kw + 1, 0) :, :] + y = mx.conv_general(x, self.weight, groups=groups) y = y + self.bias - return y, cache + + return y, x[:, -K + 1 :, :] class RGLRU(nn.Module): @@ -269,19 +207,9 @@ def __call__( # x branch. 
x = self.linear_x(x) if cache is None: - conv_state, recurrent_state = (None, None) - else: - conv_state, recurrent_state = cache[0], cache[1] - x, conv_state = self.conv_1d( - x=x, - cache=conv_state, - ) - x, recurrent_state = self.rg_lru( - x=x, - cache=recurrent_state, - ) - if cache is not None: - cache.update(conv_state, recurrent_state) + cache = [None, None] + x, cache[0] = self.conv_1d(x=x, cache=cache[0]) + x, cache[1] = self.rg_lru(x=x, cache=cache[1]) x = x * y x = self.linear_out(x) @@ -467,12 +395,14 @@ def __call__( if self.scale_by_sqrt_dim: x = x * math.sqrt(x.shape[-1]) - mask = None - if x.shape[1] > 1: - mask = create_window_causal_mask( - x.shape[1], self.config.attention_window_size - ) - mask = mask.astype(x.dtype) + if cache is None: + cache = [None] * len(self.layers) + + for i, block in enumerate(self.layers): + if block.temporal_block_type != "recurrent": + mask_cache = [cache[i]] + + mask = create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -485,6 +415,7 @@ class Model(nn.Module): def __init__(self, config): self.args = config self.model = Griffin(config) + self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__(self, tokens: mx.array, cache=None) -> mx.array: @@ -508,10 +439,9 @@ def layers(self): return self.model.layers def sanitize(self, weights): - # Remove unused precomputed rotary freqs for k, v in weights.items(): if "conv_1d.weight" in k and v.ndim == 3: - weights[k] = v.squeeze(1).T + weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") return weights @@ -520,7 +450,7 @@ def make_cache(self): cache = [] for layer in self.layers: if layer.temporal_block_type == "recurrent": - cache.append(RecurrentCache()) + cache.append(MambaCache()) else: - cache.append(WindowKVCache(self.args.attention_window_size)) + cache.append(RotatingKVCache(max_size=self.args.attention_window_size)) return cache diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index b340de28..11202b02 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -2,7 +2,6 @@ import math from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -198,8 +197,8 @@ def __call__( self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.model(x, mask, cache) return self.lm_head(y) @@ -207,11 +206,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 9cec0e39..ce0a2ec5 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -45,7 +45,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ def __call__( self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -164,11 +164,3 @@ def __call__( @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 159efb54..8b9fb27d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -18,7 +18,7 @@ from transformers import PreTrainedTokenizer # Local imports -from .models.base import KVCache, RotatingKVCache +from .models import base, cache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits -def make_kv_caches( - model: nn.Module, max_kv_size: Optional[int] = None -) -> List[Union[KVCache, RotatingKVCache]]: - if hasattr(model, "make_cache"): - return model.make_cache() - - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - if max_kv_size is not None: - return [ - RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4) - for n in kv_heads - ] - else: - return [KVCache(model.head_dim, n) for n in kv_heads] - - def generate_step( prompts: mx.array, model: nn.Module, @@ -155,7 +135,7 @@ def generate_step( min_tokens_to_keep: int = 1, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, - cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, + prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -180,6 +160,8 @@ def generate_step( prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. logit_bias (dictionary, optional): Additive logit bias. logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed @@ -243,20 +225,13 @@ def logit_bias_processor(_, logits): tokens = None # Create the KV cache for generation - cache = make_kv_caches(model, max_kv_size) - - if cache_history is not None: - if len(cache_history) != len(cache): - raise ValueError("Wrong number of layers in the cache history") - - # Set the history in the cache objects and evaluate them to prepare for - # generation. 
- for c, h in zip(cache, cache_history): - c.update_and_fetch(h[0], h[1]) - mx.eval([c.state for c in cache]) + if prompt_cache is None: + prompt_cache = cache.make_prompt_cache(model, max_kv_size) + elif len(prompt_cache) != len(model.layers): + raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): - logits = model(y, cache=cache) + logits = model(y, cache=prompt_cache) logits = logits[:, -1, :] if logits_processor: @@ -270,7 +245,7 @@ def _step(y): return y, logprobs while y.shape[1] > prefill_step_size: - model(y[:, :prefill_step_size], cache=cache) + model(y[:, :prefill_step_size], cache=prompt_cache) - mx.eval([c.state for c in cache]) + mx.eval([c.state for c in prompt_cache]) y = y[:, prefill_step_size:] @@ -312,9 +287,9 @@ def stream_generate( detokenizer = tokenizer.detokenizer detokenizer.reset() - for (token, _), n in zip( - generate_step(prompt_tokens[None], model, **kwargs), + for _, (token, _) in zip( range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), ): token = token.item() if token == tokenizer.eos_token_id: @@ -365,9 +340,9 @@ def generate( tic = time.perf_counter() detokenizer.reset() - for (token, logprobs), n in zip( - generate_step(prompt_tokens[None], model, **kwargs), + for n, (token, logprobs) in zip( range(max_tokens), + generate_step(prompt_tokens[None], model, **kwargs), ): token = token.item() if n == 0: diff --git a/llms/setup.py b/llms/setup.py index e2cfe0cd..1c696dc0 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -32,6 +32,7 @@ entry_points={ "console_scripts": [ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", + "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main", "mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.generate = mlx_lm.generate:main", diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index cd7e7fd0..1efde5ae 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -1,17 +1,15 @@ # Copyright © 2024 Apple Inc.
- import unittest import mlx.core as mx from mlx.utils import tree_map -from mlx_lm.models.base import KVCache, RotatingKVCache -from mlx_lm.utils import make_kv_caches +from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache class TestModels(unittest.TestCase): def test_kv_cache(self): - cache = KVCache(32, 4) + cache = KVCache() k = mx.ones((1, 4, 1, 32), mx.float16) v = mx.ones((1, 4, 1, 32), mx.float16) @@ -32,7 +30,7 @@ def test_kv_cache(self): def test_rotating_kv_cache(self): b, h, d = 1, 2, 32 - cache = RotatingKVCache(d, h, max_size=8, step=4) + cache = RotatingKVCache(max_size=8, step=4) k = mx.random.uniform(shape=(b, h, 2, d)) v = mx.random.uniform(shape=(b, h, 2, d)) @@ -65,7 +63,7 @@ def test_rotating_kv_cache(self): idx %= 8 # Try with nonzero keep - cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2) + cache = RotatingKVCache(max_size=8, step=4, keep=2) # Check a large update k = mx.random.uniform(shape=(b, h, 20, d)) @@ -88,6 +86,46 @@ def test_rotating_kv_cache(self): if idx >= 8: idx = 2 + def test_rotating_kv_cache_chat_mode(self): + # Test that the rotating kv cache can handle + # alternating prompt/prefill with generation + d = 4 + h = 2 + cache = RotatingKVCache(max_size=18, step=4) + + x = mx.random.uniform(shape=(1, h, 8, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 8) + self.assertEqual(cache.offset, 8) + + x = mx.random.uniform(shape=(1, h, 1, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 9) + self.assertEqual(cache.offset, 9) + self.assertTrue(mx.allclose(x, k[..., 8:9, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 11) + self.assertEqual(cache.offset, 11) + self.assertTrue(mx.allclose(x, k[..., 9:11, :])) + + x = mx.random.uniform(shape=(1, h, 3, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 14) + self.assertEqual(cache.offset, 14) + self.assertTrue(mx.allclose(x, k[..., 11:14, :])) + + x = mx.random.uniform(shape=(1, h, 6, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 20) + self.assertTrue(mx.allclose(x, k[..., -6:, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 22) + self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers) @@ -101,7 +139,7 @@ def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - cache = make_kv_caches(model) + cache = make_prompt_cache(model) outputs = model(inputs, cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -549,6 +587,179 @@ def test_llama3_1(self): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek(self): + from mlx_lm.models import deepseek + + args = deepseek.ModelArgs( + model_type="deepseek", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=4, + ) + model = deepseek.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_deepseek_v2(self): + from mlx_lm.models import deepseek_v2 + + args = deepseek_v2.ModelArgs( + model_type="deepseek_v2", + vocab_size=1024, + hidden_size=128, + 
intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gemma2(self): + from mlx_lm.models import gemma2 + + args = gemma2.ModelArgs( + model_type="gemma2", + hidden_size=128, + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=2, + head_dim=32, + rms_norm_eps=1e-4, + vocab_size=1024, + num_key_value_heads=2, + ) + model = gemma2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gpt_bigcode(self): + from mlx_lm.models import gpt_bigcode + + args = gpt_bigcode.ModelArgs( + model_type="gpt_bigcode", + n_embd=128, + n_layer=128, + n_inner=256, + n_head=4, + n_positions=1000, + layer_norm_epsilon=1e-5, + vocab_size=1024, + ) + model = gpt_bigcode.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer) + + def test_nemotron(self): + from mlx_lm.models import nemotron + + args = nemotron.ModelArgs( + model_type="nemotron", + hidden_size=128, + hidden_act="gelu", + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=4, + norm_eps=1e-5, + vocab_size=1024, + num_key_value_heads=2, + ) + model = nemotron.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phi3small(self): + from mlx_lm.models import phi3small + + args = phi3small.ModelArgs( + model_type="phi3small", + hidden_size=128, + dense_attention_every_n_layers=2, + ff_intermediate_size=256, + gegelu_limit=1.0, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + layer_norm_epsilon=1e-4, + vocab_size=1000, + ) + model = phi3small.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phimoe(self): + from mlx_lm.models import phimoe + + args = phimoe.ModelArgs( + model_type="phimoe", + vocab_size=320, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + rope_scaling={ + "long_factor": [1.0] * 16, + "long_mscale": 1.243163121016122, + "original_max_position_embeddings": 4096, + "short_factor": [1.0] * 16, + "short_mscale": 1.243163121016122, + "type": "longrope", + }, + ) + model = phimoe.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_recurrent_gemma(self): + from mlx_lm.models import recurrent_gemma + + args = recurrent_gemma.ModelArgs( + model_type="recurrent_gemma", + hidden_size=128, + attention_bias=False, + conv1d_width=3, + intermediate_size=256, + logits_soft_cap=1.0, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + attention_window_size=1024, + vocab_size=1000, + block_types=["recurrent", "recurrent", "attention"], + ) + model = recurrent_gemma.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_prompt_cache.py 
b/llms/tests/test_prompt_cache.py new file mode 100644 index 00000000..3c1ef49b --- /dev/null +++ b/llms/tests/test_prompt_cache.py @@ -0,0 +1,220 @@ +# Copyright © 2024 Apple Inc. + +import os +import tempfile +import unittest + +import mlx.core as mx +from mlx_lm.models.cache import ( + KVCache, + MambaCache, + RotatingKVCache, + load_prompt_cache, + make_prompt_cache, + save_prompt_cache, + trim_prompt_cache, +) +from mlx_lm.utils import generate_step, load + +HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + + +class TestPromptCache(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.test_dir_fid = tempfile.TemporaryDirectory() + cls.test_dir = cls.test_dir_fid.name + + @classmethod + def tearDownClass(cls): + cls.test_dir_fid.cleanup() + + def test_save_load(self): + cache = [KVCache() for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertEqual(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) + self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_save_load_rotating_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + # Test with rotating cache + cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertEqual(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertEqual(c.keep, lc.keep) + self.assertEqual(c.max_size, lc.max_size) + self.assertEqual(c.step, lc.step) + self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) + self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) + + # Do a couple single token updates to get a rotation + for _ in range(2): + for c in cache: + x = mx.random.uniform(shape=(1, 8, 1, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + + for c, lc in zip(cache, loaded_cache): + x = mx.random.uniform(shape=(1, 8, 1, 4)) + k, v = c.update_and_fetch(x, x) + lk, lv = lc.update_and_fetch(x, x) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_save_load_mixed_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()] + for c in cache: + if isinstance(c, MambaCache): + c[0] = mx.random.uniform(shape=(4, 4, 4)) + c[1] = mx.random.uniform(shape=(4, 4, 4)) + else: + x = mx.random.uniform(shape=(4, 4, 7, 4)) + y = mx.random.uniform(shape=(4, 4, 7, 4)) + c.update_and_fetch(x, y) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + for c, lc in zip(cache, loaded_cache): + if isinstance(c, MambaCache):
self.assertTrue(mx.array_equal(c[0], lc[0])) + self.assertTrue(mx.array_equal(c[1], lc[1])) + else: + x = mx.random.uniform(shape=(4, 4, 1, 4)) + y = mx.random.uniform(shape=(4, 4, 1, 4)) + k, v = c.update_and_fetch(x, y) + lk, lv = lc.update_and_fetch(x, y) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + i += 1 + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + + def test_trim_cache(self): + cache = [KVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + # Trim + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + + # Can't trim mamba cache + cache = [MambaCache() for _ in range(2)] + for c in cache: + c.state = mx.zeros((5, 5)) + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 0) + + # All caches have to be trimmable + cache = [MambaCache(), KVCache()] + cache[0].state = mx.zeros((5, 5)) + x = mx.random.uniform(shape=(1, 8, 10, 4)) + cache[1].update_and_fetch(x, x) + num_trimmed = trim_prompt_cache(cache, 1) + self.assertEqual(num_trimmed, 0) + + cache = [RotatingKVCache(max_size=6) for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 5, 4)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 4) + + # Can't trim fixed-size KV cache after processing + # more than max_kv_size tokens + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 0) + + def test_trim_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + + prompt_cache = make_prompt_cache(model) + + # Generate one token so we process the full prompt + last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache)) + last_tok = mx.array([last_tok]) + + # Generate two more tokens + results = zip( + range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + ) + toks, all_logits = zip(*(r[1] for r in results)) + + # To get back to the cache just after processing the prompt, + # trim by 3 tokens + trim_prompt_cache(prompt_cache, 3) + + # Generate the same thing again + results = zip( + range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + ) + second_toks, second_all_logits = zip(*(r[1] for r in results)) + self.assertEqual(toks, second_toks) + self.assertTrue( + all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) + ) + + +if __name__ == "__main__": + unittest.main()
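
Usage note (not part of the patch): the sketch below mirrors how the new prompt-cache API is exercised in `llms/tests/test_prompt_cache.py`. It is a minimal sketch only; the model path is the small 4-bit model used by the tests and is purely illustrative, and any model supported by `mlx-lm` should work the same way.

```python
# Minimal sketch of the prompt-cache workflow added in this change, following
# the usage in llms/tests/test_prompt_cache.py. The model path is an example.
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

# One cache entry per layer (KVCache, RotatingKVCache, or MambaCache depending
# on the model); generate_step updates it in place as tokens are produced.
prompt_cache = make_prompt_cache(model)
for _, (token, _) in zip(
    range(4), generate_step(prompt, model, prompt_cache=prompt_cache)
):
    pass

# Save the filled cache with optional string metadata, then reload it later to
# keep generating without reprocessing the prompt.
save_prompt_cache("prompt_cache.safetensors", prompt_cache, {"model": "example"})
prompt_cache, metadata = load_prompt_cache(
    "prompt_cache.safetensors", return_metadata=True
)
```

The same cache objects back `mlx_lm.cache_prompt` and the `prompt_cache` argument of `generate_step`, which is why per-model `make_cache` methods now return classes from `mlx_lm.models.cache` instead of per-file cache implementations.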