This is a step-by-step tutorial on how to integrate the NNCF package into an existing PyTorch or TensorFlow project. It assumes that you already have a pretrained model and a training pipeline that reproduces its training in floating-point precision. The task is to prepare this model for accelerated inference by simulating compression at training time. Please refer to this document for details of the implementation.
Quantize the model using the Post Training Quantization method.
**PyTorch**

```python
import nncf

model = TorchModel()  # instance of torch.nn.Module
quantized_model = nncf.quantize(model, ...)
```
**TensorFlow**

```python
import nncf

model = TensorFlowModel()  # instance of tf.keras.Model
quantized_model = nncf.quantize(model, ...)
```
At this point, NNCF is fully integrated into your training pipeline. You can run it as usual, monitor your original model's metrics and/or the compression algorithm metrics, and balance model quality against the level of compression.
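For illustration, a quantization-aware fine-tuning loop in PyTorch might look like the sketch below; the data loader, loss function, and hyperparameters are placeholders, not part of the NNCF API.

```python
import torch

optimizer = torch.optim.SGD(quantized_model.parameters(), lr=1e-4)
quantized_model.train()  # remember to disable Dropout layers, see the note below
for epoch in range(3):  # number of fine-tuning epochs is a placeholder
    for inputs, targets in train_loader:  # assumed existing DataLoader
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(quantized_model(inputs), targets)
        loss.backward()
        optimizer.step()
```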
Important points you should consider when training your networks with compression algorithms:

- Turn off `Dropout` layers (and similar ones like `DropConnect`) when training a network with quantization (see the sketch below for one way to do this in PyTorch).
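For example, a sketch of disabling `Dropout` while keeping the rest of the model in training mode (plain PyTorch, not an NNCF API):

```python
import torch

quantized_model.train()
# Switch only the Dropout layers to eval mode so they act as identity
# during quantization-aware fine-tuning.
for module in quantized_model.modules():
    if isinstance(module, torch.nn.Dropout):
        module.eval()
```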
After the compressed model has been fine-tuned to an acceptable accuracy and compression stage, you can export it.
**PyTorch**

Trace the model via inference in framework operations.

```python
# To OpenVINO format
import openvino as ov

# `dummy_input` is an example input tensor matching the model's expected input.
ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)
```
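A minimal sketch of preparing the `example_input` used above and serializing the converted model to disk; the input shape and output file name are assumptions of this example.

```python
import torch
import openvino as ov

# Example input for tracing; this shape assumes an image model and
# should be adjusted to match your model's expected input.
dummy_input = torch.randn(1, 3, 224, 224)

ov_quantized_model = ov.convert_model(quantized_model.cpu(), example_input=dummy_input)

# Serialize the converted model to OpenVINO IR (.xml + .bin files).
ov.save_model(ov_quantized_model, "quantized_model.xml")
```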
**TensorFlow**

```python
# To OpenVINO format
import nncf
import openvino as ov

# Removes auxiliary layers and operations added during the quantization process,
# resulting in a clean, fully quantized model ready for deployment.
stripped_model = nncf.strip(quantized_model)
ov_quantized_model = ov.convert_model(stripped_model)
```
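After conversion (from either framework), the model can be compiled and run with the OpenVINO runtime; the device name and input data below are placeholders.

```python
import openvino as ov

# Compile the converted model for a target device and run inference.
compiled_model = ov.compile_model(ov_quantized_model, device_name="CPU")
result = compiled_model(input_data)  # `input_data` is a placeholder input
```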
**PyTorch**

The complete information about the compression is defined by the compressed model and the NNCF config. The model characterizes the weights and topology of the network; the NNCF config describes how to restore the additional modules introduced by NNCF. The NNCF config can be obtained via `quantized_model.nncf.get_config()` on saving and passed to the `nncf.torch.load_from_config` helper function to restore those additional modules from the given NNCF config. Saving the quantized model this way allows loading the quantized modules into the target model in a new Python process; it requires only an example input for the target module, the corresponding NNCF config, and the quantized model's state dict.
```python
# Save part
quantized_model = nncf.quantize(model, calibration_dataset)
checkpoint = {
    'state_dict': quantized_model.state_dict(),
    'nncf_config': quantized_model.nncf.get_config(),
    ...
}
torch.save(checkpoint, path)

# Load part
resuming_checkpoint = torch.load(path)
nncf_config = resuming_checkpoint['nncf_config']
state_dict = resuming_checkpoint['state_dict']

quantized_model = nncf.torch.load_from_config(model, nncf_config, dummy_input)
quantized_model.load_state_dict(state_dict)
```
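As a quick sanity check (a sketch, not required by NNCF), you can run the restored model on the same example input used for loading:

```python
import torch

quantized_model.eval()
with torch.no_grad():
    _ = quantized_model(dummy_input)  # should run without errors
```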
You can save the `compressed_model` object with `torch.save` as usual, via its `state_dict` and `load_state_dict` methods.
**TensorFlow**

To save a model checkpoint, use the following API:

```python
from nncf.tensorflow import ConfigState
from nncf.tensorflow import get_config
from nncf.tensorflow.callbacks.checkpoint_callback import CheckpointManagerCallback

nncf_config = get_config(quantized_model)
checkpoint = tf.train.Checkpoint(model=quantized_model,
                                 nncf_config_state=ConfigState(nncf_config),
                                 ...  # the rest of the user-defined objects to save
                                 )
callbacks = []
callbacks.append(CheckpointManagerCallback(checkpoint, path_to_checkpoint))
...
quantized_model.fit(..., callbacks=callbacks)
```
To restore the model from a checkpoint, use the following API:

```python
from nncf.tensorflow import ConfigState
from nncf.tensorflow import load_from_config

checkpoint = tf.train.Checkpoint(nncf_config_state=ConfigState())
checkpoint.restore(path_to_checkpoint)

quantized_model = load_from_config(model, checkpoint.nncf_config_state.config)

checkpoint = tf.train.Checkpoint(model=quantized_model,
                                 ...  # the rest of the user-defined objects to load
                                 )
checkpoint.restore(path_to_checkpoint)
```
**PyTorch**

Without any modifications to the target model code, NNCF supports compression of trainable parameters (weights) only for native PyTorch modules such as `torch.nn.Conv2d`. If your model contains a custom, non-standard PyTorch module with trainable weights that should be compressed, you can register it using the `@nncf.register_module` decorator:
```python
import torch
import nncf

@nncf.register_module(ignored_algorithms=[...])
class MyModule(torch.nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.weight = torch.nn.Parameter(...)
    # ...
```
If the registered module should be ignored by specific algorithms, use the `ignored_algorithms` parameter of the decorator.

In the example above, NNCF-compressed models that contain instances of `MyModule` will have the corresponding modules extended with functionality that allows NNCF to quantize the `weight` parameter of `MyModule` before it takes part in `MyModule`'s `forward` calculation.
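To make the registration flow concrete, below is a self-contained sketch with a hypothetical custom module (`ScaleModule`) embedded in a small model and quantized as usual; the module, model, and calibration dataset names are assumptions of this example.

```python
import torch
import nncf

@nncf.register_module(ignored_algorithms=[])
class ScaleModule(torch.nn.Module):
    """Hypothetical custom module with a trainable `weight` parameter."""
    def __init__(self, num_channels):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(num_channels, 1, 1))

    def forward(self, x):
        return x * self.weight

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.scale = ScaleModule(8)

    def forward(self, x):
        return self.scale(self.conv(x))

# After registration, nncf.quantize also quantizes ScaleModule.weight
# (calibration_dataset is assumed to be built as in the first step).
quantized_model = nncf.quantize(MyModel(), calibration_dataset)
```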