feat(mlx_lm): support batch input in generate() #948

Open
wants to merge 4 commits into
base: main

Conversation

llllvvuu
Contributor

@llllvvuu llllvvuu commented Aug 21, 2024

The prompt argument can now be either a str or list[str].

This is based on @willccbb's implementation at https://github.com/willccbb/mlx_parallm; I noticed that it aligned with the KVCache upgrades in #911.

The change to generate() is backwards-compatible.

The changes to generate_step(), top_p_sampling(), and min_p_sampling() are backwards-incompatible in order to unify shapes; this could be changed by adding a few if-statements, if preferred.
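A minimal sketch of how a single entry point can accept both types while staying backwards-compatible: normalize the input to a list, run the batched path, and unwrap the result for a plain `str`. The names `generate_like` and `fake_complete` are hypothetical and not part of this PR; the stand-in completion function does no real decoding.

```python
from typing import Callable, List, Union


def fake_complete(prompts: List[str]) -> List[str]:
    # Hypothetical stand-in for batched decoding: echoes each prompt in upper case.
    return [p.upper() for p in prompts]


def generate_like(
    prompt: Union[str, List[str]],
    complete: Callable[[List[str]], List[str]] = fake_complete,
) -> Union[str, List[str]]:
    # Treat a single string as a batch of one so downstream code sees one shape.
    is_batch = not isinstance(prompt, str)
    prompts = list(prompt) if is_batch else [prompt]
    outputs = complete(prompts)
    # Unwrap for str input so existing callers still get a str back.
    return outputs if is_batch else outputs[0]
```

With this shape, callers that pass a `str` keep getting a `str`, and only the internal helpers need to assume a leading batch dimension.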

@llllvvuu llllvvuu changed the title feat: support batch input in generate() feat(mlx_lm): support batch input in generate() Aug 21, 2024
@llllvvuu llllvvuu marked this pull request as draft August 21, 2024 05:22
@llllvvuu llllvvuu marked this pull request as ready for review August 21, 2024 05:25
@llllvvuu
Contributor Author

llllvvuu commented Aug 26, 2024

Kind of interesting: for quantized models, the throughput doesn't go up much between small batch sizes (bs=1,2,3,4), but then it starts to go up a lot at higher bs, which is the opposite of what I expected intuitively. For unquantized models the throughput does go up between small bs. I observe the same on @willccbb's original repo.

prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
if is_batch:
    tokenizer._tokenizer.padding_side = "left"
Member

I see that we left pad shorter prompts here which makes sense. But one thing that I'm wondering is how this is handled in the causal models if at all? Shouldn't the causal mask take into account the padding?

Contributor Author

@llllvvuu llllvvuu Aug 29, 2024

I didn't handle it. Generation seems OK without it, but indeed, to be correct I should consume tokenizer._tokenizer(prompt, padding=True)["attention_mask"]. To do this I would need to update our model APIs to take attention_mask as an input, similar to how transformers has model.generate taking attention_mask. This probably involves touching every file in models/. It should mostly be copy/paste, though. I can look into it.
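For reference, a small sketch of what left padding plus an attention mask looks like for a batch of token ID lists. It follows the common convention of 1 for real tokens and 0 for padding; `left_pad` is a hypothetical helper and the token IDs are made up, so this is not the tokenizer's actual implementation.

```python
from typing import List, Tuple


def left_pad(
    batch: List[List[int]], pad_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    # Pad every sequence on the left to the longest length in the batch.
    max_len = max(len(seq) for seq in batch)
    padded, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        padded.append([pad_id] * n_pad + seq)
        # 0 marks padding, 1 marks a real token.
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return padded, attention_mask


tokens, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
print(tokens)  # [[5, 6, 7], [0, 0, 8]]
print(mask)    # [[1, 1, 1], [0, 0, 1]]
```

Left padding keeps the last real token of every row in the final column, which is where the next-token logits are read during generation.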

@awni
Member

awni commented Aug 29, 2024

I think it makes sense, to minimize the complexity of the generate function (which is becoming a bit spaghetti), to split out the batched generation into a separate function called batch_generate. I would simplify that function to have fewer arguments (e.g. no formatter, no printing during generation, and verbose only prints the timings, as you have it now).

Also maybe more tricky is the fact that I think for this to be correct, the causal masks need to consider the left padding in the input (please correct me if I'm wrong about that). This has two implications:

  1. Probably we'd need to add a mask parameter to the model __call__ functions and provide an appropriately constructed mask for the batch case.
  2. The Rotating KV cache will be broken in this case (it keeps the initial tokens, which would be the padded tokens), and when it rotates the mask would need to be updated to consider the padding (which is a bit complicated/tedious). In this case I may suggest disabling this option entirely.

Let me know what you think about the above.
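To make the first implication concrete, here is a minimal sketch of a causal mask that also respects left padding, for one row of the batch. `causal_mask_with_padding` is a hypothetical helper written in plain Python for clarity; it is an assumption about how such a mask could be built, not code from this PR.

```python
from typing import List


def causal_mask_with_padding(attention_mask: List[int]) -> List[List[bool]]:
    # attention_mask: one row of the batch, 1 = real token, 0 = left padding.
    n = len(attention_mask)
    # allowed[q][k] is True when query position q may attend to key position k:
    # the key must not be in the future and must not be a padding token.
    return [
        [k <= q and attention_mask[k] == 1 for k in range(n)]
        for q in range(n)
    ]


for row in causal_mask_with_padding([0, 0, 1, 1]):
    print(row)
```

Note that the rows belonging to padded query positions are entirely False in this sketch. A real implementation would need to handle those rows explicitly, since a softmax over a fully masked row yields NaN.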

@llllvvuu
Contributor Author

> I think it makes sense, to minimize the complexity of the generate function (which is becoming a bit spaghetti), to split out the batched generation into a separate function called batch_generate. I would simplify that function to have fewer arguments (e.g. no formatter, no printing during generation, and verbose only prints the timings, as you have it now).

Makes sense to me, will implement.

> Also maybe more tricky is the fact that I think for this to be correct, the causal masks need to consider the left padding in the input (please correct me if I'm wrong about that). This has two implications:
>
>   1. Probably we'd need to add a mask parameter to the model __call__ functions and provide an appropriately constructed mask for the batch case.

Yes, this sounds straightforward enough.

>   2. The Rotating KV cache will be broken in this case (it keeps the initial tokens, which would be the padded tokens), and when it rotates the mask would need to be updated to consider the padding (which is a bit complicated/tedious). In this case I may suggest disabling this option entirely.

I'll think a bit about whether there's an easy way to handle this; otherwise I'll remove that parameter from batch_generate.
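A rough illustration of why the rotating cache and left padding interact badly. This models a rotating cache very loosely as "retain the first `keep` positions plus the most recent ones"; `retained_positions`, `max_size`, and `keep` are assumptions made for this sketch and do not describe the actual cache implementation.

```python
from typing import List


def retained_positions(seq_len: int, max_size: int, keep: int) -> List[int]:
    # Simplified model: keep the first `keep` positions plus the most recent ones.
    if seq_len <= max_size:
        return list(range(seq_len))
    recent = max_size - keep
    return list(range(keep)) + list(range(seq_len - recent, seq_len))


# One row left-padded with 3 pad tokens (0 = pad, 1 = real), total length 8.
attention_mask = [0, 0, 0, 1, 1, 1, 1, 1]
kept = retained_positions(len(attention_mask), max_size=6, keep=2)
print(kept)                               # [0, 1, 4, 5, 6, 7]
print([attention_mask[i] for i in kept])  # [0, 0, 1, 1, 1, 1]
```

In this toy case the "initial tokens" that survive rotation are both padding, so the mask would have to be re-derived for the retained positions on every rotation.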

Will update when these changes are ready!

@awni
Member

awni commented Sep 27, 2024

@llllvvuu are you coming back to this?

@llllvvuu
Contributor Author

Hey @awni, sorry for the delay, I'd been job hunting this month. I should be able to get back to this in about a week.

@awni
Member

awni commented Sep 28, 2024

No worries, just checking. I'll follow up in a week or so.

@nath1295

Just realised the attention mask has already been mentioned in this PR; it is the reason I raised issue #1044.


3 participants