diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index f85551635..73107d64b 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -30,6 +30,7 @@ from loguru import logger from safetensors import safe_open from torch.nn import Module +from tqdm import tqdm __all__ = [ @@ -95,7 +96,10 @@ def load_pretrained_quantization_parameters( def apply_quantization_config( - model: Module, config: QuantizationConfig | None, run_compressed: bool = False + model: Module, + config: QuantizationConfig | None, + run_compressed: bool = False, + show_progress: bool = True, ): """ Initializes the model for quantization in-place based on the given config. @@ -105,6 +109,7 @@ def apply_quantization_config( :param config: quantization config :param run_compressed: Whether the model will be run in compressed mode or decompressed fully on load + :param show_progress: Whether to show progress bar during quantization """ config = deepcopy(config) if config is None: # see PR #180 @@ -128,8 +133,14 @@ def apply_quantization_config( target_to_scheme[target] = scheme # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in match_named_modules( - model, target_to_scheme, config.ignore, warn_on_fail=True + matched_modules = list( + match_named_modules(model, target_to_scheme, config.ignore, warn_on_fail=True) + ) + + for name, submodule in tqdm( + matched_modules, + desc="Applying quantization config", + disable=not show_progress, ): # mark modules to be quantized by adding # quant scheme to the matching layers