Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/compressed_tensors/quantization/lifecycle/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from loguru import logger
from safetensors import safe_open
from torch.nn import Module
from tqdm import tqdm


__all__ = [
Expand Down Expand Up @@ -95,7 +96,10 @@ def load_pretrained_quantization_parameters(


def apply_quantization_config(
model: Module, config: QuantizationConfig | None, run_compressed: bool = False
model: Module,
config: QuantizationConfig | None,
run_compressed: bool = False,
show_progress: bool = True,
Comment thread
kylesayrs marked this conversation as resolved.
):
"""
Initializes the model for quantization in-place based on the given config.
Expand All @@ -105,6 +109,7 @@ def apply_quantization_config(
:param config: quantization config
:param run_compressed: Whether the model will be run in compressed mode or
decompressed fully on load
:param show_progress: Whether to show progress bar during quantization
"""
config = deepcopy(config)
if config is None: # see PR #180
Expand All @@ -128,8 +133,14 @@ def apply_quantization_config(
target_to_scheme[target] = scheme

# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in match_named_modules(
model, target_to_scheme, config.ignore, warn_on_fail=True
matched_modules = list(
match_named_modules(model, target_to_scheme, config.ignore, warn_on_fail=True)
)

for name, submodule in tqdm(
matched_modules,
desc="Applying quantization config",
disable=not show_progress,
):
# mark modules to be quantized by adding
# quant scheme to the matching layers
Expand Down
Loading