
Commit 7e45a94

break out function
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 5eac577

File tree

1 file changed: +27 −13 lines changed

  • src/compressed_tensors/quantization/lifecycle/apply.py


src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 27 additions & 13 deletions
@@ -29,6 +29,7 @@
     initialize_module_for_quantization,
     is_attention_module,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
@@ -133,21 +134,11 @@ def apply_quantization_config(
     # force zero points during initialization
     force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED
 
-    # apply kv cache quantization before any attention quantization
-    # because attention quantization is a superset of kv cache quantization
+    # apply and initialize kv cache quantization
     if config.kv_cache_scheme is not None:
-        scheme = QuantizationScheme(
-            targets=[".*self_attn$"], input_activations=config.kv_cache_scheme
+        _apply_kv_cache_scheme(
+            model, config.kv_cache_scheme, config.quantization_status, force_zero_point
         )
-        for submodule in model.modules():
-            if is_attention_module(submodule):
-                submodule.quantization_scheme = scheme
-                initialize_hooked_kv_cache(model, submodule)
-                initialize_module_for_quantization(
-                    submodule,
-                    force_zero_point=force_zero_point,
-                )
-                submodule.quantization_status = config.quantization_status
 
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -196,6 +187,29 @@ def apply_quantization_config(
         submodule.quantization_status = config.quantization_status
 
 
+def _apply_kv_cache_scheme(
+    model: torch.nn.Module,
+    kv_cache_scheme: QuantizationArgs,
+    status: QuantizationStatus,
+    force_zero_point: bool,
+):
+    # applies and initializes kv cache quantization
+    # this step cannot come after attention apply/initialize
+    # otherwise it will override the attention qparams
+    scheme = QuantizationScheme(
+        targets=[".*self_attn$"], input_activations=kv_cache_scheme
+    )
+    for submodule in model.modules():
+        if is_attention_module(submodule):
+            submodule.quantization_scheme = scheme
+            initialize_hooked_kv_cache(model, submodule)
+            initialize_module_for_quantization(
+                submodule,
+                force_zero_point=force_zero_point,
+            )
+            submodule.quantization_status = status
+
+
 @deprecated(
     message="This function is deprecated and will be removed in a future release."
     "Please use `match_targets` from `compressed_tensors.utils.match` instead."
