Feature request
Removing the line `logits = logits.float()` in most `ModelForCausalLM` classes. This would save a lot of memory for models with a large vocabulary. On Llama3, it divides the memory peak by more than 2.
Motivation
This is in relation to my work in #30536.
I noticed that almost all `ModelForCausalLM` classes contain the following line in their `forward`: `logits = logits.float()`.
Since most models are now used in (b)float16, or even quantized, that line almost always doubles the memory footprint of the logits. As the vocabulary can be quite large (e.g. 128256 tokens for Llama3), this results in a lot of memory being used.
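To put rough numbers on this, here is a back-of-the-envelope sketch of the logits size for Llama3. It assumes batch size 1, 2 bytes per element in bfloat16 and 4 in float32, and uses 6500 tokens, the largest input length in the benchmark snippet further down (500 × 13). It only counts the logits tensor, not the rest of the forward pass.

```python
# Back-of-the-envelope size of the logits tensor of a single forward pass.
# Assumptions: batch size 1, Llama3 vocabulary (128256 tokens),
# 2 bytes per bfloat16 element and 4 bytes per float32 element.
VOCAB_SIZE_LLAMA3 = 128256
BF16_BYTES = 2
FP32_BYTES = 4
GIB = 1024**3


def logits_bytes(seq_len: int, vocab_size: int, bytes_per_elem: int, batch_size: int = 1) -> int:
    """Size in bytes of a (batch_size, seq_len, vocab_size) logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_elem


seq_len = 6500  # largest input length used in the benchmark snippet
bf16 = logits_bytes(seq_len, VOCAB_SIZE_LLAMA3, BF16_BYTES)
fp32 = logits_bytes(seq_len, VOCAB_SIZE_LLAMA3, FP32_BYTES)
# While `logits = logits.float()` runs, the bfloat16 original and the
# float32 copy are both allocated at the same time.
during_cast = bf16 + fp32

print(f"bf16 logits:        {bf16 / GIB:.2f} GiB")
print(f"fp32 copy:          {fp32 / GIB:.2f} GiB")
print(f"both alive at once: {during_cast / GIB:.2f} GiB")
```

So for the logits alone, the cast roughly triples the transient footprint compared to keeping them in bfloat16.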
I suspect that it was originally introduced so that later manipulations of the logits (processors, warpers...) can be applied without losing too much precision. However, in `generate()` we only ever use the last token's logits, not the whole logit matrix, so this is a huge waste of memory.
Your contribution
If casting the logits to float is indeed only meant to avoid losing precision when manipulating them, I propose to cast only the last token's logits to float in each decoding strategy function.
So, instead of `logits = logits.float()` in `forward()`, do `next_token_logits = outputs.logits[:, -1, :].clone().float()` in each decoding strategy function. This only casts the last token's logits vector to float, which is negligible in terms of memory overhead.
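A minimal sketch of what a decoding step would look like under this proposal. The random tensor is a hypothetical stand-in for `outputs.logits` (kept in bfloat16 because the cast in `forward()` is gone), and the greedy `argmax` is only an illustration of a decoding step, not the actual `generate()` code.

```python
import torch

# Hypothetical stand-in for `outputs.logits`: (batch, seq_len, vocab) in bfloat16,
# i.e. what forward() would return without `logits = logits.float()`.
batch, seq_len, vocab = 2, 16, 1000
logits = torch.randn(batch, seq_len, vocab).to(torch.bfloat16)

# Proposed: upcast only the last position, which is all a decoding step uses.
# `.clone()` guarantees a fresh tensor even if the logits are already float32
# (then `.float()` would return the slice itself, a view that keeps the whole
# logits storage alive).
next_token_logits = logits[:, -1, :].clone().float()

# Greedy choice, as a minimal example of a step working on float32 logits.
next_tokens = torch.argmax(next_token_logits, dim=-1)

# Compare what each approach allocates in float32.
full_cast_bytes = logits.numel() * 4  # what `logits.float()` in forward() allocates
last_token_bytes = next_token_logits.numel() * next_token_logits.element_size()
print(next_token_logits.dtype, tuple(next_token_logits.shape), tuple(next_tokens.shape))
print(f"full cast: {full_cast_bytes} bytes, last-token cast: {last_token_bytes} bytes")
```

The float32 buffer shrinks by a factor of `seq_len`, which is why the overhead of the per-step cast is negligible.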
As an example of the potential memory gains, running this very simple code snippet on Llama3 8B (vocabulary size 128256):
```python
import torch
from transformers import AutoModelForCausalLM

model_name = 'meta-llama/Meta-Llama-3-8B'
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                             torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)

memory = []
sizes = [100] + [500*i for i in range(1, 14)]
for size in sizes:
    # Random token ids (batch size 1, sequence length `size`) on GPU 1
    input_ids = torch.randint(1, 120000, (1, size), device=1)
    # Reset the peak counter so the measurement covers only this forward pass
    torch.cuda.reset_peak_memory_stats(1)
    actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
    # Single forward pass (first iteration of `generate()`)
    with torch.no_grad():
        out = model(input_ids)
    # Extra peak memory (GiB) on top of what was already allocated (mostly the weights)
    memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
    memory.append(memory_used)
    del out
```
gives:

[Figure: memory peak of a single forward pass vs. input length]
That is, the memory footprint is more than halved. This is because the vocabulary is so large that computing the logits from the hidden states actually requires more memory than computing the hidden states themselves. Thus, when casting to `float()`, we more than double the memory requirements (twice the size for the new float32 logits, plus the overhead of the copy itself).
Of course, other models usually have a smaller vocabulary, so they will not benefit as much, but the memory peak will still decrease by a non-negligible amount for all applicable models (see below for Mistral, ~30% memory gain). And Llama3, which I believe is the hottest open-source model at the moment, will be much more efficient.
mistral_ratio_example.pdf
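Why Mistral gains less can be seen from the per-token size of the float32 logits relative to the final hidden state. This sketch assumes the configs of Llama3 8B (vocab 128256, hidden size 4096) and Mistral 7B v0.1 (vocab 32000, hidden size 4096). It only compares these two tensors; the actual memory peak also depends on attention and MLP activations, so it does not reproduce the measured ratios.

```python
# Per-token float32 logits vs. final hidden state, for two assumed configs.
FP32_BYTES = 4
configs = {
    "Llama3 8B": {"vocab_size": 128256, "hidden_size": 4096},
    "Mistral 7B (v0.1)": {"vocab_size": 32000, "hidden_size": 4096},
}


def per_token_bytes(n_features: int, bytes_per_elem: int = FP32_BYTES) -> int:
    """Bytes needed to store one token's vector of `n_features` values."""
    return n_features * bytes_per_elem


ratios = {}
for name, cfg in configs.items():
    logits = per_token_bytes(cfg["vocab_size"])
    hidden = per_token_bytes(cfg["hidden_size"])
    ratios[name] = logits / hidden  # how many hidden states fit in one logits row
    print(f"{name}: {logits / 1024:.0f} KiB of fp32 logits per token, "
          f"{ratios[name]:.1f}x the final hidden state")
```

With a vocabulary four times smaller, Mistral's logits are a much smaller share of the forward pass, so removing the cast matters less there than on Llama3.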
Of course, if this casting to float is made for something else that I overlooked, this may not be applicable. Otherwise, I would be happy to make the change.
@ArthurZucker @gante
Cheers,
Cyril