# Variational Autoencoder (VAE) in PyTorch
The repository contains a variational autoencoder implemented in PyTorch and trained on the MNIST dataset. It can be used to generate images and also to produce cool t-SNE visualisations of the latent space.
## VAE: Overview
At first glance, a variational autoencoder seems like just another autoencoder. An autoencoder consists of an encoder and a decoder. **The encoder converts the input into another dimension space, generally of a smaller size**, and the decoder then tries to reconstruct the input from this representation. This forces the network to filter out the less useful features and keep only the useful ones, so autoencoders are sometimes used to get a lower-dimensional representation of our data.
So what's so special about variational autoencoders?
Well, this is not a VAE tutorial, so let's just get an overview.
A VAE is a generative model, meaning it can be used to generate new data. So why can't we use a standard autoencoder for this? The problem **with the standard autoencoder is that the latent-space representation of the data follows some very complex distribution which is not known to us**, so we can't sample new latent variables from that distribution and decode them into something that looks like an image.
That's where **VAEs are different: they constrain the latent-space representation to be a unit Gaussian**, which we can easily sample from and use to create new samples.
This is done using a lot of complicated maths, something called variational inference. It's easier to explain using the loss function.
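Written out, this loss (the negative evidence lower bound) takes its standard form as

$$\mathcal{L}(x) = -\,\mathbb{E}_{z \sim Q(z|x)}\big[\log P(x|z)\big] + D_{\mathrm{KL}}\big(Q(z|x)\,\|\,P(z)\big)$$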
The first term maximises the likelihood of the input data; simply put, it is the reconstruction loss. The second term is a KL-divergence loss which measures the similarity of `Q(z|x)` and `P(z)`. `P(z)` is what the distribution of the latent variables should be (i.e. a unit Gaussian), and `Q(z|x)` is our approximation of `P(z)` produced by the encoder neural network (it is also a Gaussian, but with mean and variance output by the encoder).
## Contents
1. [Setup Instructions and Dependencies](#1-setup-instructions-and-dependencies)
7. [Observations](#7-observations)
8. [Credits](#8-credits)
So the loss has two opposing terms: the reconstruction loss, which tries to recreate the input without caring about the latent-variable distribution, and the KL-divergence term, which forces the distribution to be Gaussian.
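As a concrete sketch of this loss in PyTorch (the repository's actual implementation lives in `train.py` and may differ; a binary-cross-entropy reconstruction term is assumed):

```
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term: how faithfully the decoder reproduces the input
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between Q(z|x) = N(mu, sigma^2) and P(z) = N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld
```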
## 1. Setup Instructions and Dependencies
You can either download the repo or clone it by running the following in a command prompt:
```
git clone https://github.com/ayushtues/GenZoo.git
```
You can create a virtual environment and run the below command to install all required dependencies:
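Assuming the dependencies are the ones listed in the repo's `requirements.txt`, that command is presumably:

```
pip install -r requirements.txt
```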
The `config.ini` file should specify all the required parameters of the model.
**To train on GPU, set `gpu = 1` in `config.ini`; otherwise set `gpu = 0`.**
The program automatically downloads the MNIST dataset and saves it in `MNIST_dataset` (creating the folder itself). This only happens once.
It also creates an `experiments` folder and, inside it, an `exp_name` folder as specified in your `config.ini` file.
The `exp_name` folder has 3 subfolders:
- `training_checkpoints` - Contains the models saved at the frequency specified by `model_save_frequency` in `config.ini`
To generate new images from `z` sampled randomly from a unit Gaussian, and to make a nice digit-transition grid, run the following command (a sketch of the invocation appears after the parameter list below).
It can also generate a single digit: it takes the mean of the means and variances produced for all samples of that digit in the test dataset, and uses this averaged mean and variance to sample some latent variables, which are passed through the decoder, as sketched below.
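A minimal sketch of that single-digit procedure (the `model.encoder`/`model.decoder` API and names here are hypothetical; the actual code lives in `generate.py`):

```
import torch

@torch.no_grad()
def single_digit_samples(model, test_loader, digit, n_samples, z_dims=20):
    # Collect the encoder's mean and variance for every test image of this digit
    mus, variances = [], []
    for x, y in test_loader:
        mu, logvar = model.encoder(x[y == digit])
        mus.append(mu)
        variances.append(logvar.exp())
    mu = torch.cat(mus).mean(dim=0)                 # average of the per-sample means
    std = torch.cat(variances).mean(dim=0).sqrt()   # average variance -> std
    z = mu + std * torch.randn(n_samples, z_dims)   # sample latents from N(mu, std^2)
    return model.decoder(z)
```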
- `save_path` - the directory where the images are saved
- `z_dims` - the size of the latent space (useful if training models with different z_dims) (default 20)
- `grid_size2` - the size of the grid of generated images of a single, randomly sampled number
- `no_datapoints` - the number of test data points used for approximating the mean and var of the single digit
- `digit` - the digit to generate
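A hypothetical invocation (the flag names are assumed to mirror the parameters above; check `generate.py` for the actual interface):

```
python generate.py --save_path experiments/mnist/generated_images --z_dims 20 --grid_size2 8 --no_datapoints 500 --digit 9
```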
You can use a pre-trained model (with z_dims = 20) by downloading it from the link in `model.txt`
## 4. Repository Overview
The repository contains the following files:
- `main.py` - Does the major work: calls functions from the other files to make, train, and save the model, and also does the t-SNE visualisation.
- `train.py` - Has the loss function and the functions to display images in a grid, save the reconstructed training images, and save the model. Basically handles the training.
- `model.py` - Contains the VAE model: the encoder, decoder, and forward functions.
- `dataloader.py` - Returns a data loader for the MNIST dataset with a given batch size (a sketch follows this list).
- `generate.py` - Generates new images, and also the transition grid between two digits, from a pre-trained model.
- `model.txt` - Contains a link to a pre-trained model (with z_dims = 20)
- `readme.md` - Readme giving an overview of the repo
- `requirements.txt` - Has all the required dependencies to be installed for the repo
- `readme_images` - Has various images for the readme
- `configs` - Has the `config.ini` file
- `MNIST_dataset` - Contains the downloaded MNIST dataset (though `main.py` will download it automatically if needed)
- `experiments` - Has various results of the training of the model
- `generated_images` - Contains the digit-transit images and also newly generated images
- `mnist` - Contains the saved models, training images, and t-SNE visualisation
- `runs` - Contains the tensorboard logs (automatically created by the program)
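A minimal sketch of what such a data loader looks like (the real `dataloader.py` may differ):

```
import torch
from torchvision import datasets, transforms

def get_dataloader(batch_size, train=True):
    # Downloads MNIST into MNIST_dataset on first use, then reuses it
    dataset = datasets.MNIST('MNIST_dataset', train=train, download=True,
                             transform=transforms.ToTensor())
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train)
```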
The architecture is basically divided into two parts: an encoder and a decoder.

The encoder first has a series of convolutional layers with LeakyReLU activations, max-pooling, and batch norm; the last conv layer also has dropout. Then there is a series of fully connected layers with LeakyReLU activations and dropout. Finally, a fully connected layer gives us the mean and logvar respectively.
Then we sample from the distribution using the reparameterisation trick, with `z = (std * eps) + mean`, where `eps ~ N(0, I)`.
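A minimal sketch of the trick, assuming the encoder outputs the mean `mu` and the log-variance `logvar`:

```
import torch

def reparameterize(mu, logvar):
    # std = exp(0.5 * logvar); writing z = mu + std * eps keeps the sample
    # differentiable with respect to the encoder outputs
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```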
The decoder consists of a series of fully connected layers followed by transpose-convolutional layers and finally a sigmoid, which gives the output images.
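An illustrative skeleton of such a model (the layer sizes are made up; the real architecture is in `model.py`):

```
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, z_dims=20):
        super().__init__()
        # Encoder: conv -> LeakyReLU -> max-pool -> batch norm, dropout on the last conv
        self.enc_conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.LeakyReLU(),
            nn.MaxPool2d(2), nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, padding=1), nn.LeakyReLU(),
            nn.MaxPool2d(2), nn.BatchNorm2d(32), nn.Dropout(0.2),
        )
        self.enc_fc = nn.Sequential(nn.Linear(32 * 7 * 7, 256), nn.LeakyReLU(), nn.Dropout(0.2))
        self.fc_mu = nn.Linear(256, z_dims)       # mean head
        self.fc_logvar = nn.Linear(256, z_dims)   # log-variance head
        # Decoder: FC layers, then transpose convolutions, then sigmoid
        self.dec_fc = nn.Sequential(nn.Linear(z_dims, 32 * 7 * 7), nn.LeakyReLU())
        self.dec_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 2, stride=2), nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 1, 2, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.enc_fc(self.enc_conv(x).flatten(1))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterisation
        out = self.dec_conv(self.dec_fc(z).view(-1, 32, 7, 7))
        return out, mu, logvar
```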
## 6. Results
### 1. Training images
Image from the 0th epoch
After 100 epochs, the t-SNE visualisation is:

As we can see, there are clearly 10 clusters formed, and there is a smooth transition between them, which means the model is working fine (although there is 1 separate cluster).
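A visualisation like this can be produced along these lines (a sketch; `mu` is assumed to be the stacked encoder means for the test set and `labels` the digit labels):

```
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# mu: numpy array of shape (N, z_dims); labels: array of shape (N,)
embedded = TSNE(n_components=2).fit_transform(mu)
plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='tab10', s=2)
plt.colorbar()
plt.show()
```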
### 3. Image generated from random Gaussian input
The following image was generated by sampling a random `z` from the unit Gaussian and passing it through the decoder.
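In code this amounts to something like the following (assuming a `model.decoder` that maps latents to images, and `z_dims = 20`):

```
import torch

z = torch.randn(64, 20)           # 64 latent vectors from the unit Gaussian
with torch.no_grad():
    images = model.decoder(z)     # decode them into 64 generated digits
```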
Although it's clear that the images could be better, they still resemble digits well enough for a randomly sampled Gaussian input.
### 4. Smooth transition between two digits
The following shows the images formed when the latent variable `z` of one image is uniformly changed into that of another image.
**As we can see, the total loss and reconstruction loss decrease uniformly, as expected.**
### 6. Single-digit generated samples
The following images are generated by randomly sampling latent variables from a Gaussian whose mean and var are the average of the means and vars output for all the test examples of the images belonging to that class.
As we can see, the images resemble the corresponding digits.
## 7. Observations
The model was trained on Google Colab for 100 epochs, with batch size 50. It took approx 10-15 mins to train.
After training, the model was able to reconstruct the input images quite well, and was also able to generate new images, although the generated images are not so clear.

The t-SNE visualisation of the latent space shows that it has been divided into 10 clusters, but with smooth transitions between them, indicating that the model forced the latent space to be somewhat similar to the normal distribution.

The digit-transit images show that the **latent space has a linear nature**: linearly changing the latent variables from one digit to another leads to a smooth transition.
**Also, using the estimate of the mean and var for a single class and generating samples from it gave pretty good images. This might mean that the marginal distribution of a particular class can be approximated by a Gaussian with the mean and var given by this method, and the dataset can be viewed as a mixture of 10 different Gaussians. However, there are a few 4's and 8's generated in the 9's image, meaning these Gaussians have some mixing.**
One peculiar thing to notice was that the **KL-divergence loss actually increased as the model trained**.
I found a possible explanation in [this Reddit post](https://www.reddit.com/r/MachineLearning/comments/6m2tje/d_kl_divergence_decreases_to_a_point_and_then/).
TLDR: It's easier for the model to make the distribution Gaussian, so it does that first. But a perfect Gaussian means no uniqueness is introduced by the training images, so after a while the distribution shifts away from a perfect Gaussian to incorporate the subtle patterns in the data, increasing the KLD loss.
## 8. Credits

[Auto-Encoding Variational Bayes, Diederik P Kingma, Max Welling](https://arxiv.org/abs/1312.6114)