Commit

Merge branch 'main' into feat/batch_generate

llllvvuu committed Oct 9, 2024
2 parents 8fb82fe + b7373cb commit 308ad24
Showing 43 changed files with 1,150 additions and 690 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so

# Vim
*.swp

# Distribution / packaging
.Python
build/
39 changes: 35 additions & 4 deletions llms/README.md
@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)

### Quick Start

To generate text with an LLM, use:

```bash
mlx_lm.generate --prompt "Hi!"
```

To chat with an LLM, use:

```bash
mlx_lm.chat
```

This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.

Commands in `mlx-lm` typically take command-line options that let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:

```bash
mlx_lm.generate -h
```
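
For instance, a typical invocation might look like the following (the model and
flag values here are only illustrative; `-h` lists what your installed version
actually supports):

```bash
mlx_lm.generate \
  --model mlx-community/Llama-3.2-3B-Instruct-4bit \
  --prompt "Write a haiku about autumn." \
  --temp 0.7 \
  --max-tokens 256
```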

### Python API

You can use `mlx-lm` as a module:
@@ -138,7 +163,7 @@ mlx_lm.convert \

### Long Prompts and Generations

MLX LM has some tools to scale efficiently to long prompts and generations:
`mlx-lm` has some tools to scale efficiently to long prompts and generations:

- A rotating fixed-size key-value cache.
- Prompt caching
@@ -155,24 +180,30 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--kv-cache-file mistral_prompt.safetensors
--prompt-cache-file mistral_prompt.safetensors
```

Then use the cached prompt with `mlx_lm.generate`:

```
mlx_lm.generate \
--kv-cache-file mistral_prompt.safetensors \
--prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```

The cached prompt is treated as a prefix to the supplied prompt. Also note that
when using a cached prompt, the model is read from the cache and need not be
supplied explicitly.

Prompt caching can also be used in the Python API in order to avoid
recomputing the prompt. This is useful in multi-turn dialogues or across
requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
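
As a minimal sketch of that pattern (condensed from the linked example; the
model name is only illustrative):

```python
from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Reusing the same cache across turns avoids reprocessing earlier context.
prompt_cache = make_prompt_cache(model)

for question in ["Hi, my name is <Name>.", "What's my name?"]:
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    generate(model, tokenizer, prompt=prompt, verbose=True, prompt_cache=prompt_cache)
```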

### Supported Models

MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
2 changes: 1 addition & 1 deletion llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.18.2"
__version__ = "0.19.1"
19 changes: 9 additions & 10 deletions llms/mlx_lm/cache_prompt.py
@@ -7,13 +7,14 @@

import mlx.core as mx

from .utils import load, make_kv_caches
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load


def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
@@ -60,7 +61,9 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file", help="The file to save the KV caches in", required=True
"--prompt-cache-file",
help="The file to save the prompt cache in",
required=True,
)
parser.add_argument(
"--prompt",
@@ -115,7 +118,7 @@ def main():
else:
prompt = args.prompt

cache = make_kv_caches(model, args.max_kv_size)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))

# Process the prompt
@@ -137,16 +140,12 @@ def main():
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")

print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
metadata["max_kv_size"] = str(args.max_kv_size)
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
save_prompt_cache(args.prompt_cache_file, cache, metadata)


if __name__ == "__main__":
82 changes: 82 additions & 0 deletions llms/mlx_lm/chat.py
@@ -0,0 +1,82 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import json

import mlx.core as mx

from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()

    mx.random.seed(args.seed)

    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )

    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            temp=args.temp,
            top_p=args.top_p,
            prompt_cache=prompt_cache,
        ):
            print(response, flush=True, end="")
        print()


if __name__ == "__main__":
    main()
53 changes: 53 additions & 0 deletions llms/mlx_lm/examples/chat.py
@@ -0,0 +1,53 @@
# Copyright © 2024 Apple Inc.

"""
An example of a multi-turn chat with prompt caching.
"""

from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)

# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)

# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)

# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)

# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")
2 changes: 2 additions & 0 deletions llms/mlx_lm/examples/generate_response.py
@@ -1,3 +1,5 @@
# Copyright © 2024 Apple Inc.

from mlx_lm import generate, load

# Specify the checkpoint