Support ensemble distillation. (#23)
close  #22
ChengcanWang-zte committed Jun 17, 2021
1 parent 4ba7077 commit 385fcf2
Showing 34 changed files with 660 additions and 267 deletions.
26 changes: 19 additions & 7 deletions README.md
@@ -8,7 +8,7 @@ sparsity pruning depends on special algorithms and hardware to achieve accelerat
Adlik pruning focuses on channel pruning and filter pruning, which can effectively reduce the number of parameters and
FLOPs. In terms of quantization, Adlik focuses on 8-bit quantization, which is easier to accelerate on specific hardware.
Testing shows that calibrating on a small batch of data yields a quantized model with little loss of
accuracy, so Adlik focuses on this method. Knowledge distillation is another way to improve the performance of deep
learning algorithms. It makes it possible to compress the knowledge of a big model into a smaller model.
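
As an illustration of this small-calibration-batch approach, here is a minimal sketch of post-training 8-bit
quantization with the TensorFlow Lite converter. The model path and the `calibration_ds` dataset are placeholders,
and this is not necessarily how Adlik's quantizer is invoked internally.

```python
import tensorflow as tf

# Load a trained Keras model (placeholder path).
model = tf.keras.models.load_model('checkpoint-12.h5')

def representative_dataset():
    # A small calibration batch is enough to estimate activation ranges.
    # `calibration_ds` is an assumed tf.data.Dataset yielding (images, labels).
    for images, _ in calibration_ds.take(100):
        yield [tf.cast(images, tf.float32)]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()

with open('model_int8.tflite', 'wb') as f:
    f.write(tflite_model)
```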

The proposed framework mainly consists of two categories of algorithm components, i.e. pruner and quantizer. The
@@ -23,15 +23,15 @@ three modules.
After filter pruning, the model can be further quantized. The following table shows the accuracy of the pruned and
quantized LeNet-5 and ResNet-50 models.

| model | baseline | pruned | pruned+quantization(TF-Lite) | pruned+quantization(TF-TRT) |
| Model | Baseline | Pruned | Pruned + Quantization(TF-Lite) | Pruned + Quantization(TF-TRT) |
| --------- | -------- | -------------------- | ---------------------------- | --------------------------- |
| LeNet-5 | 98.85 | 99.11(59% pruned) | 99.05 | 99.11 |
| ResNet-50 | 76.174 | 75.456(31.9% pruned) | 75.158 | 75.28 |

The Pruner completely removes redundant parameters, which leads to a smaller model size and faster execution.
The following table lists the sizes of the above model files:

| model | baseline(H5) | pruned(H5) | quantization(TF-Lite) | quantization(TF-TRT) |
| Model | Baseline(H5) | Pruned(H5) | Quantization(TF-Lite) | Quantization(TF-TRT) |
| --------- | ------------ | ------------------ | --------------------- | -------------------- |
| LeNet-5 | 1176KB | 499KB(59% pruned) | 120KB | 1154KB (pb) |
| ResNet-50 | 99MB | 67MB(31.9% pruned) | 18MB | 138MB(pb) |
@@ -47,12 +47,24 @@ which was tested on ImageNet. The original test accuracy is 71.25%, and model si

Knowledge distillation is an effective way to improve the performance of a model.

The following table shows the distillation result for ResNet-50 as the student network, with ResNet-101 as the teacher network.

| student model | ResNet-101 distilled | accuracy change |
| Student Model | ResNet-101 Distilled | Accuracy Change |
| ------------- | -------------------- | --------------- |
| ResNet-50 | 77.14% | +0.97% |

Ensemble distillation can significantly improve the accuracy of the model. With 72.8% of the parameters pruned, and
using senet154 and resnet152b as the teacher networks, ensemble distillation increases the accuracy by more than 4%.
The details are shown in the table below, and the code can be found in examples/resnet_50_imagenet_prune_distill.py;
a minimal sketch of the idea follows the table.

| Model                                      | Accuracy | Params                 | FLOPs | Model Size |
| ------------------------------------------ | -------- | ---------------------- | ----- | ---------- |
| ResNet-50                                  | 76.174   | 25610152               | 3899M | 99M        |
| + pruned                                   | 72.28    | 6954152 (72.8% pruned) | 1075M | 27M        |
| + pruned + distill                         | 76.39    | 6954152 (72.8% pruned) | 1075M | 27M        |
| + pruned + distill + quantization(TF-Lite) | 75.938   | -                      | -     | 7.1M       |
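
For illustration only, here is a minimal sketch of an ensemble distillation loss in which the teachers' logits are
averaged into a single soft target. The function and tensor names, the temperature, and the weighting are assumptions
made for the sketch, not the model_optimizer implementation.

```python
import tensorflow as tf

def ensemble_distillation_loss(student_logits, teacher_logits_list, labels,
                               temperature=2.0, alpha=0.5):
    """Combine the hard-label loss with a soft-target loss built from several teachers."""
    # Average the teachers' logits into one ensemble prediction (one possible combination rule).
    ensemble_logits = tf.reduce_mean(tf.stack(teacher_logits_list, axis=0), axis=0)
    soft_targets = tf.nn.softmax(ensemble_logits / temperature)
    log_student = tf.nn.log_softmax(student_logits / temperature)
    # Cross-entropy between the softened teacher and student distributions (KL up to a constant).
    soft_loss = -tf.reduce_mean(tf.reduce_sum(soft_targets * log_student, axis=-1)) * temperature ** 2
    hard_loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(labels, student_logits, from_logits=True))
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```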

## 1. Pruning and quantization principle

### 1.1 Filter pruning
@@ -85,7 +97,7 @@ quantification of ResNet-50 in less than one minute.

Knowledge distillation is a compression technique by which the knowledge of a larger model (the teacher) is transferred
into a smaller one (the student). During distillation, the student model learns from the teacher model to generalize
well by using the teacher's final softmax outputs, computed at a raised temperature, as soft targets.

![Distillation](imgs/distillation.png)
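
As a rough numerical illustration of the soft targets described above (the logits here are made up for the example),
raising the temperature spreads the teacher's probability mass across classes:

```python
import numpy as np

logits = np.array([8.0, 2.0, 0.5])  # hypothetical teacher logits for three classes

def softmax(x, temperature=1.0):
    z = np.exp(x / temperature - np.max(x / temperature))
    return z / z.sum()

print(softmax(logits, temperature=1.0))  # ~[0.997, 0.002, 0.001] -- almost one-hot
print(softmax(logits, temperature=4.0))  # ~[0.73, 0.16, 0.11]   -- softer targets reveal class similarities
```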

@@ -252,7 +264,7 @@ This step is the same as described above. You can get detailed instructions from
Batch size is an important hyper-parameter for deep learning model training. If you have more GPU memory available,
you can try a larger batch size, but you have to adjust the learning rate accordingly (see the sketch after the table).

| model | card | batch size | learning-rate |
| Model | Card | Batch Size | Learning Rate |
| --------- | --------- | ---------- | ------------- |
| ResNet-50 | V100 32GB | 256 | 0.1 |
| ResNet-50 | P100 16GB | 128 | 0.05 |
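
The two settings above are consistent with linear learning-rate scaling. A minimal sketch, assuming 0.1 at batch
size 256 as the reference point (an assumption for illustration, not a rule stated by the project):

```python
def scaled_learning_rate(batch_size, base_lr=0.1, base_batch_size=256):
    """Scale the learning rate linearly with the batch size."""
    return base_lr * batch_size / base_batch_size

print(scaled_learning_rate(256))  # 0.1
print(scaled_learning_rate(128))  # 0.05
```
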
4 changes: 2 additions & 2 deletions azure-pipelines.yml
@@ -8,7 +8,7 @@ jobs:
strategy:
matrix:
Linux:
vmImage: ubuntu-latest
vmImage: ubuntu-18.04
pool:
vmImage: $(vmImage)
steps:
@@ -25,7 +25,7 @@ jobs:
- job: Markdownlint
displayName: Markdownlint
pool:
vmImage: ubuntu-latest
vmImage: ubuntu-18.04
steps:
- script: sudo npm install -g markdownlint-cli
displayName: Install markdownlint-cli
10 changes: 5 additions & 5 deletions examples/get_flops_lenet.py
@@ -11,18 +11,18 @@
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer.stat import get_keras_model_flops # noqa: E402
from model_optimizer.stat import get_keras_model_params_flops # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
model_h5_path = './models_eval_ckpt/lenet_mnist/checkpoint-12.h5'
origin_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path))
origin_params, origin_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path))
model_h5_path = './models_eval_ckpt/lenet_mnist_pruned/checkpoint-12.h5'
pruned_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path))
pruned_params, pruned_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path))

print('flops before prune: {}'.format(origin_flops))
print('flops after pruned: {}'.format(pruned_flops))
print('Before prune, FLOPs: {}, Params: {}'.format(origin_flops, origin_params))
print('After pruned, FLOPs: {}, Params: {}'.format(pruned_flops, pruned_params))


if __name__ == "__main__":
10 changes: 5 additions & 5 deletions examples/get_flops_resnet_50.py
@@ -12,18 +12,18 @@
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer.stat import get_keras_model_flops # noqa: E402
from model_optimizer.stat import get_keras_model_params_flops # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
model_h5_path = './models_eval_ckpt/resnet_50_imagenet/checkpoint-90.h5'
origin_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path))
origin_params, origin_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path))
model_h5_path = './models_eval_ckpt/resnet_50_imagenet_pruned/checkpoint-120.h5'
pruned_flops = get_keras_model_flops(os.path.join(base_dir, model_h5_path))
pruned_params, pruned_flops = get_keras_model_params_flops(os.path.join(base_dir, model_h5_path))

print('flops before prune: {}'.format(origin_flops))
print('flops after pruned: {}'.format(pruned_flops))
print('Before prune, FLOPs: {}, Params: {}'.format(origin_flops, origin_params))
print('After pruned, FLOPs: {}, Params: {}'.format(pruned_flops, pruned_params))


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions examples/resnet_101_imagenet_train.py
@@ -27,8 +27,7 @@ def _main():
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_101_imagenet"),
"checkpoint_save_period": 5, # save a checkpoint every 5 epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_101_imagenet"),
"scheduler": "train",
"classifier_activation": None # None or "softmax", default is softmax
"scheduler": "train"
}
prune_model(request)

3 changes: 1 addition & 2 deletions examples/resnet_50_imagenet_distill.py
@@ -28,8 +28,7 @@ def _main():
"checkpoint_save_period": 5, # save a checkpoint every 5 epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_distill"),
"scheduler": "distill",
"scheduler_file_name": "resnet_50_imagenet_0.3.yaml",
"classifier_activation": None # None or "softmax", default is softmax
"scheduler_file_name": "resnet_50_imagenet_0.3.yaml"
}
prune_model(request)

45 changes: 45 additions & 0 deletions examples/resnet_50_imagenet_prune_distill.py
@@ -0,0 +1,45 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
This is an example of pruning and ensemble distillation of the resnet50 model
Please download the two models senet154 and resnet152b to the directory configured in the file
resnet_50_imagenet_0.5_distill.yaml.
wget -c -O resnet152b-0431-b41ec90e.tf2.h5.zip https://github.com/osmr/imgclsmob/releases/
download/v0.0.517/resnet152b-0431-b41ec90e.tf2.h5.zip
wget -c -O senet154-0466-f1b79a9b_tf2.h5.zip https://github.com/osmr/imgclsmob/releases/
download/v0.0.422/senet154-0466-f1b79a9b_tf2.h5.zip
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
base_dir = os.path.dirname(__file__)
request = {
"dataset": "imagenet",
"model_name": "resnet_50",
"data_dir": os.path.join(base_dir, "/data/imagenet/tfrecord-dataset"),
"batch_size": 256,
"batch_size_val": 100,
"learning_rate": 0.1,
"epochs": 360,
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet_pruned"),
"checkpoint_save_period": 5, # save a checkpoint every 5 epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_pruned"),
"scheduler": "uniform_auto",
"is_distill": True,
"scheduler_file_name": "resnet_50_imagenet_0.5_distill.yaml"
}
os.environ['L2_WEIGHT_DECAY'] = "5e-5"
prune_model(request)


if __name__ == "__main__":
_main()
3 changes: 1 addition & 2 deletions examples/resnet_50_imagenet_train.py
@@ -28,8 +28,7 @@ def _main():
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet"),
"checkpoint_save_period": 5, # save a checkpoint every 5 epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet"),
"scheduler": "train",
"classifier_activation": None # None or "softmax", default is softmax
"scheduler": "train"
}
prune_model(request)

9 changes: 7 additions & 2 deletions setup.py
@@ -12,12 +12,17 @@
_VERSION = '0.0.0'

_REQUIRED_PACKAGES = [
'requests',
'requests==2.25.0',
'tensorflow==2.3.0',
'jsonschema==3.1.1',
'networkx==2.4',
'mpi4py==3.0.3',
'horovod==0.19.1'
'horovod==0.19.1',
'tf2cv==0.0.16',
'PyYAML==5.3.1',
'types-PyYAML',
'types-pkg_resources',
'types-requests'
]

_TEST_REQUIRES = [
4 changes: 4 additions & 0 deletions src/model_optimizer/pruner/config_schema.json
@@ -30,6 +30,10 @@
"checkpoint_eval_path": {
"type": "string",
"description": "file path of eval checkpoint"
},
"is_distill":{
"type": "boolean",
"description": "if start train model with distilling"
}
},
"required": [
2 changes: 1 addition & 1 deletion src/model_optimizer/pruner/core/__init__.py
@@ -27,5 +27,5 @@ def get_pruner(config, epoch):
func_name = scheduler['pruner']['func_name']
pruner_type = scheduler_config['pruners'][func_name]['prune_type']
if pruner_type in pruners:
pruner_list.append(pruners[pruner_type](scheduler_config['pruners'][func_name]))
pruner_list.append(pruners[pruner_type](scheduler_config['pruners'][func_name], config))
return pruner_list