Details of the CNN structure in the demo, as well as the mathematical derivation of backpropagation, can be found in "Derivation of Backpropagation in Convolutional Neural Network (CNN)", which is specifically written for this demo.
The implementation of the CNN uses a trimmed version of DeepLearnToolbox by R. B. Palm.
- Pre-requisite
- [MNIST dataset](#mnist-dataset)
- [CNN structure](#cnn-structure)
- Run the demo
- Results
- Wider or deeper CNN
```matlab
>> demo_CNN_MNIST
```
Note that only 1 epoch will be performed. If you want to run more epochs, modify the variable `num_epochs` in the file demo_CNN_MNIST.m (line 62).
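As a rough sketch of what the demo script does, assuming the layer-definition and training API of the original DeepLearnToolbox (`cnnsetup`/`cnntrain` and the `mnist_uint8` data file; exact names in the trimmed version may differ):

```matlab
% Hypothetical sketch based on the DeepLearnToolbox API, not the exact demo code.
load mnist_uint8;                                % assumed MNIST data file
train_x = double(reshape(train_x', 28, 28, 60000)) / 255;
train_y = double(train_y');

% Prototype structure: two convolutional (pooling) layers, 6 and 12 channels.
cnn.layers = {
    struct('type', 'i')                                     % input layer
    struct('type', 'c', 'outputmaps', 6,  'kernelsize', 5)  % conv, 6 channels
    struct('type', 's', 'scale', 2)                         % 2x2 mean pooling
    struct('type', 'c', 'outputmaps', 12, 'kernelsize', 5)  % conv, 12 channels
    struct('type', 's', 'scale', 2)
};

opts.alpha = 1;          % learning rate
opts.batchsize = 50;
opts.numepochs = 1;      % the demo's default; raise num_epochs for more epochs

cnn = cnnsetup(cnn, train_x, train_y);
cnn = cnntrain(cnn, train_x, train_y, opts);
```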
num_epochs | Training accuracy | Testing accuracy |
---|---|---|
200 | 99.34% | 99.02% |
<img src="https://github.com/ZZUTK/An-Example-of-CNN-on-MNIST-dataset-/blob/master/figs/loss_func.png" width="200">
where y and y_hat denote the true label and prediction, respectively.
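As a minimal illustration of this loss (variable names here are illustrative, not taken from the demo code), the mean squared error over a batch of one-hot labels can be computed as:

```matlab
% Toy batch: one sample per digit class, stored column-wise.
y     = eye(10);               % 10 x 10 one-hot true labels
y_hat = y + 0.05;              % predictions slightly off the labels

% Mean squared error averaged over the batch (columns).
mse = sum((y - y_hat).^2, 'all') / size(y, 2);
```

Note that DeepLearnToolbox-style implementations may include an extra factor of 1/2 in the loss; the shape of the curve is unaffected.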
<img src="https://github.com/ZZUTK/An-Example-of-CNN-on-MNIST-dataset-/blob/master/figs/train_MSE.png" width="400">
### The learned kernels of the first and second convolutional layers

The first convolutional layer has 6 kernels, and the second has 6x12 kernels. All kernels are of size 5x5.

### An example of feedforward on the trained CNN

## Use a wider or deeper CNN

The classification accuracy increases when a wider or deeper CNN is used. "Wider" means increasing the number of channels in each convolutional (pooling) layer, and "deeper" means increasing the number of convolutional (pooling) layers. We consider the CNN structure used in the demo as the prototype, which has 6 and 12 channels in the first and second convolutional (pooling) layers, respectively. The wider CNN uses 12 and 24 channels in the first and second convolutional (pooling) layers, respectively. The deeper CNN has three convolutional (pooling) layers with 6, 12, and 24 channels, respectively. The training and testing accuracy are shown as follows.
num_epochs=200 | Training accuracy | Testing accuracy | Training MSE |
---|---|---|---|
Wider | 99.60% | 99.16% | 0.0052 |
Deeper | 99.69% | 99.28% | 0.0042 |
num_epochs=500 | Training accuracy | Testing accuracy | Training MSE |
---|---|---|---|
Wider | 99.83% | 99.20% | 0.0022 |
Deeper | 99.84% | 99.37% | 0.0016 |