diff --git a/README.md b/README.md
index f2ede89ed..7aba1007c 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# ADI MAX78000/MAX78002 Model Training and Synthesis
-August 27, 2024
+November 7, 2024
**Note: This branch requires PyTorch 2. Please see the archive-1.8 branch for PyTorch 1.8 support. [KNOWN_ISSUES](KNOWN_ISSUES.txt) contains a list of known issues.**
@@ -1636,7 +1636,7 @@ Quantization-aware training can be disabled by specifying `--qat-policy N
The proper choice of `start_epoch` is important for achieving good results, and the default policy’s `start_epoch` may be much too small. As a rule of thumb, set `start_epoch` to a very high value (e.g., 1000) to begin, and then observe where in the training process the model stops learning. This epoch can be used as `start_epoch`, and the final network metrics (after an additional number of epochs) should be close to the non-QAT metrics. *Additionally, ensure that the learning rate after the `start_epoch` epoch is relatively small.*
-For more information, please also see [Quantization](#quantization).
+For more information, please also see [Quantization](#quantization) and [QATv2](https://github.com/analogdevicesinc/ai8x-training/blob/develop/docs/QATv2.md).
#### Batch Normalization
diff --git a/README.pdf b/README.pdf
index b426a1f14..c954703df 100644
Binary files a/README.pdf and b/README.pdf differ
diff --git a/ai8x.py b/ai8x.py
index ee66f9d4b..665484aaa 100644
--- a/ai8x.py
+++ b/ai8x.py
@@ -19,6 +19,8 @@
from torch.autograd import Function
from torch.fx import symbolic_trace
+from tqdm import tqdm
+
import devices
dev = None
@@ -435,45 +437,6 @@ def forward(self, _, x): # pylint: disable=arguments-differ
return x
-def interp(x, xp, fp, method='linear'):
- """
- Simple PyTorch implementation of `np.interp`.
- 1D data only, length must be 2 or greater.
- `method` must be "linear" or "lower".
- """
- # Find the index
- n = len(xp) - 1
- if n == 0:
- return fp[0]
- if x == 1.:
- return fp[-1]
- i = torch.clip(torch.searchsorted(xp, x, side='right').unsqueeze(0), 1, n) - 1
- # Calculate fractional index
- if method == 'linear':
- g = x * n - i
- else:
- assert method == 'lower'
- g = .0
- # Interpolate result
- return fp[i] + g * (fp[i + 1] - fp[i])
-
-
-def quantile(x, q, method='linear'):
- """
- Ersatz quantile function in PyTorch that works with torch.compile().
- 1D data only, len(x) must be 2 or greater.
- `method` must be "linear" or "lower".
- """
- x = x.flatten()
- n = len(x)
- return interp(
- q,
- torch.linspace(1 / (2 * n), (2 * n - 1) / (2 * n), n, device=x.device),
- torch.sort(x)[0],
- method,
- ).squeeze(0)
-
-
class OutputShiftLimit(nn.Module):
"""
Calculate the clamped output shift when adjusting during quantization-aware training.
@@ -484,7 +447,7 @@ def __init__(self, shift_quantile=1.0):
def forward(self, x, _): # pylint: disable=arguments-differ
"""Forward prop"""
- limit = quantile(x.abs(), self.shift_quantile)
+ limit = torch.quantile(x.abs(), self.shift_quantile)
return -(1./limit).log2().floor().clamp(min=-15., max=15.)
@@ -2265,6 +2228,26 @@ def apply_scales(model):
requires_grad=False)
+@torch.no_grad()
+def stat_collect(train_loader, model, args):
+ """Collect statistics for quantization aware training"""
+ model.eval()
+ for inputs, _ in tqdm(train_loader):
+ inputs = inputs.to(args.device)
+ model(inputs)
+
+
+def pre_qat(model, train_loader, args, qat_policy):
+ """
+ Prepare the model for quantization aware training
+ """
+ init_hist(model)
+ stat_collect(train_loader, model, args)
+ init_threshold(model, qat_policy["outlier_removal_z_score"])
+ release_hist(model)
+ apply_scales(model)
+
+
def init_hist(model):
"""
Place forward hooks to collect histograms of activations
diff --git a/docs/QATv2-Adds.png b/docs/QATv2-Adds.png
new file mode 100644
index 000000000..862bdcb9a
Binary files /dev/null and b/docs/QATv2-Adds.png differ
diff --git a/docs/QATv2-Apply Scales.png b/docs/QATv2-Apply Scales.png
new file mode 100644
index 000000000..200c780ad
Binary files /dev/null and b/docs/QATv2-Apply Scales.png differ
diff --git a/docs/QATv2-Concats.png b/docs/QATv2-Concats.png
new file mode 100644
index 000000000..6a7ef1e3a
Binary files /dev/null and b/docs/QATv2-Concats.png differ
diff --git a/docs/QATv2-Layer Sharing.png b/docs/QATv2-Layer Sharing.png
new file mode 100644
index 000000000..93c1bda72
Binary files /dev/null and b/docs/QATv2-Layer Sharing.png differ
diff --git a/docs/QATv2.md b/docs/QATv2.md
new file mode 100644
index 000000000..195553927
--- /dev/null
+++ b/docs/QATv2.md
@@ -0,0 +1,56 @@
+# Quantization Aware Training (QAT)
+
+This document aims to explain Quantization Aware Training framework for MAX7800x series microcontrollers. QAT for MAX7800x consists of four main stages:
+Activation Statistics Collection, Activation Threshold Determination, Scale Adjustments, and Weights Quantization.
+
+## Activation Statistics Collection
+
+To train a quantization-aware model, the first step is to collect activation statistics. The activation statistics are collected by running the model on the training dataset. The training script includes the activation statistics collection step(stat_collect() function). The activation statistics are the histogram of the activations for each layer.
+
+## Activation Threshold Determination
+
+The collected statistics are use to determine the activation thresholds. To do this, first, an outlier removal step based on z-score is applied to the activation statistics. The default z-score is 8.0, and it can be changed by defining a z-score on the qat policy file. Then, an iterative algorithm [1] that minimizes the quantization error by adjusting the threshold to determine the full activation range. This algorithm finds a balance point in the tradeoff between range and resolution. Scales are calculated as powers of two, making the scaling-down operation more computationally efficient by defining them as bit shift operations at the edge hardware.
+
+## Scale Adjustments
+
+To implement the threshold-based quantization, the scales of the layers are adjusted. The scales are adjusted based on the type of operation that is performed on the layers. The scale adjustments are made for residual additions, concatenations, and layer sharing. Figure 1. shows the scale adjustments for residual additions. In the figure, Layer1 and Layer2 are layers that are added together. The scale of the residual addition is selected as the scale of the layers that are connected to the residual addition.
+
+
+
+Figure 1. Scale Adjustments for Residual Additions
+
+Figure 2. shows the scale adjustments for concatenations. In the figure, Layer1 and Layer2 are layers that are concatenated. The maximum scale of the layers is selected as the scale for the concatenated layer.
+
+
+
+Figure 2. Scale Adjustments for Concatenations
+
+Figure 3. shows the scale adjustments for layer sharing. In the figure, Layer1, Layer2 and Layer3 are layers that share weights. The maximum scale of the layers is selected as the scale for the shared layer.
+
+
+
+Figure 3. Scale Adjustments for Layer Sharing
+
+Figure 4. provides a simplified diagram showing how the scaling-down and scale carry-over operations are implemented. In the diagram, Layer1 and Layer2 represent linear layers with weights w1 and w2, and biases b1 and b2. S1 and S2 represent the activation scales, which are calculated as previously described. As shown, the output of Layer1 is scaled down using the S1 threshold, and the scale carry-over operation is achieved by adjusting
+Layer2’s scale and dividing its biases accordingly.
+
+
+
+Figure 4. Scaling-down and Scale Carry Over Diagram
+
+## Weights Quantization
+
+After determining the activation thresholds and scales, the next step is to quantize the weights. The weights are quantized using the QAT framework, which is based on the method proposed by Jacob et al. [2]. While training the model, weights and biasses are fake quantized to integers. The fake quantization is done by quantizing the weights and biases to integers and then dequantizing them back to floating-point numbers.
+
+## Deploying the Quantized Model
+
+The output shifts from the weights quantization are merged with the scale shifts from the activation quantization to form the final shifts of the quantized model. When the model is deployed, the final layer's scale should be restored to the original scale by multiplying the outputs with the final layer's scale. In the auto-generated C code, the cnn_unload() function is responsible for restoring the final layer's scale. If the cnn_unload() function is not used, the final layer's scale should be restored manually by multiplying the outputs with the final layer's scale. The final layer's scale values can be found at the cnn.c file in the comments section.
+
+
+
+
+## References
+
+[1] [Habi, Hai Victor, et al. "Hptq: Hardware-friendly post training quantization." arXiv preprint arXiv:2109.09113 (2021).](https://arxiv.org/abs/2109.09113)
+
+[2] [Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., ... & Kalenichenko, D. (2018). Quantization and training of neural networks for efficient integer-arithmetic-only inference. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 2704-2713).](https://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf)
diff --git a/train.py b/train.py
index 04fb6435b..7d3a00b63 100755
--- a/train.py
+++ b/train.py
@@ -101,7 +101,6 @@
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import CustomKNN
from torchmetrics.detection import MeanAveragePrecision
-from tqdm import tqdm
import ai8x
import ai8x_nas
@@ -608,15 +607,10 @@ def flush(self):
# Fuse the BN parameters into conv layers before Quantization Aware Training (QAT)
ai8x.fuse_bn_layers(model)
- ai8x.init_hist(model)
msglogger.info('Collecting statistics for quantization aware training (QAT)...')
- stat_collect(train_loader, model, args)
- ai8x.init_threshold(model, qat_policy["outlier_removal_z_score"])
- ai8x.release_hist(model)
-
- ai8x.apply_scales(model)
+ ai8x.pre_qat(model, train_loader, args, qat_policy)
# Update the optimizer to reflect fused batchnorm layers
optimizer = ai8x.update_optimizer(model, optimizer)
@@ -646,6 +640,12 @@ def flush(self):
torch._dynamo.reset() # pylint: disable=protected-access
model = torch.compile(model, mode=args.compiler_mode,
backend=args.compiler_backend)
+
+ # TODO: Optimize DDP is currently not supported with QAT.
+ # Once pytorch supports DDP with higher order ops,
+ # we can enable optimize DDP with QAT.
+ # https://github.com/pytorch/pytorch/issues/104674.
+ torch._dynamo.config.optimize_ddp = False # pylint: disable=protected-access
msglogger.info(
'torch.compile() successful, mode=%s, cache limit=%d',
args.compiler_mode,
@@ -740,7 +740,7 @@ def flush(self):
if not args.dr:
test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt")
test(test_loader, model, criterion, [pylogger], args=args, mode="best",
- ckpt_name=checkpoint_name)
+ ckpt_name=checkpoint_name, local_rank=local_rank)
if args.copy_output_folder and local_rank <= 0:
msglogger.info('Copying output folder to: %s', args.copy_output_folder)
@@ -850,15 +850,6 @@ def create_nas_kd_policy(model, compression_scheduler, epoch, next_state_start_e
' | '.join([f'{val:.2f}' for val in dlw]))
-@torch.no_grad()
-def stat_collect(train_loader, model, args):
- """Collect statistics for quantization aware training"""
- model.eval()
- for inputs, _ in tqdm(train_loader):
- inputs = inputs.to(args.device)
- model(inputs)
-
-
def train(train_loader, model, criterion, optimizer, epoch,
compression_scheduler, loggers, args, loss_optimizer=None):
"""Training loop for one epoch."""
@@ -1082,7 +1073,7 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non
return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger)
-def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None):
+def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None, local_rank=0):
"""Model Test"""
assert msglogger is not None
if mode == 'ckpt':
@@ -1090,11 +1081,32 @@ def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=No
top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
else:
msglogger.info('--- test (best) ---------------------')
- if ckpt_name is None:
- best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
- else:
- best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
- model = apputils.load_lean_checkpoint(model, best_ckpt_path)
+ model, dynamo, ddp = model_wrapper.unwrap(model)
+ if local_rank <= 0:
+ if ckpt_name is None:
+ best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
+ else:
+ best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
+ model = apputils.load_lean_checkpoint(model, best_ckpt_path)
+
+ if ddp:
+ model = DistributedDataParallel(
+ model,
+ device_ids=[local_rank] if args.device == 'cuda' else None,
+ output_device=local_rank if args.device == 'cuda' else None,
+ )
+
+ if dynamo:
+ torch._dynamo.reset() # pylint: disable=protected-access
+ model = torch.compile(model, mode=args.compiler_mode,
+ backend=args.compiler_backend)
+ torch._dynamo.config.optimize_ddp = False # pylint: disable=protected-access
+ msglogger.info(
+ 'torch.compile() successful, mode=%s, cache limit=%d',
+ args.compiler_mode,
+ torch._dynamo.config.cache_size_limit, # pylint: disable=protected-access
+ )
+
top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
return top1, top5, vloss, mAP