@@ -21,6 +21,7 @@ def __init__(self, config) -> None:
         self.quant_config.output_dir = Path(self.quant_config.output_dir) / self.quant_config.model_name
         for k, v in self.quant_config.layer_config.to_dict().items():
             setattr(self.quant_config, k, v)
+        self.quant_cache_dir = Path(f"{self.quant_config.output_dir}/quant_cache")
 
     def set_tokenizer(self, tokenizer):
         self.tokenizer = tokenizer
@@ -89,7 +90,7 @@ def collect_hessian_pre(self, model, model_prefix, dev):
         from .qllm_hessian import process_collect_hessian
         sample_args = self.quant_config.hessian_config
         sample_args.base_model = self.quant_config.model_name
-        sample_args.save_path = f"./hessian_path/{sample_args.base_model}_{sample_args.devset_size}_{sample_args.ctx_size}"
+        sample_args.save_path = f"{self.quant_config.output_dir}/hessian_path/{sample_args.base_model}_{sample_args.devset_size}_{sample_args.ctx_size}"
 
         self.quant_config.hessian_path = sample_args.save_path
         self.quant_config.inv_hessian_path = sample_args.save_path + "_inv"
@@ -123,21 +124,20 @@ def parallel_quantize(self, quantize_layer, attention_layers, num_gpus, dev):
 
         pbar = tqdm.tqdm(total=len(attention_layers), desc=f"running VPTQ on {num_gpus} GPUs")
         output_queue = theading_queue.Queue()
-        quant_tmp = Path("quant_tmp")
         for i in range(num_gpus):
             output_queue.put(i)  # poison pill
         def fetch_next_task(future):
             comm_utils.clear_memory()
             pbar.update(1)
             pbar.set_postfix_str(f'gpu memory: {torch.cuda.memory_allocated(future.gpu_idx)/1024**3:.2f} GB')
             output_queue.put(future.gpu_idx)
-            torch.save(future.result(), quant_tmp / f"layer_{future.layer_idx}.pt")
+            torch.save(future.result(), self.quant_cache_dir / f"layer_{future.layer_idx}.pt")
 
         for layer_idx, layer in enumerate(attention_layers):
-            if (quant_tmp / f"layer_{layer_idx}.pt").exists():
+            if (self.quant_cache_dir / f"layer_{layer_idx}.pt").exists():
                 import warnings
                 warnings.simplefilter(action='ignore', category=FutureWarning)
-                attention_layers[layer_idx] = torch.load(quant_tmp / f"layer_{layer_idx}.pt", weights_only=False)
+                attention_layers[layer_idx] = torch.load(self.quant_cache_dir / f"layer_{layer_idx}.pt", weights_only=False)
                 pbar.update(1)
                 continue
             free_gpu_id = output_queue.get()
@@ -178,15 +178,14 @@ def do_quantize(self, model, dataloader, model_prefix, dev):
         vptq_quantizer = InternalVPTQQuantizer()
         quantize_layer = vptq_quantizer.quantize_layer
         quantizers = {}
-        quant_tmp = Path("quant_tmp")
-        quant_tmp.mkdir(exist_ok=True)
+        self.quant_cache_dir.mkdir(exist_ok=True)
 
         if num_gpus > 1:
             self.parallel_quantize(quantize_layer, attention_layers, num_gpus, dev)
         else:
             for layer_idx in tqdm.trange((len(attention_layers)), desc="running VPTQ"):
-                if (quant_tmp / f"layer_{layer_idx}.pt").exists():
-                    attention_layers[layer_idx] = torch.load(quant_tmp / f"layer_{layer_idx}.pt", weights_only=False)
+                if (self.quant_cache_dir / f"layer_{layer_idx}.pt").exists():
+                    attention_layers[layer_idx] = torch.load(self.quant_cache_dir / f"layer_{layer_idx}.pt", weights_only=False)
                     continue
                 attention_layers[layer_idx] = quantize_layer(
                     (attention_layers[layer_idx], layer_idx), self.quant_config, self.quant_config,
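
The hunks above all follow the same cache-and-resume pattern, now rooted at the run's output directory instead of a hard-coded ./quant_tmp. A minimal, self-contained sketch of that pattern is shown below; the CachedLayerQuantizer class and the quantize_fn callback are illustrative names only, not APIs from this repository.

# Sketch of the per-layer cache/resume pattern used in the diff above.
# CachedLayerQuantizer and quantize_fn are hypothetical names for illustration.
from pathlib import Path

import torch


class CachedLayerQuantizer:
    def __init__(self, output_dir):
        # As in the commit: keep per-layer checkpoints under the run's output
        # directory rather than in the current working directory.
        self.quant_cache_dir = Path(output_dir) / "quant_cache"

    def run(self, layers, quantize_fn):
        self.quant_cache_dir.mkdir(parents=True, exist_ok=True)
        for layer_idx, layer in enumerate(layers):
            cache_file = self.quant_cache_dir / f"layer_{layer_idx}.pt"
            if cache_file.exists():
                # Resume: reuse the cached quantized layer instead of recomputing it.
                layers[layer_idx] = torch.load(cache_file, weights_only=False)
                continue
            layers[layer_idx] = quantize_fn(layer, layer_idx)
            torch.save(layers[layer_idx], cache_file)
        return layers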