Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
zaidalyafeai committed Dec 26, 2020
1 parent ea39062 commit 9f16266
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 1 deletion.
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,12 @@
# gan-mosaics
# gan-mosaics

The models were trained using a Stylegan2-Ada model for two days.

## Colab Notebook
You can use this notebook to traversing and interpolation.

## Size-256
### Generated Images
![alt text](mosaic-256.png)
### Time Elapse
![alt text](time-elapse-256.gif)
84 changes: 84 additions & 0 deletions gan-mosaics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Download the model of choice
import argparse
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import re
import sys
from io import BytesIO
import IPython.display
from math import ceil
from PIL import Image, ImageDraw
import os
import pickle
from utils import log_progress, imshow

class Mosaic:
def __init__(self, path):
dnnlib.tflib.init_tf()

print('Loading networks from "%s"...' % path)
with dnnlib.util.open_url(path) as fp:
self._G, self._D, self.Gs = pickle.load(fp)
self.noise_vars = [var for name, var in self.Gs.components.synthesis.vars.items() if name.startswith('noise')]
# Generates a list of images, based on a list of latent vectors (Z), and a list (or a single constant) of truncation_psi's.
def generate_images_in_w_space(self, dlatents, truncation_psi):
Gs_kwargs = dnnlib.EasyDict()
Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
Gs_kwargs.randomize_noise = False
Gs_kwargs.truncation_psi = truncation_psi
# dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]

imgs = []
for _, dlatent in log_progress(enumerate(dlatents), name = "Generating images"):
#row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(truncation_psi, [-1, 1, 1]) + dlatent_avg
# dl = (dlatent-dlatent_avg)*truncation_psi + dlatent_avg
row_images = self.Gs.components.synthesis.run(dlatent, **Gs_kwargs)
imgs.append(PIL.Image.fromarray(row_images[0], 'RGB'))
return imgs

def generate_images(self, zs, truncation_psi):
Gs_kwargs = dnnlib.EasyDict()
Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
Gs_kwargs.randomize_noise = False
if not isinstance(truncation_psi, list):
truncation_psi = [truncation_psi] * len(zs)

imgs = []
for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = "Generating images"):
Gs_kwargs.truncation_psi = truncation_psi[z_idx]
noise_rnd = np.random.RandomState(1) # fix noise
tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in self.noise_vars}) # [height, width]
images = self.Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
imgs.append(PIL.Image.fromarray(images[0], 'RGB'))
return imgs

def generate_zs_from_seeds(self, seeds):
zs = []
for _, seed in enumerate(seeds):
rnd = np.random.RandomState(seed)
z = rnd.randn(1, *self.Gs.input_shape[1:]) # [minibatch, component]
zs.append(z)
return zs

# Generates a list of images, based on a list of seed for latent vectors (Z), and a list (or a single constant) of truncation_psi's.
def generate_images_from_seeds(self, seeds, truncation_psi):
return imshow(self.generate_images(self.generate_zs_from_seeds(seeds), truncation_psi))


def convertZtoW(self, latent, truncation_psi=0.7, truncation_cutoff=9):
dlatent = self.Gs.components.mapping.run(latent, None) # [seed, layer, component]
dlatent_avg = self.Gs.get_var('dlatent_avg') # [component]
for i in range(truncation_cutoff):
dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg

return dlatent

def interpolate(self, zs, steps):
out = []
for i in range(len(zs)-1):
for index in range(steps):
fraction = index/float(steps)
out.append(zs[i+1]*fraction + zs[i]*(1-fraction))
return out
Binary file added mosaic-256.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added time-elapse-256.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
106 changes: 106 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@

import numpy as np
import PIL.Image
import sys
from io import BytesIO
import IPython.display
import numpy as np
from math import ceil
from PIL import Image, ImageDraw
import os

def imshow(a, format='png', jpeg_fallback=True):
a = np.asarray(a, dtype=np.uint8)
str_file = BytesIO()
PIL.Image.fromarray(a).save(str_file, format)
im_data = str_file.getvalue()
try:
disp = IPython.display.display(IPython.display.Image(im_data))
except IOError:
if jpeg_fallback and format != 'jpeg':
print ('Warning: image was too large to display in format "{}"; '
'trying jpeg instead.').format(format)
return imshow(a, format='jpeg')
else:
raise
return disp

def show_array(self, a, fmt='png'):
a = np.uint8(a)
f = StringIO()
PIL.Image.fromarray(a).save(f, fmt)
IPython.display.display(IPython.display.Image(data=f.getvalue()))


def clamp(x, minimum, maximum):
return max(minimum, min(x, maximum))

def create_image_grid(images, scale=0.25, rows=1):
w,h = images[0].size
w = int(w*scale)
h = int(h*scale)
height = rows*h
cols = ceil(len(images) / rows)
width = cols*w
canvas = PIL.Image.new('RGBA', (width,height), 'white')
for i,img in enumerate(images):
img = img.resize((w,h), PIL.Image.ANTIALIAS)
canvas.paste(img, (w*(i % cols), h*(i // cols)))
return canvas

# Taken from https://github.com/alexanderkuk/log-progress
def log_progress(sequence, every=1, size=None, name='Items'):
from ipywidgets import IntProgress, HTML, VBox
from IPython.display import display

is_iterator = False
if size is None:
try:
size = len(sequence)
except TypeError:
is_iterator = True
if size is not None:
if every is None:
if size <= 200:
every = 1
else:
every = int(size / 200) # every 0.5%
else:
assert every is not None, 'sequence is iterator, set every'

if is_iterator:
progress = IntProgress(min=0, max=1, value=1)
progress.bar_style = 'info'
else:
progress = IntProgress(min=0, max=size, value=0)
label = HTML()
box = VBox(children=[label, progress])
display(box)

index = 0
try:
for index, record in enumerate(sequence, 1):
if index == 1 or index % every == 0:
if is_iterator:
label.value = '{name}: {index} / ?'.format(
name=name,
index=index
)
else:
progress.value = index
label.value = u'{name}: {index} / {size}'.format(
name=name,
index=index,
size=size
)
yield record
except:
progress.bar_style = 'danger'
raise
else:
progress.bar_style = 'success'
progress.value = index
label.value = "{name}: {index}".format(
name=name,
index=str(index or '?')
)

0 comments on commit 9f16266

Please sign in to comment.