get_best_device() can lead to CUDA OOM #611

Open
johannaSommer opened this issue Sep 13, 2024 · 0 comments
johannaSommer commented Sep 13, 2024

Dear AutoAWQ team,

Thanks a lot for maintaining this amazing repository.

As far as I understand, AutoAWQ supports quantizing models that are distributed across multiple GPUs, i.e. loaded with device_map="auto". However, I think the device casting during the initialization of the quantization could be improved.

In particular, in awq/quantize/quantizer.py, lines 542-544, we cast both the embeddings and the first decoder layer to best_device:

best_device = get_best_device()
modules[0] = modules[0].to(best_device)
self.awq_model.move_embed(self.model, best_device)

However, when CUDA is available, get_best_device() always returns cuda:0, so everything is cast onto GPU 0. The casts in lines 543-544 can then fail with a CUDA OOM, even though this would not have to happen if one of the other available GPUs were chosen:

def get_best_device():
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        return "cuda:0"
    else:
        return "cpu"

This failure can occur if something else already holds memory on GPU 0, or possibly if a model loaded with device_map="sequential" ends up with an unfortunate distribution of the decoder and embedding layers (e.g. model components are placed on GPU 0 until it is full and the embeddings end up on GPU 1).
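
For reference, the actual placement can be inspected via the device map that accelerate records on the loaded model. This is just a minimal sketch, assuming the underlying transformers model exposes hf_device_map (set whenever a device_map is used):

from collections import Counter
from transformers import AutoModelForCausalLM

# Load with the same sequential strategy and print where each module landed.
# hf_device_map maps module names to devices, e.g. {'model.embed_tokens': 0, ...}
hf_model = AutoModelForCausalLM.from_pretrained(
    'GalrionSoftworks/Pleiades-12B-v1', device_map="sequential"
)
print(hf_model.hf_device_map)
print(Counter(hf_model.hf_device_map.values()))  # number of modules per device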

Do you think it would be possible and useful to adjust this device selection, e.g. by picking the least utilized GPU (a rough sketch is included below)?
I would be grateful to hear your thoughts on this and am happy to help out with this issue if I can. Thank you!
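
To make this concrete, here is a minimal sketch (not a patch against the current code base, just an illustration) of what such a selection could look like, using torch.cuda.mem_get_info to pick the CUDA device with the most free memory:

import torch

def get_best_device():
    # Same structure as the current implementation, but instead of
    # hard-coding cuda:0, pick the CUDA device with the most free memory.
    if torch.backends.mps.is_available():
        return "mps"
    elif torch.cuda.is_available():
        free_bytes = [
            torch.cuda.mem_get_info(i)[0]  # mem_get_info returns (free, total)
            for i in range(torch.cuda.device_count())
        ]
        return f"cuda:{free_bytes.index(max(free_bytes))}"
    else:
        return "cpu"

In the reproduction below, this would select one of the GPUs that still has free memory instead of the blocked GPU 0.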


Here is a simple example that reproduces the issue (autoawq==0.2.6):

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'GalrionSoftworks/Pleiades-12B-v1'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# allocate a large tensor (~22 GB in fp32) on GPU 0 to block its memory for demonstration purposes
large_tensor = torch.randn(100000, 55000, device="cuda:0")

model = AutoAWQForCausalLM.from_pretrained(
    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="sequential"
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.quantize(tokenizer, quant_config=quant_config)

On a system with 4 NVIDIA A10Gs, even though sufficient memory is left on the other GPUs, this fails with an OOM:

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
[location]                                line 5
      1 model = AutoAWQForCausalLM.from_pretrained(
      2     model_path, low_cpu_mem_usage=True, use_cache=False, device_map="sequential"
      3 )
      4 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
----> 5 model.quantize(tokenizer, quant_config=quant_config)

[location]                                in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

[location]                                in BaseAWQForCausalLM.quantize(self, tokenizer, quant_config, calib_data, split, text_column, duo_scaling, export_compatible, apply_clip, n_parallel_calib_samples, max_calib_samples, max_calib_seq_len, max_chunk_memory)
    208 if hasattr(self, "modules_to_not_convert"):
    209     self.quant_config.modules_to_not_convert = self.modules_to_not_convert
--> 211 self.quantizer = AwqQuantizer(
    212     self,
    213     self.model,
    214     tokenizer,
    215     self.quant_config.w_bit,
    216     self.quant_config.q_group_size,
    217     self.quant_config.zero_point,
    218     self.quant_config.version,
    219     calib_data,
    220     split,
    221     text_column,
    222     duo_scaling,
    223     modules_to_not_convert=self.quant_config.modules_to_not_convert,
    224     export_compatible=export_compatible,
    225     apply_clip=apply_clip,
    226     n_parallel_calib_samples=n_parallel_calib_samples,
    227     max_calib_samples=max_calib_samples,
    228     max_calib_seq_len=max_calib_seq_len,
    229     max_chunk_memory=max_chunk_memory,
    230 )
    231 self.quantizer.quantize()
    233 self.is_quantized = True

[location]                                in AwqQuantizer.__init__(self, awq_model, model, tokenizer, w_bit, group_size, zero_point, version, calib_data, split, text_column, duo_scaling, modules_to_not_convert, export_compatible, apply_clip, n_parallel_calib_samples, max_calib_samples, max_calib_seq_len, max_chunk_memory)
     65 self.max_chunk_memory = max_chunk_memory
     66 self.modules_to_not_convert = (
     67     modules_to_not_convert if modules_to_not_convert is not None else []
     68 )
---> 69 self.modules, self.module_kwargs, self.inps = self.init_quant(
     70     n_samples=self.max_calib_samples, max_seq_len=self.max_calib_seq_len
     71 )

[location]                                in AwqQuantizer.init_quant(self, n_samples, max_seq_len)
    542 best_device = get_best_device()
    543 modules[0] = modules[0].to(best_device)
--> 544 self.awq_model.move_embed(self.model, best_device)

[location]                                in MistralAWQForCausalLM.move_embed(model, device)
     31 @staticmethod
     32 def move_embed(model: OldMistralForCausalLM, device: str):
---> 33     model.model.embed_tokens = model.model.embed_tokens.to(device)

[location]                                in Module.to(self, *args, **kwargs)
   1170         else:
   1171             raise
--> 1173 return self._apply(convert)

[location]                                in Module._apply(self, fn, recurse)
    800 # Tensors stored in modules are graph leaves, and we don't want to
    801 # track autograd history of `param_applied`, so we have to use
    802 # `with torch.no_grad():`
    803 with torch.no_grad():
--> 804     param_applied = fn(param)

[location]                                in Module.to.<locals>.convert(t)
   1152     if convert_to_format is not None and t.dim() in (4, 5):
   1153         return t.to(
   1154             device,
   1155             dtype if t.is_floating_point() or t.is_complex() else None,
   1156             non_blocking,
   1157             memory_format=convert_to_format,
   1158         )
--> 1159     return t.to(
   1160         device,
   1161         dtype if t.is_floating_point() or t.is_complex() else None,
   1162         non_blocking,
   1163     )

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.25 GiB. GPU 0