diff --git a/.github/linters/.python-lint b/.github/linters/.python-lint index c4f4a6f6c..c194cc279 100644 --- a/.github/linters/.python-lint +++ b/.github/linters/.python-lint @@ -7,11 +7,11 @@ ignored-classes = ModelProto max-line-length = 99 [DESIGN] max-locals=100 -max-statements=350 +max-statements=360 min-public-methods=1 max-branches=130 max-module-lines=5000 max-args=20 max-returns=10 -max-attributes=25 +max-attributes=30 max-nested-blocks=10 diff --git a/README.md b/README.md index 4df6c395d..f2ede89ed 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ADI MAX78000/MAX78002 Model Training and Synthesis -July 22, 2024 +August 27, 2024 **Note: This branch requires PyTorch 2. Please see the archive-1.8 branch for PyTorch 1.8 support. [KNOWN_ISSUES](KNOWN_ISSUES.txt) contains a list of known issues.** @@ -1620,13 +1620,15 @@ When using the `-8` command line switch, all module outputs are quantized to 8-b The last layer can optionally use 32-bit output for increased precision. This is simulated by adding the parameter `wide=True` to the module function call. -##### Weights: Quantization-Aware Training (QAT) +##### Weights and Activations: Quantization-Aware Training (QAT) Quantization-aware training (QAT) is enabled by default. QAT is controlled by a policy file, specified by `--qat-policy`. -* After `start_epoch` epochs, training will learn an additional parameter that corresponds to a shift of the final sum of products. +* After `start_epoch` epochs, an intermediate epoch without backpropagation is run to collect activation statistics. Each layer's activation range is then determined from the collected activations, trading off range against resolution. QAT then begins, and an additional parameter (`output_shift`) is learned that shifts activations to compensate for the scaling down of weights and biases. * `weight_bits` describes the number of bits available for weights. * `overrides` allows specifying the `weight_bits` on a per-layer basis. +* `outlier_removal_z_score` defines the z-score threshold for outlier removal during activation range calculation (default: 8.0). +* `shift_quantile` defines the quantile of the parameter distribution to be used for the `output_shift` parameter (default: 1.0). By default, weights are quantized to 8-bits after 30 epochs as specified in `policies/qat_policy.yaml`. A more refined example that specifies weight sizes for individual layers can be seen in `policies/qat_policy_cifar100.yaml`. @@ -1745,7 +1747,7 @@ For both approaches, the `quantize.py` software quantizes an existing PyTorch ch #### Quantization-Aware Training (QAT) -Quantization-aware training is the better performing approach. It is enabled by default. QAT learns additional parameters during training that help with quantization (see [Weights: Quantization-Aware Training (QAT)](#weights-quantization-aware-training-qat). No additional arguments (other than input, output, and device) are needed for `quantize.py`. +Quantization-aware training is the better-performing approach. It is enabled by default. QAT learns additional parameters during training that help with quantization (see [Weights and Activations: Quantization-Aware Training (QAT)](#weights-and-activations-quantization-aware-training-qat)). No additional arguments (other than input, output, and device) are needed for `quantize.py`. The input checkpoint to `quantize.py` is either `qat_best.pth.tar`, the best QAT epoch’s checkpoint, or `qat_checkpoint.pth.tar`, the final QAT epoch’s checkpoint.
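For reference, the activation-statistics pass described above under “Weights and Activations: Quantization-Aware Training (QAT)” maps onto helper functions this change adds to `ai8x.py` and `train.py` (`init_hist`, `init_threshold`, `release_hist`, `apply_scales`). The sketch below only illustrates the order of operations; the wrapper name `start_qat` and the `device` argument are made up for this illustration, and the real training script interleaves these calls with scheduler, optimizer, and checkpoint handling.

```python
import torch
import ai8x

def start_qat(model, train_loader, qat_policy, device='cpu'):
    """Sketch of the QAT start-up sequence added in this change."""
    ai8x.fuse_bn_layers(model)      # fold BatchNorm into the preceding conv weights
    ai8x.init_hist(model)           # register forward hooks that record activation histograms

    model.eval()
    with torch.no_grad():           # one pass over the training data, no backpropagation
        for inputs, _ in train_loader:
            model(inputs.to(device))

    # Derive per-layer activation thresholds from the histograms, then drop the hooks
    ai8x.init_threshold(model, qat_policy['outlier_removal_z_score'])
    ai8x.release_hist(model)

    ai8x.apply_scales(model)              # propagate thresholds/scales through the model graph
    ai8x.initiate_qat(model, qat_policy)  # switch the modules into quantization-aware mode
```

The statistics pass reuses the regular training loader in evaluation mode under `torch.no_grad()`, so it costs roughly one extra epoch of forward passes when QAT is engaged.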
@@ -2004,7 +2006,7 @@ The behavior of a training session might change when Quantization Aware Training While there can be multiple reasons for this, check two important settings that can influence the training behavior: * The initial learning rate may be set too high. Reduce LR by a factor of 10 or 100 by specifying a smaller initial `--lr` on the command line, and possibly by reducing the epoch `milestones` for further reduction of the learning rate in the scheduler file specified by `--compress`. Note that the selected optimizer and the batch size both affect the learning rate. -* The epoch when QAT is engaged may be set too low. Increase `start_epoch` in the QAT scheduler file specified by `--qat-policy`, and increase the total number of training epochs by increasing the value specified by the `--epochs` command line argument and by editing the `ending_epoch` in the scheduler file specified by `--compress`. *See also the rule of thumb discussed in the section [Weights: Quantization-Aware Training (QAT)](#weights:-auantization-aware-training \(qat\)).* +* The epoch when QAT is engaged may be set too low. Increase `start_epoch` in the QAT scheduler file specified by `--qat-policy`, and increase the total number of training epochs by increasing the value specified by the `--epochs` command line argument and by editing the `ending_epoch` in the scheduler file specified by `--compress`. *See also the rule of thumb discussed in the section [Weights and Activations: Quantization-Aware Training (QAT)](#weights-and-activations-quantization-aware-training-qat).* @@ -2209,6 +2211,7 @@ The following table describes the most important command line arguments for `ai8 | `--no-unload` | Do not create the `cnn_unload()` function | | | `--no-kat` | Do not generate the `check_output()` function (disable known-answer test) | | | `--no-deduplicate-weights` | Do not deduplicate weights and bias values | | +| `--no-scale-output` | Do not use scales from the checkpoint to recover the output range when generating the `cnn_unload()` function | | ### YAML Network Description @@ -2330,6 +2333,12 @@ The following keywords are required for each `unload` list item: `width`: Data width (optional, defaults to 8) — either 8 or 32 `write_gap`: Gap between data words (optional, defaults to 0) +When `--no-scale-output` is not specified, the scales from the checkpoint file are used to recover the output range. If an 8-bit output has a non-zero scale, it is scaled and kept as 16-bit data; if its scale is zero, it remains 8-bit. 32-bit outputs are always kept as 32 bits. + +Example: + +![Unload Array](docs/unload_example.png) + ##### `layers` (Mandatory) `layers` is a list that defines the per-layer description, as shown below: @@ -2654,7 +2663,7 @@ Example: By default, the final layer is used as the output layer. Output layers are checked using the known-answer test, and they are copied from hardware memory when `cnn_unload()` is called. The tool also checks that output layer data isn’t overwritten by any later layers. When specifying `output: true`, any layer (or a combination of layers) can be used as an output layer. -*Note:* When `unload:` is used, output layers are not used for generating `cnn_unload()`. +*Note:* When `--no-unload` is used, output layers are not used for generating `cnn_unload()`.
Example: `output: true` diff --git a/README.pdf b/README.pdf index 3eca99513..b426a1f14 100644 Binary files a/README.pdf and b/README.pdf differ diff --git a/ai8x.py b/ai8x.py index 277cd6ff3..ee66f9d4b 100644 --- a/ai8x.py +++ b/ai8x.py @@ -1,6 +1,6 @@ ################################################################################################### # -# Copyright (C) 2020-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2020-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -13,9 +13,11 @@ the limits into account. """ +import numpy as np import torch from torch import nn from torch.autograd import Function +from torch.fx import symbolic_trace import devices @@ -327,7 +329,7 @@ def forward(self, x): # pylint: disable=arguments-differ return x.mul(factor).floor().div(factor) -def quantize_clamp(wide, quantize_activation=False, weight_bits=8): +def quantize_clamp(wide, quantize_activation=False, clamp_activation=False, weight_bits=8): """ Return new Quantization and Clamp objects. """ @@ -352,21 +354,25 @@ def quantize_clamp(wide, quantize_activation=False, weight_bits=8): quantize = Quantize(num_bits=dev.WIDE_LAYER_RESOLUTION_BITS) else: quantize = Empty() - if not wide: - clamp = Clamp( # Do not combine with ReLU - min_val=-1., - max_val=(2.**(dev.ACTIVATION_BITS-1)-1)/(2.**(dev.ACTIVATION_BITS-1)), - ) + + if clamp_activation: + if not wide: + clamp = Clamp( # Do not combine with ReLU + min_val=-1., + max_val=(2.**(dev.ACTIVATION_BITS-1)-1)/(2.**(dev.ACTIVATION_BITS-1)), + ) + else: + clamp = Clamp( + min_val=-(2.**((dev.FULL_ACC_BITS-2*(dev.DATA_BITS-1))-1)), + max_val=2.**((dev.FULL_ACC_BITS-2*(dev.DATA_BITS-1))-1), + ) else: - clamp = Clamp( - min_val=-(2.**((dev.FULL_ACC_BITS-2*(dev.DATA_BITS-1))-1)), - max_val=2.**((dev.FULL_ACC_BITS-2*(dev.DATA_BITS-1))-1), - ) + clamp = Empty() return quantize, clamp -def quantize_clamp_pool(pooling, quantize_activation=False): +def quantize_clamp_pool(pooling, quantize_activation=False, clamp_activation=False): """ Return new Quantization and Clamp objects for pooling. """ @@ -385,7 +391,10 @@ def quantize_clamp_pool(pooling, quantize_activation=False): if pooling == 'Avg': if quantize_activation: quantize = RoundQat() if dev.round_avg else FloorQat() - clamp = Clamp(min_val=-1., max_val=127./128.) + if clamp_activation: + clamp = Clamp(min_val=-1., max_val=127./128.) 
+ else: + clamp = Empty() else: # Max, None clamp = Empty() @@ -494,7 +503,7 @@ class One(nn.Module): """ def forward(self, x): # pylint: disable=arguments-differ """Forward prop""" - return torch.ones(1, device=x.device) + return torch.ones(1).to(x.device) class WeightScale(nn.Module): @@ -563,6 +572,152 @@ def get_activation(activation=None): return Empty() +def histogram(inp, bins): + """ + CUDA compatible histogram calculation + """ + minimum, maximum = inp.min(), inp.max() + counts = torch.histc(inp, bins, min=minimum, max=maximum).cpu() + boundaries = torch.linspace(minimum, maximum, bins + 1) + return counts, boundaries + + +def calc_q_error(module, threshold, bits, eps=1e-9): + """ + Activation quantization error calculation + """ + quantized_hist = module.hist[1].clone() + quantized_hist = torch.round((quantized_hist / (threshold + eps)) * 2**(bits-1)) + quantized_hist = torch.clamp(quantized_hist, -2**(bits-1), 2**(bits-1)-1) + quantized_hist = (quantized_hist * (threshold + eps) / 2**(bits-1)) + err = torch.sum(((quantized_hist - module.hist[1])**2)*module.hist[0]) \ + / torch.sum(module.hist[0]) + + return err + + +def _merge_hist(module): + """ + Merge histograms of activations + """ + bins_to_stack = [] + for hist in module.hist: + bins_to_stack.append(hist[1]) + stacked_bins = torch.stack(bins_to_stack) + min_edge = stacked_bins.min() + max_edge = stacked_bins.max() + # 2048 is the number of bins and 2049 is the number of edges + merged_bins = torch.linspace(min_edge.item(), max_edge.item(), 2049) + merged_counts = None + + for hist in module.hist: + if merged_counts is None: + merged_counts = _interpolate_hist(hist[0], hist[1], merged_bins) + else: + merged_counts += _interpolate_hist(hist[0], hist[1], merged_bins) + + module.hist = (merged_counts, merged_bins) + + +def _interpolate_hist(counts, bins, new_bins): + """ + Helper function for interpolating histograms to new bins + """ + cumulative_hist = torch.cumsum(counts, dim=0).to(device=bins.device) + cumulative_hist = torch.cat((torch.tensor([0]), cumulative_hist)) + cumulative_interp_hist = torch.from_numpy(np.interp(new_bins.numpy(), bins.numpy(), + cumulative_hist.numpy())) + interp_counts = torch.diff(cumulative_interp_hist, prepend=torch.tensor([0])) + + return interp_counts + + +# pylint: disable=unused-argument +def _hist_hook(module, inp, output): + """ + Hook to collect histogram of activations + """ + if not hasattr(module, 'hist'): + module.hist = [] + # dynamic histogram collection + hist = histogram(output.clone().detach().flatten(), bins=2048) + module.hist.append(hist) + + +def register_hist_hooks(module): + """ + Register hooks for histogram collection + """ + module.handle = module.register_forward_hook(_hist_hook, always_call=True) + + +def release_hist_hooks(module): + """ + Release hooks after histogram collection + """ + module.handle.remove() + + +def _remove_outliers(module, outlier_removal_z_score=8.0): + """ + Remove outliers from histogram + """ + # Get mean and std of histogram + hist_count = module.hist[0] + hist_bins = module.hist[1] + hist_bins_middle = [] + for i in range(len(hist_bins) - 1): + hist_bins_middle.append((hist_bins[i] + hist_bins[i+1])/2) + hist_bins_middle = torch.tensor(hist_bins_middle) + mean = torch.sum(hist_count[1:] * hist_bins_middle) / torch.sum(hist_count[1:]) + std = torch.sqrt(torch.sum(hist_count[1:] * (hist_bins_middle - mean)**2) + / torch.sum(hist_count[1:])) + + # When activations are very small, std ends up being 0 due to rounding. 
+ # In this case, we set std to a very small value to prevent zero element histogram. + if std == 0: + std = 1e-9 + # Calculate bounds according to z-score + upper_bound = mean + outlier_removal_z_score * std + lower_bound = mean - outlier_removal_z_score * std + hist_bins_middle = torch.cat((torch.tensor([0]), hist_bins_middle)) + # Remove outliers according to bounds + hist_count[hist_bins_middle > upper_bound] = 0 + hist_count[hist_bins_middle < lower_bound] = 0 + non_zero_bins = hist_count != 0 + hist_count = hist_count[non_zero_bins] + hist_bins = hist_bins[non_zero_bins] + module.hist = (hist_count, hist_bins) + + +def init_threshold_module(module, outlier_removal_z_score): + """ + Initialize activation threshold + """ + _merge_hist(module) + _remove_outliers(module, outlier_removal_z_score) + module.activation_threshold = nn.Parameter(module.hist[1].abs().max().log2().ceil().exp2(), + requires_grad=False) + + +def calc_threshold(module, iterations=5, bits=8): + """ + Iteratively calculate threshold for activation quantization + """ + e_min = torch.inf + t_nc = module.activation_threshold + t = None + + for i in range(iterations): + t_i = t_nc / (2**i) + e_i = calc_q_error(module, t_i, bits) + if e_i < e_min: + e_min = e_i + t = t_i + + module.activation_threshold = nn.Parameter(torch.log2(t), requires_grad=False) + + class QuantizationAwareModule(nn.Module): """ Common code for Quantization-Aware Training @@ -579,6 +734,7 @@ def __init__( op=None, bn=None, shift_quantile=1.0, + clamp_activation=False, ): super().__init__() @@ -609,13 +765,20 @@ def __init__( self.pooling = pooling self.output_shift = nn.Parameter(torch.tensor([0.]), requires_grad=False) - self.init_module(weight_bits, bias_bits, quantize_activation, shift_quantile) + # Activation threshold determined during QAT, used in quantization + # It determines the range of quantization + self.activation_threshold = nn.Parameter(torch.tensor(0.), requires_grad=False) + self.final_scale = nn.Parameter(torch.tensor(0.), requires_grad=False) + + self.init_module(weight_bits, bias_bits, quantize_activation, + clamp_activation, shift_quantile) def init_module( self, weight_bits, bias_bits, quantize_activation, + clamp_activation, shift_quantile, export=False, ): @@ -625,12 +788,15 @@ def init_module( self.weight_bits = nn.Parameter(torch.tensor([0]), requires_grad=False) self.bias_bits = nn.Parameter(torch.tensor([0]), requires_grad=False) self.quantize_activation = nn.Parameter(torch.tensor([False]), requires_grad=False) + self.clamp_activation = nn.Parameter(torch.tensor([clamp_activation]), + requires_grad=False) self.adjust_output_shift = nn.Parameter(torch.tensor([False]), requires_grad=False) elif weight_bits in [1, 2, 4, 8] and bias_bits in [1, 2, 4, 8] and quantize_activation: self.weight_bits = nn.Parameter(torch.tensor([weight_bits]), requires_grad=False) if not export: self.bias_bits = nn.Parameter(torch.tensor([bias_bits]), requires_grad=False) self.quantize_activation = nn.Parameter(torch.tensor([True]), requires_grad=False) + self.clamp_activation = nn.Parameter(torch.tensor([True]), requires_grad=False) self.adjust_output_shift = nn.Parameter(torch.tensor([not dev.simulate]), requires_grad=False) else: @@ -659,9 +825,11 @@ def set_functions(self): self.bias_bits.detach().item()) self.quantize, self.clamp = \ quantize_clamp(self.wide, bool(self.quantize_activation.detach().item()), + bool(self.clamp_activation.detach().item()), int(self.weight_bits.detach().item())) self.quantize_pool, self.clamp_pool = \ - 
quantize_clamp_pool(self.pooling, bool(self.quantize_activation.detach().item())) + quantize_clamp_pool(self.pooling, bool(self.quantize_activation.detach().item()), + bool(self.clamp_activation.detach().item())) def forward(self, x): # pylint: disable=arguments-differ """Forward prop""" @@ -676,8 +844,13 @@ def forward(self, x): # pylint: disable=arguments-differ params_r = torch.flatten(self.op.weight.detach()) out_shift = self.calc_out_shift(params_r, self.output_shift.detach()) weight_scale = self.calc_weight_scale(out_shift) - out_scale = self.calc_out_scale(out_shift) + # Quantized checkpoint will have subtracted threshold from output shift + # Therefore, it shouldn't be done again in simulate mode + if not dev.simulate: + out_shift = (out_shift - self.activation_threshold).clamp(min=-15., max=15.) + + out_scale = self.calc_out_scale(out_shift) x = self._conv_forward( # pylint: disable=protected-access x, self.clamp_weight(self.quantize_weight(self.op.weight.mul(weight_scale))), @@ -686,11 +859,14 @@ def forward(self, x): # pylint: disable=arguments-differ ) if self.bn is not None: - x = self.bn(x).div(4.) + x = self.bn(x) if not self.wide: # The device does not apply output shift in wide mode x = self.scale(x, out_scale) x = self.clamp(self.quantize(self.activate(x))) + + # This is the final scale for the output, in the device it will be realized in SW + x = x.mul(2.**(self.final_scale)) return x @@ -1607,14 +1783,24 @@ class Eltwise(nn.Module): """ Base Class for Elementwise Operation """ - def __init__(self, f): + def __init__(self, f, clamp_activation=False): super().__init__() self.f = f + self.activation_threshold = nn.Parameter(torch.tensor(0.), requires_grad=False) + self.set_clamp(clamp_activation) + + def set_clamp(self, clamp_activation): + """ + Set Clamping Function + """ if dev.simulate: bits = dev.ACTIVATION_BITS self.clamp = Clamp(min_val=-(2**(bits-1)), max_val=2**(bits-1)-1) else: - self.clamp = Clamp(min_val=-1., max_val=127./128.) + if clamp_activation: + self.clamp = Clamp(min_val=-1., max_val=127./128.) + else: + self.clamp = Empty() def forward(self, *x): """Forward prop""" @@ -1822,19 +2008,28 @@ def initiate_qat(m, qat_policy, export=False): if 'shift_quantile' in qat_policy: module.init_module(qat_policy['weight_bits'], qat_policy['weight_bits'], - True, qat_policy['shift_quantile'], export) + True, True, qat_policy['shift_quantile'], export) else: module.init_module(qat_policy['weight_bits'], - qat_policy['weight_bits'], True, 1.0, export) + qat_policy['weight_bits'], True, True, 1.0, export) if 'overrides' in qat_policy: if name in qat_policy['overrides']: - weight_field = qat_policy['overrides'][name]['weight_bits'] - if 'shift_quantile' in qat_policy: - module.init_module(weight_field, weight_field, + if 'weight_bits' in qat_policy['overrides'][name]: + weight_field = qat_policy['overrides'][name]['weight_bits'] + else: + weight_field = qat_policy['weight_bits'] + if 'shift_quantile' in qat_policy['overrides'][name]: + module.init_module(weight_field, weight_field, True, + True, qat_policy['overrides'][name]['shift_quantile'], + export) + elif 'shift_quantile' in qat_policy: + module.init_module(weight_field, weight_field, True, True, qat_policy['shift_quantile'], export) else: module.init_module(weight_field, - weight_field, True, 1.0, export) + weight_field, True, True, 1.0, export) + elif isinstance(module, Eltwise): + module.set_clamp(True) def update_model(m): @@ -1842,14 +2037,9 @@ def update_model(m): Update model `m` with the current parameters. 
It is used to update model functions after loading a checkpoint file. """ - def _update_model(m): - for attr_str in dir(m): - target_attr = getattr(m, attr_str) - if isinstance(target_attr, QuantizationAwareModule): - target_attr.set_functions() - setattr(m, attr_str, target_attr) - - m.apply(_update_model) + for _, module in m.named_modules(): + if isinstance(module, QuantizationAwareModule): + module.set_functions() def update_optimizer(m, optimizer): @@ -1898,38 +2088,211 @@ def fuse_bn_layers(m): """ Fuse the bn layers before the quantization aware training starts. """ - def _fuse_bn_layers(m): - for attr_str in dir(m): - target_attr = getattr(m, attr_str) - if isinstance(target_attr, QuantizationAwareModule) \ - and target_attr.bn is not None: - w = target_attr.op.weight.data - b = target_attr.op.bias.data - device = w.device - - r_mean = target_attr.bn.running_mean - r_var = target_attr.bn.running_var - r_inv_std = torch.rsqrt(r_var + target_attr.bn.eps) - beta = target_attr.bn.weight - gamma = target_attr.bn.bias - - if beta is None: - beta = torch.ones(w.shape[0], device=device) - if gamma is None: - gamma = torch.zeros(w.shape[0], device=device) - - beta = 0.25 * beta - gamma = 0.25 * gamma - - w_new = w * (beta * r_inv_std).reshape((w.shape[0],) + (1,) * (len(w.shape) - 1)) - b_new = (b - r_mean) * r_inv_std * beta + gamma - - target_attr.op.weight.data = w_new - target_attr.op.bias.data = b_new - target_attr.bn = None - setattr(m, attr_str, target_attr) - - m.apply(_fuse_bn_layers) + for _, module in m.named_modules(): + if isinstance(module, QuantizationAwareModule) and module.bn is not None: + w = module.op.weight.data + b = module.op.bias.data + device = w.device + + r_mean = module.bn.running_mean + r_var = module.bn.running_var + r_inv_std = torch.rsqrt(r_var + module.bn.eps) + beta = module.bn.weight + gamma = module.bn.bias + + if beta is None: + beta = torch.ones(w.shape[0], device=device) + if gamma is None: + gamma = torch.zeros(w.shape[0], device=device) + + w_new = w * (beta * r_inv_std).reshape((w.shape[0],) + (1,) * (len(w.shape) - 1)) + b_new = (b - r_mean) * r_inv_std * beta + gamma + + module.op.weight.data = w_new + module.op.bias.data = b_new + module.bn = None + + +def apply_scales(model): + """ + Readjust the scales and apply according to the model graph. 
+ """ + net_graph = symbolic_trace(model) + adds = {} + concats = {} + prevs = {} + op_names = ["torch.conv2d", "torch.conv1d", "torch.linear", + "torch._C._nn.linear", "torch.conv_transpose2d"] + nodes_to_search = [] + name_prev = None + + # Model graph traversal for finding the adds, concats and previous layers + for node in net_graph.graph.nodes: + name = node.format_node() + if ("torch.add" in name) or ("torch.cat" in name): + nodes_to_search.clear() + if "target=view" in name: + if len(node.all_input_nodes) > 0: + input_node = (node.all_input_nodes)[0] + nodes_to_search.append(input_node) + else: + nodes_to_search.extend(node.all_input_nodes) + for node_prev in reversed(net_graph.graph.nodes): + name_prev = node_prev.format_node() + if any(op_name in name_prev for op_name in op_names): + if node_prev in nodes_to_search: + node_prev_name = next(reversed(node_prev.__dict__['meta'] + ['nn_module_stack'])) + if "torch.add" in name: + node_name = next(reversed(node.__dict__['meta']['nn_module_stack'])) + adds[node_prev_name] = node_name + elif "torch.cat" in name: + concats[node_prev_name] = str(node) + nodes_to_search.pop(nodes_to_search.index(node_prev)) + else: + if node_prev in nodes_to_search: + nodes_to_search.pop(nodes_to_search.index(node_prev)) + if "target=view" in name_prev: + if len(node_prev.all_input_nodes) > 0: + input_node = (node_prev.all_input_nodes)[0] + nodes_to_search.append(input_node) + else: + nodes_to_search.extend(node_prev.all_input_nodes) + + elif any(op_name in name for op_name in op_names): + nodes_to_search.clear() + if len(node.all_input_nodes) > 0: + input_node = (node.all_input_nodes)[0] + nodes_to_search.append(input_node) + for node_prev in reversed(net_graph.graph.nodes): + name_prev = node_prev.format_node() + if any(op_name in name_prev for op_name in op_names): + if node_prev in nodes_to_search: + node_prev_name = next(reversed(node_prev.__dict__['meta'] + ['nn_module_stack'])) + node_name = next(reversed(node.__dict__['meta']['nn_module_stack'])) + if prevs.get(str(node_name)) is None: + node_prevs = [] + node_prevs.append(str(node_prev_name)) + prevs[str(node_name)] = node_prevs + else: + prevs[str(node_name)].append(str(node_prev_name)) + nodes_to_search.pop(nodes_to_search.index(node_prev)) + + else: + for name_node in nodes_to_search: + if node_prev == name_node: + nodes_to_search.extend(node_prev.all_input_nodes) + nodes_to_search.pop(nodes_to_search.index(name_node)) + + # Override the thresholds of layers that are connected to adds + for name, module in model.named_modules(): + if isinstance(module, QuantizationAwareModule): + if name in adds: + for name1, module1 in model.named_modules(): + if isinstance(module1, Eltwise): + if adds[name] == name1: + module.activation_threshold = module1.activation_threshold + break + + # Find the maximum threshold from the layers that are concatenated together + concat_thresholds = {} + for name, module in model.named_modules(): + if isinstance(module, QuantizationAwareModule): + if name in concats: + if concat_thresholds.get(concats[name]) is None: + concat_thresholds[concats[name]] = module.activation_threshold + elif module.activation_threshold > concat_thresholds[concats[name]]: + concat_thresholds[concats[name]] = module.activation_threshold + + # Apply the maximum threshold to the layers that are concatenated together + for name, module in model.named_modules(): + if isinstance(module, QuantizationAwareModule): + if name in concats: + module.activation_threshold = 
nn.Parameter(concat_thresholds[concats[name]], + requires_grad=False) + + # Find weight sharing layers and apply the maximum threshold from the multiple passes + shared_threshold = {} + for name, module in model.named_modules(): + if isinstance(module, QuantizationAwareModule): + if prevs.get(name) is not None: + for prev in prevs[name]: + for name1, module1 in model.named_modules(): + if isinstance(module1, QuantizationAwareModule): + if prev == name1: + if shared_threshold.get(name) is None: + shared_threshold[name] = module1.activation_threshold + elif module1.activation_threshold > shared_threshold[name]: + shared_threshold[name] = module1.activation_threshold + for prev in prevs[name]: + for name1, module1 in model.named_modules(): + if isinstance(module1, QuantizationAwareModule): + if prev == name1: + module1.activation_threshold = shared_threshold[name] + + # Get the thresholds after overrides + thresholds = {} + for name, module in model.named_modules(): + if isinstance(module, QuantizationAwareModule): + thresholds[name] = module.activation_threshold + + # Adjust bias and threshold values according to the previous layers, + # and set the final scale value for output layers + for name, module in model.named_modules(): + if isinstance(module, QuantizationAwareModule): + if name in prevs: + prev_threshold_set = False + for name1, module1 in model.named_modules(): + if isinstance(module1, QuantizationAwareModule): + if name1 in prevs[name]: + if not prev_threshold_set: + if module.op is not None and module.op.bias is not None: + module.op.bias = nn.Parameter(module.op.bias / + torch.exp2(thresholds[name1])) + module.activation_threshold = \ + nn.Parameter((module.activation_threshold - thresholds[name1]), + requires_grad=False) + if module.wide: + module.final_scale = nn.Parameter(thresholds[name] - + module.activation_threshold, + requires_grad=False) + else: + module.final_scale = nn.Parameter(thresholds[name], + requires_grad=False) + prev_threshold_set = True + module1.final_scale = nn.Parameter(torch.tensor(0.), + requires_grad=False) + + +def init_hist(model): + """ + Place forward hooks to collect histograms of activations + """ + for _, module in model.named_modules(): + if isinstance(module, (Eltwise, QuantizationAwareModule)): + register_hist_hooks(module) + + +def release_hist(model): + """ + Remove forward hooks after histogram collection + """ + for _, module in model.named_modules(): + if isinstance(module, (Eltwise, QuantizationAwareModule)): + release_hist_hooks(module) + + +def init_threshold(model, outlier_removal_z_score=8.0): + """ + Calculate thresholds based on the collected histograms + """ + for _, module in model.named_modules(): + if isinstance(module, (Eltwise, QuantizationAwareModule)): + # If module defined but not called on forward, it won't have hist + if hasattr(module, 'hist'): + init_threshold_module(module, outlier_removal_z_score) + calc_threshold(module) def onnx_export_prep(m, simplify=False, remove_clamp=False): diff --git a/ai8x_blocks.py b/ai8x_blocks.py index fc2b314c7..42fffc693 100644 --- a/ai8x_blocks.py +++ b/ai8x_blocks.py @@ -1,6 +1,6 @@ ################################################################################################### # -# Copyright (C) 2020-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2020-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. 
Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -240,12 +240,11 @@ def __init__(self, eps=1e-03, momentum=0.01, **kwargs) # Depthwise Convolution phase if fused is not True: - self.depthwise_conv = ai8x.FusedConv2dBNReLU(in_channels=out, out_channels=out, - groups=out, # groups makes it depthwise - padding=1, kernel_size=kernel_size, - stride=stride, batchnorm='Affine', - bias=bias, eps=1e-03, momentum=0.01, - **kwargs) + self.depthwise_conv = ai8x.FusedDepthwiseConv2dBNReLU(out, out, kernel_size, + padding=1, stride=stride, + batchnorm='Affine', bias=bias, + eps=1e-03, momentum=0.01, + **kwargs) # Squeeze and Excitation phase if self.has_se: num_squeezed_channels = max(1, int(in_channels * se_ratio)) @@ -260,7 +259,9 @@ def __init__(self, kernel_size=1, batchnorm='Affine', bias=bias, eps=1e-03, momentum=0.01, **kwargs) # Skip connection - self.resid = ai8x.Add() + input_filters, output_filters = self.in_channels, self.out_channels + if self.stride == 1 and input_filters == output_filters: + self.resid = ai8x.Add() def forward(self, inputs): """MBConvBlock's forward function. diff --git a/datasets/imagenet.py b/datasets/imagenet.py index 4e77633e5..59e0972ba 100644 --- a/datasets/imagenet.py +++ b/datasets/imagenet.py @@ -1,6 +1,6 @@ # # Copyright (c) 2018 Intel Corporation -# Portions Copyright (C) 2019-2023 Maxim Integrated Products, Inc. +# Portions Copyright (C) 2019-2024 Maxim Integrated Products, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -48,7 +48,6 @@ def imagenet_get_datasets(data, load_train=True, load_test=True, transforms.RandomResizedCrop(input_size, antialias=True), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ai8x.normalize(args=args), ]) @@ -71,7 +70,6 @@ def imagenet_get_datasets(data, load_train=True, load_test=True, transforms.Resize(int(input_size / 0.875), antialias=True), # type: ignore transforms.CenterCrop(input_size), transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), ai8x.normalize(args=args), ]) diff --git a/docs/unload_example.png b/docs/unload_example.png new file mode 100644 index 000000000..7ecacf2bc Binary files /dev/null and b/docs/unload_example.png differ diff --git a/models/ai85net-actiontcn.py b/models/ai85net-actiontcn.py index 23310856b..f8fece919 100644 --- a/models/ai85net-actiontcn.py +++ b/models/ai85net-actiontcn.py @@ -1,6 +1,6 @@ ################################################################################################### # -# Copyright (C) 2022-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2022-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -18,6 +18,9 @@ class AI85ActionTCN(nn.Module): """ Conv2D backbone + TCN layers for Action Recognition + Model was designed to be used with the Kinetics dataset. + Number of frames was set to 15, as the model optimally performs with this number + within the constraints of the AI85 hardware. 
""" def __init__( self, @@ -33,6 +36,7 @@ def __init__( self.num_classes = num_classes self.cnn_out_shape = (1, 1) self.cnn_out_channel = 32 + self.num_frames = 15 num_filters = 64 len_frame_vector = self.cnn_out_shape[0]*self.cnn_out_shape[1]*self.cnn_out_channel @@ -113,15 +117,14 @@ def create_cnn(self, x): def forward(self, x): """Forward prop""" batch_size = x.shape[0] - num_frames = x.shape[1] - cnnoutputs = torch.zeros(batch_size, num_frames, self.cnn_out_channel, - self.cnn_out_shape[0], self.cnn_out_shape[1], - device=x.get_device()) - for i in range(num_frames): - prep_out = self.create_prep(x[:, i]) - cnnoutputs[:, i] = self.create_cnn(prep_out) + cnnoutputs = torch.zeros_like(x) + cnnoutputs = cnnoutputs[:, :, :self.cnn_out_channel, :self.cnn_out_shape[0], + :self.cnn_out_shape[1]] - tcn_input = cnnoutputs.permute(0, 1, 3, 4, 2).reshape(batch_size, num_frames, -1) \ + for i in range(self.num_frames): + prep_out = self.create_prep(x[:, i]) + cnnoutputs = assign_cnnoutputs(cnnoutputs, i, self.create_cnn(prep_out)) + tcn_input = cnnoutputs.permute(0, 1, 3, 4, 2).reshape(batch_size, self.num_frames, -1) \ .permute(0, 2, 1) tcn_output = self.tcn0(tcn_input) tcn_output = self.tcn1(tcn_output) @@ -130,6 +133,15 @@ def forward(self, x): return tcn_output.reshape(batch_size, self.num_classes) +@torch.fx.wrap +def assign_cnnoutputs(cnnoutputs, index, value): + """ + Assigns a value to a slice of a tensor, required for symbolic tracing + """ + cnnoutputs[:, index] = value + return cnnoutputs + + def ai85actiontcn(pretrained=False, **kwargs): """ Constructs an AI85ActionTCN model. diff --git a/models/ai85net-autoencoder.py b/models/ai85net-autoencoder.py index caa316d8c..0855c371b 100644 --- a/models/ai85net-autoencoder.py +++ b/models/ai85net-autoencoder.py @@ -144,7 +144,7 @@ def __init__(self, self.initWeights(weight_init) - def forward(self, x, return_bottleneck=False): + def forward(self, x): """Forward prop""" x = self.en_conv1(x) x = self.en_conv2(x) @@ -152,9 +152,6 @@ def forward(self, x, return_bottleneck=False): x = self.en_lin1(x) x = self.en_lin2(x) - if return_bottleneck: - return x - x = self.de_lin1(x) x = self.de_lin2(x) x = self.out_lin(x) diff --git a/models/ai85net-faceid_112.py b/models/ai85net-faceid_112.py index eeb33e2bd..841e98fdd 100644 --- a/models/ai85net-faceid_112.py +++ b/models/ai85net-faceid_112.py @@ -1,6 +1,6 @@ ################################################################################################### # -# Copyright (C) 2019-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2019-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -97,8 +97,7 @@ def _create_bottleneck_stage(self, setting, bias, depthwise_bias, def forward(self, x): # pylint: disable=arguments-differ """Forward prop""" - if x.shape[1] == 6: - x = x[:, 0:3, :, :] + x = x[:, 0:3, :, :] x = self.pre_stage(x) x = self.pre_stage_2(x) for stage in self.feature_stage: diff --git a/models/ai87net-imagenet-effnetv2.py b/models/ai87net-imagenet-effnetv2.py index ffa9dbc47..54b98feee 100644 --- a/models/ai87net-imagenet-effnetv2.py +++ b/models/ai87net-imagenet-effnetv2.py @@ -9,6 +9,8 @@ """ ImageNet EfficientNet v.2 network implementation for MAX78002. 
""" +import math + from torch import nn import ai8x @@ -75,6 +77,8 @@ def __init__( # Final linear layer self.fc = ai8x.Linear(1024, num_classes, bias=bias, wide=True, **kwargs) + self._initialize_weights() + def forward(self, x): # pylint: disable=arguments-differ """ Forward prop """ x = self.conv_stem(x) @@ -97,6 +101,17 @@ def forward(self, x): # pylint: disable=arguments-differ x = self.fc(x) return x + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / (n))) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + def ai87imageneteffnetv2(pretrained=False, **kwargs): """ diff --git a/models/ai87net-mobilefacenet_112.py b/models/ai87net-mobilefacenet_112.py index c039d1e3e..6707fdfd8 100644 --- a/models/ai87net-mobilefacenet_112.py +++ b/models/ai87net-mobilefacenet_112.py @@ -1,6 +1,6 @@ ################################################################################################### # -# Copyright (C) 2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2023-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -98,8 +98,7 @@ def _create_bottleneck_stage(self, setting, bias, depthwise_bias, def forward(self, x): # pylint: disable=arguments-differ """Forward prop""" - if x.shape[1] == 6: - x = x[:, 0:3, :, :] + x = x[:, 0:3, :, :] x = self.pre_stage(x) x = self.dwise(x) for stage in self.feature_stage: diff --git a/parse_qat_yaml.py b/parse_qat_yaml.py index 7d9ed513f..f431c639f 100644 --- a/parse_qat_yaml.py +++ b/parse_qat_yaml.py @@ -1,6 +1,6 @@ ################################################################################################### # -# Copyright (C) 2020-2023 Maxim Integrated Products, Inc. All Rights Reserved. +# Copyright (C) 2020-2024 Maxim Integrated Products, Inc. All Rights Reserved. # # Maxim Integrated Products, Inc. Default Copyright Notice: # https://www.maximintegrated.com/en/aboutus/legal/copyrights.html @@ -31,5 +31,10 @@ def parse(yaml_file, msglogger=None): assert False, '`start_epoch` must be defined in QAT policy' if policy and 'weight_bits' not in policy: assert False, '`weight_bits` must be defined in QAT policy' + if policy and 'outlier_removal_z_score' not in policy: + policy['outlier_removal_z_score'] = 8.0 + if msglogger is not None: + msglogger.info('`outlier_removal_z_score` not defined in QAT policy.' 
+ 'Using default value of 8.0') return policy diff --git a/policies/qat_policy_cifar100.yaml b/policies/qat_policy_cifar100.yaml index 3e1ae75fb..41c1830b4 100644 --- a/policies/qat_policy_cifar100.yaml +++ b/policies/qat_policy_cifar100.yaml @@ -2,6 +2,7 @@ start_epoch: 240 shift_quantile: 0.985 weight_bits: 2 +outlier_removal_z_score: 2 overrides: conv1: weight_bits: 8 diff --git a/policies/qat_policy_imagenet.yaml b/policies/qat_policy_imagenet.yaml index 036c3db68..61a418a4b 100644 --- a/policies/qat_policy_imagenet.yaml +++ b/policies/qat_policy_imagenet.yaml @@ -1,3 +1,3 @@ --- -start_epoch: 1 +start_epoch: 150 weight_bits: 8 diff --git a/policies/qat_policy_late_cifar.yaml b/policies/qat_policy_late_cifar.yaml index b921a3e98..ede6ed4fc 100644 --- a/policies/qat_policy_late_cifar.yaml +++ b/policies/qat_policy_late_cifar.yaml @@ -2,3 +2,4 @@ start_epoch: 210 weight_bits: 8 shift_quantile: 0.995 +outlier_removal_z_score: 2 diff --git a/policies/qat_policy_mnist.yaml b/policies/qat_policy_mnist.yaml index 45c0a2db7..370a73360 100644 --- a/policies/qat_policy_mnist.yaml +++ b/policies/qat_policy_mnist.yaml @@ -1,6 +1,7 @@ --- start_epoch: 10 weight_bits: 2 +outlier_removal_z_score: 2 overrides: conv1: weight_bits: 8 diff --git a/policies/qat_policy_pascalvoc.yaml b/policies/qat_policy_pascalvoc.yaml index 35b4539ad..486c26b48 100644 --- a/policies/qat_policy_pascalvoc.yaml +++ b/policies/qat_policy_pascalvoc.yaml @@ -1,3 +1,4 @@ --- start_epoch: 250 +outlier_removal_z_score: 6 weight_bits: 8 diff --git a/policies/schedule-imagenet-effnet2.yaml b/policies/schedule-imagenet-effnet2.yaml index e68983f93..090575e98 100644 --- a/policies/schedule-imagenet-effnet2.yaml +++ b/policies/schedule-imagenet-effnet2.yaml @@ -1,13 +1,13 @@ --- lr_schedulers: training_lr: - class: MultiStepLR - milestones: [50, 100, 150, 200, 250] - gamma: 0.5 + class: CosineAnnealingLR + T_max: 150 + eta_min: 0.0000005 policies: - lr_scheduler: instance_name: training_lr starting_epoch: 0 - ending_epoch: 300 + ending_epoch: 150 frequency: 1 diff --git a/requirements.txt b/requirements.txt index a972c6c72..b1a718dc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ torchaudio==2.3.1 GitPython==3.1.43 Pillow==10.4.0 PyYAML==6.0.1 -albumentations==1.4.12 +albumentations==1.4.15 +albucore==0.0.16 faiss-cpu==1.8.0.post1 batch-face==1.4.0 h5py==3.11.0 diff --git a/scripts/train_imagenet_effnet2.sh b/scripts/train_imagenet_effnet2.sh index 64e381e05..8bded0394 100755 --- a/scripts/train_imagenet_effnet2.sh +++ b/scripts/train_imagenet_effnet2.sh @@ -1,2 +1,2 @@ #!/bin/sh -python train.py --deterministic --epochs 200 --optimizer Adam --lr 0.001 --wd 0 --compress policies/schedule-imagenet-effnet2.yaml --model ai87imageneteffnetv2 --dataset ImageNet --device MAX78002 --batch-size 256 --print-freq 100 --validation-split 0 --use-bias --qat-policy policies/qat_policy_imagenet.yaml "$@" +python train.py --deterministic --epochs 200 --optimizer Adam --lr 0.001 --wd 1e-5 --momentum 0.9 --compress policies/schedule-imagenet-effnet2.yaml --model ai87imageneteffnetv2 --dataset ImageNet --device MAX78002 --batch-size 256 --print-freq 100 --validation-split 0 --use-bias --qat-policy policies/qat_policy_imagenet.yaml "$@" diff --git a/scripts/train_mnist_qat.sh b/scripts/train_mnist_qat.sh index aa01fe3e8..84a61f0cb 100755 --- a/scripts/train_mnist_qat.sh +++ b/scripts/train_mnist_qat.sh @@ -1,2 +1,2 @@ #!/bin/sh -python train.py --lr 0.1 --optimizer SGD --epochs 200 --deterministic --seed 1 --compress 
policies/schedule.yaml --model ai85net5 --dataset MNIST --confusion --param-hist --pr-curves --embedding --device MAX78000 --qat-policy policies/qat_policy_mnist.yaml "$@" +python train.py --lr 0.01 --optimizer SGD --epochs 200 --deterministic --seed 1 --compress policies/schedule.yaml --model ai85net5 --dataset MNIST --confusion --param-hist --pr-curves --embedding --device MAX78000 --qat-policy policies/qat_policy_mnist.yaml "$@" diff --git a/train.py b/train.py index bd5e49139..04fb6435b 100755 --- a/train.py +++ b/train.py @@ -101,6 +101,7 @@ from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator from pytorch_metric_learning.utils.inference import CustomKNN from torchmetrics.detection import MeanAveragePrecision +from tqdm import tqdm import ai8x import ai8x_nas @@ -384,13 +385,22 @@ def flush(self): args.name = f'{args.name}_qat' else: args.name = 'qat' - model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint( - model, args.resumed_checkpoint_path, model_device=args.device) + try: + model, compression_scheduler, optimizer, start_epoch = apputils.load_checkpoint( + model, args.resumed_checkpoint_path, model_device=args.device) + except ValueError as exc: + raise ValueError('\n ERROR: Unable to resume from the checkpoint. ' + 'The reason might be the size mismatch between checkpoint and' + ' optimizer. Instead of "--resume-from", "--exp-load-weights-from" ' + 'argument can be used to load the lean model. ') from exc elif args.load_model_path: + init_qat = False + update_old_model_params(args.load_model_path, model) if qat_policy is not None: checkpoint = torch.load(args.load_model_path, map_location=lambda storage, loc: storage) if checkpoint.get('epoch', None) >= qat_policy['start_epoch']: + init_qat = True ai8x.fuse_bn_layers(model) if args.name: args.name = f'{args.name}_qat' @@ -399,6 +409,10 @@ def flush(self): model = apputils.load_lean_checkpoint(model, args.load_model_path, model_device=args.device) + # If model is in QAT mode, guarantee that the model is initialized for QATv2 + if init_qat: + ai8x.initiate_qat(model, qat_policy) + ai8x.update_model(model) if args.reset_optimizer: @@ -576,7 +590,7 @@ def flush(self): if args.evaluate: msglogger.info('Dataset sizes:\n\ttest=%d', len(test_loader.sampler)) - return evaluate_model(model, criterion, test_loader, pylogger, args, compression_scheduler) + return test(test_loader, model, criterion, pylogger, args=args) assert train_loader and val_loader msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d', @@ -594,6 +608,15 @@ def flush(self): # Fuse the BN parameters into conv layers before Quantization Aware Training (QAT) ai8x.fuse_bn_layers(model) + ai8x.init_hist(model) + + msglogger.info('Collecting statistics for quantization aware training (QAT)...') + stat_collect(train_loader, model, args) + + ai8x.init_threshold(model, qat_policy["outlier_removal_z_score"]) + ai8x.release_hist(model) + + ai8x.apply_scales(model) # Update the optimizer to reflect fused batchnorm layers optimizer = ai8x.update_optimizer(model, optimizer) @@ -715,7 +738,9 @@ def flush(self): # Finally run results on the test set if not args.dr: - test(test_loader, model, criterion, [pylogger], args=args) + test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt") + test(test_loader, model, criterion, [pylogger], args=args, mode="best", + ckpt_name=checkpoint_name) if args.copy_output_folder and local_rank <= 0: msglogger.info('Copying output folder to: %s', 
args.copy_output_folder) @@ -825,6 +850,15 @@ def create_nas_kd_policy(model, compression_scheduler, epoch, next_state_start_e ' | '.join([f'{val:.2f}' for val in dlw])) +@torch.no_grad() +def stat_collect(train_loader, model, args): + """Collect statistics for quantization aware training""" + model.eval() + for inputs, _ in tqdm(train_loader): + inputs = inputs.to(args.device) + model(inputs) + + def train(train_loader, model, criterion, optimizer, epoch, compression_scheduler, loggers, args, loss_optimizer=None): """Training loop for one epoch.""" @@ -1048,11 +1082,20 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger) -def test(test_loader, model, criterion, loggers, args): +def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None): """Model Test""" assert msglogger is not None - msglogger.info('--- test ---------------------') - top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args) + if mode == 'ckpt': + msglogger.info('--- test (ckpt) ---------------------') + top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args) + else: + msglogger.info('--- test (best) ---------------------') + if ckpt_name is None: + best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar') + else: + best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar") + model = apputils.load_lean_checkpoint(model, best_ckpt_path) + top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args) return top1, top5, vloss, mAP @@ -1122,16 +1165,15 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1, tflogger=N # compute output from model output_boxes, output_conf = model(inputs) - # correct output for accurate loss calculation if args.act_mode_8bit: output_boxes /= 128. output_conf /= 128. - if (hasattr(m, 'are_locations_wide') and m.are_locations_wide): + if (hasattr(m, 'are_locations_wide') and m.are_locations_wide()): output_boxes /= 128. - if (hasattr(m, 'are_scores_wide') and m.are_scores_wide): + if (hasattr(m, 'are_scores_wide') and m.are_scores_wide()): output_conf /= 128. output = (output_boxes, output_conf) @@ -1191,10 +1233,9 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1, tflogger=N # correct output for accurate loss calculation if args.act_mode_8bit: output /= 128. - for key in model.__dict__['_modules'].keys(): - if (hasattr(model.__dict__['_modules'][key], 'wide') - and model.__dict__['_modules'][key].wide): - output /= 256. + for _, module in model.named_modules(): + if hasattr(module, 'wide') and module.wide: + output /= 128. if args.regression: target /= 128. @@ -1395,37 +1436,6 @@ def update_training_scores_history(perf_scores_history, model, top1, top5, mAP, score.epoch) -def evaluate_model(model, criterion, test_loader, loggers, args, scheduler=None, local_rank=-1): - """ - This sample application can be invoked to evaluate the accuracy of your model on - the test dataset. - You can optionally quantize the model to 8-bit integer before evaluation. 
- For example: - python3 compress_classifier.py --arch resnet20_cifar \ - ../data.cifar10 -p=50 --resume-from=checkpoint.pth.tar --evaluate - """ - - if not isinstance(loggers, list): - loggers = [loggers] - - top1, _, _, mAP = test(test_loader, model, criterion, loggers, args=args) - - if args.quantize_eval and local_rank <= 0: # not DistributedDataParallel or rank 0 - checkpoint_name = 'quantized' - - if args.obj_detection: - extras = {'quantized_mAP': mAP} - else: - extras = {'quantized_top1': top1} - assert msglogger is not None - - m, _, _ = model_wrapper.unwrap(model) - apputils.save_checkpoint(0, args.cnn, m, optimizer=None, scheduler=scheduler, - name='_'.join([args.name, checkpoint_name]) - if args.name else checkpoint_name, - dir=msglogger.logdir, extras=extras) - - def summarize_model(model, dataset, which_summary, filename='model'): """summarize_model""" if which_summary.startswith('png'): @@ -1516,6 +1526,37 @@ def __missing__(self, key): return None # note, does *not* set self[key] - we don't want defaultdict's behavior +def update_old_model_params(model_path, model_new): + """Adds missing model parameters added with default values. + This is mainly due to the saved checkpoints from previous versions of the repo. + New model is saved to `model_path` and the old model copied into the same file_path with + `__obsolete__` prefix.""" + is_model_old = False + model_old = torch.load(model_path, + map_location=lambda storage, loc: storage) + # Fix up any instances of DataParallel + old_dict = model_old['state_dict'].copy() + for k in old_dict: + if k.startswith('module.'): + model_old['state_dict'][k[7:]] = old_dict[k] + for new_key, new_val in model_new.state_dict().items(): + if new_key not in model_old['state_dict'] and '.bn.' not in new_key: + is_model_old = True + model_old['state_dict'][new_key] = new_val + if 'compression_sched' in model_old: + if 'masks_dict' in model_old['compression_sched']: + model_old['compression_sched']['masks_dict'][new_key] = None + + if is_model_old: + dir_path, file_name = os.path.split(model_path) + new_file_name = '__obsolete__' + file_name + old_model_path = os.path.join(dir_path, new_file_name) + os.rename(model_path, old_model_path) + torch.save(model_old, model_path) + msglogger.info('Model `%s` is old. Missing parameters added with default values!', + model_path) + + if __name__ == '__main__': try: np.set_printoptions(threshold=sys.maxsize, linewidth=190)
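As a closing illustration of the new activation-threshold logic in `ai8x.py`: `calc_threshold()` searches over power-of-two thresholds and keeps the one with the smallest quantization error as measured by `calc_q_error()` on the merged histograms. The snippet below is a minimal, self-contained sketch of the same search applied directly to raw activation samples instead of histograms; `pick_threshold` is a hypothetical name used only here.

```python
import torch

def pick_threshold(act, bits=8, iterations=5, eps=1e-9):
    """Minimal sketch of the search in ai8x.calc_threshold(): start from the
    smallest power of two covering the activation range, halve it a few times,
    and keep the candidate with the lowest quantization error."""
    t_nc = act.abs().max().log2().ceil().exp2()   # initial (non-clipping) threshold
    best_err, best_t = torch.inf, t_nc
    for i in range(iterations):
        t = t_nc / (2 ** i)                                   # candidate threshold
        q = torch.round(act / (t + eps) * 2 ** (bits - 1))    # quantize ...
        q = torch.clamp(q, -2 ** (bits - 1), 2 ** (bits - 1) - 1)
        deq = q * (t + eps) / 2 ** (bits - 1)                 # ... and de-quantize
        err = torch.mean((deq - act) ** 2)                    # mean squared quantization error
        if err < best_err:
            best_err, best_t = err, t
    return torch.log2(best_t)   # stored as log2, like `activation_threshold` in ai8x.py

# Activations concentrated near zero favor a smaller (clipping) threshold:
acts = torch.randn(10_000) * 0.05
print(pick_threshold(acts))
```

In the actual implementation the error is weighted by the merged, outlier-filtered histogram collected during the statistics epoch, which avoids keeping every activation sample in memory; the selected threshold is likewise stored as its base-2 logarithm in `activation_threshold`.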