From 385fcf2fe8dd4dc1c4d17423ad4f044353673ad4 Mon Sep 17 00:00:00 2001 From: ChengcanWang-com Date: Thu, 17 Jun 2021 01:09:51 +0000 Subject: [PATCH] Support ensemble distillation. (#23) close #22 --- README.md | 26 ++- azure-pipelines.yml | 4 +- examples/get_flops_lenet.py | 10 +- examples/get_flops_resnet_50.py | 10 +- examples/resnet_101_imagenet_train.py | 3 +- examples/resnet_50_imagenet_distill.py | 3 +- examples/resnet_50_imagenet_prune_distill.py | 45 +++++ examples/resnet_50_imagenet_train.py | 3 +- setup.py | 9 +- src/model_optimizer/pruner/config_schema.json | 4 + src/model_optimizer/pruner/core/__init__.py | 2 +- src/model_optimizer/pruner/core/pruner.py | 161 +++++++++++++----- .../pruner/dataset/imagenet.py | 2 +- .../pruner/distill/distill_loss.py | 11 +- .../pruner/distill/distiller.py | 16 +- .../pruner/distill/tf_model_loader.py | 74 ++++++++ .../pruner/learner/learner_base.py | 73 +++----- .../pruner/learner/lenet_mnist.py | 3 +- .../pruner/learner/mobilenet_v1_imagenet.py | 3 +- .../pruner/learner/mobilenet_v2_imagenet.py | 3 +- .../pruner/learner/resnet_101_imagenet.py | 9 +- .../pruner/learner/resnet_50_imagenet.py | 12 +- .../pruner/learner/vgg_m_16_cifar10.py | 3 +- src/model_optimizer/pruner/models/__init__.py | 30 ++-- src/model_optimizer/pruner/models/config.py | 63 +++++++ src/model_optimizer/pruner/models/lenet.py | 10 +- .../pruner/models/mobilenet_v1.py | 56 +++--- .../pruner/models/mobilenet_v2.py | 76 ++++++--- src/model_optimizer/pruner/models/resnet.py | 40 +++-- src/model_optimizer/pruner/models/vgg.py | 63 ++++--- .../distill/resnet_50_imagenet_0.3.yaml | 9 +- .../resnet_50_imagenet_0.5_distill.yaml | 44 +++++ src/model_optimizer/stat.py | 19 +-- .../utils/imagenet_preprocessing.py | 28 ++- 34 files changed, 660 insertions(+), 267 deletions(-) create mode 100644 examples/resnet_50_imagenet_prune_distill.py create mode 100644 src/model_optimizer/pruner/distill/tf_model_loader.py create mode 100644 src/model_optimizer/pruner/models/config.py create mode 100644 src/model_optimizer/pruner/scheduler/uniform_auto/resnet_50_imagenet_0.5_distill.yaml diff --git a/README.md b/README.md index 8aaa12c..ca8ffca 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ sparsity pruning depends on special algorithms and hardware to achieve accelerat Adlik pruning focuses on channel pruning and filter pruning, which can really reduce the number of parameters and flops. In terms of quantization, Adlik focuses on 8-bit quantization that is easier to accelerate on specific hardware. After testing, it is found that running a small batch of datasets can obtain a quantitative model with little loss of -accuracy, so Adlik focuses on this method. Knowledge distillation is another way to improve the performance of deep +accuracy, so Adlik focuses on this method. Knowledge distillation is another way to improve the performance of deep learning algorithm. It is possible to compress the knowledge in the big model into a smaller model. The proposed framework mainly consists of two categories of algorithm components, i.e. pruner and quantizer. The @@ -23,7 +23,7 @@ three modules. After filter pruning, model can continue to be quantized, the following table shows the accuracy of the pruned and quantized Lenet-5 and ResNet-50 models. -| model | baseline | pruned | pruned+quantization(TF-Lite) | pruned+quantization(TF-TRT) | +| Model | Baseline | Pruned | Pruned + Quantization(TF-Lite) | Pruned + Quantization(TF-TRT) | | --------- | -------- | -------------------- | ---------------------------- | --------------------------- | | LeNet-5 | 98.85 | 99.11(59% pruned) | 99.05 | 99.11 | | ResNet-50 | 76.174 | 75.456(31.9% pruned) | 75.158 | 75.28 | @@ -31,7 +31,7 @@ quantized Lenet-5 and ResNet-50 models. The Pruner completely removes redundant parameters, which further leads to smaller model size and faster execution. The following table is the size of the above model files: -| model | baseline(H5) | pruned(H5) | quantization(TF-Lite) | quantization(TF-TRT) | +| Model | Baseline(H5) | Pruned(H5) | Quantization(TF-Lite) | Quantization(TF-TRT) | | --------- | ------------ | ------------------ | --------------------- | -------------------- | | LeNet-5 | 1176KB | 499KB(59% pruned) | 120KB | 1154KB (pb) | | ResNet-50 | 99MB | 67MB(31.9% pruned) | 18MB | 138MB(pb) | @@ -47,12 +47,24 @@ which was tested on ImageNet. The original test accuracy is 71.25%, and model si Knowledge distillation is an effective way to imporve the performance of model. -The following table shows the distillation result of ResNet-50 as the student network where ResNet-101 as the teacher network. +The following table shows the distillation result of ResNet-50 as the student network where ResNet-101 as the teacher network. -| student model | ResNet-101 distilled | accuracy change | +| Student Model | ResNet-101 Distilled | Accuracy Change | | ------------- | -------------------- | --------------- | | ResNet-50 | 77.14% | +0.97% | +Ensemble distillation can significantly improve the accuracy of the model. In the case of cutting 72.8% of the +parameters, using senet154 and resnet152b as the teacher network, ensemble distillation can increase the accuracy +by more than 4%. +The details are shown in the table below, and the code can refer to examples\resnet_50_imagenet_prune_distill.py. + +| Model | Accuracy | Params | FLOPs | Model Size | +| --------- | -------- | -------------------- | ---------------------------- | ---------------------------- | +| ResNet-50 | 76.174 | 25610152 | 3899M|99M | +| + pruned | 72.28 | 6954152 ( 72.8% pruned) | 1075M | 27M| +| + pruned + distill | 76.39 | 6954152 ( 72.8% pruned) | 1075M | 27M| +| + pruned + distill + quantization(TF-Lite) | 75.938 | - | - | 7.1M| + ## 1. Pruning and quantization principle ### 1.1 Filter pruning @@ -85,7 +97,7 @@ quantification of ResNet-50 in less than one minute. Knowledge distillation is a compression technique by which the knowledge of a larger model(teacher) is transfered into a smaller one(student). During distillation, a student model learns from a teacher model to generalize well by raise -the temperature of the final softmax of the teacher model as the soft set of targets. +the temperature of the final softmax of the teacher model as the soft set of targets. ![Distillation](imgs/distillation.png) @@ -252,7 +264,7 @@ This step is the same as described above. You can get detailed instructions from Batch size is an important hyper-parameter for Deep Learning model training. If you have more GPU memory available, you can try larger batch size! You have to adjust the learning rate according to different batch size. -| model | card | batch size | learning-rate | +| Model | Card | Batch Size | Learning Rate | | --------- | --------- | ---------- | ------------- | | ResNet-50 | V100 32GB | 256 | 0.1 | | ResNet-50 | P100 16GB | 128 | 0.05 | diff --git a/azure-pipelines.yml b/azure-pipelines.yml index cdfe411..69d8308 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -8,7 +8,7 @@ jobs: strategy: matrix: Linux: - vmImage: ubuntu-latest + vmImage: ubuntu-18.04 pool: vmImage: $(vmImage) steps: @@ -25,7 +25,7 @@ jobs: - job: Markdownlint displayName: Markdownlint pool: - vmImage: ubuntu-latest + vmImage: ubuntu-18.04 steps: - script: sudo npm install -g markdownlint-cli displayName: Install markdownlint-cli diff --git a/examples/get_flops_lenet.py b/examples/get_flops_lenet.py index 93626ac..791bdb4 100644 --- a/examples/get_flops_lenet.py +++ b/examples/get_flops_lenet.py @@ -11,18 +11,18 @@ # sys.path.insert(0, join(abspath(dirname(__file__)), '../src')) # print(sys.path) -from model_optimizer.stat import get_keras_model_flops # noqa: E402 +from model_optimizer.stat import get_keras_model_params_flops # noqa: E402 def _main(): base_dir = os.path.dirname(__file__) model_h5_path = './models_eval_ckpt/lenet_mnist/checkpoint-12.h5' - origin_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path)) + origin_params, origin_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path)) model_h5_path = './models_eval_ckpt/lenet_mnist_pruned/checkpoint-12.h5' - pruned_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path)) + pruned_params, pruned_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path)) - print('flops before prune: {}'.format(origin_flops)) - print('flops after pruned: {}'.format(pruned_flops)) + print('Before prune, FLOPs: {}, Params: {}'.format(origin_flops, origin_params)) + print('After pruned, FLOPs: {}, Params: {}'.format(pruned_flops, pruned_params)) if __name__ == "__main__": diff --git a/examples/get_flops_resnet_50.py b/examples/get_flops_resnet_50.py index c2de4e9..692a4f9 100644 --- a/examples/get_flops_resnet_50.py +++ b/examples/get_flops_resnet_50.py @@ -12,18 +12,18 @@ # sys.path.insert(0, join(abspath(dirname(__file__)), '../src')) # print(sys.path) -from model_optimizer.stat import get_keras_model_flops # noqa: E402 +from model_optimizer.stat import get_keras_model_params_flops # noqa: E402 def _main(): base_dir = os.path.dirname(__file__) model_h5_path = './models_eval_ckpt/resnet_50_imagenet/checkpoint-90.h5' - origin_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path)) + origin_params, origin_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path)) model_h5_path = './models_eval_ckpt/resnet_50_imagenet_pruned/checkpoint-120.h5' - pruned_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path)) + pruned_params, pruned_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path)) - print('flops before prune: {}'.format(origin_flops)) - print('flops after pruned: {}'.format(pruned_flops)) + print('Before prune, FLOPs: {}, Params: {}'.format(origin_flops, origin_params)) + print('After pruned, FLOPs: {}, Params: {}'.format(pruned_flops, pruned_params)) if __name__ == "__main__": diff --git a/examples/resnet_101_imagenet_train.py b/examples/resnet_101_imagenet_train.py index a74832e..dd323af 100644 --- a/examples/resnet_101_imagenet_train.py +++ b/examples/resnet_101_imagenet_train.py @@ -27,8 +27,7 @@ def _main(): "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_101_imagenet"), "checkpoint_save_period": 5, # save a checkpoint every 5 epoch "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_101_imagenet"), - "scheduler": "train", - "classifier_activation": None # None or "softmax", default is softmax + "scheduler": "train" } prune_model(request) diff --git a/examples/resnet_50_imagenet_distill.py b/examples/resnet_50_imagenet_distill.py index 82b5269..f68ea60 100644 --- a/examples/resnet_50_imagenet_distill.py +++ b/examples/resnet_50_imagenet_distill.py @@ -28,8 +28,7 @@ def _main(): "checkpoint_save_period": 5, # save a checkpoint every 5 epoch "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_distill"), "scheduler": "distill", - "scheduler_file_name": "resnet_50_imagenet_0.3.yaml", - "classifier_activation": None # None or "softmax", default is softmax + "scheduler_file_name": "resnet_50_imagenet_0.3.yaml" } prune_model(request) diff --git a/examples/resnet_50_imagenet_prune_distill.py b/examples/resnet_50_imagenet_prune_distill.py new file mode 100644 index 0000000..7e7125c --- /dev/null +++ b/examples/resnet_50_imagenet_prune_distill.py @@ -0,0 +1,45 @@ +# Copyright 2019 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +This is an example of pruning and ensemble distillation of the resnet50 model +Please download the two models senet154 and resnet152b to the directory configured in the file +resnet_50_imagenet_0.5_distill.yaml. +wget -c -O resnet152b-0431-b41ec90e.tf2.h5.zip https://github.com/osmr/imgclsmob/releases/ +download/v0.0.517/resnet152b-0431-b41ec90e.tf2.h5.zip +wget -c -O senet154-0466-f1b79a9b_tf2.h5.zip https://github.com/osmr/imgclsmob/releases/ +download/v0.0.422/senet154-0466-f1b79a9b_tf2.h5.zip +""" +import os +# If you did not execute the setup.py, uncomment the following four lines +# import sys +# from os.path import abspath, join, dirname +# sys.path.insert(0, join(abspath(dirname(__file__)), '../src')) +# print(sys.path) + +from model_optimizer import prune_model # noqa: E402 + + +def _main(): + base_dir = os.path.dirname(__file__) + request = { + "dataset": "imagenet", + "model_name": "resnet_50", + "data_dir": os.path.join(base_dir, "/data/imagenet/tfrecord-dataset"), + "batch_size": 256, + "batch_size_val": 100, + "learning_rate": 0.1, + "epochs": 360, + "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet_pruned"), + "checkpoint_save_period": 5, # save a checkpoint every 5 epoch + "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_pruned"), + "scheduler": "uniform_auto", + "is_distill": True, + "scheduler_file_name": "resnet_50_imagenet_0.5_distill.yaml" + } + os.environ['L2_WEIGHT_DECAY'] = "5e-5" + prune_model(request) + + +if __name__ == "__main__": + _main() diff --git a/examples/resnet_50_imagenet_train.py b/examples/resnet_50_imagenet_train.py index ec122aa..616a68b 100644 --- a/examples/resnet_50_imagenet_train.py +++ b/examples/resnet_50_imagenet_train.py @@ -28,8 +28,7 @@ def _main(): "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet"), "checkpoint_save_period": 5, # save a checkpoint every 5 epoch "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet"), - "scheduler": "train", - "classifier_activation": None # None or "softmax", default is softmax + "scheduler": "train" } prune_model(request) diff --git a/setup.py b/setup.py index 6d11194..c0ef229 100644 --- a/setup.py +++ b/setup.py @@ -12,12 +12,17 @@ _VERSION = '0.0.0' _REQUIRED_PACKAGES = [ - 'requests', + 'requests==2.25.0', 'tensorflow==2.3.0', 'jsonschema==3.1.1', 'networkx==2.4', 'mpi4py==3.0.3', - 'horovod==0.19.1' + 'horovod==0.19.1', + 'tf2cv==0.0.16', + 'PyYAML==5.3.1', + 'types-PyYAML', + 'types-pkg_resources', + 'types-requests' ] _TEST_REQUIRES = [ diff --git a/src/model_optimizer/pruner/config_schema.json b/src/model_optimizer/pruner/config_schema.json index eda7b9e..e9247d9 100644 --- a/src/model_optimizer/pruner/config_schema.json +++ b/src/model_optimizer/pruner/config_schema.json @@ -30,6 +30,10 @@ "checkpoint_eval_path": { "type": "string", "description": "file path of eval checkpoint" + }, + "is_distill":{ + "type": "boolean", + "description": "if start train model with distilling" } }, "required": [ diff --git a/src/model_optimizer/pruner/core/__init__.py b/src/model_optimizer/pruner/core/__init__.py index 6d536d1..8a83b1d 100644 --- a/src/model_optimizer/pruner/core/__init__.py +++ b/src/model_optimizer/pruner/core/__init__.py @@ -27,5 +27,5 @@ def get_pruner(config, epoch): func_name = scheduler['pruner']['func_name'] pruner_type = scheduler_config['pruners'][func_name]['prune_type'] if pruner_type in pruners: - pruner_list.append(pruners[pruner_type](scheduler_config['pruners'][func_name])) + pruner_list.append(pruners[pruner_type](scheduler_config['pruners'][func_name], config)) return pruner_list diff --git a/src/model_optimizer/pruner/core/pruner.py b/src/model_optimizer/pruner/core/pruner.py index 50a3211..490eca4 100644 --- a/src/model_optimizer/pruner/core/pruner.py +++ b/src/model_optimizer/pruner/core/pruner.py @@ -7,6 +7,12 @@ import networkx as nx import tensorflow as tf import numpy as np +from ..distill.distill_loss import DistillLossLayer + + +_custom_objects = { + 'DistillLossLayer': DistillLossLayer + } def get_network(model): @@ -84,6 +90,7 @@ def _get_dense_mask(model, layer_id, digraph, num_retain_channels, criterion): return _get_layer_mask(model, layer_id, 'dense', digraph, num_retain_channels, criterion) +# pylint: disable=too-many-arguments def _get_layer_mask(model, layer_id, layer_type, digraph, num_retain_channels, criterion): """ Get conv2d layer mask @@ -138,7 +145,13 @@ def get_relate_father_id(layer_id, digraph): 'AveragePooling2D', 'BatchNormalization', 'Flatten', - 'MaxPooling2D'] + 'MaxPooling2D', + 'DepthwiseConv2D', + 'ReLU', + 'Reshape', + 'Dropout', + 'GlobalAveragePooling2D', + 'ZeroPadding2D'] node_list = [] for node in digraph.predecessors(layer_id): node_list.append(node) @@ -202,6 +215,8 @@ def update_weights(model, pruned_model, digraph, mask_dict): """ for i, layer in enumerate(model.layers): layer_type = str(type(layer)) + if 'DepthwiseConv2D' in str(type(layer)): + continue if layer_type.endswith('Conv2D\'>') or layer_type.endswith('Dense\'>') or \ layer_type.endswith('BatchNormalization\'>'): new_model_layer_input_shape = pruned_model.layers[i].input.shape @@ -227,7 +242,8 @@ def update_weights(model, pruned_model, digraph, mask_dict): _layer_set_weights(pruned_model, layer, weights_0, i, mask_dict) -def specified_layers_prune(orig_model, cur_model, layers_name, ratio, criterion='l1_norm'): +# pylint: disable=too-many-arguments,too-many-branches,too-many-statements +def specified_layers_prune(orig_model, cur_model, layers_name, ratio, criterion='l1_norm', basic_config=None): """ Prune with specified layers :param orig_model: original model, never pruned once @@ -235,90 +251,142 @@ def specified_layers_prune(orig_model, cur_model, layers_name, ratio, criterion= :param layers_name: name list of pruned layers :param ratio: ratio of pruned :param criterion: 'l1_norm' or 'bn_gamma' + :param basic_config: config :return: pruned model """ clone_model = tf.keras.models.clone_model(cur_model) - digraph = get_network(cur_model) + if basic_config is None: + is_distill = False + model_name = 'no_name' + else: + is_distill = basic_config.get_attribute('is_distill', False) + model_name = basic_config.get_attribute('model_name') + if is_distill: + _clone_model = clone_model.get_layer(model_name) + _cur_model = cur_model.get_layer(model_name) + _orig_model = orig_model.get_layer(model_name) + else: + _clone_model = clone_model + _cur_model = cur_model + _orig_model = orig_model + digraph = get_network(_cur_model) mask_dict = {} - conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(orig_model) + conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(_orig_model) channel = -1 - for i, layer in enumerate(cur_model.layers): + for i, layer in enumerate(_cur_model.layers): layer_type = str(type(layer)) if not dense_ahead_of_conv and i == conv_index: if layer_type.endswith('Conv2D\'>'): - channel = clone_model.get_layer(layer.name).filters + channel = _clone_model.get_layer(layer.name).filters continue if layer_type.endswith('Reshape\'>'): if i == last_reshape: target_shape = (channel,) - clone_model.layers[i].target_shape = target_shape + _clone_model.layers[i].target_shape = target_shape continue elif channel != -1: target_shape = (1, 1, channel) - clone_model.layers[i].target_shape = target_shape + _clone_model.layers[i].target_shape = target_shape continue if 'Conv2D' in str(type(layer)): if layer.name in layers_name: - clone_model.layers[i].filters = \ - clone_model.layers[i].filters - int(orig_model.layers[i].filters * ratio) - mask_dict[i] = _get_conv_mask(cur_model, i, digraph, int(clone_model.layers[i].filters), criterion) - channel = clone_model.layers[i].filters + _clone_model.layers[i].filters = \ + _clone_model.layers[i].filters - int(_orig_model.layers[i].filters * ratio) + mask_dict[i] = _get_conv_mask(_cur_model, i, digraph, int(_clone_model.layers[i].filters), criterion) + channel = _clone_model.layers[i].filters elif 'Dense' in str(type(layer)): - if i == len(cur_model.layers) - 1: + if i == len(_cur_model.layers) - 1: continue else: if layer.name in layers_name: - clone_model.layers[i].units = \ - clone_model.layers[i].units - int(orig_model.layers[i].units * ratio) - mask_dict[i] = _get_dense_mask(cur_model, i, digraph, int(clone_model.layers[i].units), criterion) - pruned_model = tf.keras.models.model_from_json(clone_model.to_json()) - update_weights(cur_model, pruned_model, digraph, mask_dict) + _clone_model.layers[i].units = \ + _clone_model.layers[i].units - int(_orig_model.layers[i].units * ratio) + mask_dict[i] = _get_dense_mask(_cur_model, i, digraph, int(_clone_model.layers[i].units), criterion) + if is_distill: + custom_objects = _custom_objects + else: + custom_objects = None + pruned_model = tf.keras.models.model_from_json(clone_model.to_json(), custom_objects=custom_objects) + if not is_distill: + update_weights(cur_model, pruned_model, digraph, mask_dict) return pruned_model -def auto_prune(orig_model, cur_model, ratio, criterion='l1_norm'): +# pylint: disable=too-many-arguments,too-many-branches,too-many-statements +def auto_prune(orig_model, cur_model, ratio, criterion='l1_norm', basic_config=None): """ Auto prune layer with fixed ratio :param orig_model: original model, never pruned once :param cur_model: model before this step of pruned :param ratio: ratio of pruned :param criterion: 'l1_norm' or 'bn_gamma' + :param basic_config: config :return: pruned model """ clone_model = tf.keras.models.clone_model(cur_model) - digraph = get_network(cur_model) + if basic_config is None: + is_distill = False + model_name = 'no_name' + else: + is_distill = basic_config.get_attribute('is_distill', False) + model_name = basic_config.get_attribute('model_name') + if is_distill: + _clone_model = clone_model.get_layer(model_name) + _cur_model = cur_model.get_layer(model_name) + _orig_model = orig_model.get_layer(model_name) + else: + _clone_model = clone_model + _cur_model = cur_model + _orig_model = orig_model + digraph = get_network(_cur_model) mask_dict = {} - conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(orig_model) + conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(_orig_model) channel = -1 - for i, layer in enumerate(cur_model.layers): + last_dense_or_conv = True + for i, layer in enumerate(_cur_model.layers): layer_type = str(type(layer)) if not dense_ahead_of_conv and i == conv_index: if layer_type.endswith('Conv2D\'>'): - channel = clone_model.get_layer(layer.name).filters + channel = _clone_model.get_layer(layer.name).filters continue if layer_type.endswith('Reshape\'>'): if i == last_reshape: target_shape = (channel,) - clone_model.layers[i].target_shape = target_shape + _clone_model.layers[i].target_shape = target_shape continue elif channel != -1: target_shape = (1, 1, channel) - clone_model.layers[i].target_shape = target_shape + _clone_model.layers[i].target_shape = target_shape continue - if 'Conv2D' in str(type(layer)): - clone_model.layers[i].filters = \ - clone_model.layers[i].filters - int(orig_model.layers[i].filters * ratio) - mask_dict[i] = _get_conv_mask(cur_model, i, digraph, int(clone_model.layers[i].filters), criterion) - channel = clone_model.layers[i].filters + if 'DepthwiseConv2D' in str(type(layer)): + continue + elif 'Conv2D' in str(type(layer)): + _clone_model.layers[i].filters = \ + _clone_model.layers[i].filters - int(_orig_model.layers[i].filters * ratio) + mask_dict[i] = _get_conv_mask(_cur_model, i, digraph, int(_clone_model.layers[i].filters), criterion) + channel = _clone_model.layers[i].filters elif 'Dense' in str(type(layer)): - if i == len(cur_model.layers) - 1: + if i == len(_cur_model.layers) - 1: continue else: - clone_model.layers[i].units = \ - clone_model.layers[i].units - int(orig_model.layers[i].units * ratio) - mask_dict[i] = _get_dense_mask(cur_model, i, digraph, int(clone_model.layers[i].units), criterion) - pruned_model = tf.keras.models.model_from_json(clone_model.to_json()) - update_weights(cur_model, pruned_model, digraph, mask_dict) + + for index in range(i+1, len(_cur_model.layers)): + if 'Dense' in str(type(_cur_model.layers[index])) or \ + 'Conv2D' in str(type(_cur_model.layers[index])): + last_dense_or_conv = False + break + if not last_dense_or_conv: + _clone_model.layers[i].units = \ + _clone_model.layers[i].units - int(_orig_model.layers[i].units * ratio) + mask_dict[i] = _get_dense_mask(_cur_model, i, digraph, + int(_clone_model.layers[i].units), criterion) + if is_distill: + custom_objects = _custom_objects + else: + custom_objects = None + pruned_model = tf.keras.models.model_from_json(clone_model.to_json(), custom_objects=custom_objects) + if not is_distill: + update_weights(cur_model, pruned_model, digraph, mask_dict) return pruned_model @@ -327,9 +395,10 @@ class AutoPruner: Auto select layers to prune. """ - def __init__(self, config): - self.ratio = config['ratio'] - self.criterion = config['criterion'] + def __init__(self, scheduler_config, basic_config=None): + self.basic_config = basic_config + self.ratio = scheduler_config['ratio'] + self.criterion = scheduler_config['criterion'] def prune(self, orig_model, cur_model): """ @@ -338,7 +407,7 @@ def prune(self, orig_model, cur_model): :param cur_model: model before this step of pruned :return: pruned model """ - return auto_prune(orig_model, cur_model, self.ratio, self.criterion) + return auto_prune(orig_model, cur_model, self.ratio, self.criterion, self.basic_config) class SpecifiedLayersPruner: @@ -346,10 +415,11 @@ class SpecifiedLayersPruner: Specified layers to prune. """ - def __init__(self, config): - self.ratio = config['ratio'] - self.criterion = config['criterion'] - self.layers_name = config['layers_to_be_pruned'] + def __init__(self, scheduler_config, basic_config=None): + self.basic_config = basic_config + self.ratio = scheduler_config['ratio'] + self.criterion = scheduler_config['criterion'] + self.layers_name = scheduler_config['layers_to_be_pruned'] def prune(self, orig_model, cur_model): """ @@ -358,4 +428,5 @@ def prune(self, orig_model, cur_model): :param cur_model: model before this step of pruned :return: pruned model """ - return specified_layers_prune(orig_model, cur_model, self.layers_name, self.ratio, self.criterion) + return specified_layers_prune(orig_model, cur_model, + self.layers_name, self.ratio, self.criterion, self.basic_config) diff --git a/src/model_optimizer/pruner/dataset/imagenet.py b/src/model_optimizer/pruner/dataset/imagenet.py index 295b596..637070d 100644 --- a/src/model_optimizer/pruner/dataset/imagenet.py +++ b/src/model_optimizer/pruner/dataset/imagenet.py @@ -54,7 +54,7 @@ def parse_fn(self, example_serialized): features = tf.io.parse_single_example(serialized=example_serialized, features=feature_description) - label = tf.cast(features['image/class/label'], dtype=tf.int32) + label = tf.cast(features['image/class/label'], dtype=tf.int32) - 1 xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) diff --git a/src/model_optimizer/pruner/distill/distill_loss.py b/src/model_optimizer/pruner/distill/distill_loss.py index a51112d..64d721e 100644 --- a/src/model_optimizer/pruner/distill/distill_loss.py +++ b/src/model_optimizer/pruner/distill/distill_loss.py @@ -22,7 +22,8 @@ class DistillLossLayer(tf.keras.layers.Layer): Call arguments: inputs: inputs of the layer. It corresponds to [input, y_true, y_prediction] """ - def __init__(self, teacher_path, alpha=1.0, temperature=10, name="DistillLoss", **kwargs): + def __init__(self, teacher_path, alpha=1.0, temperature=10, name="DistillLoss", + teacher_model_load_func=None, **kwargs): """ :param teacher_path: the model path of teacher. The format of the model is h5. :param alpha: a float between [0.0, 1.0]. It corresponds to the importance between the student loss and the @@ -35,7 +36,12 @@ def __init__(self, teacher_path, alpha=1.0, temperature=10, name="DistillLoss", self.temperature = temperature self.teacher_path = teacher_path self.accuracy_fn = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy") - self.teacher = tf.keras.models.load_model(self.teacher_path) + self.teacher_model_load_func = teacher_model_load_func + if self.teacher_model_load_func == "load_tf_ensemble_model": + from .tf_model_loader import load_tf_ensemble_model + self.teacher = load_tf_ensemble_model(self.teacher_path) + else: + self.teacher = tf.keras.models.load_model(self.teacher_path) # pylint: disable=unused-argument def call(self, inputs, **kwargs): @@ -72,4 +78,5 @@ def get_config(self): config.update({"teacher_path": self.teacher_path}) config.update({"alpha": self.alpha}) config.update({"temperature": self.temperature}) + config.update({"teacher_model_load_func": self.teacher_model_load_func}) return config diff --git a/src/model_optimizer/pruner/distill/distiller.py b/src/model_optimizer/pruner/distill/distiller.py index 6eb3a27..e969539 100644 --- a/src/model_optimizer/pruner/distill/distiller.py +++ b/src/model_optimizer/pruner/distill/distiller.py @@ -9,19 +9,27 @@ from .distill_loss import DistillLossLayer -def get_distiller(student_model, scheduler_config): +def get_distiller(student_model, scheduler_config, teacher_model_load_func=None): """ Get distiller model :param student_model: student model function :param scheduler_config: scheduler config object + :param teacher_model_load_func: func to load teacher model :return: keras model of distiller """ + + if teacher_model_load_func is None: + if "model_load_func" in scheduler_config['distill']: + teacher_model_load_func = scheduler_config['distill']["model_load_func"] + input_img = tf.keras.layers.Input(shape=(224, 224, 3), name='image') input_lbl = tf.keras.layers.Input((), name="label", dtype='int32') student = student_model - _, logits = student(input_img) - total_loss = DistillLossLayer(scheduler_config['teacher_path'], scheduler_config['alpha'], - scheduler_config['temperature'], )([input_img, input_lbl, logits]) + logits = student(input_img) + total_loss = DistillLossLayer(scheduler_config['distill']['teacher_path'], + scheduler_config['distill']['alpha'], + scheduler_config['distill']['temperature'], + teacher_model_load_func=teacher_model_load_func)([input_img, input_lbl, logits]) distill_model = tf.keras.Model(inputs=[input_img, input_lbl], outputs=[logits, total_loss]) return distill_model diff --git a/src/model_optimizer/pruner/distill/tf_model_loader.py b/src/model_optimizer/pruner/distill/tf_model_loader.py new file mode 100644 index 0000000..c35d3fe --- /dev/null +++ b/src/model_optimizer/pruner/distill/tf_model_loader.py @@ -0,0 +1,74 @@ +# Copyright 2019 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +load tf models func +""" +import tensorflow as tf + + +def load_tf_ensemble_model(model_path): + """ + Load multiple models to form an integrated model + :param model_path: Model path, multiple model paths are separated by commas + :return: + """ + return EnsembleModel(model_path) + + +def _load_model(model_path, prefix=None): + if 'tf2' in model_path: + from tf2cv.model_provider import get_model as tf2cv_get_model # type: ignore + _model = tf2cv_get_model(model_path.split('/')[-1].split('-')[0], pretrained=False, data_format="channels_last") + _model.build(input_shape=(1, 224, 224, 3)) + _model.load_weights(model_path) + _model.trainable = False + else: + _model = tf.keras.models.load_model(model_path) + _model.trainable = False + if prefix is not None: + for weight in _model.weights: + weight._handle_name = prefix + '_' + weight.name # pylint: disable=W0212 + return _model + + +class EnsembleModel(tf.keras.Model): # pylint: disable=too-many-ancestors + """ + Ensemble model for distillation + """ + def __init__(self, model_path): + super().__init__() + self.model_path = model_path + self.avg = tf.keras.layers.Average() + self.models = self._get_models(model_path) + + # pylint: disable=R0201 + def _get_models(self, model_path): + models = [] + model_path_list = model_path.split(',') + if len(model_path_list) == 1: + _model = _load_model(model_path) + models.append(_model) + else: + for i, _model_path in enumerate(model_path_list): + _model = _load_model(_model_path, prefix='t'+str(i)) + models.append(_model) + return models + + def call(self, inputs, training=None, mask=None): # pylint: disable=unused-argument + """ + Model call func + :param inputs: Model input + :return: average output + """ + model_outputs = [model(inputs) for model in self.models] + output = self.avg(model_outputs) + return output + + def get_config(self): + """ + Implement get_config to enable serialization. + """ + config = super().get_config() + config.update({"model_path": self.model_path}) + return config diff --git a/src/model_optimizer/pruner/learner/learner_base.py b/src/model_optimizer/pruner/learner/learner_base.py index 81921be..dca5c13 100644 --- a/src/model_optimizer/pruner/learner/learner_base.py +++ b/src/model_optimizer/pruner/learner/learner_base.py @@ -13,7 +13,6 @@ from .utils import get_call_backs from ...stat import print_keras_model_summary, print_keras_model_params_flops from ..distill.distill_loss import DistillLossLayer -from ..core.pruner import dense_present_before_conv class LearnerBase(metaclass=abc.ABCMeta): @@ -104,7 +103,7 @@ def build_dataset(self): eval_dataset = ds_eval.build() train_dataset_distill = None eval_dataset_distill = None - if self.config.get_attribute("scheduler") == "distill": + if self.config.get_attribute("scheduler") == "distill" or self.config.get_attribute('is_distill'): ds_train_distill = get_dataset(self.config, is_training=True, num_shards=hvd.size(), shard_index=hvd.rank()) train_dataset_distill = ds_train_distill.build(True) ds_eval_distill = get_dataset(self.config, is_training=False) @@ -153,7 +152,7 @@ def train(self, initial_epoch=0, epochs=1, lr_schedulers=None): self.callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join(self.checkpoint_path, './checkpoint-{epoch}.h5'), period=self.checkpoint_save_period)) - if self.config.get_attribute('scheduler') == 'distill': + if self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill'): train_dataset = self.train_dataset_distill else: train_dataset = self.train_dataset @@ -165,8 +164,7 @@ def eval(self): """ Model eval process, only evaluate on rank 0 the format of score is like as follows: - {loss: 7.6969 dense1_loss: 5.4490 softmax_1_sparse_categorical_accuracy: 0.0665 - dense1_sparse_categorical_accuracy: 0.0665} + {loss: 7.6969 dense1_sparse_categorical_accuracy: 0.0665} :return: """ if hvd.rank() != 0: @@ -174,10 +172,7 @@ def eval(self): eval_model = self.models_eval[-1] score = eval_model.evaluate(self.eval_dataset, steps=self.eval_steps_per_epoch) loss = score[0] - if self.config.get_attribute("classifier_activation", "softmax") == "softmax": - accuracy = score[2] - else: - accuracy = score[3] + accuracy = score[1] print('Test loss:', loss) print('Test accuracy:', accuracy) @@ -236,7 +231,7 @@ def load_model(self): _custom_objects = { 'DistillLossLayer': DistillLossLayer } - if self.config.get_attribute('scheduler') == 'distill': + if self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill'): custom_objects = _custom_objects else: custom_objects = None @@ -262,49 +257,25 @@ def save_eval_model(self): """ if hvd.rank() != 0: return - train_model = self.models_train[-1] + if self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill'): + train_model = self.models_train[-1].get_layer(self.config.get_attribute('model_name')) + else: + train_model = self.models_train[-1] eval_model = self.models_eval[-1] - channel = -1 - conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(train_model) save_model_path = os.path.join(self.save_model_path, 'checkpoint-') + str(self.cur_epoch) + '.h5' - if self.config.get_attribute('scheduler') == 'distill': - model_name = self.config.get_attribute('model_name') - for layer_eval in eval_model.layers: - for layer in train_model.layers: - if layer.name == model_name and layer_eval.name == model_name: - layer_eval.set_weights(layer.get_weights()) - student_eval = layer_eval - break - student_eval.save(save_model_path) - self.eval_models_update(student_eval) - else: - clone_model = tf.keras.models.clone_model(eval_model) - for i, layer in enumerate(clone_model.layers): - layer_type = str(type(layer)) - # the model's output with convolution, no pruning and getting its channel - if not dense_ahead_of_conv and i == conv_index: - if layer_type.endswith('Conv2D\'>'): - channel = train_model.get_layer(layer.name).filters - continue - # incoperate with the change of channel resulting from pruing filters in convolution - if layer_type.endswith('Reshape\'>'): - if i == last_reshape: - target_shape = (channel,) - clone_model.layers[i].target_shape = target_shape - continue - elif channel != -1: - target_shape = (1, 1, channel) - clone_model.layers[i].target_shape = target_shape - continue - if 'Conv2D' in str(type(layer)): - clone_model.layers[i].filters = train_model.get_layer(layer.name).filters - channel = train_model.get_layer(layer.name).filters - elif 'Dense' in str(type(layer)): - clone_model.layers[i].units = train_model.get_layer(layer.name).units - pruned_eval_model = tf.keras.models.model_from_json(clone_model.to_json()) - pruned_eval_model.set_weights(train_model.get_weights()) - pruned_eval_model.save(save_model_path) - self.eval_models_update(pruned_eval_model) + clone_model = tf.keras.models.clone_model(eval_model) + for i, layer in enumerate(clone_model.layers): + layer_type = str(type(layer)) + if 'Conv2D' in str(type(layer)): + clone_model.layers[i].filters = train_model.get_layer(layer.name).filters + elif 'Dense' in str(type(layer)): + clone_model.layers[i].units = train_model.get_layer(layer.name).units + elif layer_type.endswith('Reshape\'>'): + clone_model.layers[i].target_shape = train_model.get_layer(layer.name).target_shape + pruned_eval_model = tf.keras.models.model_from_json(clone_model.to_json()) + pruned_eval_model.set_weights(train_model.get_weights()) + pruned_eval_model.save(save_model_path) + self.eval_models_update(pruned_eval_model) def print_model_summary(self): """ diff --git a/src/model_optimizer/pruner/learner/lenet_mnist.py b/src/model_optimizer/pruner/learner/lenet_mnist.py index 0ee7cd2..c331be4 100644 --- a/src/model_optimizer/pruner/learner/lenet_mnist.py +++ b/src/model_optimizer/pruner/learner/lenet_mnist.py @@ -56,6 +56,7 @@ def get_metrics(self, is_training=True): :param is_training: is training or not :return: Return model compile metrics """ - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill', False)) \ + and is_training: return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py b/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py index 464a9b0..efd3b1a 100644 --- a/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py +++ b/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py @@ -66,6 +66,7 @@ def get_metrics(self, is_training=True): :param is_training: is training or not :return: Return model compile metrics """ - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill', False)) \ + and is_training: return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py b/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py index b29cbfc..7923d0b 100644 --- a/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py +++ b/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py @@ -65,6 +65,7 @@ def get_metrics(self, is_training=True): :param is_training: is training or not :return: Return model compile metrics """ - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill', False)) \ + and is_training: return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/resnet_101_imagenet.py b/src/model_optimizer/pruner/learner/resnet_101_imagenet.py index 62a3e80..65236b1 100644 --- a/src/model_optimizer/pruner/learner/resnet_101_imagenet.py +++ b/src/model_optimizer/pruner/learner/resnet_101_imagenet.py @@ -58,14 +58,10 @@ def get_losses(self, is_training=True): :return: Return model compile losses """ softmax_loss = tf.keras.losses.SparseCategoricalCrossentropy() - logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) if self.config.get_attribute('scheduler') == 'distill' and is_training: return None else: - if self.config.get_attribute("classifier_activation", "softmax") == "softmax": - return [softmax_loss, None] - else: - return [None, logits_loss] + return softmax_loss def get_metrics(self, is_training=True): """ @@ -73,6 +69,7 @@ def get_metrics(self, is_training=True): :param: is_training: is training of not :return: Return model compile metrics """ - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill', False)) \ + and is_training: return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/resnet_50_imagenet.py b/src/model_optimizer/pruner/learner/resnet_50_imagenet.py index b92905c..2080f82 100644 --- a/src/model_optimizer/pruner/learner/resnet_50_imagenet.py +++ b/src/model_optimizer/pruner/learner/resnet_50_imagenet.py @@ -58,14 +58,11 @@ def get_losses(self, is_training=True): :return: Return model compile losses """ softmax_loss = tf.keras.losses.SparseCategoricalCrossentropy() - logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill'))\ + and is_training: return None else: - if self.config.get_attribute("classifier_activation", "softmax") == "softmax": - return [softmax_loss, None] - else: - return [None, logits_loss] + return softmax_loss def get_metrics(self, is_training=True): """ @@ -73,6 +70,7 @@ def get_metrics(self, is_training=True): :param is_training: is training or not :return: Return model compile metrics """ - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill')) \ + and is_training: return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py b/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py index 326e9ee..ad901e9 100644 --- a/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py +++ b/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py @@ -64,6 +64,7 @@ def get_metrics(self, is_training=True): :param is_training: is training or not :return: Return model compile metrics """ - if self.config.get_attribute('scheduler') == 'distill' and is_training: + if (self.config.get_attribute('scheduler') == 'distill' or self.config.get_attribute('is_distill')) \ + and is_training: return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/models/__init__.py b/src/model_optimizer/pruner/models/__init__.py index 16d806c..ccc2f24 100644 --- a/src/model_optimizer/pruner/models/__init__.py +++ b/src/model_optimizer/pruner/models/__init__.py @@ -21,31 +21,35 @@ def get_model(config, is_training=True): if model_name not in ['lenet', 'resnet_18', 'vgg_m_16', 'resnet_50', 'resnet_101', 'mobilenet_v1', 'mobilenet_v2']: raise Exception('Not support model %s' % model_name) + if (config.get_attribute('scheduler') == 'distill' or config.get_attribute('is_distill')) and is_training: + classifier_activation = None + else: + classifier_activation = 'softmax' if model_name == 'lenet': from .lenet import lenet - return lenet(model_name, is_training) + student_model = lenet(is_training, model_name, classifier_activation=classifier_activation) elif model_name == 'vgg_m_16': from .vgg import vgg_m_16 - return vgg_m_16(is_training, model_name) + student_model = vgg_m_16(is_training, model_name, classifier_activation=classifier_activation) elif model_name == 'resnet_18': from .resnet import resnet_18 - return resnet_18(is_training, model_name) + student_model = resnet_18(is_training, model_name, classifier_activation=classifier_activation) elif model_name == 'resnet_50': from .resnet import resnet_50 - student_model = resnet_50(is_training, model_name) - if config.get_attribute('scheduler') == 'distill': - distill_model = get_distiller(student_model, scheduler_config) - return distill_model - else: - return student_model + student_model = resnet_50(is_training, model_name, classifier_activation=classifier_activation) elif model_name == 'resnet_101': from .resnet import resnet_101 - return resnet_101(is_training, model_name) + student_model = resnet_101(is_training, model_name, classifier_activation=classifier_activation) elif model_name == 'mobilenet_v1': from .mobilenet_v1 import mobilenet_v1_1 - return mobilenet_v1_1(is_training=is_training, name=model_name) + student_model = mobilenet_v1_1(is_training=is_training, name=model_name, + classifier_activation=classifier_activation) elif model_name == 'mobilenet_v2': from .mobilenet_v2 import mobilenet_v2_1 - return mobilenet_v2_1(is_training=is_training, name=model_name) + student_model = mobilenet_v2_1(is_training=is_training, name=model_name, + classifier_activation=classifier_activation) + if (config.get_attribute('scheduler') == 'distill' or config.get_attribute('is_distill')) and is_training: + distill_model = get_distiller(student_model, scheduler_config) else: - raise Exception('Not support model {}'.format(model_name)) + distill_model = student_model + return distill_model diff --git a/src/model_optimizer/pruner/models/config.py b/src/model_optimizer/pruner/models/config.py new file mode 100644 index 0000000..7bce445 --- /dev/null +++ b/src/model_optimizer/pruner/models/config.py @@ -0,0 +1,63 @@ +# Copyright 2019 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Get the configuration, if the environment variable exists, get it from the environment variable. +If it does not exist, get it from the initial value. +""" +import os + + +class ModelConfig: + """ + Model train config parameters + """ + def __init__(self, l2_weight_decay=1e-4, batch_norm_decay=0.95, batch_norm_epsilon=0.001, std_dev=0.09): + self._l2_weight_decay = l2_weight_decay + self._std_dev = std_dev + self._batch_norm_decay = batch_norm_decay + self._batch_norm_epsilon = batch_norm_epsilon + + @property + def l2_weight_decay(self): + """ + l2_weight_decay + :return: l2_weight_decay + """ + if 'L2_WEIGHT_DECAY' in os.environ: + return float(os.environ['L2_WEIGHT_DECAY']) + else: + return self._l2_weight_decay + + @property + def batch_norm_decay(self): + """ + batch_norm_decay + :return: batch_norm_decay + """ + if 'BATCH_NORM_DECAY' in os.environ: + return float(os.environ['BATCH_NORM_DECAY']) + else: + return self._batch_norm_decay + + @property + def batch_norm_epsilon(self): + """ + batch_norm_epsilon + :return: batch_norm_epsilon + """ + if 'BATCH_NORM_EPSILON' in os.environ: + return float(os.environ['BATCH_NORM_EPSILON']) + else: + return self._batch_norm_epsilon + + @property + def std_dev(self): + """ + std_dev + :return: std_dev + """ + if 'STD_DEV' in os.environ: + return float(os.environ['STD_DEV']) + else: + return self._std_dev diff --git a/src/model_optimizer/pruner/models/lenet.py b/src/model_optimizer/pruner/models/lenet.py index 44767e3..3b4bdf9 100644 --- a/src/model_optimizer/pruner/models/lenet.py +++ b/src/model_optimizer/pruner/models/lenet.py @@ -7,11 +7,12 @@ import tensorflow as tf -def lenet(name, is_training=True): +def lenet(is_training=True, name='lenet', classifier_activation='softmax'): """ This implements a slightly modified LeNet-5 [LeCun et al., 1998a] + :param is_training: if training or :param name: the model name - :param is_training: if training or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: LeNet model """ input_ = tf.keras.layers.Input(shape=(28, 28, 1), name='input') @@ -33,6 +34,9 @@ def lenet(name, is_training=True): x = tf.keras.layers.Flatten(name='flatten')(x) x = tf.keras.layers.Dense(120, activation='relu', name='dense_1')(x) x = tf.keras.layers.Dense(84, activation='relu', name='dense_2')(x) - output_ = tf.keras.layers.Dense(10, activation='softmax', name='dense_3')(x) + if classifier_activation == 'softmax': + output_ = tf.keras.layers.Dense(10, activation='softmax', name='dense_3')(x) + else: + output_ = tf.keras.layers.Dense(10, activation=None, name='dense_3')(x) model = tf.keras.Model(input_, output_, name=name) return model diff --git a/src/model_optimizer/pruner/models/mobilenet_v1.py b/src/model_optimizer/pruner/models/mobilenet_v1.py index 6ace82d..4991bcb 100644 --- a/src/model_optimizer/pruner/models/mobilenet_v1.py +++ b/src/model_optimizer/pruner/models/mobilenet_v1.py @@ -7,12 +7,13 @@ Adapted from tf.keras.applications.mobilenet.MobileNetV2(). """ import tensorflow as tf +from .config import ModelConfig - -L2_WEIGHT_DECAY = 0.00004 -STD_DEV = 0.09 -BATCH_NORM_DECAY = 0.95 -BATCH_NORM_EPSILON = 0.001 +_config = ModelConfig(l2_weight_decay=0.00004, batch_norm_decay=0.95, batch_norm_epsilon=0.001, std_dev=0.09) +L2_WEIGHT_DECAY = _config.l2_weight_decay +BATCH_NORM_DECAY = _config.batch_norm_decay +BATCH_NORM_EPSILON = _config.batch_norm_epsilon +STD_DEV = _config.std_dev def _gen_l2_regularizer(use_l2_regularizer=True): @@ -23,55 +24,65 @@ def _gen_initializer(use_initializer=True): return tf.keras.initializers.TruncatedNormal(stddev=STD_DEV) if use_initializer else None -def mobilenet_v1_0_25(num_classes=1001, +def mobilenet_v1_0_25(num_classes=1000, dropout_prob=1e-3, is_training=True, - depth_multiplier=1): + depth_multiplier=1, + classifier_activation='softmax'): """ Build mobilenet_v1_0.25 model :param num_classes: :param dropout_prob: :param is_training: :param depth_multiplier: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.25, depth_multiplier=depth_multiplier) + return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.25, + depth_multiplier=depth_multiplier, classifier_activation=classifier_activation) -def mobilenet_v1_0_5(num_classes=1001, +def mobilenet_v1_0_5(num_classes=1000, dropout_prob=1e-3, is_training=True, - depth_multiplier=1): + depth_multiplier=1, + classifier_activation='softmax'): """ Build mobilenet_v1_0.5 model :param num_classes: :param dropout_prob: :param is_training: :param depth_multiplier: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.5, depth_multiplier=depth_multiplier) + return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.5, + depth_multiplier=depth_multiplier, classifier_activation=classifier_activation) -def mobilenet_v1_0_75(num_classes=1001, +def mobilenet_v1_0_75(num_classes=1000, dropout_prob=1e-3, is_training=True, - depth_multiplier=1): + depth_multiplier=1, + classifier_activation='softmax'): """ Build mobilenet_v1_0.75 model :param num_classes: :param dropout_prob: :param is_training: :param depth_multiplier: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.75, depth_multiplier=depth_multiplier) + return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.75, + depth_multiplier=depth_multiplier, classifier_activation=classifier_activation) -def mobilenet_v1_1(name, num_classes=1001, +def mobilenet_v1_1(name, num_classes=1000, dropout_prob=1e-3, is_training=True, - depth_multiplier=1): + depth_multiplier=1, + classifier_activation='softmax'): """ Build mobilenet_v1_1.0 model :param name: the model name @@ -79,16 +90,19 @@ def mobilenet_v1_1(name, num_classes=1001, :param dropout_prob: :param is_training: :param depth_multiplier: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v1(name, num_classes, dropout_prob, is_training, scale=1.0, depth_multiplier=depth_multiplier) + return _mobilenet_v1(name, num_classes, dropout_prob, is_training, scale=1.0, + depth_multiplier=depth_multiplier, classifier_activation=classifier_activation) def _mobilenet_v1(name, num_classes=1000, dropout_prob=1e-3, is_training=True, scale=1.0, - depth_multiplier=1): + depth_multiplier=1, + classifier_activation='softmax'): """ Build mobilenet_v1 model :param name: the model name @@ -97,6 +111,7 @@ def _mobilenet_v1(name, num_classes=1000, :param is_training: :param scale: :param depth_multiplier: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ inputs = tf.keras.layers.Input(shape=(224, 224, 3), name='input') @@ -132,7 +147,10 @@ def _mobilenet_v1(name, num_classes=1000, kernel_regularizer=_gen_l2_regularizer(), name='conv_preds')(x) x = tf.keras.layers.Reshape((num_classes,), name='reshape_2')(x) - outputs = tf.keras.layers.Activation('softmax', name='act_softmax')(x) + if classifier_activation == 'softmax': + outputs = tf.keras.layers.Activation('softmax', name='act_softmax')(x) + else: + outputs = x model = tf.keras.Model(inputs, outputs, name=name) return model diff --git a/src/model_optimizer/pruner/models/mobilenet_v2.py b/src/model_optimizer/pruner/models/mobilenet_v2.py index dab3a9c..b036f19 100644 --- a/src/model_optimizer/pruner/models/mobilenet_v2.py +++ b/src/model_optimizer/pruner/models/mobilenet_v2.py @@ -6,12 +6,13 @@ Adapted from tf.keras.applications.mobilenet.MobileNetV2(). """ import tensorflow as tf +from .config import ModelConfig - -L2_WEIGHT_DECAY = 0.00004 -STD_DEV = 0.09 -BATCH_NORM_DECAY = 0.99 -BATCH_NORM_EPSILON = 0.001 +_config = ModelConfig(l2_weight_decay=0.00004, batch_norm_decay=0.99, batch_norm_epsilon=0.001, std_dev=0.09) +L2_WEIGHT_DECAY = _config.l2_weight_decay +BATCH_NORM_DECAY = _config.batch_norm_decay +BATCH_NORM_EPSILON = _config.batch_norm_epsilon +STD_DEV = _config.std_dev def _gen_l2_regularizer(use_l2_regularizer=True): @@ -22,90 +23,109 @@ def _gen_initializer(use_initializer=True): return tf.keras.initializers.TruncatedNormal(stddev=STD_DEV) if use_initializer else None -def mobilenet_v2_0_35(num_classes=1001, +def mobilenet_v2_0_35(num_classes=1000, dropout_prob=1e-3, - is_training=True): + is_training=True, + classifier_activation='softmax'): """ Build mobilenet_v2_0.35 model :param num_classes: :param dropout_prob: :param is_training: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=0.35) + return _mobilenet_v2(num_classes, dropout_prob, is_training, + scale=0.35, classifier_activation=classifier_activation) -def mobilenet_v2_0_5(num_classes=1001, +def mobilenet_v2_0_5(num_classes=1000, dropout_prob=1e-3, - is_training=True): + is_training=True, + classifier_activation='softmax'): """ Build mobilenet_v2_0.5 model :param num_classes: :param dropout_prob: :param is_training: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=0.5) + return _mobilenet_v2(num_classes, dropout_prob, is_training, + scale=0.5, classifier_activation=classifier_activation) -def mobilenet_v2_0_75(num_classes=1001, +def mobilenet_v2_0_75(num_classes=1000, dropout_prob=1e-3, - is_training=True): + is_training=True, + classifier_activation='softmax'): """ Build mobilenet_v2_0.75 model :param num_classes: :param dropout_prob: :param is_training: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=0.75) + return _mobilenet_v2(num_classes, dropout_prob, is_training, + scale=0.75, classifier_activation=classifier_activation) -def mobilenet_v2_1(name, num_classes=1001, +def mobilenet_v2_1(name, num_classes=1000, dropout_prob=1e-3, - is_training=True): + is_training=True, + classifier_activation='softmax'): """ Build mobilenet_v2_1.0 model :param name: the model name :param num_classes: :param dropout_prob: :param is_training: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v2(name, num_classes, dropout_prob, is_training, scale=1.0) + return _mobilenet_v2(name, num_classes, dropout_prob, is_training, + scale=1.0, classifier_activation=classifier_activation) -def mobilenet_v2_1_3(name, num_classes=1001, +def mobilenet_v2_1_3(name, num_classes=1000, dropout_prob=1e-3, - is_training=True): + is_training=True, + classifier_activation='softmax'): """ Build mobilenet_v2_1.3 model :param name: the model name :param num_classes: :param dropout_prob: :param is_training: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v2(name, num_classes, dropout_prob, is_training, scale=1.3) + return _mobilenet_v2(name, num_classes, dropout_prob, is_training, + scale=1.3, classifier_activation=classifier_activation) -def mobilenet_v2_1_4(num_classes=1001, +def mobilenet_v2_1_4(num_classes=1000, dropout_prob=1e-3, - is_training=True): + is_training=True, + classifier_activation='softmax'): """ Build mobilenet_v2_1.4 model :param num_classes: :param dropout_prob: :param is_training: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ - return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=1.4) + return _mobilenet_v2(num_classes, dropout_prob, is_training, + scale=1.4, classifier_activation=classifier_activation) -def _mobilenet_v2(name, num_classes=1001, +def _mobilenet_v2(name, num_classes=1000, dropout_prob=1e-3, is_training=True, - scale=1.0): + scale=1.0, + classifier_activation='softmax'): """ Build mobilenet_v2 model :param name: the model name @@ -113,6 +133,7 @@ def _mobilenet_v2(name, num_classes=1001, :param dropout_prob: :param is_training: :param scale: + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ first_block_filters = _make_divisible(32 * scale, 8) @@ -195,7 +216,10 @@ def _mobilenet_v2(name, num_classes=1001, kernel_regularizer=_gen_l2_regularizer(), name='conv_preds')(x) x = tf.keras.layers.Reshape((num_classes,), name='reshape_2')(x) - outputs = tf.keras.layers.Activation('softmax', name='act_softmax')(x) + if classifier_activation == 'softmax': + outputs = tf.keras.layers.Activation('softmax', name='act_softmax')(x) + else: + outputs = x model = tf.keras.Model(inputs, outputs, name=name) return model diff --git a/src/model_optimizer/pruner/models/resnet.py b/src/model_optimizer/pruner/models/resnet.py index 93e1590..3a13ede 100644 --- a/src/model_optimizer/pruner/models/resnet.py +++ b/src/model_optimizer/pruner/models/resnet.py @@ -9,18 +9,20 @@ """ import tensorflow as tf +from .config import ModelConfig - -L2_WEIGHT_DECAY = 1e-4 -BATCH_NORM_DECAY = 0.9 -BATCH_NORM_EPSILON = 1e-5 +_config = ModelConfig(l2_weight_decay=1e-4, batch_norm_decay=0.9, batch_norm_epsilon=1e-5) +L2_WEIGHT_DECAY = _config.l2_weight_decay +BATCH_NORM_DECAY = _config.batch_norm_decay +BATCH_NORM_EPSILON = _config.batch_norm_epsilon def _gen_l2_regularizer(use_l2_regularizer=True): return tf.keras.regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None -def resnet(layer_num, name, num_classes=1001, use_l2_regularizer=True, is_training=True): +def resnet(layer_num, name, num_classes=1000, use_l2_regularizer=True, + is_training=True, classifier_activation='softmax'): """ Build resnet-18 resnet-34 resnet-50 resnet-101 resnet-152 model :param layer_num: 18, 34, 50, 101, 152 @@ -28,6 +30,7 @@ def resnet(layer_num, name, num_classes=1001, use_l2_regularizer=True, is_traini :param num_classes: classification class :param use_l2_regularizer: if use l2_regularizer :param is_training: if training or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: keras model """ if layer_num == 18: @@ -62,49 +65,56 @@ def resnet(layer_num, name, num_classes=1001, use_l2_regularizer=True, is_traini logits = tf.keras.layers.Dense(num_classes, kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), name='dense1')(x) - outputs = tf.keras.layers.Softmax()(logits) - model = tf.keras.Model(inputs, [outputs, logits], name=name) + if classifier_activation == 'softmax': + outputs = tf.keras.layers.Softmax()(logits) + else: + outputs = logits + model = tf.keras.Model(inputs, outputs, name=name) return model -def resnet_18(is_training, name): +def resnet_18(is_training, name, classifier_activation='softmax'): """ Build resnet-18 model :param is_training: if training or not :param name: the model name + :param classifier_activation: classifier_activation can only be None or "softmax" :return: resnet-18 model """ - return resnet(18, is_training=is_training, name=name) + return resnet(18, is_training=is_training, name=name, classifier_activation=classifier_activation) -def resnet_34(is_training, name): +def resnet_34(is_training, name, classifier_activation='softmax'): """ Build resnet-34 model :param is_training: if training or not :param name: the model name + :param classifier_activation: classifier_activation can only be None or "softmax" :return: resnet-34 model """ - return resnet(34, is_training=is_training, name=name) + return resnet(34, is_training=is_training, name=name, classifier_activation=classifier_activation) -def resnet_50(is_training, name): +def resnet_50(is_training, name, classifier_activation='softmax'): """ Build resnet-50 model :param is_training: if training or not :param name: the model name + :param classifier_activation: classifier_activation can only be None or "softmax" :return: resnet-50 model """ - return resnet(50, is_training=is_training, name=name) + return resnet(50, is_training=is_training, name=name, classifier_activation=classifier_activation) -def resnet_101(is_training, name): +def resnet_101(is_training, name, classifier_activation='softmax'): """ Build resnet-101 model :param is_training: if training or not :param name: the model name + :param classifier_activation: classifier_activation can only be None or "softmax" :return: resnet-101 model """ - return resnet(101, is_training=is_training, name=name) + return resnet(101, is_training=is_training, name=name, classifier_activation=classifier_activation) def residual_block(stage, block_num, input_data, filters, kernel_size, is_training): diff --git a/src/model_optimizer/pruner/models/vgg.py b/src/model_optimizer/pruner/models/vgg.py index 3430f2b..2f71c15 100644 --- a/src/model_optimizer/pruner/models/vgg.py +++ b/src/model_optimizer/pruner/models/vgg.py @@ -5,62 +5,68 @@ VGG models """ import tensorflow as tf +from .config import ModelConfig -L2_WEIGHT_DECAY = 1e-4 -BATCH_NORM_DECAY = 0.9 -BATCH_NORM_EPSILON = 1e-5 +_config = ModelConfig(l2_weight_decay=1e-4, batch_norm_decay=0.9, batch_norm_epsilon=1e-5) +L2_WEIGHT_DECAY = _config.l2_weight_decay +BATCH_NORM_DECAY = _config.batch_norm_decay +BATCH_NORM_EPSILON = _config.batch_norm_epsilon -def vgg_16(is_training, name, num_classes=1001, use_l2_regularizer=True): +def vgg_16(is_training, name, num_classes=1001, use_l2_regularizer=True, classifier_activation='softmax'): """ VGG-16 model :param is_training: if training or not :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ return vgg(ver='D', is_training=is_training, name=name, num_classes=num_classes, - use_l2_regularizer=use_l2_regularizer) + use_l2_regularizer=use_l2_regularizer, classifier_activation=classifier_activation) -def vgg_19(is_training, name, num_classes=1001, use_l2_regularizer=True): +def vgg_19(is_training, name, num_classes=1001, use_l2_regularizer=True, classifier_activation='softmax'): """ VGG-19 model :param is_training: if training or not :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ return vgg(ver='E', is_training=is_training, name=name, num_classes=num_classes, - use_l2_regularizer=use_l2_regularizer) + use_l2_regularizer=use_l2_regularizer, classifier_activation=classifier_activation) -def vgg_m_16(is_training, name, num_classes=10, use_l2_regularizer=True): +def vgg_m_16(is_training, name, num_classes=10, use_l2_regularizer=True, classifier_activation='softmax'): """ VGG-M-16 model :param is_training: if training or not :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ return vgg_m(ver='D', is_training=is_training, name=name, num_classes=num_classes, - use_l2_regularizer=use_l2_regularizer) + use_l2_regularizer=use_l2_regularizer, classifier_activation=classifier_activation) -def vgg_m_19(is_training, name, num_classes=10, use_l2_regularizer=True): +def vgg_m_19(is_training, name, num_classes=10, use_l2_regularizer=True, classifier_activation='softmax'): """ VGG-M-19 model :param is_training: if training or not :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ return vgg_m(ver='E', is_training=is_training, name=name, num_classes=num_classes, - use_l2_regularizer=use_l2_regularizer) + use_l2_regularizer=use_l2_regularizer, classifier_activation=classifier_activation) def _gen_l2_regularizer(use_l2_regularizer=True): @@ -81,7 +87,7 @@ def _vgg_blocks(block, conv_num, filters, x, is_training, use_l2_regularizer=Tru return x -def vgg(ver, is_training, name, num_classes=1001, use_l2_regularizer=True): +def vgg(ver, is_training, name, num_classes=1001, use_l2_regularizer=True, classifier_activation='softmax'): """ VGG models :param ver: 'D' or 'E' @@ -89,6 +95,7 @@ def vgg(ver, is_training, name, num_classes=1001, use_l2_regularizer=True): :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ if ver == 'D': @@ -117,16 +124,21 @@ def vgg(ver, is_training, name, num_classes=1001, use_l2_regularizer=True): kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), name='fc2')(x) - outputs = tf.keras.layers.Dense(num_classes, activation='softmax', - kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), - kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), - bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), - name='fc3')(x) + + logits = tf.keras.layers.Dense(num_classes, activation=None, + kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), + kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), + bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), + name='fc3')(x) + if classifier_activation == 'softmax': + outputs = tf.keras.layers.Softmax()(logits) + else: + outputs = logits model = tf.keras.Model(inputs, outputs, name=name) return model -def vgg_m(ver, is_training, name, num_classes=10, use_l2_regularizer=True): +def vgg_m(ver, is_training, name, num_classes=10, use_l2_regularizer=True, classifier_activation='softmax'): """ VGG-M models :param ver: 'D' or 'E' @@ -134,6 +146,7 @@ def vgg_m(ver, is_training, name, num_classes=10, use_l2_regularizer=True): :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not + :param classifier_activation: classifier_activation can only be None or "softmax" :return: """ if ver == 'D': @@ -153,10 +166,14 @@ def vgg_m(ver, is_training, name, num_classes=10, use_l2_regularizer=True): x = tf.keras.layers.Flatten(name='flat1')(x) - outputs = tf.keras.layers.Dense(num_classes, activation='softmax', - kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), - kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), - bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), - name='fc2')(x) + logits = tf.keras.layers.Dense(num_classes, activation=None, + kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), + kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), + bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), + name='fc2')(x) + if classifier_activation == 'softmax': + outputs = tf.keras.layers.Softmax()(logits) + else: + outputs = logits model = tf.keras.Model(inputs, outputs, name=name) return model diff --git a/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml b/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml index 442efed..b19ad18 100644 --- a/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml +++ b/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml @@ -1,5 +1,6 @@ version: 1 -alpha: 0.3 -temperature: 10 -student_name: "resnet_50" -teacher_path: "/root/work/examples/models_ckpt/resnet_101_imagenet_120e_logits/checkpoint-120.h5" \ No newline at end of file +distill: + alpha: 0.3 + temperature: 10 + student_name: "resnet_50" + teacher_path: "/root/work/examples/models_ckpt/resnet_101_imagenet_120e_logits/checkpoint-120.h5" \ No newline at end of file diff --git a/src/model_optimizer/pruner/scheduler/uniform_auto/resnet_50_imagenet_0.5_distill.yaml b/src/model_optimizer/pruner/scheduler/uniform_auto/resnet_50_imagenet_0.5_distill.yaml new file mode 100644 index 0000000..9c717b9 --- /dev/null +++ b/src/model_optimizer/pruner/scheduler/uniform_auto/resnet_50_imagenet_0.5_distill.yaml @@ -0,0 +1,44 @@ +version: 1 +pruners: + prune_func1: + criterion: l1_norm + prune_type: auto_prune + ratio: 0.5 + +lr_schedulers: + # Learning rate + - name: warmup_lr + class: LearningRateWarmupCallback + warmup_epochs: 5 + verbose: 0 + - name: lr_multiply_1 + class: LearningRateScheduleCallback + start_epoch: 5 + end_epoch: 120 + multiplier: 1.0 + - name: lr_multiply_0.1 + class: LearningRateScheduleCallback + start_epoch: 120 + end_epoch: 240 + multiplier: 1e-1 + - name: lr_multiply_0.01 + class: LearningRateScheduleCallback + start_epoch: 240 + end_epoch: 320 + multiplier: 1e-2 + - name: lr_multiply_0.001 + class: LearningRateScheduleCallback + start_epoch: 320 + multiplier: 1e-3 + +prune_schedulers: + - pruner: + func_name: prune_func1 + epochs: [0] + +distill: + alpha: 0.4 + temperature: 1 + student_name: "resnet_50" + teacher_path: "/models_zoo/senet154-0466-f1b79a9b_tf2.h5,/models_zoo/resnet152b-0431-b41ec90e.tf2.h5" + model_load_func: "load_tf_ensemble_model" \ No newline at end of file diff --git a/src/model_optimizer/stat.py b/src/model_optimizer/stat.py index 8b5afa9..39d04d8 100644 --- a/src/model_optimizer/stat.py +++ b/src/model_optimizer/stat.py @@ -9,24 +9,15 @@ # pylint: disable=not-context-manager -def get_keras_model_flops(model_h5_path): +def get_keras_model_params_flops(model_h5_path): """ Get keras model FLOPs :param model_h5_path: keras model path - :return: FLOPs + :return: Params, FLOPs """ - session = tf.compat.v1.Session() - graph = tf.compat.v1.get_default_graph() - - with graph.as_default(): - with session.as_default(): - _ = tf.keras.models.load_model(model_h5_path) - run_meta = tf.compat.v1.RunMetadata() - opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation() - flops = tf.compat.v1.profiler.profile(graph=graph, - run_meta=run_meta, cmd='op', options=opts) - tf.compat.v1.reset_default_graph() - return flops.total_float_ops + model = tf.keras.models.load_model(model_h5_path) + total_params, total_flops = _count_model_params_flops(model) + return total_params, total_flops def print_keras_model_summary(model, hvd_rank): diff --git a/src/model_optimizer/utils/imagenet_preprocessing.py b/src/model_optimizer/utils/imagenet_preprocessing.py index 9905d6d..9820c39 100644 --- a/src/model_optimizer/utils/imagenet_preprocessing.py +++ b/src/model_optimizer/utils/imagenet_preprocessing.py @@ -10,6 +10,29 @@ _RESIZE_MIN = 256 +def image_normalization(img): + """ + Normalization as in the ImageNet-1K validation procedure. + Parameters: + ---------- + img : np.array + input image. + mean_rgb : tuple of 3 float + Mean of RGB channels in the dataset. + std_rgb : tuple of 3 float + STD of RGB channels in the dataset. + Returns: + ------- + np.array + Output image. + """ + mean_rgb = [123.675, 116.28, 103.53] + std_rgb = [58.395, 57.12, 57.375] + img = tf.subtract(img, tf.broadcast_to(mean_rgb, tf.shape(img))) + img = tf.divide(img, tf.broadcast_to(std_rgb, tf.shape(img))) + return img + + def preprocess_image(image_buffer, bbox, output_height, output_width, num_channels=3, is_training=False): """ @@ -42,7 +65,7 @@ def preprocess_image(image_buffer, bbox, output_height, output_width, image_buffer, crop_window, channels=num_channels) cropped = tf.image.random_flip_left_right(cropped) - return tf.image.resize(cropped, [output_height, output_width], method=tf.image.ResizeMethod.BILINEAR) + img = tf.image.resize(cropped, [output_height, output_width], method=tf.image.ResizeMethod.BILINEAR) else: image = tf.image.decode_jpeg(image_buffer, channels=num_channels) @@ -62,4 +85,5 @@ def preprocess_image(image_buffer, bbox, output_height, output_width, height, width = shape[0], shape[1] crop_top = (height - output_height) // 2 crop_left = (width - output_width) // 2 - return tf.slice(image, [crop_top, crop_left, 0], [output_height, output_width, -1]) + img = tf.slice(image, [crop_top, crop_left, 0], [output_height, output_width, -1]) + return image_normalization(img)