@@ -15,7 +15,7 @@
 
 batch_size = 32
 num_epochs = 100  # Just for the sake of demonstration
-total_timesteps = 1000
+total_timesteps = 300
 norm_groups = 8  # Number of groups used in GroupNormalization layer
 learning_rate = 2e-4
 
@@ -90,6 +90,34 @@ def train_preprocessing(x):
 )
 
 
+
+
+# Define various beta schedules over the T diffusion timesteps
+def cosine_beta_schedule(timesteps, s=0.008):
+    """
+    Cosine schedule as proposed in https://arxiv.org/abs/2102.09672
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, timesteps, steps)
+    alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, 0.0001, 0.9999)
+
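+# A gloss on the function above: the cosine schedule sets
+#     alpha_bar(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)
+# and derives the per-step betas as beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1),
+# clipped to [1e-4, 0.9999] so no single step destroys too much signal.
+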
+def linear_beta_schedule(timesteps):
+    beta_start = 1e-4
+    beta_end = 0.02
+    return np.linspace(beta_start, beta_end, timesteps)
+
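+# Quick sanity check: both schedules return a length-`timesteps` array of
+# betas, e.g. cosine_beta_schedule(300).shape == (300,), so they can be
+# swapped freely wherever a beta schedule is consumed below.
+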
 """
 ## Gaussian diffusion utilities
 We define the forward process and the reverse process
@@ -110,7 +138,7 @@ def __init__(
         self,
         beta_start=1e-4,
         beta_end=0.02,
-        timesteps=1000,
+        timesteps=300,
         clip_min=-1.0,
         clip_max=1.0,
     ):
@@ -121,12 +149,10 @@ def __init__(
         self.clip_max = clip_max
 
-        # Define the linear variance schedule
-        self.betas = betas = np.linspace(
-            beta_start,
-            beta_end,
-            timesteps,
-            dtype=np.float64,  # Using float64 for better precision
-        )
+        # Define the cosine variance schedule
+        self.betas = betas = cosine_beta_schedule(timesteps)
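+        # Note: beta_start / beta_end above no longer drive the schedule;
+        # cosine_beta_schedule derives the betas from `timesteps` alone.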
+
         self.num_timesteps = int(timesteps)
 
         alphas = 1.0 - betas
@@ -137,6 +163,8 @@ def __init__(
         self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
         self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)
 
+        tf.print(self.alphas_cumprod_prev)  # Debug: inspect the shifted cumulative products
+
         # Calculations for diffusion q(x_t | x_{t-1}) and others
         self.sqrt_alphas_cumprod = tf.constant(
             np.sqrt(alphas_cumprod), dtype=tf.float32
@@ -180,8 +208,7 @@ def __init__(
             (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
             dtype=tf.float32,
         )
-        tf.print(self.sqrt_one_minus_alphas_cumprod)
-
+
 
     def _extract(self, a, t, x_shape):
         """Extract some coefficients at specified timesteps,
@@ -670,7 +697,100 @@ def plot_images(
 
 #print("tensorflow version:", tf.__version__)
 #print("keras version:", tf.keras.__version__)
-tf.saved_model.save(model.ema_network, "saved_model/")
+# tf.saved_model.save(model.ema_network, "saved_model/")
 
 # Generate and plot some samples
-#model.plot_images(num_rows=1, num_cols=2)
+#model.plot_images(num_rows=2, num_cols=4)
+
+"""
+import torch
+import torch.nn as nn
+from dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+## 1. Define the noise schedule.
+noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.from_numpy(gdf_util.betas))  # gdf_util.betas is already a NumPy array
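+# ('discrete' mode reconstructs the continuous-time schedule from these
+# per-step betas, so the solver integrates the same forward process the
+# network was trained on.)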
+
+tf.compat.v1.enable_eager_execution()
+class XXModel(nn.Module):
+    def __init__(self, ema_network):
+        super().__init__()
+        self.ema_network = ema_network
+
+    def forward(self, x, t):
+        print(t)  # Debug: log the timesteps the solver queries
+        samples = tf.constant(np.transpose(x.numpy(), [0, 2, 3, 1]))
+        tt = tf.constant(t.numpy())
+        pred_noise = self.ema_network.predict(
+            [samples, tt], verbose=0, batch_size=1
+        )
+        return torch.from_numpy(np.transpose(pred_noise, [0, 3, 1, 2]))
+
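+# The forward pass above converts torch's NCHW layout to the NHWC layout the
+# Keras network expects, runs the EMA noise predictor, and converts the
+# prediction back to NCHW for the solver.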
+xx_model = XXModel(model.ema_network)
+## 2. Convert your discrete-time `model` to the continuous-time
+## noise prediction model. Here is an example for a diffusion model
+## `model` with the noise prediction type ("noise").
+model_fn = model_wrapper(
+    xx_model,
+    noise_schedule,
+    model_type="noise",  # or "x_start" or "v" or "score"
+    model_kwargs={},
+)
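+# (model_wrapper adapts the wrapped network to the continuous-time interface
+# DPM_Solver expects; model_type="noise" declares that the network predicts
+# the added noise epsilon rather than x_0 or a velocity.)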
+
+
+## 3. Define dpm-solver and sample by singlestep DPM-Solver.
+## (We recommend singlestep DPM-Solver for unconditional sampling.)
+## You can adjust the `steps` to balance the computation
+## costs and the sample quality.
+dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
+## Can also try
+# dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
+
+x_T = torch.randn(8, img_channels, img_size, img_size, dtype=torch.float32)
+
+## You can use steps = 10, 12, 15, 20, 25, 50, 100.
+## Empirically, we find that steps in [10, 20] can generate quite good samples.
+## And steps = 20 can almost converge.
+x_sample = dpm_solver.sample(
+    x_T,
+    steps=20,
+    order=3,
+    skip_type="time_quadratic",
+    method="singlestep",
+    denoise_to_zero=True,
+)
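+# (With method="singlestep" and order=3, the solver spends roughly `steps`
+# network evaluations in total; denoise_to_zero adds one final denoising
+# step near t = 0, which typically sharpens the samples slightly.)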
+
+x_sample = np.transpose(x_sample.numpy(), [0, 2, 3, 1])
+
+generated_samples = (
+    tf.clip_by_value(x_sample * 127.5 + 127.5, 0.0, 255.0)
+    .numpy()
+    .astype(np.uint8)
+)
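+# The model works in [-1, 1], so x * 127.5 + 127.5 maps pixels back to [0, 255].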
+
+num_rows = 2
+num_cols = 4
+figsize = (12, 5)
+_, ax = plt.subplots(num_rows, num_cols, figsize=figsize)
+for i, image in enumerate(generated_samples):
+    if num_rows == 1:
+        ax[i].imshow(image)
+        ax[i].axis("off")
+    else:
+        ax[i // num_cols, i % num_cols].imshow(image)
+        ax[i // num_cols, i % num_cols].axis("off")
+
+plt.tight_layout()
+plt.show()
+"""