Skip to content

Commit 5a1a65e

Browse files
committed
+ prev
1 parent d51c765 commit 5a1a65e

File tree

3 files changed

+175
-188
lines changed

3 files changed

+175
-188
lines changed

docs/64x64_cosin_300/worker.js

Lines changed: 170 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,9 @@ const url = self.location.toString();
44
let index = url.lastIndexOf('/');
55
index = url.lastIndexOf('/', index-1);
66
console.log(url, index);
7-
const TF_JS_URL = url.substring( 0, index) + "/@tensorflow/[email protected]/dist/tf.min.js";
8-
9-
console.log(TF_JS_URL);
10-
7+
const TF_JS_URL = url.substring( 0, index) + "/@tensorflow/tfjs/dist/tf.min.js";
8+
const TF_JS_CDN_URL = "https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js";
119
async function load_model() {
12-
try {
13-
//await import('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf-backend-webgl.min.js');
14-
//await tf.setBackend('webgl');
15-
16-
//tf.wasm.setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/')
17-
//const backend = await import('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf-backend-wasm.min.js');
18-
//tf.wasm.setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/');
19-
20-
21-
22-
23-
24-
self.postMessage({ type: 'progress', progress: 0.3, message: 'Loading model' });
25-
} catch (e) {
26-
self.postMessage({ type: 'error', message: 'Failed to load tensorflow WASM backend' });
27-
console.log('Failed to initialize tensorflow backend', e);
28-
return null;
29-
}
3010

3111
const model = await (async () => {
3212
try {
@@ -49,196 +29,203 @@ async function load_model() {
4929
return;
5030
}
5131

52-
self.postMessage({ type: 'ready' });
32+
5333
return model;
5434
}
5535

36+
async function main() {
37+
const model = await load_model();
5638

39+
const stable_sqrt = (number) => number > 0 && number < 0.001 ? Math.exp( 0.5 * Math.log(Math.max(number, 1e-20)) ) : Math.sqrt(number);
5740

58-
self.postMessage({ type: 'progress', progress: 0, message: 'Loading ' + TF_JS_URL });
59-
import(TF_JS_URL)
60-
.then(async () => {
61-
62-
const model = await load_model();
63-
64-
const stable_sqrt = (number) => Math.sqrt(number); // Math.exp( 0.5 * Math.log(Math.max(number, 1e-20)) );
65-
66-
// cosine schedule as proposed in https://arxiv.org/abs/2102.09672
67-
const cosine_beta_schedule = (timesteps, s) => {
68-
const steps = timesteps + 1.0;
69-
let alphas_cumprod = tf.linspace(0, timesteps, steps).arraySync().map( x => {
70-
return Math.pow( Math.cos(((x / timesteps) + s) / (1 + s) * Math.PI * 0.5), 2);
71-
});
72-
const base = alphas_cumprod[0];
73-
alphas_cumprod = alphas_cumprod.map( x => x / base);
74-
const betas = new Array(timesteps);
75-
for( let index = 0; index < betas.length; index++){
76-
const b = 1 - (alphas_cumprod[index+1] / alphas_cumprod[index]);
77-
betas[index] = Math.min(0.9999, Math.max(0.0001, b));
78-
}
79-
return betas;
80-
};
81-
82-
83-
const image_size = 64;
84-
const timesteps = 300;
85-
//const betas = tf.linspace( beta_start, beta_end, timesteps).arraySync();
86-
const betas = cosine_beta_schedule(timesteps, 0.008);
87-
const alphas = new Array(timesteps);
88-
const alphas_cumprod = new Array(timesteps);
89-
const alphas_cumprod_prev = new Array(timesteps);
90-
const sqrt_one_minus_alphas_cumprod = new Array(timesteps);
91-
const sqrt_recip_alphas_cumprod = new Array(timesteps);
92-
const sqrt_recipm1_alphas_cumprod = new Array(timesteps);
93-
const stddevs = new Array(timesteps);
94-
95-
alphas_cumprod_prev[0] = 1.0;
96-
97-
// prepare variables
98-
betas.forEach( (beta, index) => {
99-
alphas[index] = 1.0 - beta;
100-
101-
alphas_cumprod[index] = (index > 0) ? alphas_cumprod[index-1] * alphas[index] : alphas[index];
102-
103-
if( index < timesteps - 1 ){
104-
alphas_cumprod_prev[index+1] = alphas_cumprod[index];
105-
}
106-
107-
sqrt_recip_alphas_cumprod[index] = stable_sqrt(1.0 / alphas_cumprod[index]);
108-
sqrt_recipm1_alphas_cumprod[index] = stable_sqrt(1.0 / alphas_cumprod[index] - 1);
109-
110-
sqrt_one_minus_alphas_cumprod[index] = Math.sqrt(1.0 - alphas_cumprod[index]);
111-
112-
const variance = beta * (1.0 - alphas_cumprod_prev[index]) / (1.0 - alphas_cumprod[index]);
113-
stddevs[index] = Math.exp( 0.5 * Math.log(Math.max(variance, 1e-20)) ); // Log calculation clipped because the posterior variance is 0 at the beginning
41+
// cosine schedule as proposed in https://arxiv.org/abs/2102.09672
42+
const cosine_beta_schedule = (timesteps, s) => {
43+
const steps = timesteps + 1.0;
44+
let alphas_cumprod = tf.linspace(0, timesteps, steps).arraySync().map( x => {
45+
return Math.pow( Math.cos(((x / timesteps) + s) / (1 + s) * Math.PI * 0.5), 2);
11446
});
115-
116-
if( typeof(console.log) === 'function' ){
117-
console.log("betas=", betas);
118-
console.log("alphas=", alphas);
119-
console.log("alphas_cumprod=", alphas_cumprod);
120-
console.log("alphas_cumprod_prev=", alphas_cumprod_prev);
121-
console.log("sqrt_one_minus_alphas_cumprod=", sqrt_one_minus_alphas_cumprod);
122-
console.log("stddevs=", stddevs);
47+
const base = alphas_cumprod[0];
48+
alphas_cumprod = alphas_cumprod.map( x => x / base);
49+
const betas = new Array(timesteps);
50+
for( let index = 0; index < betas.length; index++){
51+
const b = 1 - (alphas_cumprod[index+1] / alphas_cumprod[index]);
52+
betas[index] = Math.min(0.9999, Math.max(0.0001, b));
12353
}
124-
125-
126-
const ddpm_p_sample = (image_input, timeStep) => {
127-
// When using WebGL backend, tf.Tensor memory must be managed explicitly (it is not sufficient to let a tf.Tensor go out of scope for its memory to be released).
128-
// Here we use an array to collect all tensors to be disposed when this method exits
129-
const collection = new Array();
130-
131-
const time_input = tf.tensor(timeStep, [1]/*shape*/, 'int32' /* model.signature.inputs.time_input.dtype */);
132-
collection.push(time_input);
133-
134-
const inputs = {
135-
time_input : time_input,
136-
image_input : image_input
137-
};
138-
139-
const epsilon = model.predict(inputs);
140-
collection.push(epsilon);
54+
return betas;
55+
};
14156

142-
const epsilon2 = epsilon.mul(Math.sqrt(1 - alphas_cumprod[timeStep]));
143-
collection.push(epsilon2);
57+
58+
const image_size = 64;
59+
const timesteps = 300;
60+
//const betas = tf.linspace( beta_start, beta_end, timesteps).arraySync();
61+
const betas = cosine_beta_schedule(timesteps, 0.008);
62+
const alphas = new Array(timesteps);
63+
const alphas_cumprod = new Array(timesteps);
64+
const alphas_cumprod_prev = new Array(timesteps);
65+
const sqrt_one_minus_alphas_cumprod = new Array(timesteps);
66+
const sqrt_recip_alphas_cumprod = new Array(timesteps);
67+
const sqrt_recipm1_alphas_cumprod = new Array(timesteps);
68+
const stddevs = new Array(timesteps);
69+
70+
alphas_cumprod_prev[0] = 1.0;
71+
72+
// prepare variables
73+
betas.forEach( (beta, index) => {
74+
alphas[index] = 1.0 - beta;
75+
76+
alphas_cumprod[index] = (index > 0) ? alphas_cumprod[index-1] * alphas[index] : alphas[index];
77+
78+
if( index < timesteps - 1 ){
79+
alphas_cumprod_prev[index+1] = alphas_cumprod[index];
80+
}
14481

145-
const xt_sub_epsilon2 = image_input.sub(epsilon2);
146-
collection.push(xt_sub_epsilon2);
82+
sqrt_recip_alphas_cumprod[index] = stable_sqrt(1.0 / alphas_cumprod[index]);
83+
sqrt_recipm1_alphas_cumprod[index] = stable_sqrt(1.0 / alphas_cumprod[index] - 1);
14784

148-
const x0 = xt_sub_epsilon2.div(Math.sqrt(alphas_cumprod[timeStep]));
149-
collection.push(x0);
85+
sqrt_one_minus_alphas_cumprod[index] = stable_sqrt(1.0 - alphas_cumprod[index]);
15086

151-
const clipped_x0 = tf.clipByValue(x0, -1.0, 1.0);
152-
collection.push(clipped_x0);
87+
const variance = beta * (1.0 - alphas_cumprod_prev[index]) / (1.0 - alphas_cumprod[index]);
88+
stddevs[index] = Math.exp( 0.5 * Math.log(Math.max(variance, 1e-20)) ); // Log calculation clipped because the posterior variance is 0 at the beginning
89+
});
15390

154-
const x0_coefficient = Math.sqrt(alphas_cumprod_prev[timeStep]) * betas[timeStep] / (1 - alphas_cumprod[timeStep]);
155-
const x0_multipled_by_coef = clipped_x0.mul(x0_coefficient);
156-
collection.push(x0_multipled_by_coef);
91+
if( typeof(console.log) === 'function' ){
92+
console.log("betas=", betas);
93+
console.log("alphas=", alphas);
94+
console.log("alphas_cumprod=", alphas_cumprod);
95+
console.log("alphas_cumprod_prev=", alphas_cumprod_prev);
96+
console.log("sqrt_one_minus_alphas_cumprod=", sqrt_one_minus_alphas_cumprod);
97+
console.log("stddevs=", stddevs);
98+
}
15799

158-
const xt_coefficient = Math.sqrt(alphas[timeStep]) * (1 - alphas_cumprod_prev[timeStep]) / (1 - alphas_cumprod[timeStep]);
159-
const xt_multipled_by_coef = image_input.mul(xt_coefficient);
160-
collection.push(xt_multipled_by_coef);
161-
162-
const mean = x0_multipled_by_coef.add(xt_multipled_by_coef);
163-
collection.push(mean);
164-
165-
let normal_noise = tf.randomNormal(image_input.shape, 0/*mean*/, 1/*stddev*/, 'float32', Math.random()*10000/*seed*/);
166-
collection.push(normal_noise);
100+
self.postMessage({ type: 'ready' });
167101

168-
normal_noise = normal_noise.mul(stddevs[timeStep]);
169-
collection.push(normal_noise);
102+
const ddpm_p_sample = (image_input, timeStep) => {
103+
// When using WebGL backend, tf.Tensor memory must be managed explicitly (it is not sufficient to let a tf.Tensor go out of scope for its memory to be released).
104+
// Here we use an array to collect all tensors to be disposed when this method exits
105+
const collection = new Array();
106+
107+
const time_input = tf.tensor(timeStep, [1]/*shape*/, 'int32' /* model.signature.inputs.time_input.dtype */);
108+
collection.push(time_input);
170109

171-
const xt_minus_one = mean.add(normal_noise);
172-
173-
collection.forEach( (t) => t.dispose() );
174-
return xt_minus_one;
110+
const inputs = {
111+
time_input : time_input,
112+
image_input : image_input
175113
};
176114

177-
const map = {};
178-
let keySource = 0;
179-
self.onmessage = (evt) => {
180-
const request = evt.data;
181-
182-
switch (request.type) {
183-
case 'ddpmStart': {
184-
const timestep = timesteps - 1;
185-
const shape = [1, image_size, image_size, 3];
186-
const initial_noises = tf.randomNormal(shape, 0/*mean*/, 1/*stddev*/, 'float32', Math.random()*10000/*seed*/);
187-
const image = ddpm_p_sample(initial_noises, timestep);
188-
initial_noises.dispose();
189-
const key = ++keySource;
190-
map[key] = {
191-
step : timestep,
192-
image : image
193-
};
194-
115+
const epsilon = model.predict(inputs);
116+
collection.push(epsilon);
117+
118+
const epsilon2 = epsilon.mul(stable_sqrt(1 - alphas_cumprod[timeStep]));
119+
collection.push(epsilon2);
120+
121+
const xt_sub_epsilon2 = image_input.sub(epsilon2);
122+
collection.push(xt_sub_epsilon2);
123+
124+
const x0 = xt_sub_epsilon2.div(stable_sqrt(alphas_cumprod[timeStep]));
125+
collection.push(x0);
126+
127+
const clipped_x0 = tf.clipByValue(x0, -1.0, 1.0);
128+
collection.push(clipped_x0);
129+
130+
const x0_coefficient = stable_sqrt(alphas_cumprod_prev[timeStep]) * betas[timeStep] / (1 - alphas_cumprod[timeStep]);
131+
const x0_multipled_by_coef = clipped_x0.mul(x0_coefficient);
132+
collection.push(x0_multipled_by_coef);
133+
134+
const xt_coefficient = stable_sqrt(alphas[timeStep]) * (1 - alphas_cumprod_prev[timeStep]) / (1 - alphas_cumprod[timeStep]);
135+
const xt_multipled_by_coef = image_input.mul(xt_coefficient);
136+
collection.push(xt_multipled_by_coef);
137+
138+
const mean = x0_multipled_by_coef.add(xt_multipled_by_coef);
139+
collection.push(mean);
140+
141+
let normal_noise = tf.randomNormal(image_input.shape, 0/*mean*/, 1/*stddev*/, 'float32', Math.random()*10000/*seed*/);
142+
collection.push(normal_noise);
143+
144+
normal_noise = normal_noise.mul(stddevs[timeStep]);
145+
collection.push(normal_noise);
146+
147+
const xt_minus_one = mean.add(normal_noise);
148+
149+
collection.forEach( (t) => t.dispose() );
150+
return xt_minus_one;
151+
};
152+
153+
const map = {};
154+
let keySource = 0;
155+
self.onmessage = (evt) => {
156+
const request = evt.data;
157+
158+
switch (request.type) {
159+
case 'ddpmStart': {
160+
const timestep = timesteps - 1;
161+
const shape = [1, image_size, image_size, 3];
162+
const initial_noises = tf.randomNormal(shape, 0/*mean*/, 1/*stddev*/, 'float32', Math.random()*10000/*seed*/);
163+
const image = ddpm_p_sample(initial_noises, timestep);
164+
initial_noises.dispose();
165+
const key = ++keySource;
166+
map[key] = {
167+
step : timestep,
168+
image : image
169+
};
170+
171+
self.postMessage({ type: 'reply', id: request.id, data: {
172+
step : timestep,
173+
image : image.arraySync(),
174+
key : key,
175+
percent : (timesteps - timestep) / (1.0 * timesteps),
176+
} });
177+
break;
178+
}
179+
180+
case 'ddpmNext': {
181+
const prev = map[request.kwargs && request.kwargs.key];
182+
if(!prev) {
183+
console.log('invalid request');
184+
} else {
185+
delete map[request.kwargs.key];
186+
const timestep = prev.step - 1;
187+
const image = ddpm_p_sample(prev.image, timestep);
188+
prev.image.dispose();
189+
if(timestep > 0) {
190+
map[request.kwargs.key] = {
191+
step : timestep,
192+
image : image
193+
};
194+
}
195195
self.postMessage({ type: 'reply', id: request.id, data: {
196196
step : timestep,
197197
image : image.arraySync(),
198-
key : key,
199-
percent : (timesteps - timestep) / (1.0 * timesteps),
198+
key : request.kwargs.key,
199+
percent : (timesteps - timestep) / (1.0 * timesteps)
200200
} });
201-
break;
202201
}
203-
204-
case 'ddpmNext': {
205-
const prev = map[request.kwargs && request.kwargs.key];
206-
if(!prev) {
207-
console.log('invalid request');
208-
} else {
209-
delete map[request.kwargs.key];
210-
const timestep = prev.step - 1;
211-
const image = ddpm_p_sample(prev.image, timestep);
212-
prev.image.dispose();
213-
if(timestep > 0) {
214-
map[request.kwargs.key] = {
215-
step : timestep,
216-
image : image
217-
};
218-
}
219-
self.postMessage({ type: 'reply', id: request.id, data: {
220-
step : timestep,
221-
image : image.arraySync(),
222-
key : request.kwargs.key,
223-
percent : (timesteps - timestep) / (1.0 * timesteps)
224-
} });
225-
}
226-
break;
227-
}
228-
202+
break;
229203
}
230-
};
231-
232-
204+
205+
}
206+
};
233207

234208

209+
}
235210

236211

212+
self.postMessage({ type: 'progress', progress: 0, message: 'Loading ' + TF_JS_URL });
213+
import(TF_JS_URL)
214+
.then(async () => {
215+
216+
await main();
237217

238218
})
239-
.catch((err) => {
240-
self.postMessage({ type: 'error', message: 'Unable to load tensorflow.js' });
241-
console.log('Unable to load tensorflow.js', err);
219+
.catch(async (err) => {
220+
221+
try{
222+
await import(TF_JS_CDN_URL);
223+
await main();
224+
}
225+
catch{
226+
self.postMessage({ type: 'error', message: 'Unable to load ' + TF_JS_URL });
227+
console.log('Unable to load ' + TF_JS_URL, err);
228+
}
242229
});
243230

244231

0 commit comments

Comments
 (0)