Skip to content

Commit 9cfdfe0

Browse files
committed
add cuda
1 parent 9f29bcd commit 9cfdfe0

File tree

5 files changed

+322
-128
lines changed

5 files changed

+322
-128
lines changed

GAN.ipynb

Lines changed: 89 additions & 90 deletions
Large diffs are not rendered by default.

WGAN.ipynb

Lines changed: 49 additions & 38 deletions
Large diffs are not rendered by default.

WGAN.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
2+
# coding: utf-8
3+
4+
# In[1]:
5+
6+
7+
import torch
8+
import torch.nn as nn
9+
from torch.autograd import Variable
10+
import torchvision.datasets as dset
11+
import torchvision.transforms as transforms
12+
import torch.nn.functional as F
13+
import torch.optim as optim
14+
import sys
15+
16+
17+
# In[2]:
18+
19+
20+
from matplotlib import pyplot as plt
21+
from torchvision import utils
22+
show_image=True
23+
def imshow(inp, file_name, save=False, title=None):
24+
"""Imshow for Tensor."""
25+
fig = plt.figure(figsize=(5, 5))
26+
inp = inp.numpy().transpose((1, 2, 0))
27+
plt.imshow(inp, cmap='gray')
28+
if show_image:
29+
plt.show()
30+
31+
32+
# In[4]:
33+
34+
35+
root = './data'
36+
download = True
37+
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5,), (0.5,0.5,0.5))])
38+
trans = transforms.Compose([transforms.ToTensor()])
39+
train_set = dset.MNIST(root=root, train=True, transform=trans, download=download)
40+
test_set = dset.MNIST(root=root, train=False, transform=trans)
41+
batch_size = 128
42+
kwargs = {}
43+
train_loader = torch.utils.data.DataLoader(
44+
dataset=train_set,
45+
batch_size=batch_size,
46+
shuffle=True)
47+
test_loader = torch.utils.data.DataLoader(
48+
dataset=test_set,
49+
batch_size=batch_size,
50+
shuffle=False)
51+
52+
53+
# In[70]:
54+
55+
56+
z_size=128
57+
hidden_size=128
58+
img_size=28
59+
60+
61+
# In[71]:
62+
63+
64+
class Generator(nn.Module):
65+
def __init__(self):
66+
super().__init__()
67+
self.model = nn.Sequential(
68+
nn.Linear(z_size, hidden_size*2),
69+
nn.LeakyReLU(0.2, inplace=True),
70+
nn.Linear(hidden_size*2, hidden_size*4),
71+
nn.LeakyReLU(0.2, inplace=True),
72+
# nn.Linear(hidden_size*4, hidden_size*8),
73+
# nn.LeakyReLU(0.2, inplace=True),
74+
nn.Linear(hidden_size*4, img_size**2),
75+
nn.Tanh()
76+
)
77+
def forward(self, x):
78+
x = x.view(x.size()[0], z_size)
79+
out = self.model(x)
80+
out = out.view(x.size()[0], 1,img_size,img_size)
81+
return out
82+
83+
84+
# In[72]:
85+
86+
87+
class Discriminator(nn.Module):
88+
def __init__(self):
89+
super().__init__()
90+
self.model = nn.Sequential(
91+
nn.Linear(img_size**2, hidden_size*4),
92+
nn.LeakyReLU(0.2, inplace=True),
93+
nn.Dropout(0.3),
94+
# nn.Linear(hidden_size*8, hidden_size*4),
95+
# nn.LeakyReLU(0.2, inplace=True),
96+
# nn.Dropout(0.3),
97+
nn.Linear(hidden_size*4, hidden_size*2),
98+
nn.LeakyReLU(0.2, inplace=True),
99+
nn.Dropout(0.3),
100+
nn.Linear(hidden_size*2, 1),
101+
)
102+
def forward(self, x):
103+
out = self.model(x.view(x.size(0), img_size**2))
104+
out = out.view(out.size(0), -1)
105+
return out
106+
107+
108+
# In[73]:
109+
110+
111+
from tqdm import tqdm
112+
G = Generator()
113+
D = Discriminator()
114+
if torch.cuda.is_available():
115+
G.cuda()
116+
D.cuda()
117+
G_lr = D_lr = 5e-5
118+
optimizers = {
119+
'D': torch.optim.RMSprop(D.parameters(), lr=D_lr),
120+
'G': torch.optim.RMSprop(G.parameters(), lr=G_lr)
121+
}
122+
criterion = nn.BCELoss()
123+
for epoch in tqdm(range(10000)):
124+
for _ in range(5):
125+
optimizers['D'].zero_grad()
126+
data=next(iter(train_loader))[0]
127+
data = Variable(data)
128+
output_real = D(data)
129+
noisev = torch.randn(data.size()[0], z_size, 1, 1)
130+
noisev = Variable(noisev)
131+
fake_data = G(noisev)
132+
output_fake = D(fake_data)
133+
D_loss = -(torch.mean(output_real) - torch.mean(output_fake))
134+
135+
D_loss.backward()
136+
optimizers['D'].step()
137+
for p in D.parameters():
138+
p.data.clamp_(-0.01, 0.01)
139+
140+
optimizers['G'].zero_grad()
141+
noisev = torch.randn(data.size()[0], z_size, 1, 1)
142+
noisev = Variable(noisev)
143+
fake_data = G(noisev)
144+
output_fake1 = D(fake_data)
145+
G_loss = -torch.mean(output_fake1)
146+
147+
G_loss.backward()
148+
optimizers['G'].step()
149+
150+
if epoch % 1000 == 0:
151+
dd = utils.make_grid(fake_data.data[:16])
152+
imshow(dd,'./results/WGAN_%d.png'%(epoch))
153+
154+
155+
# In[28]:
156+
157+
158+
class Generator(nn.Module):
159+
def __init__(self):
160+
super().__init__()
161+
d=128
162+
self.model = nn.Sequential(
163+
nn.ConvTranspose2d(z_size, d*8, 4, 1, 0),
164+
nn.BatchNorm2d(d*8),
165+
nn.ReLU(),
166+
nn.ConvTranspose2d(d*8, d*4, 4, 2, 1),
167+
nn.BatchNorm2d(d*4),
168+
nn.ReLU(),
169+
nn.ConvTranspose2d(d*4, d*2, 4, 2, 1),
170+
nn.BatchNorm2d(d*2),
171+
nn.ReLU(),
172+
nn.ConvTranspose2d(d*2, d, 4, 2, 1),
173+
nn.BatchNorm2d(d),
174+
nn.ReLU(),
175+
nn.ConvTranspose2d(d, 1, 4, 2, 1),
176+
nn.Tanh(),
177+
)
178+
def forward(self, x):
179+
x = x.view(x.size()[0], z_size, 1,1)
180+
out = self.model(x)
181+
print(out.size())
182+
out = out.view(x.size()[0], 1,img_size,img_size)
183+
return out
184+
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)