Skip to content

Commit 833af5c

Browse files
authored
Fixed Grammar
1 parent 77ddda0 commit 833af5c

File tree

1 file changed

+43
-41
lines changed

1 file changed

+43
-41
lines changed

VAE/readme.md

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,30 @@
11
# Variational Auto Encoder (VAE) in Pytorch
2-
The repository consists of a variational autoencoder implemented in pytorch and trained on MNIST dataset . It can be used to generate images and also for generating cool T-SNE visulaisations of the latent space .
2+
The repository consists of a variational autoencoder implemented in PyTorch and trained on MNIST dataset. It can be used to generate images and also for generating cool T-SNE visualization of the latent space.
33

4-
## VAE : Overview
4+
## VAE: Overview
55

6-
Variational autoencoders at first glance seems like another autoencoder . An autoencoder basically consists of an encoder and a decoder . **The encoder converts the input into another dimension space , generally of a smaller size** and then tries to reconstruct the input from this representation . This kind of forces the network to filter out the not so useful features and only stores useful features . So this is sometimes used to get a lower dimension representation of our data .
6+
Variational autoencoders at first glance seem like another autoencoder. An autoencoder basically consists of an encoder and a decoder. **The encoder converts the input into another dimension space, generally of a smaller size** and then tries to reconstruct the input from this representation. This kind of forces the network to filter out the not so useful features and only stores useful features. So this is sometimes used to get a lower-dimensional representation of our data.
77

88
<img src='readme_images/autoencoder.png' style="max-width:100%">
99

1010

11-
Now whats so special about Variational autoencoders .
11+
Now, what's so special about Variational autoencoders.
1212

13-
Well this is not a tutorial for VAE so let's just get an overview .
13+
Well, this is not a tutorial for VAE so let's just get an overview.
1414

15-
Now VAE is a generative model ..meaning it can be used to generate new data . Now why can't we use a standard autoencoder to do this . The problem **with the standard auto-encoder is that the latent space repesentation of the data follows some very complex distribution which is not known to us** . So we can't sample new latent variables from that distribution and decode them into something that looks like an image .
15+
Now VAE is a generative model ..meaning it can be used to generate new data. Now, why can't we use a standard autoencoder to do this. The problem **with the standard auto-encoder is that the latent space representation of the data follows some very complex distribution which is not known to us **. So we can't sample new latent variables from that distribution and decode them into something that looks like an image.
1616

17-
So that's were **VAE are different ..they constraint the latent space representation to be of that of a unit gaussian** which we can easily sample from and use to create new samples .
17+
So that's were **VAE are different ..they constraint the latent space representation to be of that of a unit gaussian** which we can easily sample from and use to create new samples.
1818

19-
Now this is done using ..well a lot of complicated maths ..something called variational inference . I guess its easier to explain it using the loss function .
19+
Now, this is done using ..well a lot of complicated maths ..something called variational inference. I guess its easier to explain it using the loss function.
2020

2121

2222
<img src='readme_images/vae.png' style="max-width:100%">
2323

2424
`Loss = Ez∼Q(z|x)[logP(x|z)]−KL[Q(z|x)||P(z)]`
2525

2626

27-
The first term is basically maximising the likelihood of the input data and is simply said the reconstruction loss . The second term is a KL divergence loss and it measures the similarity of `Q(z|x)` and `P(z)` . `P(z)` is what the distribution of the latent variables should be (ie . unit gaussian) and `Q(z|x)` is our approximator of `P(z)` using the encoder neural network ( Its also a gaussian but with mean and variance output by the encoder)
27+
The first term is basically maximising the likelihood of the input data and is simply said the reconstruction loss. The second term is a KL divergence loss and it measures the similarity of `Q(z|x)` and `P(z)`. `P(z)` is what the distribution of the latent variables should be (ie . unit gaussian) and `Q(z|x)` is our approximator of `P(z)` using the encoder neural network ( Its also a gaussian but with mean and variance output by the encoder)
2828

2929
## Contents
3030
1. [Setup Instructions and Dependencies](#1-setup-instructions-and-dependencies)
@@ -42,14 +42,14 @@ The first term is basically maximising the likelihood of the input data and is s
4242
7. [Observations](#7observations)
4343
8. [Credits](#8-credits)
4444

45-
So basically the loss has two opposing functions ..the reconstruction loss which tries to recreate the input as such not caring about the latent variable distribution and the KL divergence term which forces the distribution to be gaussian .
45+
So basically the loss has two opposing functions ..the reconstruction loss which tries to recreate the input as such not caring about the latent variable distribution and the KL divergence term which forces the distribution to be gaussian.
4646

4747
## 1. Setup Instructions and Dependencies
4848
You can either download the repo or clone it by running the following in cmd prompt
4949
```
5050
https://github.com/ayushtues/GenZoo.git
5151
```
52-
You can create a virtual environment and run the below command to install all required dependecies
52+
You can create a virtual environment and run the below command to install all required dependencies
5353

5454
```
5555
pip3 install -r requirements.txt
@@ -65,9 +65,9 @@ python main.py --config /path/to/config.ini
6565
The `config.ini` file should specify all the required parameters of the model.
6666
**To train on gpu set gpu = 1 else gpu = 0 in config.ini**
6767

68-
The program automatically downloads the MNIST dataset and saves it in `MNIST_dataset` (creating the folder itself) . This only happens once
68+
The program automatically downloads the MNIST dataset and saves it in `MNIST_dataset` (creating the folder itself). This only happens once
6969

70-
It also creates a `experiments` folder and inside it creates a `exp_name` folder as specified in your config.ini file .
70+
It also creates a `experiments` folder and inside it creates an `exp_name` folder as specified in your config.ini file.
7171

7272
The `exp_name` file has 3 folders
7373
- `training_checkpoints` - Contains the models saved with frequency as specified in `model_save_frequency` in `config.ini`
@@ -78,7 +78,7 @@ The `exp_name` file has 3 folders
7878

7979
To generate new images from z sampled randomly from uniform gaussian and to make a nice digit transit grid run the following command.
8080

81-
It can also generate a single digit , for which it basically takes the mean of the means and variance produced for all samples in the test dataset for that digit and uses this averaged mean and variance to sample some latent variables and pass it in the decoder .
81+
It can also generate a single digit, for which it basically takes the mean of the means and variance produced for all samples in the test dataset for that digit and uses this averaged mean and variance to sample some latent variables and pass it in the decoder.
8282
```
8383
python generate.py --dataset [DATASET] --model_path [PATH_TO_MODEL] --grid_size [GRID_SIZE] --save_path [SAVE_DIRECTORY] --z_dims [Z_DIMENSIONS] --grid_size2 [GRID_SIZE2] --no_datapoints [NO_DATAPOINTS] --digit [DIGIT]
8484
```
@@ -89,30 +89,30 @@ python generate.py --dataset [DATASET] --model_path [PATH_TO_MODEL] --grid_size
8989
- `save_path` - the directory where to save the images
9090
- `z_dims` - the size of the latent space (Useful if training models with different z_dims otherwise) (default 20 )
9191
- `grid_size2` - the size of the grid of the generated images of a single number sampled randomly
92-
- `no_datapoints`- the number of test datapoints used for approximating mean and var of the single digit
92+
- `no_datapoints`- the number of test data points used for approximating mean and var of the single-digit
9393
- `digit` - the digit to generate
9494

9595
You can use a pre-trained model (with z_dims = 20) by downloading it from the link in `model.txt`
9696

9797
## 4. Repository Overview
9898

99-
The repository contains of the following files
99+
The repository contains the following files
100100

101-
- `main.py` - Does the major work , Calls functions from various other files to nake , train , save the model and also do the t-sne visualisation .
102-
- `train.py`- Has the loss function , also the function to display images in a grid , save the reconstructed training images and also the model. Basically handles the training.
103-
- `model.py` - Contains the VAE model , the encoder , decoder , forward functions.
104-
- `dataloader.py` - Returns a dataloader from the MNIST dataset with given batch size .
105-
- `generate-py` - Generates new images and also the transition between two digits grid from a pretrained model.
106-
- `model.txt` - Contains link to a pretrained model (with z_dims =20)
107-
- `readme.md` - Readme giving overview of the repo
101+
- `main.py` - Does the major work, Calls functions from various other files to make, train, save the model and also do the t-sne visualisation.
102+
- `train.py`- Has the loss function, also the function to display images in a grid, save the reconstructed training images and also the model. Basically handles the training.
103+
- `model.py` - Contains the VAE model, the encoder, decoder, forward functions.
104+
- `dataloader.py` - Returns a data loader from the MNIST dataset with given batch size.
105+
- `generate-py` - Generates new images and also the transition between two digits grid from a pre-trained model.
106+
- `model.txt` - Contains a link to a pre-trained model (with z_dims =20)
107+
- `readme.md` - Readme giving an overview of the repo
108108
- `requirements.txt` - Has all the required dependencies to be installed for the repo
109109
- `readme_images` - Has various images for the readme
110110
- `configs` - Has the config.ini file
111111
- `MNIST_dataset` - Contains the downloaded MNIST Dataset(though `main.py` will download it if needed automatically)
112112
- `experiments` - Has various results of the training of the model
113-
- `generated_images` - Contains the digit_transit images and also new generated images
114-
- `mnist` - Contains the saved models , training images , and t-sne visualisation
115-
- `runs` Contains the tensorboard logs (automatically created by program)
113+
- `generated_images` - Contains the digit_transit images and also newly generated images
114+
- `mnist` - Contains the saved models, training images, and t-sne visualisation
115+
- `runs` Contains the tensorboard logs (automatically created by the program)
116116

117117
Do the following to run tensorboard
118118

@@ -123,12 +123,12 @@ tensordboard --logdir path/to/directory/runs
123123

124124
<img src='readme_images/VAE_architecture.png' style="max-width:100%">
125125

126-
The architecture is basically divided into two parts an encoder and a decoder .
127-
The encoder first has a bunch of convultional layers with LeakyRelu activation function and max pooling and batchnorm . The last conv layer also has dropout. Then there are a bunch of Fully connected layers with leakyRelu activation and dropout . Finally a fully connected layer gives us the mean and logvar respectively.
126+
The architecture is basically divided into two parts an encoder and a decoder.
127+
The encoder first has a bunch of convolutional layers with LeakyRelu activation function and max-pooling and batch norm . The last conv layer also has dropout. Then there are a bunch of Fully connected layers with leakyRelu activation and dropout. Finally, a fully connected layer gives us the mean and logvar respectively.
128128

129129
Then we sample from the distribution using the reparameterisation trick . With z = (std*eps)+mean , where eps = N (0,I) .
130130

131-
The decoder consists of a bunch of fully conncected layers followed by Transpose Convolutional layers and finally a sigmoid function which gives the output images .
131+
The decoder consists of a bunch of fully connected layers followed by Transpose Convolutional layers and finally a sigmoid function which gives the output images.
132132
## 6. Results
133133
### 1. Training images
134134
Image from 0th epoch
@@ -163,16 +163,16 @@ After 100 epochs the t-sne visualisation is
163163

164164
<img src='readme_images/t_sne_visualization.png' style="max-width:100%">
165165

166-
As we can see there are clearly 10 clusters formed and also there is a smooth transition between them , which means the model is working fine ( although there is 1 seperate cluster )
166+
As we can see there are clearly 10 clusters formed and also there is a smooth transition between them, which means the model is working fine ( although there is 1 separate cluster )
167167

168-
### 3. Image generated from random gaussian input
168+
### 3. Image generated from random Gaussian input
169169

170170
The following image was generated after passing a randomly sampled z from the unit gaussian and then passed through the decoder
171171

172172
<img src='readme_images/user_generated_image.png' style="max-width:100%">
173173

174174

175-
Although its clear that the images could be better but still they resemble digits well enough for a randomly sampled gaussian input .
175+
Although it’s clear that the images could be better still they resemble digits well enough for a randomly sampled Gaussian input.
176176

177177
### 4. Smooth transition between two digits
178178

@@ -197,8 +197,8 @@ The following shows images formed when the latent variable z of one image was un
197197

198198
**As we can see the total loss and reconstruction loss decrease uniformly as expected**
199199

200-
### 6 . Single digit generated samples
201-
The following images are genrated by randomly sampling latent variables from a gaussian with mean and var given by the average of all the means and var output by all the test examples of the images belonging to that class
200+
### 6 . Single-digit generated samples
201+
The following images are generated by randomly sampling latent variables from a gaussian with mean and var given by the average of all the means and var output by all the test examples of the images belonging to that class
202202

203203
Generated images of digit 1
204204

@@ -208,19 +208,19 @@ Generated images of digit 9
208208

209209
<img src='readme_images/digit_9.png' style="max-width:100%">
210210

211-
As we can the see the images resemble the corresponding digits .
211+
As we can see the images resemble the corresponding digits.
212212

213213
## 7.Observations
214-
The model was trained on google colab for 100 epoch , with batch size 50 . It took approx 10-15 mins to train .
214+
The model was trained on google collab for 100 epoch, with batch size 50. It took approx 10-15 mins to train.
215215

216-
After training the model was able to reconstruct the input images quite well , and was also able to generate new images although the generated images are not so clear .
217-
The T-sne visualisation of the latent space shows that the latent space has been divided into 10 clusters , but also has smooth transitions between these spaces , indicating that the model forced the latent space to be a bit similar with the normal distribution .
218-
The digit-transit images show that **latent space has a linear nature** and linearly changing the latent variables from one digit to another leads to a smooth transition .
216+
After training the model was able to reconstruct the input images quite well, and was also able to generate new images although the generated images are not so clear.
217+
The T-sne visualisation of the latent space shows that the latent space has been divided into 10 clusters, but also has smooth transitions between these spaces, indicating that the model forced the latent space to be a bit similar with the normal distribution.
218+
The digit-transit images show that **latent space has a linear nature** and linearly changing the latent variables from one digit to another leads to a smooth transition.
219219

220-
**Also using the estimate of mean and var for a single class , and generating samples from it gave pretty good images , so this might mean that the marginal distribution of a particular class is can be approximated by a gaussian with the mean and var as given by the method and the dataset can be viewed as a sum of 10 different gaussians . Although there are a few 4's and 8's generated in the 9's image meaning these gaussians have some mixing**
220+
**Also using the estimate of mean and var for a single class , and generating samples from it gave pretty good images , so this might mean that the marginal distribution of a particular class is can be approximated by a gaussian with the mean and var as given by the method and the dataset can be viewed as a sum of 10 different Gaussians . Although there are a few 4's and 8's generated in the 9's image meaning these Gaussians have some mixing**
221221

222222
One peculiar thing to notice was that the **KL-Divergence loss actually increased as the model trained** .
223-
I found a possible explanation at [this reddit post](https://www.reddit.com/r/MachineLearning/comments/6m2tje/d_kl_divergence_decreases_to_a_point_and_then/)
223+
I found a possible explanation at [this Reddit post](https://www.reddit.com/r/MachineLearning/comments/6m2tje/d_kl_divergence_decreases_to_a_point_and_then/)
224224

225225
TLDR : Guess its easier for the model to make the distribution gaussian so it first does that , but a perfect gaussian means that there is no uniqueness introduced by the training images , so after a while the distribution shifts from a perfect gaussian to incorporate the subtle patterns in the data and thus increasing the KLD loss .
226226

@@ -238,3 +238,5 @@ Diederik P Kingma, Max Welling](https://arxiv.org/abs/1312.6114)
238238

239239

240240

241+
242+

0 commit comments

Comments
 (0)