VQ-VAE training example (v2) returned NaN loss #198

Open
EBGU opened this issue Feb 13, 2021 · 4 comments
EBGU commented Feb 13, 2021

Dear DeepMind team,

I am really grateful that you shared a vqvae_example with Sonnet 2. However, when running it I encounter a NaN VQ-VAE loss from the very beginning. The output is:
100 train loss: nan recon_error: 1.010 perplexity: 1.031 vqvae loss: nan
and so on.
The plot of the training set is fine, but the reconstruction is pure grey. I tried both vq_use_ema = False and vq_use_ema = True and got the same results.
I have slightly modified your code by replacing the download and data loading with the previous version (https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb), which reads from a local directory. I'm using TensorFlow 2.2.0 and Sonnet 2.0.0. My code didn't return any error, just NaN loss.
I wonder if you could kindly help me with this problem.
Thanks a lot!

Sincerely,
Harold

My code:
import os
import subprocess
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree

try:
  import sonnet.v2 as snt
  tf.enable_v2_behavior()
except ImportError:
  import sonnet as snt

from six.moves import cPickle
from six.moves import urllib
from six.moves import xrange

# for plt display
os.environ['DISPLAY'] = ':0'  # os.system('export ...') would only affect a subshell

print("TensorFlow version {}".format(tf.version))
print("Sonnet version {}".format(snt.version))

local_data_dir = '/home/harold/Documents/VQ-VAE'
'''
#Downloading cifar10
cifar10 = tfds.as_numpy(tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))
cifar10.pop("id", None)
cifar10.pop("label")
tree.map_structure(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)
'''

#Data loading
'''
train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10)
valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10)
test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)

def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['image']
  data_dict['image'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['image'] / 255.0)
print('train data variance: %s' % train_data_variance)
'''

def unpickle(filename):
  with open(filename, 'rb') as fo:
    return cPickle.load(fo, encoding='latin1')

def reshape_flattened_image_batch(flat_image_batch):
  return flat_image_batch.reshape(-1, 3, 32, 32).transpose([0, 2, 3, 1])  # convert from NCHW to NHWC

def combine_batches(batch_list):
  images = np.vstack([reshape_flattened_image_batch(batch['data'])
                      for batch in batch_list])
  labels = np.vstack([np.array(batch['labels']) for batch in batch_list]).reshape(-1, 1)
  return {'images': images, 'labels': labels}

train_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/data_batch_%d' % i))
    for i in range(1, 5)
])

valid_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/data_batch_5'))])

test_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir, 'cifar-10-batches-py/test_batch'))])

def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['images']
  data_dict['images'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['images'] / 255.0)
print('train data variance: %s' % train_data_variance)

#Encoder & Decoder Architecture
class ResidualStack(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._layers = []
    for i in range(num_residual_layers):
      conv3 = snt.Conv2D(
          output_channels=num_residual_hiddens,
          kernel_shape=(3, 3),
          stride=(1, 1),
          name="res3x3_%d" % i)
      conv1 = snt.Conv2D(
          output_channels=num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          name="res1x1_%d" % i)
      self._layers.append((conv3, conv1))

  def __call__(self, inputs):
    h = inputs
    for conv3, conv1 in self._layers:
      conv3_out = conv3(tf.nn.relu(h))
      conv1_out = conv1(tf.nn.relu(conv3_out))
      h += conv1_out
    return tf.nn.relu(h)  # Resnet V1 style

class Encoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._enc_1 = snt.Conv2D(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="enc_3")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = tf.nn.relu(self._enc_3(h))
    return self._residual_stack(h)

class Decoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._dec_1 = snt.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="dec_1")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = snt.Conv2DTranspose(
        output_channels=self._num_hiddens // 2,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    self._dec_3 = snt.Conv2DTranspose(
        output_channels=3,
        output_shape=None,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")

  def __call__(self, x):
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = tf.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon

class VQVAEModel(snt.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1,
               data_variance, name=None):
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

# Build model and train.
#%%time

# Set hyper-parameters.
batch_size = 32
image_size = 32

# 100k steps should take < 30 minutes on a modern (>= 2017) GPU.
# 10k steps gives reasonable accuracy with VQVAE on Cifar10.
num_training_updates = 10000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

# These hyper-parameters define the size of the model (number of parameters and layers).
# The hyper-parameters in the paper were (for ImageNet):
# batch_size = 128
# image_size = 128
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2

# This value is not that important, usually 64 works.
# This will not change the capacity in the information bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity in the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It's often useful to try a couple
# of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)). So if the reconstruction cost is 100x higher, the
# commitment_cost should also be multiplied by the same amount.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster, and makes the model less dependent on the
# choice of optimizer. In the VQ-VAE paper EMA updates were not used (they were
# developed afterwards). See the Appendix of the paper for more details.
vq_use_ema = False

# This is only used for EMA updates.
decay = 0.99

learning_rate = 3e-4

# Data loading.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

'''
train_batch = next(iter(train_dataset))

def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5

f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['images'].numpy()),
          interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')
plt.show()
'''

# Build modules.

encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,
                          kernel_shape=(1, 1),
                          stride=(1, 1),
                          name="to_vq")

if vq_use_ema:
  vq_vae = snt.nets.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost,
      decay=decay)
else:
  vq_vae = snt.nets.VectorQuantizer(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)

model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(data):
  with tf.GradientTape() as tape:
    model_output = model(data['images'], is_training=True)
  trainable_variables = model.trainable_variables
  grads = tape.gradient(model_output['loss'], trainable_variables)
  optimizer.apply(grads, trainable_variables)

  return model_output

train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

for step_index, data in enumerate(train_dataset):
  train_results = train_step(data)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  if (step_index + 1) % 100 == 0:
    print('%d train loss: %f ' % (step_index + 1,
                                  np.mean(train_losses[-100:])) +
          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
  if step_index == num_training_updates:
    break

#Plot loss
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
ax.set_title('NMSE.')

ax = f.add_subplot(1,2,2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')
plt.show()
# Visualization

# Reconstructions
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of
# using EMA the codebook is not updated.
train_reconstructions = model(train_batch['images'],
                              is_training=False)['x_recon'].numpy()
valid_reconstructions = model(valid_batch['images'],
                              is_training=False)['x_recon'].numpy()

def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5

f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['images'].numpy()),
          interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')

ax = f.add_subplot(2,2,2)
ax.imshow(convert_batch_to_image_grid(train_reconstructions),
          interpolation='nearest')
ax.set_title('training data reconstructions')
plt.axis('off')

ax = f.add_subplot(2,2,3)
ax.imshow(convert_batch_to_image_grid(valid_batch['images'].numpy()),
          interpolation='nearest')
ax.set_title('validation data originals')
plt.axis('off')

ax = f.add_subplot(2,2,4)
ax.imshow(convert_batch_to_image_grid(valid_reconstructions),
          interpolation='nearest')
ax.set_title('validation data reconstructions')
plt.axis('off')
plt.show()
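
For reference, a minimal sketch (not part of the original notebook) for narrowing down where the NaN first appears; tf.debugging.enable_check_numerics and tf.debugging.check_numerics are standard TensorFlow 2.x APIs:

# Raise an error at the first op that produces a NaN or Inf anywhere in the
# model, instead of only seeing a NaN loss in the step-100 summary.
tf.debugging.enable_check_numerics()

# Alternatively, check a single tensor explicitly inside train_step:
# loss = tf.debugging.check_numerics(model_output['loss'], 'loss is NaN/Inf')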

tomhennigan (Collaborator) commented

Hi @EBGU, there's quite a lot of code there! I recognize at least some of it from our vqvae example notebook. Rather than pasting the whole file, it might be more useful if you could highlight what you have changed.

I've just run our vqvae notebook on a free GPU instance on Google Colab with TF 2.4.1; you can see the results in the gist below:

https://colab.research.google.com/gist/tomhennigan/62edee62a4638e0d0ab9738a757043ed/tf2_vq_vae_training_example.ipynb

As far as I can tell things are working correctly?

EBGU (Author) commented Feb 16, 2021

Hi @tomhennigan! I also tried your original code without any changes, and the result was still NaN. I thought it could be an environment problem, but no error came up.

EBGU (Author) commented Feb 16, 2021

I upgraded my TF to 2.4.1 and it worked! I guess TF 2.2.0 is somehow incompatible with the code. Thanks a lot!
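
For anyone who hits the same thing, a small hypothetical guard (not part of the example) that warns when an older TensorFlow is installed:

import tensorflow as tf

# Hypothetical guard: the NaNs reported here disappeared after upgrading to
# TF 2.4.1, so warn early if an older TensorFlow is installed.
major, minor = (int(part) for part in tf.__version__.split('.')[:2])
if (major, minor) < (2, 4):
  print('Warning: this example was verified with TF >= 2.4.1; '
        'TF %s has been reported to produce NaN losses here.' % tf.__version__)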

abhilash1910 commented Sep 6, 2021

Hi,
I tried to run the notebook with TF 2.2.0 and it works. Please find the notebook here: https://colab.research.google.com/drive/18GT4HVkjDwHB4e2AEU2G8XYU__A2t-F6?usp=sharing
Hope this helps.
