add MobileNet-v1 pruning (#21)
lyg95 committed Jun 7, 2021
1 parent 074c81e commit 4ba7077
Showing 9 changed files with 272 additions and 14 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
@@ -11,7 +11,8 @@ disable = fixme,
no-member,
unnecessary-pass,
import-outside-toplevel,
no-else-continue
no-else-continue,
similarities

[FORMAT]
max-line-length = 120
9 changes: 9 additions & 0 deletions README.md
@@ -36,6 +36,15 @@ The following table is the size of the above model files:
| LeNet-5 | 1176KB | 499KB(59% pruned) | 120KB | 1154KB (pb) |
| ResNet-50 | 99MB | 67MB(31.9% pruned) | 18MB | 138MB(pb) |

Currently, the MobileNet-v1 model is only pruned, not quantized. We show the results for different pruning ratios,
tested on ImageNet. The original test accuracy is 71.25%, and the model size is 17MB.

| Pruning ratio (%) | FLOPs change (%) | Params change (%) | Test accuracy (%) | Size (MB) |
| ----------------- | ---------------- | ----------------- | ----------------- | ------------ |
| 25 | -33.12 | -38.37 | 69.658 | 11(-35.29%) |
| 35 | -51.32 | -51.41 | 68.66 | 8.2(-51.76%) |
| 50 | -57.21 | -67.69 | 66.87 | 5.5(-67.65%) |

Knowledge distillation is an effective way to improve the performance of a model.

The following table shows the distillation result for ResNet-50 as the student network, with ResNet-101 as the teacher network.
57 changes: 57 additions & 0 deletions doc/MobileNet-v1-Training-Pruning.md
@@ -0,0 +1,57 @@
# MobileNet v1 training and pruning

The following uses MobileNet v1 on the ImageNet dataset to illustrate how to use the model optimizer for model
training and pruning.

## 1 Prepare data

### 1.1 Generate training and test data sets

You may follow the data preparation guide [here](https://github.com/tensorflow/models/tree/v1.13.0/research/inception)
to download the full dataset and convert it into TFRecord files. By default, when the script finishes, you will find
1024 training files and 128 validation files in the DATA_DIR. The files will match the patterns
train-?????-of-01024 and validation-?????-of-00128, respectively.
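
Before training, you can optionally sanity-check the conversion by counting the generated shards and peeking at
one record. This is a small stand-alone sketch, not part of the model optimizer; DATA_DIR is a placeholder, and
the feature keys follow the layout written by the inception preparation script.

```python
import glob
import os

import tensorflow as tf

DATA_DIR = '/path/to/imagenet_tfrecord'  # placeholder; point this at your own DATA_DIR

# The preparation script should have produced 1024 training and 128 validation shards.
train_shards = sorted(glob.glob(os.path.join(DATA_DIR, 'train-?????-of-01024')))
val_shards = sorted(glob.glob(os.path.join(DATA_DIR, 'validation-?????-of-00128')))
print('train shards:', len(train_shards), 'validation shards:', len(val_shards))

# Peek at one record to confirm the expected feature keys are present.
raw_record = next(iter(tf.data.TFRecordDataset(train_shards[:1])))
example = tf.train.Example.FromString(raw_record.numpy())
print(sorted(example.features.feature.keys()))  # expect image/encoded, image/class/label, ...
```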

## 2 Train

Enter the examples directory and execute

```shell
cd examples
horovodrun -np 8 -H localhost:8 python mobilenet_v1_imagenet_train.py
```

After execution, the default checkpoint file will be generated in ./models_ckpt/mobilenet_v1_imagenet, and the
inference checkpoint file will be generated in ./models_eval_ckpt/mobilenet_v1_imagenet. You can also modify
checkpoint_path and checkpoint_eval_path in mobilenet_v1_imagenet_train.py to change the generated file paths.
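
The exact layout of mobilenet_v1_imagenet_train.py is not reproduced here; the snippet below is only a
hypothetical illustration of the kind of settings to edit, with made-up values.

```python
# Hypothetical excerpt of mobilenet_v1_imagenet_train.py; the real structure may differ.
# Point these at any writable locations before launching the run.
checkpoint_path = './models_ckpt/mobilenet_v1_imagenet'             # training checkpoints
checkpoint_eval_path = './models_eval_ckpt/mobilenet_v1_imagenet'   # inference checkpoints
```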

## 3 Prune

Here, you can use a fully trained model or a model still in the training process as the initial model to prune. The
following uses the specified-layer pruning strategy as an example.

If you have a well-trained model, for example checkpoint-120.h5 in the ./models_ckpt/mobilenet_v1_imagenet directory,
you can copy it to the ./models_ckpt/mobilenet_v1_imagenet_specified_pruned directory and then perform pruning. Enter
the examples directory and execute

```shell
cd examples
cp ./models_ckpt/mobilenet_v1_imagenet/checkpoint-120.h5 ./models_ckpt/mobilenet_v1_imagenet_specified_pruned/
horovodrun -np 8 -H localhost:8 python mobilenet_v1_imagenet_prune.py
```

Or you can start a training and pruning process from scratch

```shell
cd examples
horovodrun -np 8 -H localhost:8 python mobilenet_v1_imagenet_prune.py
```

After execution, the default checkpoint file will be generated in ./models_ckpt/mobilenet_v1_imagenet_specified_pruned,
and the inference checkpoint file will be generated in ./models_eval_ckpt/mobilenet_v1_imagenet_specified_pruned. You
can also modify checkpoint_path and checkpoint_eval_path in mobilenet_v1_imagenet_prune.py to change the generated
file paths.
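
After pruning, a quick stand-alone check is to load the saved inference checkpoint and confirm that the parameter
count has dropped. The epoch number in the file name below is hypothetical; use whichever checkpoint your run
actually produced.

```python
import tensorflow as tf

# Hypothetical checkpoint name; pick the epoch your run wrote.
model = tf.keras.models.load_model(
    './models_eval_ckpt/mobilenet_v1_imagenet_specified_pruned/checkpoint-120.h5',
    compile=False)
model.summary()  # the parameter count should be well below the unpruned MobileNet-v1 (about 4.2M)
```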

## 4 Quantize

To be continued.
78 changes: 77 additions & 1 deletion src/model_optimizer/pruner/core/pruner.py
@@ -30,6 +30,45 @@ def get_network(model):
return digraph


def dense_present_before_conv(orig_model):
"""
Determine whether the model uses a fully connected (Dense) layer or a convolution followed by a Reshape layer
for classification.
:param orig_model: keras model
:return:
- layer_index: index of the last convolution or Dense layer
- last_reshape: index of the last Reshape layer
- Boolean: True if the model ends with a Dense layer
"""
layer_name = []
dense_present = False
conv_present = False
last_reshape = -1
conv_dense_present = False
layer_index = -1
for layer in orig_model.layers:
layer_name.append(str(type(layer)))
length = len(layer_name)
for index in range(length - 1, -1, -1):
if not conv_dense_present and 'Reshape' in layer_name[index]:
last_reshape = index
if 'Conv2D' in layer_name[index] and not conv_dense_present:
conv_present = True
conv_dense_present = True
layer_index = index
break
if 'Dense' in layer_name[index] and not conv_dense_present:
dense_present = True
conv_dense_present = True
layer_index = index
break
if dense_present and not conv_present:
return layer_index, last_reshape, True
elif conv_present and not dense_present:
return layer_index, last_reshape, False
else:
return -1, -1, False
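# Reference behaviour: for a MobileNet-v1 style top (..., Reshape, Conv2D classifier, Reshape, softmax) the
# backwards scan meets the last Reshape first and then the classifier Conv2D, so the function returns
# (index of that Conv2D, index of the last Reshape, False); for a Dense-topped model such as ResNet-50 it
# returns (index of the Dense layer, -1, True).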


def _get_sorted_mask(arr, num_retain_channels):
arg_sort = np.argsort(-arr)
mask_arr = np.zeros(arr.shape[-1], dtype=bool)
@@ -134,7 +173,10 @@ def _layer_set_weights(pruned_model, layer, weights_0, idx, mask_dict):
pruned_model.layers[idx].set_weights([weights_0[:, :, :, mask_dict[idx]],
layer.weights[1].numpy()[mask_dict[idx]]])
else:
pruned_model.layers[idx].set_weights([weights_0[:, :, :, mask_dict[idx]]])
if layer_type.endswith('DepthwiseConv2D\'>'):
# depthwise kernels have shape (kh, kw, in_channels, depth_multiplier), so the mask is applied on axis 2
pruned_model.layers[idx].set_weights([weights_0[:, :, mask_dict[idx], :]])
else:
pruned_model.layers[idx].set_weights([weights_0[:, :, :, mask_dict[idx]]])
else:
if layer.use_bias:
pruned_model.layers[idx].set_weights([weights_0[:, mask_dict[idx]],
@@ -198,12 +240,29 @@ def specified_layers_prune(orig_model, cur_model, layers_name, ratio, criterion=
clone_model = tf.keras.models.clone_model(cur_model)
digraph = get_network(cur_model)
mask_dict = {}
conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(orig_model)
channel = -1
for i, layer in enumerate(cur_model.layers):
layer_type = str(type(layer))
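# Keep a Conv2D classification head un-pruned (its filter count must stay equal to the number of classes),
# record that count in `channel`, and rewrite each Reshape layer's target_shape to track the channel count of
# the preceding convolution: the final Reshape becomes (channel,), intermediate ones become (1, 1, channel).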
if not dense_ahead_of_conv and i == conv_index:
if layer_type.endswith('Conv2D\'>'):
channel = clone_model.get_layer(layer.name).filters
continue
if layer_type.endswith('Reshape\'>'):
if i == last_reshape:
target_shape = (channel,)
clone_model.layers[i].target_shape = target_shape
continue
elif channel != -1:
target_shape = (1, 1, channel)
clone_model.layers[i].target_shape = target_shape
continue
if 'Conv2D' in str(type(layer)):
if layer.name in layers_name:
clone_model.layers[i].filters = \
clone_model.layers[i].filters - int(orig_model.layers[i].filters * ratio)
mask_dict[i] = _get_conv_mask(cur_model, i, digraph, int(clone_model.layers[i].filters), criterion)
channel = clone_model.layers[i].filters
elif 'Dense' in str(type(layer)):
if i == len(cur_model.layers) - 1:
continue
@@ -229,11 +288,28 @@ def auto_prune(orig_model, cur_model, ratio, criterion='l1_norm'):
clone_model = tf.keras.models.clone_model(cur_model)
digraph = get_network(cur_model)
mask_dict = {}
conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(orig_model)
channel = -1
for i, layer in enumerate(cur_model.layers):
layer_type = str(type(layer))
if not dense_ahead_of_conv and i == conv_index:
if layer_type.endswith('Conv2D\'>'):
channel = clone_model.get_layer(layer.name).filters
continue
if layer_type.endswith('Reshape\'>'):
if i == last_reshape:
target_shape = (channel,)
clone_model.layers[i].target_shape = target_shape
continue
elif channel != -1:
target_shape = (1, 1, channel)
clone_model.layers[i].target_shape = target_shape
continue
if 'Conv2D' in str(type(layer)):
clone_model.layers[i].filters = \
clone_model.layers[i].filters - int(orig_model.layers[i].filters * ratio)
mask_dict[i] = _get_conv_mask(cur_model, i, digraph, int(clone_model.layers[i].filters), criterion)
channel = clone_model.layers[i].filters
elif 'Dense' in str(type(layer)):
if i == len(cur_model.layers) - 1:
continue
21 changes: 21 additions & 0 deletions src/model_optimizer/pruner/learner/learner_base.py
@@ -13,6 +13,7 @@
from .utils import get_call_backs
from ...stat import print_keras_model_summary, print_keras_model_params_flops
from ..distill.distill_loss import DistillLossLayer
from ..core.pruner import dense_present_before_conv


class LearnerBase(metaclass=abc.ABCMeta):
@@ -253,6 +254,7 @@ def load_model(self):
custom_objects=custom_objects)
self.train_models_update(model)

# pylint: disable=too-many-branches
def save_eval_model(self):
"""
Save evaluate model
@@ -262,6 +264,8 @@ def save_eval_model(self):
return
train_model = self.models_train[-1]
eval_model = self.models_eval[-1]
channel = -1
conv_index, last_reshape, dense_ahead_of_conv = dense_present_before_conv(train_model)
save_model_path = os.path.join(self.save_model_path, 'checkpoint-') + str(self.cur_epoch) + '.h5'
if self.config.get_attribute('scheduler') == 'distill':
model_name = self.config.get_attribute('model_name')
@@ -276,8 +280,25 @@
else:
clone_model = tf.keras.models.clone_model(eval_model)
for i, layer in enumerate(clone_model.layers):
layer_type = str(type(layer))
# if the model ends with a convolution layer, do not prune it; just record its channel count
if not dense_ahead_of_conv and i == conv_index:
if layer_type.endswith('Conv2D\'>'):
channel = train_model.get_layer(layer.name).filters
continue
# adjust the Reshape target shape to match the channel change caused by pruning convolution filters
if layer_type.endswith('Reshape\'>'):
if i == last_reshape:
target_shape = (channel,)
clone_model.layers[i].target_shape = target_shape
continue
elif channel != -1:
target_shape = (1, 1, channel)
clone_model.layers[i].target_shape = target_shape
continue
if 'Conv2D' in str(type(layer)):
clone_model.layers[i].filters = train_model.get_layer(layer.name).filters
channel = train_model.get_layer(layer.name).filters
elif 'Dense' in str(type(layer)):
clone_model.layers[i].units = train_model.get_layer(layer.name).units
pruned_eval_model = tf.keras.models.model_from_json(clone_model.to_json())
3 changes: 1 addition & 2 deletions src/model_optimizer/pruner/runner.py
@@ -46,8 +46,7 @@ def run_scheduler(config):
learner.train(initial_epoch=initial_epoch, epochs=cur_epoch+epoch_span, lr_schedulers=lr_schedulers)
cur_epoch += epoch_span
_prune(config, cur_epoch, learner)
if cur_epoch < ln_cur_epoch:
cur_epoch = ln_cur_epoch
cur_epoch = max(cur_epoch, ln_cur_epoch)
target_epoch = config.get_attribute('epochs')
if cur_epoch < target_epoch:
learner.build_train()
@@ -0,0 +1,44 @@
version: 1
pruners:
prune_func1:
criterion: l1_norm
prune_type: auto_prune
ratio: 0.30

lr_schedulers:
# Learning rate
- name: warmup_lr
class: LearningRateWarmupCallback
warmup_epochs: 5
verbose: 0
- name: lr_multiply_1
class: LearningRateScheduleCallback
start_epoch: 5
end_epoch: 30
multiplier: 1.0
- name: lr_multiply_0.1
class: LearningRateScheduleCallback
start_epoch: 30
end_epoch: 80
multiplier: 1e-1
- name: lr_multiply_0.01
class: LearningRateScheduleCallback
start_epoch: 80
end_epoch: 120
multiplier: 1e-2
- name: lr_multiply_0.001
class: LearningRateScheduleCallback
start_epoch: 120
end_epoch: 140
multiplier: 1e-3
- name: lr_multiply_0.0001
class: LearningRateScheduleCallback
start_epoch: 140
end_epoch: 200
multiplier: 1e-4

prune_schedulers:
- pruner:
func_name: prune_func1
epochs: [50]

@@ -0,0 +1,52 @@
version: 1
pruners:
prune_func1:
criterion: mean_l1_norm
prune_type: specified_layer_prune
ratio: 0.35
layers_to_be_pruned: [
conv_pw_1,
conv_pw_2,
conv_pw_4,
conv_pw_5,
conv_pw_6,
conv_pw_7,
conv_pw_8,
conv_pw_9,
conv_pw_10,
conv_pw_11,
conv_pw_12,
conv_pw_13
]

lr_schedulers:
# Learning rate
- name: warmup_lr
class: LearningRateWarmupCallback
warmup_epochs: 5
verbose: 0
- name: lr_multiply_1
class: LearningRateScheduleCallback
start_epoch: 5
end_epoch: 30
multiplier: 1.0
- name: lr_multiply_0.1
class: LearningRateScheduleCallback
start_epoch: 30
end_epoch: 50
multiplier: 1e-1
- name: lr_multiply_0.01
class: LearningRateScheduleCallback
start_epoch: 50
end_epoch: 70
multiplier: 1e-2
- name: lr_multiply_0.001
class: LearningRateScheduleCallback
start_epoch: 70
end_epoch: 90
multiplier: 1e-3

prune_schedulers:
- pruner:
func_name: prune_func1
epochs: [5]
19 changes: 9 additions & 10 deletions src/model_optimizer/quantizer/compressor.py
19 changes: 9 additions & 10 deletions src/model_optimizer/quantizer/compressor.py
@@ -16,14 +16,13 @@ def compress_dir(source_list, zip_file_path):
:param zip_file_path: path of zip file
:return:
"""
zip_file = zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED)
for source in source_list:
basename = os.path.basename(source)
zip_file.write(source, basename)
if os.path.isdir(source):
for path, _, filenames in os.walk(source):
fpath = path.replace(source, basename)
for filename in filenames:
zip_file.write(os.path.join(path, filename), os.path.join(fpath, filename))
zip_file.close()
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
for source in source_list:
basename = os.path.basename(source)
zip_file.write(source, basename)
if os.path.isdir(source):
for path, _, filenames in os.walk(source):
fpath = path.replace(source, basename)
for filename in filenames:
zip_file.write(os.path.join(path, filename), os.path.join(fpath, filename))
return zip_file_path
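
For reference, compress_dir packs both plain files and whole directories into a single archive, and with the
context manager above the archive is closed even if writing fails. A minimal usage sketch follows; the import
path mirrors the file location shown above, and the paths passed in are hypothetical.

```python
from model_optimizer.quantizer.compressor import compress_dir

archive = compress_dir(
    ['./models_eval_ckpt/mobilenet_v1_imagenet/saved_model',  # a directory is walked recursively
     './labels.txt'],                                         # a plain file is added as-is
    './mobilenet_v1_imagenet.zip')
print('wrote', archive)
```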
