Commit d51c765
Commit message: + prev
Parent: 49975ce

File tree

3,473 files changed: +1,134,865 -328 lines

ddpm.py

Lines changed: 108 additions & 12 deletions
@@ -15,7 +15,7 @@
 batch_size = 32
 num_epochs = 100  # Just for the sake of demonstration
-total_timesteps = 1000
+total_timesteps = 300
 norm_groups = 8  # Number of groups used in GroupNormalization layer
 learning_rate = 2e-4

@@ -90,6 +90,25 @@ def train_preprocessing(x):
 )
 
 
+
+
+# define various schedules for the T timesteps
+def cosine_beta_schedule(timesteps, s=0.008):
+    """
+    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, timesteps, steps)
+    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, 0.0001, 0.9999)
+
+
+def linear_beta_schedule(timesteps):
+    beta_start = 1e-4
+    beta_end = 0.02
+    return np.linspace(beta_start, beta_end, timesteps)
+
 """
 ## Gaussian diffusion utilities
 We define the forward process and the reverse process
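To see what switching schedules changes in practice, here is a minimal sketch using only the two functions added above; T = 300 matches the new total_timesteps, and the variable names are illustrative:

import numpy as np

# Minimal sketch: compare signal retention under the two schedules above.
T = 300
acp_cos = np.cumprod(1.0 - cosine_beta_schedule(T))
acp_lin = np.cumprod(1.0 - linear_beta_schedule(T))

# The cosine schedule keeps alpha-bar (the fraction of signal surviving
# at step t) higher through the middle of the trajectory, which is the
# motivation given in arXiv:2102.09672.
print(acp_cos[T // 2], acp_lin[T // 2])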
@@ -110,7 +129,7 @@ def __init__(
         self,
         beta_start=1e-4,
         beta_end=0.02,
-        timesteps=1000,
+        timesteps=300,
         clip_min=-1.0,
         clip_max=1.0,
     ):
@@ -121,12 +140,8 @@ def __init__(
         self.clip_max = clip_max
 
-        # Define the linear variance schedule
-        self.betas = betas = np.linspace(
-            beta_start,
-            beta_end,
-            timesteps,
-            dtype=np.float64,  # Using float64 for better precision
-        )
+        # Define the cosine variance schedule
+        self.betas = betas = cosine_beta_schedule(timesteps)
+
         self.num_timesteps = int(timesteps)
 
         alphas = 1.0 - betas
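The constants built from these betas serve the forward process q(x_t | x_0); a hedged, self-contained sketch of that step, assuming only cosine_beta_schedule from the hunk above (the q_sample helper here is illustrative, not the class method):

import numpy as np

T = 300
alphas_cumprod = np.cumprod(1.0 - cosine_beta_schedule(T))

def q_sample(x0, t, noise):
    # Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
    return (
        np.sqrt(alphas_cumprod[t]) * x0
        + np.sqrt(1.0 - alphas_cumprod[t]) * noise
    )

x0 = np.zeros((1, 64, 64, 3), dtype=np.float32)  # dummy NHWC batch
xt = q_sample(x0, t=150, noise=np.random.randn(*x0.shape))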
@@ -137,6 +152,8 @@ def __init__(
         self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
         self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)
 
+        tf.print(self.alphas_cumprod_prev)
+
         # Calculations for diffusion q(x_t | x_{t-1}) and others
         self.sqrt_alphas_cumprod = tf.constant(
             np.sqrt(alphas_cumprod), dtype=tf.float32
@@ -180,8 +197,7 @@ def __init__(
             (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
             dtype=tf.float32,
         )
-        tf.print(self.sqrt_one_minus_alphas_cumprod)
-
+
 
     def _extract(self, a, t, x_shape):
         """Extract some coefficients at specified timesteps,
@@ -670,7 +686,87 @@ def plot_images(
 
 #print("tensorflow version:", tf.__version__)
 #print("keras version:", tf.keras.__version__)
-tf.saved_model.save(model.ema_network, "saved_model/")
+#tf.saved_model.save(model.ema_network, "saved_model/")
 
 # Generate and plot some samples
-#model.plot_images(num_rows=1, num_cols=2)
+#model.plot_images(num_rows=2, num_cols=4)
+
+"""
+import torch
+import torch.nn as nn
+from dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+## 1. Define the noise schedule.
+noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.from_numpy(gdf_util.betas))
+
+tf.compat.v1.enable_eager_execution()
+class XXModel(nn.Module):
+    def __init__(self, ema_network):
+        super().__init__()
+        self.ema_network = ema_network
+
+    def forward(self, x, t):
+        print(t)
+        samples = tf.constant(np.transpose(x.numpy(), [0, 2, 3, 1]))  # NCHW -> NHWC for Keras
+        tt = tf.constant(t.numpy())
+        pred_noise = self.ema_network.predict(
+            [samples, tt], verbose=0, batch_size=1
+        )
+        return torch.from_numpy(np.transpose(pred_noise, [0, 3, 1, 2]))  # NHWC -> NCHW
+
+xx_model = XXModel(model.ema_network)
+## 2. Convert your discrete-time `model` to the continuous-time
+## noise prediction model. Here is an example for a diffusion model
+## `model` with the noise prediction type ("noise").
+model_fn = model_wrapper(
+    xx_model,
+    noise_schedule,
+    model_type="noise",  # or "x_start" or "v" or "score"
+    model_kwargs={},
+)
+
+
+## 3. Define dpm-solver and sample by singlestep DPM-Solver.
+## (We recommend singlestep DPM-Solver for unconditional sampling)
+## You can adjust the `steps` to balance the computation
+## costs and the sample quality.
+dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
+## Can also try
+# dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
+
+x_T = torch.randn(8, img_channels, img_size, img_size, dtype=torch.float32)
+
+## You can use steps = 10, 12, 15, 20, 25, 50, 100.
+## Empirically, we find that steps in [10, 20] can generate quite good samples.
+## And steps = 20 can almost converge.
+x_sample = dpm_solver.sample(
+    x_T,
+    steps=20,
+    order=3,
+    skip_type="time_quadratic",
+    method="singlestep",
+    denoise_to_zero=True,
+)
+
+x_sample = np.transpose(x_sample.numpy(), [0, 2, 3, 1])
+
+generated_samples = (
+    tf.clip_by_value(x_sample * 127.5 + 127.5, 0.0, 255.0)
+    .numpy()
+    .astype(np.uint8)
+)
+
+num_rows = 2
+num_cols = 4
+figsize = (12, 5)
+_, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
+for i, image in enumerate(generated_samples):
+    if num_rows == 1:
+        ax[i].imshow(image)
+        ax[i].axis("off")
+    else:
+        ax[i // num_cols, i % num_cols].imshow(image)
+        ax[i // num_cols, i % num_cols].axis("off")
+
+plt.tight_layout()
+plt.show()
+"""

docs/64x64_cosin_300/group1-shard10of61.bin: 4 additions, 0 deletions
docs/64x64_cosin_300/group1-shard11of61.bin: 1 addition, 0 deletions
docs/64x64_cosin_300/group1-shard12of61.bin: 1 addition, 0 deletions
docs/64x64_cosin_300/group1-shard13of61.bin: 2 additions, 0 deletions
docs/64x64_cosin_300/group1-shard14of61.bin: 1 addition, 0 deletions
docs/64x64_cosin_300/group1-shard15of61.bin: 1 addition, 0 deletions
docs/64x64_cosin_300/group1-shard16of61.bin: 17 additions, 0 deletions
docs/64x64_cosin_300/group1-shard17of61.bin: 1 addition, 0 deletions
docs/64x64_cosin_300/group1-shard18of61.bin: 1 addition, 0 deletions

(Binary model shards; large diffs are not rendered by default.)
