From 12581fbd682482f535e66e5dbe2ddc260d24780d Mon Sep 17 00:00:00 2001
From: laoda513 <128342390+laoda513@users.noreply.github.com>
Date: Fri, 3 May 2024 02:21:12 +0800
Subject: [PATCH] support max_memory to specify mem usage for each GPU (#460)

---
 awq/models/auto.py | 2 ++
 awq/models/base.py | 7 +++++++
 2 files changed, 9 insertions(+)

diff --git a/awq/models/auto.py b/awq/models/auto.py
index 0a236979..e114fb35 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -83,6 +83,7 @@ def from_quantized(
         batch_size=1,
         safetensors=True,
         device_map="balanced",
+        max_memory=None,
         offload_folder=None,
         download_kwargs=None,
         **config_kwargs,
@@ -108,6 +109,7 @@ def from_quantized(
             use_exllama_v2=use_exllama_v2,
             safetensors=safetensors,
             device_map=device_map,
+            max_memory=max_memory,
             offload_folder=offload_folder,
             download_kwargs=download_kwargs,
             **config_kwargs,
diff --git a/awq/models/base.py b/awq/models/base.py
index ebd45ccc..0ef40a31 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -393,6 +393,12 @@ def from_quantized(
                 "A device map that will be passed onto the model loading method from transformers."
             ),
         ] = "balanced",
+        max_memory: Annotated[
+            Dict[Union[int, str], Union[int, str]],
+            Doc(
+                'A dictionary mapping device identifiers to their maximum memory, passed onto the model loading method from transformers. For example: {0: "4GB", 1: "10GB"}.'
+            ),
+        ] = None,
         offload_folder: Annotated[
             str,
             Doc("The folder ot offload the model to."),
@@ -449,6 +455,7 @@ def from_quantized(
             model,
             checkpoint=model_weights_path,
             device_map=device_map,
+            max_memory=max_memory,
             no_split_module_classes=[self.layer_type],
             offload_folder=offload_folder,
             dtype=torch_dtype,
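
For reference, a minimal usage sketch of the new max_memory argument, assuming AutoAWQ is installed; the model path and memory limits below are illustrative, not part of the patch. Keys are GPU indices (or "cpu") and values are size strings understood by accelerate.

from awq import AutoAWQForCausalLM

# Load a quantized model while capping per-device memory usage:
# GPU 0 may hold up to 4 GB of weights, GPU 1 up to 10 GB, and the CPU up to
# 30 GB; the limits are forwarded to accelerate's load_checkpoint_and_dispatch.
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",  # illustrative model path
    device_map="balanced",
    max_memory={0: "4GB", 1: "10GB", "cpu": "30GB"},
)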