-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
46 lines (38 loc) · 1.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
import pandas as pd
import PIL
from PIL import Image
import cv2
import os, os.path
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from models import Generator, Discriminator, GAN
from preprocess import Normalize, DeNormalize
# Directory containing the training face images (placeholder path from the original script).
IMAGE_DIR = '...celeb_images/Part 1'
# Number of images fed to the GAN.
MAX_IMAGES = 5000
# Images are resized to IMAGE_SIZE x IMAGE_SIZE RGB before training.
IMAGE_SIZE = 64
# Per-channel mean/std passed to Normalize / DeNormalize.
NORM_MEAN = [0.5, 0.5, 0.5]
NORM_STD = [0.5, 0.5, 0.5]
# Second positional argument of GAN.train; presumably the epoch count -- confirm in models.GAN.
NUM_EPOCHS = 200
# Channel count of the generator's input noise, shaped (1, LATENT_DIM, 1, 1) below.
LATENT_DIM = 100
# Extensions treated as images; anything else in IMAGE_DIR is skipped instead of crashing Image.open.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')


def list_image_paths(image_dir, limit):
    """Return up to *limit* image file paths from *image_dir*, in sorted filename order.

    Only regular files whose extension (case-insensitive) is in IMAGE_EXTENSIONS are
    kept.  Sorting makes the selected subset deterministic, unlike raw os.listdir order.
    """
    paths = []
    for name in sorted(os.listdir(image_dir)):
        if len(paths) >= limit:
            break
        path = os.path.join(image_dir, name)
        if os.path.isfile(path) and name.lower().endswith(IMAGE_EXTENSIONS):
            paths.append(path)
    return paths


def load_images(image_paths, size=IMAGE_SIZE):
    """Load, resize and normalize images into an array of shape (len(image_paths), size, size, 3).

    Each image is converted to RGB, resized to (size, size), scaled to [0, 1] by /255,
    then passed through Normalize(NORM_MEAN, NORM_STD).  The output array uses NumPy's
    default dtype, matching what the original script handed to GAN.train.
    """
    # Allocate exactly what is needed rather than one slot per file in the directory.
    images = np.empty([len(image_paths), size, size, 3])
    for indx, image_path in enumerate(image_paths):
        # Context manager closes the underlying file; convert() yields a loaded copy.
        with Image.open(image_path) as img:
            rgb = img.convert('RGB').resize((size, size))
        arr = np.array(rgb, dtype=np.float32) / 255
        images[indx] = Normalize(arr, NORM_MEAN, NORM_STD)
    return images


def tensor_to_uint8_image(batch):
    """Convert a generator output of one image, laid out (1, C, H, W), to an (H, W, C) uint8 array.

    The (1, C, H, W) layout is implied by the permute/squeeze below.
    Preprocessing was `/255` then `Normalize`, so it is undone in reverse order:
    DeNormalize first, then `*255`.  NOTE(review): assumes DeNormalize inverts
    Normalize (x * std + mean) -- confirm in preprocess.py.
    """
    arr = batch.permute(0, 2, 3, 1).squeeze(0).cpu().detach().numpy()
    arr = DeNormalize(arr, NORM_MEAN, NORM_STD) * 255
    # imshow clips float RGB to [0, 1]; uint8 in [0, 255] displays correctly.
    return np.clip(arr, 0, 255).astype(np.uint8)


def main():
    """Train the GAN on IMAGE_DIR and display one generated sample."""
    image_paths = list_image_paths(IMAGE_DIR, MAX_IMAGES)
    if not image_paths:
        raise SystemExit(f"No image files found in {IMAGE_DIR!r}")
    images = load_images(image_paths)

    gan = GAN()
    gen, _ = gan.train(images, NUM_EPOCHS)

    # Inference only: no autograd graph needed (Variable is deprecated and unnecessary).
    with torch.no_grad():
        noise = torch.randn(1, LATENT_DIM, 1, 1).float().cuda()
        sample = gen.feed_forward(noise)

    plt.imshow(tensor_to_uint8_image(sample))
    # Without show() a non-interactive script exits without displaying anything.
    plt.show()


if __name__ == '__main__':
    main()