-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy path025-VAE的实现.py
152 lines (111 loc) · 4.29 KB
/
025-VAE的实现.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
@file : 025-VAE的实现.py
@author : xiaolu
@time : 2019-06-13
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from keras.layers import Dense, Input
from keras.layers import Conv2D, Flatten, Lambda
from keras.layers import Reshape, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras.datasets import mnist
# 加载MNIST数据集
(x_train, y_train_), (x_test, y_test_) = mnist.load_data()
image_size = x_train.shape[1]
x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# 网络参数
input_shape = (image_size, image_size, 1)
batch_size = 100
kernel_size = 3
filters = 16
latent_dim = 2 # 隐变量取2维只是为了方便后面画图
epochs = 30
x_in = Input(shape=input_shape)
x = x_in
# 编码过程
for i in range(2):
filters *= 2
x = Conv2D(filters=filters,
kernel_size=kernel_size,
activation='relu',
strides=2,
padding='same')(x)
# 备份当前shape,等下构建decoder的时候要用
shape = K.int_shape(x) # [批量数, 28, 28, 1]
x = Flatten()(x)
x = Dense(16, activation='relu')(x)
# 算p(Z|X)的均值和方差
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)
# 重参数技巧
def sampling(args):
z_mean, z_log_var = args
epsilon = K.random_normal(shape=K.shape(z_mean))
return z_mean + K.exp(z_log_var / 2) * epsilon
# 重参数层,相当于给输入加入噪声
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# 解码层,也就是生成器部分
# 先搭建为一个独立的模型,然后再调用模型
latent_inputs = Input(shape=(latent_dim,))
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs) # 通过Dense将latent_dim维扩展到图片的Flatten维
x = Reshape((shape[1], shape[2], shape[3]))(x) # 在进行reshape 成图片的(28, 28, 1)
for i in range(2):
x = Conv2DTranspose(filters=filters,
kernel_size=kernel_size,
activation='relu',
strides=2,
padding='same')(x)
filters //= 2
# 最后的输出
outputs = Conv2DTranspose(filters=1,
kernel_size=kernel_size,
activation='sigmoid',
padding='same')(x)
# 搭建为一个独立的模型 相当于生成模型 从隐层到输出层
decoder = Model(latent_inputs, outputs)
x_out = decoder(z) # z就是隐层随机给的东西
# 建立模型 整理的模型 从编码的输入到解码的输出
vae = Model(x_in, x_out)
# xent_loss是重构loss,kl_loss是KL loss
xent_loss = K.sum(K.binary_crossentropy(x_in, x_out), axis=[1, 2, 3]) # 头到尾的损失
kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) # 头到隐层的损失
vae_loss = K.mean(xent_loss + kl_loss) # 总的损失
# add_loss是新增的方法,用于更灵活地添加各种loss
vae.add_loss(vae_loss)
vae.compile(optimizer='rmsprop')
vae.summary()
vae.fit(x_train,
shuffle=True,
epochs=epochs,
batch_size=batch_size,
validation_data=(x_test, None))
# 构建encoder,然后观察各个数字在隐空间的分布
encoder = Model(x_in, z_mean)
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test_)
plt.colorbar()
plt.show()
# 观察隐变量的两个维度变化是如何影响输出结果的
n = 15 # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# 用正态分布的分位数来构建隐变量对
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))
for i, yi in enumerate(grid_x):
for j, xi in enumerate(grid_y):
z_sample = np.array([[xi, yi]])
x_decoded = decoder.predict(z_sample)
digit = x_decoded[0].reshape(digit_size, digit_size)
figure[i * digit_size: (i + 1) * digit_size,
j * digit_size: (j + 1) * digit_size] = digit
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()