resnet50 distillation (#20)
priscilla-pan committed Jan 21, 2021
1 parent 5738a2e commit 074c81e
Showing 42 changed files with 619 additions and 107 deletions.
23 changes: 21 additions & 2 deletions README.md
@@ -8,7 +8,8 @@ sparsity pruning depends on special algorithms and hardware to achieve accelerat
Adlik pruning focuses on channel pruning and filter pruning, which can really reduce the number of parameters and
flops. In terms of quantization, Adlik focuses on 8-bit quantization that is easier to accelerate on specific hardware.
Testing shows that running calibration on a small batch of data yields a quantized model with little loss of
accuracy, so Adlik focuses on this method. Knowledge distillation is another way to improve the performance of deep
learning models: it compresses the knowledge of a large model into a smaller one.

The proposed framework mainly consists of two categories of algorithm components, i.e. pruner and quantizer. The
pruner is mainly composed of five modules: core, scheduler, models, dataset and learner. The core module defines
@@ -35,6 +36,14 @@ The following table shows the sizes of the above model files:
| LeNet-5 | 1176KB | 499KB(59% pruned) | 120KB | 1154KB (pb) |
| ResNet-50 | 99MB | 67MB(31.9% pruned) | 18MB | 138MB(pb) |

Knowledge distillation is an effective way to improve the performance of a model.

The following table shows the distillation result for ResNet-50 as the student network, with ResNet-101 as the teacher network.

| student model | accuracy (distilled from ResNet-101) | accuracy change |
| ------------- | ------------------------------------ | --------------- |
| ResNet-50     | 77.14%                                | +0.97%          |

## 1. Pruning and quantization principle

### 1.1 Filter pruning
@@ -63,6 +72,16 @@ quantization, only need to have inference model and very little calibration data
of quantization is very small, and the accuracy of some models even improves. Adlik needs only 100 sample images to
complete the quantization of ResNet-50 in less than one minute.
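
As an illustration of this calibration-based workflow, here is a minimal post-training 8-bit quantization sketch using
TensorFlow Lite's representative-dataset API. This is not Adlik's own quantizer; the SavedModel path is assumed, and
random calibration images stand in for real preprocessed samples:

```python
import numpy as np
import tensorflow as tf

def representative_dataset():
    # ~100 calibration samples; random data keeps the sketch self-contained.
    # Real use would yield preprocessed images from the training set.
    for _ in range(100):
        yield [np.random.rand(1, 224, 224, 3).astype(np.float32)]

converter = tf.lite.TFLiteConverter.from_saved_model("resnet50_savedmodel")  # assumed path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
tflite_model = converter.convert()
```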

### 1.3 Knowledge distillation

Knowledge distillation is a compression technique by which the knowledge of a larger model (teacher) is transferred
into a smaller one (student). During distillation, the student learns to generalize well from the teacher, using the
teacher's final softmax raised to a higher temperature as a soft set of targets.

![Distillation](imgs/distillation.png)

Refer to the paper [Distilling the Knowledge in a Neural Network](https://arxiv.org/pdf/1503.02531.pdf) for details.
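
As a minimal sketch of this idea (the temperature and loss weighting below are illustrative defaults, not Adlik's
internal settings):

```python
import tensorflow as tf

def distillation_loss(labels, student_logits, teacher_logits,
                      temperature=4.0, alpha=0.3):
    """Weighted sum of the hard-label loss and the soft-target (teacher) loss."""
    soft_targets = tf.nn.softmax(teacher_logits / temperature)
    soft_loss = tf.keras.losses.categorical_crossentropy(
        soft_targets, student_logits / temperature, from_logits=True)
    hard_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, student_logits, from_logits=True)
    # The T^2 factor keeps the soft-target gradients on the same scale as the
    # hard-label gradients (Hinton et al., 2015).
    return alpha * hard_loss + (1.0 - alpha) * (temperature ** 2) * soft_loss
```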

## 2. Installation

These instructions will help you get the Adlik optimizer up and running on your local machine.
@@ -102,7 +121,7 @@ rm -rf /tmp/openmpi
#### 2.2.2 Install python package

```shell
pip install tensorflow-gpu==2.1.0
pip install tensorflow-gpu==2.3.0
pip install horovod==0.19.1
pip install mpi4py
pip install networkx
46 changes: 46 additions & 0 deletions doc/ResNet-50-Knowledge-Distillation.md
@@ -0,0 +1,46 @@
# ResNet-50 Knowledge Distillation

The following uses ResNet-101 on the ImageNet data set as the teacher model to illustrate how to use the model
optimizer to improve the performance of ResNet-50 by knowledge distillation.

## 1 Prepare data

### 1.1 Generate training and test data sets

You may follow the data preparation guide [here](https://github.com/tensorflow/models/tree/v1.13.0/research/inception)
to download the full data set and convert it into TFRecord files. By default, when the script finishes, you will find
1024 training files and 128 validation files in the DATA_DIR. The files will match the patterns train-?????-of-01024
and validation-?????-of-00128, respectively.
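
A quick sanity check that all shards were generated (assuming the same DATA_DIR as in the guide):

```python
import glob
import os

data_dir = os.environ["DATA_DIR"]  # same DATA_DIR as in the preparation guide
train_shards = glob.glob(os.path.join(data_dir, "train-?????-of-01024"))
val_shards = glob.glob(os.path.join(data_dir, "validation-?????-of-00128"))
assert len(train_shards) == 1024 and len(val_shards) == 128
```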

## 2 Train the teacher model

Enter the examples directory and execute

```shell
cd examples
horovodrun -np 8 -H localhost:8 python resnet_101_imagenet_train.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/resnet_101_imagenet, and the inference
checkpoint files will be generated in ./models_eval_ckpt/resnet_101_imagenet. You can also modify the checkpoint_path
and checkpoint_eval_path of the resnet_101_imagenet_train.py file to change the generated file paths.

## 3 Distill

Enter the examples directory and execute

```shell
horovodrun -np 8 -H localhost:8 python resnet_50_imagenet_distill.py
```

After execution, the default checkpoint files will be generated in ./models_ckpt/resnet_50_imagenet_distill, and
the inference checkpoint files will be generated in ./models_eval_ckpt/resnet_50_imagenet_distill. You can also
modify the checkpoint_path and checkpoint_eval_path of the resnet_50_imagenet_distill.py file to change the generated
file paths.

> Note
>
> > i. The model in the checkpoint_path is not the pure ResNet-50 model. It is a hybrid of ResNet-50 (student) and
> > ResNet-101 (teacher).
> >
> > ii. The model in the checkpoint_eval_path is the distilled model, i.e. the pure ResNet-50 model.
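
For instance, you could inspect the evaluation checkpoint to confirm it contains only student variables. This is a
hypothetical check that assumes a standard TensorFlow checkpoint layout; the actual layout depends on how the model
optimizer saves it:

```python
import tensorflow as tf

ckpt = tf.train.latest_checkpoint("./models_eval_ckpt/resnet_50_imagenet_distill")
for name, shape in tf.train.list_variables(ckpt):
    print(name, shape)  # expect ResNet-50 (student) variables only
```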
5 changes: 3 additions & 2 deletions docker/Dockerfile.cpu
@@ -3,7 +3,8 @@ FROM ubuntu:18.04
RUN apt-get update && \
apt-get install -y software-properties-common && \
apt-get update -y && \
apt-get install -y --no-install-recommends build-essential python3.6 python3.6-dev python3-distutils \
apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \
build-essential python3.6 python3.6-dev python3-distutils \
curl git openssh-client openssh-server && \
mkdir -p /var/run/sshd && \
mkdir -p /root/work && \
@@ -26,7 +27,7 @@ RUN mkdir /tmp/openmpi && \
rm -rf /tmp/openmpi

# Install Tensorflow and Horovod
RUN pip install --no-cache-dir tensorflow==2.1.0
RUN pip install --no-cache-dir tensorflow==2.3.0

RUN HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod==0.19.1

5 changes: 3 additions & 2 deletions docker/Dockerfile.gpu
@@ -3,7 +3,8 @@ FROM nvidia/cuda:10.1-devel-ubuntu18.04
RUN apt-get update && \
apt-get install -y software-properties-common && \
apt-get update -y && \
apt-get install -y --no-install-recommends build-essential python3.6 python3.6-dev python3-distutils \
apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \
build-essential python3.6 python3.6-dev python3-distutils \
curl vim git openssh-client openssh-server \
libcudnn7=7.6.5.32-1+cuda10.1 \
libcudnn7-dev=7.6.5.32-1+cuda10.1 \
@@ -38,7 +39,7 @@ RUN mkdir /tmp/openmpi && \
rm -rf /tmp/openmpi

# Install Tensorflow and Horovod
RUN pip install --no-cache-dir tensorflow-gpu==2.1.0
RUN pip install --no-cache-dir tensorflow-gpu==2.3.0

RUN HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod==0.19.1

37 changes: 37 additions & 0 deletions examples/resnet_101_imagenet_train.py
@@ -0,0 +1,37 @@
# Copyright 2019 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Train a ResNet_101 model on the ImageNet dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "imagenet",
        "model_name": "resnet_101",
        "data_dir": os.path.join(base_dir, "./data/imagenet"),
        "batch_size": 128,
        "batch_size_val": 100,
        "learning_rate": 0.1,
        "epochs": 120,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_101_imagenet"),
        "checkpoint_save_period": 5,  # save a checkpoint every 5 epochs
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_101_imagenet"),
        "scheduler": "train",
        "classifier_activation": None  # None or "softmax", default is softmax
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
38 changes: 38 additions & 0 deletions examples/resnet_50_imagenet_distill.py
@@ -0,0 +1,38 @@
# Copyright 2020 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Distill a ResNet_50 model from a trained ResNet_101 model on the ImageNet dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines
# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)

from model_optimizer import prune_model # noqa: E402


def _main():
    base_dir = os.path.dirname(__file__)
    request = {
        "dataset": "imagenet",
        "model_name": "resnet_50",
        "data_dir": os.path.join(base_dir, "./data/imagenet"),
        "batch_size": 256,
        "batch_size_val": 100,
        "learning_rate": 0.1,
        "epochs": 90,
        "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet_distill"),
        "checkpoint_save_period": 5,  # save a checkpoint every 5 epochs
        "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_distill"),
        "scheduler": "distill",
        "scheduler_file_name": "resnet_50_imagenet_0.3.yaml",
        "classifier_activation": None  # None or "softmax", default is softmax
    }
    prune_model(request)


if __name__ == "__main__":
    _main()
6 changes: 4 additions & 2 deletions examples/resnet_50_imagenet_train.py
@@ -5,12 +5,13 @@
Train a ResNet_50 model on the ImageNet dataset
"""
import os
# If you did not execute the setup.py, uncomment the following four lines

# import sys
# from os.path import abspath, join, dirname
# sys.path.insert(0, join(abspath(dirname(__file__)), '../src'))
# print(sys.path)


from model_optimizer import prune_model # noqa: E402


@@ -27,7 +28,8 @@ def _main():
"checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet"),
"checkpoint_save_period": 5, # save a checkpoint every 5 epoch
"checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet"),
"scheduler": "train"
"scheduler": "train",
"classifier_activation": None # None or "softmax", default is softmax
}
prune_model(request)

Binary file added imgs/distillation.png
7 changes: 4 additions & 3 deletions setup.py
@@ -13,7 +13,7 @@

_REQUIRED_PACKAGES = [
    'requests',
    'tensorflow==2.1.0',
    'tensorflow==2.3.0',
    'jsonschema==3.1.1',
    'networkx==2.4',
    'mpi4py==3.0.3',
@@ -43,7 +43,7 @@ def get_dist(pkgname):


if get_dist('tensorflow') is None and get_dist('tensorflow-gpu') is not None:
    _REQUIRED_PACKAGES.remove('tensorflow==2.1.0')
    _REQUIRED_PACKAGES.remove('tensorflow==2.3.0')

setup(
name="model_optimizer",
Expand All @@ -60,7 +60,8 @@ def get_dist(pkgname):
    package_data={
        'model_optimizer': ['**/*.json',
                            'pruner/scheduler/uniform_auto/*.yaml',
                            'pruner/scheduler/uniform_specified_layer/*.yaml']
                            'pruner/scheduler/uniform_specified_layer/*.yaml',
                            'pruner/scheduler/distill/*.yaml']
    },

)
3 changes: 2 additions & 1 deletion src/model_optimizer/pruner/dataset/cifar10.py
@@ -20,7 +20,7 @@ def __init__(self, config, is_training):
        :param is_training: whether to construct the training subset
        :return:
        """
        super(Cifar10Dataset, self).__init__(config, is_training)
        super().__init__(config, is_training)
        if is_training:
            self.file_pattern = os.path.join(self.data_dir, 'train.tfrecords')
            self.batch_size = self.batch_size
@@ -32,6 +32,7 @@ def __init__(self, config, is_training):
        self.num_samples_of_train = 50000
        self.num_samples_of_val = 10000

    # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
    def parse_fn(self, example_serialized):
        """
        Parse features from the serialized data
8 changes: 6 additions & 2 deletions src/model_optimizer/pruner/dataset/dataset_base.py
@@ -62,9 +62,10 @@ def num_samples(self):
        else:
            return self.num_samples_of_val

    def build(self):
    def build(self, is_distill=False):
        """
        Build dataset
        :param is_distill: whether the dataset is built for distillation
        :return: batch of a dataset
        """
        dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=True)
@@ -73,7 +74,10 @@ def build(self):
        dataset = dataset.interleave(self.dataset_fn, cycle_length=10, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if self.is_training:
            dataset = dataset.shuffle(buffer_size=self.buffer_size).repeat()
        dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        if is_distill:
            dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        else:
            dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        return self.__build_batch(dataset)

    def __build_batch(self, dataset):
14 changes: 13 additions & 1 deletion src/model_optimizer/pruner/dataset/imagenet.py
@@ -21,7 +21,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
        :param is_training: whether to construct the training subset
        :return:
        """
        super(ImagenetDataset, self).__init__(config, is_training, num_shards, shard_index)
        super().__init__(config, is_training, num_shards, shard_index)
        if is_training:
            self.file_pattern = os.path.join(self.data_dir, 'train-*-of-*')
            self.batch_size = self.batch_size
@@ -33,6 +33,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0):
        self.num_samples_of_train = 1281167
        self.num_samples_of_val = 50000

    # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
    def parse_fn(self, example_serialized):
        """
        Parse features from the serialized data
@@ -77,3 +78,14 @@ def parse_fn(self, example_serialized):
            num_channels=3,
            is_training=self.is_training)
        return image, label

    def parse_fn_distill(self, example_serialized):
        """
        Parse features from the serialized data for distillation
        :param example_serialized: serialized data
        :return: ({image, label}, {})
        """
        image, label = self.parse_fn(example_serialized)
        inputs = {"image": image, "label": label}
        targets = {}
        return inputs, targets
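
To make the shapes concrete, here is a self-contained tf.data sketch of the same (inputs, empty-targets) pattern that
build(is_distill=True) produces; the random tensors stand in for parsed TFRecord data:

```python
import tensorflow as tf

# Stand-in images/labels; the real pipeline parses them from TFRecord files.
images = tf.random.uniform([8, 224, 224, 3])
labels = tf.random.uniform([8], maxval=1000, dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels))
# Mirror parse_fn_distill: pack both tensors into the inputs dict and leave targets
# empty, so a distillation model can compute its loss from the label it gets as input.
ds = ds.map(lambda image, label: ({"image": image, "label": label}, {})).batch(4)
for inputs, targets in ds.take(1):
    print(inputs["image"].shape, inputs["label"].shape, targets)  # (4, 224, 224, 3) (4,) {}
```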
3 changes: 2 additions & 1 deletion src/model_optimizer/pruner/dataset/mnist.py
@@ -20,7 +20,7 @@ def __init__(self, config, is_training):
        :param is_training: whether to construct the training subset
        :return:
        """
        super(MnistDataset, self).__init__(config, is_training)
        super().__init__(config, is_training)
        if is_training:
            self.file_pattern = os.path.join(self.data_dir, 'train.tfrecords')
            self.batch_size = self.batch_size
@@ -33,6 +33,7 @@ def __init__(self, config, is_training):
        self.num_samples_of_val = 10000

    # pylint: disable=R0201
    # pylint: disable=no-value-for-parameter,unexpected-keyword-arg
    def parse_fn(self, example_serialized):
        """
        Parse features from the serialized data
2 changes: 2 additions & 0 deletions src/model_optimizer/pruner/distill/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2021 ZTE corporation. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
