     initialize_module_for_quantization,
     is_attention_module,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -133,21 +134,11 @@ def apply_quantization_config(
     # force zero points during initialization
     force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED

-    # apply kv cache quantization before any attention quantization
-    # because attention quantization is a superset of kv cache quantization
+    # apply and initialize kv cache quantization
     if config.kv_cache_scheme is not None:
-        scheme = QuantizationScheme(
-            targets=[".*self_attn$"], input_activations=config.kv_cache_scheme
+        _apply_kv_cache_scheme(
+            model, config.kv_cache_scheme, config.quantization_status, force_zero_point
         )
-        for submodule in model.modules():
-            if is_attention_module(submodule):
-                submodule.quantization_scheme = scheme
-                initialize_hooked_kv_cache(model, submodule)
-                initialize_module_for_quantization(
-                    submodule,
-                    force_zero_point=force_zero_point,
-                )
-                submodule.quantization_status = config.quantization_status

     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -196,6 +187,29 @@ def apply_quantization_config(
         submodule.quantization_status = config.quantization_status


+def _apply_kv_cache_scheme(
+    model: torch.nn.Module,
+    kv_cache_scheme: QuantizationArgs,
+    status: QuantizationStatus,
+    force_zero_point: bool,
+):
+    # applies and initializes kv cache quantization
+    # this step cannot come after attention apply/initialize
+    # otherwise it will override the attention qparams
+    scheme = QuantizationScheme(
+        targets=[".*self_attn$"], input_activations=kv_cache_scheme
+    )
+    for submodule in model.modules():
+        if is_attention_module(submodule):
+            submodule.quantization_scheme = scheme
+            initialize_hooked_kv_cache(model, submodule)
+            initialize_module_for_quantization(
+                submodule,
+                force_zero_point=force_zero_point,
+            )
+            submodule.quantization_status = status
+
+
 @deprecated(
     message="This function is deprecated and will be removed in a future release."
     "Please use `match_targets` from `compressed_tensors.utils.match` instead."
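For reference, a minimal sketch of how the refactored path would be exercised: a QuantizationConfig whose kv_cache_scheme is set now routes through the new _apply_kv_cache_scheme helper before the per-target schemes are applied. The model placeholder and the specific config_groups / bit-width values below are illustrative assumptions, not part of this change.

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

# placeholder: any torch.nn.Module with *self_attn attention submodules,
# e.g. a Hugging Face decoder-only transformer
model: torch.nn.Module = ...

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, type="int", symmetric=True),
        ),
    },
    # kv_cache_scheme is a QuantizationArgs; with this change it is handed to
    # _apply_kv_cache_scheme, which attaches a ".*self_attn$" input-activation
    # scheme to every attention module before per-target schemes are applied
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", symmetric=True),
)

apply_quantization_config(model, config)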