
Improving memory efficiency further 🚀 #30860

@Cyrilvallez

Description

Feature request

Removing the line logits = logits.float() in most ModelForCausalLM would save a lot of memory for models with a large vocabulary size: on Llama3, it divides the memory peak by more than 2.

Motivation

This is in relation to my work in #30536.
I noticed that almost all ModelForCausalLM contain the following line in their forward:

logits = logits.float()

Now, since most models are used in (b)float16, or even quantized, that line will almost always double the memory footprint of the logits. As the vocabulary size can be quite large (e.g. Llama3), this results in a lot of memory being used.
I suspect it was originally introduced so that later manipulations of the logits (processors, warpers...) can be applied without losing too much precision. However, in generate() we only ever use the last token's logits, not the whole logits matrix, so this is a huge waste of memory.
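To make the waste concrete, a rough byte count (shapes are illustrative; only the 128256 vocabulary size is Llama3's): casting the full logits matrix to float32 allocates seq_len times more new memory than casting only the last-token slice that generate() actually consumes.

```python
# Illustrative shapes; vocab matches Llama3's tokenizer, seq_len is arbitrary.
batch, seq_len, vocab = 1, 2048, 128256

# New float32 memory allocated by `logits.float()` over the whole matrix...
bytes_full_float = batch * seq_len * vocab * 4
# ...versus casting only the last-token slice used by `generate()`.
bytes_last_float = batch * 1 * vocab * 4

print(bytes_full_float // bytes_last_float)  # ratio is exactly seq_len -> 2048
```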

Your contribution

If the casting of the logits to float is indeed only there to avoid losing precision in these manipulations, I propose to cast only the last token to float in each decoding strategy function.

So, instead of:

logits = logits.float()

in forward(), do

next_token_logits = outputs.logits[:, -1, :].clone().float()

in each decoding strategy function. This casts only the last token vector to float, which is negligible in terms of memory overhead.
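A minimal sketch of the proposed cast in isolation (the shapes and the random tensor are stand-ins for the real outputs.logits):

```python
import torch

# Stand-in for `outputs.logits` as produced by a bfloat16 model
batch, seq_len, vocab = 2, 10, 32  # illustrative shapes
logits = torch.randn(batch, seq_len, vocab, dtype=torch.bfloat16)

# Cast only the last position; `.clone()` detaches the slice from the
# full (batch, seq_len, vocab) buffer so the latter can be freed.
next_token_logits = logits[:, -1, :].clone().float()

print(next_token_logits.shape, next_token_logits.dtype)
# torch.Size([2, 32]) torch.float32
```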

As an example of the potential memory gains, running this very simple code snippet on Llama3 8B (vocabulary size 128256):

import torch
from transformers import AutoModelForCausalLM

model_name = 'meta-llama/Meta-Llama-3-8B'
dtype = torch.bfloat16
device = 'cuda:1'
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                             torch_dtype=dtype, low_cpu_mem_usage=True).to(device)

memory = []
sizes = [100] + [500*i for i in range(1, 14)]
for size in sizes:
    input_ids = torch.randint(1, 120000, (1, size), device=device)

    # Baseline: memory already held by the weights before the forward pass
    torch.cuda.reset_peak_memory_stats(device)
    baseline = torch.cuda.max_memory_allocated(device) / 1024**3

    # Single forward pass (first iteration of `generate()`)
    with torch.no_grad():
        out = model(input_ids)

    memory_used = (torch.cuda.max_memory_allocated(device) / 1024**3) - baseline
    memory.append(memory_used)

    del out

gives (attached plots):
llama3_example.pdf
llama3_ratio_example.pdf

That is, the memory footprint is more than halved. This is because the vocabulary size is so large that computing the logits from the hidden states is actually more costly than computing the hidden states themselves. Thus, when casting to float(), we more than double the memory requirement: double for the new float32 logits, plus the transient overhead while both copies exist during the cast.
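As a back-of-the-envelope check (pure arithmetic, assuming batch size 1 and the longest prompt size from the sweep above; activation memory is ignored):

```python
# Size of the logits tensor alone for Llama3 at the largest prompt
# from the sweep above (assumption: batch 1, seq_len 6500, vocab 128256).
GiB = 1024**3
batch, seq_len, vocab = 1, 6500, 128256

logits_bf16 = batch * seq_len * vocab * 2 / GiB  # bfloat16 logits, ~1.55 GiB
logits_fp32 = batch * seq_len * vocab * 4 / GiB  # after .float(), ~3.11 GiB

# During the cast both buffers exist at once, so the transient peak
# grows by bf16 + fp32 = 3x the original logits allocation.
print(round(logits_bf16, 2), round(logits_fp32, 2))
```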

Of course, other models usually have smaller vocabulary sizes and will not benefit as much, but the memory peak will still decrease by a non-negligible amount for all applicable models (see below for Mistral, ~30% memory gain). And Llama3, which is, I believe, the hottest open-source model at the moment, will be much more efficient.
mistral_ratio_example.pdf

Of course, if this casting to float serves some other purpose that I overlooked, this may not be applicable. Otherwise, I would be happy to make the change.

@ArthurZucker @gante

Cheers,
Cyril
