
Commit 4d8c8b7

chore: Refactor code
Refactored the goddamn code for presentation
1 parent 352f9a5 commit 4d8c8b7

File tree: models.py, preprocess.py, train.py

3 files changed, 191 insertions(+), 0 deletions(-)
models.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import cv2
import os, os.path
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable


class Generator(nn.Module):
    """DCGAN-style generator: maps a 100-d noise vector to a 3 x 64 x 64 image in [-1, 1]."""
    def __init__(self):
        super(Generator, self).__init__()
        self.seq = nn.Sequential(
            # input: (100) x 1 x 1 -> (64*8) x 4 x 4
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            # (64*8) x 4 x 4 -> (64*4) x 8 x 8
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # (64*4) x 8 x 8 -> (64*2) x 16 x 16
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # (64*2) x 16 x 16 -> (64) x 32 x 32
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (64) x 32 x 32 -> (3) x 64 x 64, squashed to [-1, 1]
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def feed_forward(self, inp):
        x = self.seq(inp)
        return x


class Discriminator(nn.Module):
    """DCGAN-style discriminator: maps a 3 x 64 x 64 image to a real/fake probability."""
    def __init__(self):
        super(Discriminator, self).__init__()

        self.seq = nn.Sequential(
            # input is (3) x 64 x 64
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (64) x 32 x 32
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (64*2) x 16 x 16
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (64*4) x 8 x 8
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size: (64*8) x 4 x 4 -> single probability per image
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def feed_forward(self, x):
        x = self.seq(x)
        return x


class GAN(nn.Module):
    """Wraps a generator/discriminator pair with their Adam optimizers and a training loop."""
    def __init__(self):
        super(GAN, self).__init__()
        self.generator = Generator().float().cuda()
        self.discriminator = Discriminator().float().cuda()
        self.generator_optim = optim.Adam(self.generator.parameters())
        self.discriminator_optim = optim.Adam(self.discriminator.parameters())

    def generate_images(self, batch_size):
        # Sample latent noise (on the same device as the generator) and decode it into images.
        inp = Variable(torch.randn(batch_size, 100, 1, 1)).cuda()
        out = self.generator.feed_forward(inp)
        return out

    def train(self, images, epochs=20):
        print('============ Starting training of GAN ============')
        batch_size = 10
        loss = nn.BCELoss()
        for epoch in range(epochs):
            discriminator_error = 0
            generator_error = 0
            for i in range(int(len(images) / batch_size) + 1):
                self.generator_optim.zero_grad()
                self.discriminator_optim.zero_grad()
                try:
                    orig_images = torch.from_numpy(images[i * batch_size:(i + 1) * batch_size, :, :, :]).permute(0, 3, 1, 2).float().cuda()
                    if orig_images.shape[0] == 0:
                        break
                except IndexError:
                    # Fallback for a final batch shorter than batch_size.
                    orig_images = torch.from_numpy(images[i * batch_size:]).permute(0, 3, 1, 2).float().cuda()

                # Training the discriminator: generated images should score 0, real images 1.
                # The (N, 1, 1, 1) discriminator output is flattened to match the label shape.
                fake_images_truth = Variable(torch.zeros(batch_size)).float().cuda()
                noise = Variable(torch.randn(batch_size, 100, 1, 1)).float().cuda()
                gen_images = self.generator.feed_forward(noise).cuda()
                fake_images_prediction = self.discriminator.feed_forward(gen_images).view(-1)
                dis_error_fake = loss(fake_images_prediction, fake_images_truth)
                dis_error_fake.backward()

                true_images_truth = Variable(torch.ones(len(orig_images))).float().cuda()
                true_images_prediction = self.discriminator.feed_forward(orig_images).view(-1)
                dis_error_true = loss(true_images_prediction, true_images_truth)
                dis_error_true.backward()
                self.discriminator_optim.step()

                discriminator_error = discriminator_error + dis_error_fake + dis_error_true

                # Training the generator: it wants its fakes to be classified as real (label 1).
                noise = Variable(torch.randn(batch_size, 100, 1, 1)).float().cuda()
                gen_images = self.generator.feed_forward(noise)
                fake_images_prediction = self.discriminator.feed_forward(gen_images.cuda()).view(-1)
                fake_images_truth = Variable(torch.ones(batch_size)).float().cuda()
                dis_error_fake = loss(fake_images_prediction, fake_images_truth)
                dis_error_fake.backward()
                self.generator_optim.step()
                gen_error = dis_error_fake

                generator_error = generator_error + gen_error

            print("========== Epoch : {} | Generator Loss : {} | Discriminator Loss : {} =========="
                  .format(epoch + 1, generator_error.detach(), discriminator_error.detach()))
        return self.generator, self.discriminator
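
As a rough usage sketch (not part of this commit), the GAN class above could be driven as follows. It assumes a CUDA-capable machine, since the constructor moves both networks to the GPU, and the `dataset` array here is a hypothetical stand-in for real training data of shape (N, 64, 64, 3) already normalized to [-1, 1]:

# Hypothetical usage sketch -- assumes a CUDA device and a pre-normalized
# (N, 64, 64, 3) dataset array; not part of the committed files.
import numpy as np
from models import GAN

dataset = np.random.uniform(-1, 1, size=(100, 64, 64, 3)).astype(np.float32)  # stand-in data
gan = GAN()
generator, discriminator = gan.train(dataset, epochs=1)
samples = gan.generate_images(batch_size=4)  # tensor of shape (4, 3, 64, 64), values in [-1, 1]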

preprocess.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import cv2


def Normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    # Channel-wise (x - mean) / std, applied in place on an H x W x 3 float image.
    for channel in range(3):
        image[:, :, channel] = (image[:, :, channel] - mean[channel]) / std[channel]
    return image


def DeNormalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
    # Inverse of Normalize: channel-wise x * std + mean, also applied in place.
    for channel in range(3):
        image[:, :, channel] = image[:, :, channel] * std[channel] + mean[channel]
    return image
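
A quick way to sanity-check these two helpers (my own illustration, not part of the commit) is a round trip on a random array: with matching mean/std, DeNormalize should recover the original values up to floating-point error. Both functions modify their argument in place, so the copies below matter; the variable names are hypothetical:

# Hypothetical round-trip check for Normalize/DeNormalize; not part of the committed files.
import numpy as np
from preprocess import Normalize, DeNormalize

original = np.random.rand(64, 64, 3).astype(np.float32)
normalized = Normalize(original.copy(), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])   # maps [0, 1] -> [-1, 1]
restored = DeNormalize(normalized.copy(), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
print(np.allclose(original, restored, atol=1e-6))   # expected: True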

train.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import cv2
import os, os.path
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from models import Generator, Discriminator, GAN
from preprocess import Normalize, DeNormalize

if __name__ == '__main__':
    imageDir = '...celeb_images/Part 1'

    image_path_list = []
    for file in os.listdir(imageDir):
        image_path_list.append(os.path.join(imageDir, file))

    # Load up to 5000 images, resize to 64x64, scale to [0, 1], then normalize to [-1, 1].
    image = np.empty([len(image_path_list), 64, 64, 3])
    for indx, imagePath in enumerate(image_path_list):
        if indx > 5000:
            break
        im = Image.open(imagePath).convert('RGB')
        im = im.resize((64, 64))
        im = np.array(im, dtype=np.float32)
        im = im / 255
        im = Normalize(im, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        image[indx, :, :, :] = im

    image = image[:5000, :, :, :]

    gan = GAN()
    gen, dis = gan.train(image, 200)

    # Sample one image from the trained generator and undo the [-1, 1] normalization for display.
    noise = Variable(torch.randn(1, 100, 1, 1)).float().cuda()
    im = gen.feed_forward(noise)
    im = im.permute(0, 2, 3, 1)
    im = im.squeeze(0).cpu().detach().numpy()
    im = DeNormalize(im, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

    plt.imshow(np.clip(im, 0, 1))
    plt.show()
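
Nothing in this commit persists the trained networks, so a natural follow-up (sketched here as an assumption, using standard PyTorch state_dict saving; the file name 'dcgan_generator.pth' is illustrative and not referenced anywhere in the commit) would be to checkpoint the generator returned by gan.train and reload it later for sampling:

# Hypothetical checkpointing sketch, continuing from the script above; not part of the committed files.
torch.save(gen.state_dict(), 'dcgan_generator.pth')

# Later / elsewhere: rebuild the architecture and load the weights back.
from models import Generator
gen_reloaded = Generator().float().cuda()
gen_reloaded.load_state_dict(torch.load('dcgan_generator.pth'))
gen_reloaded.eval()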
