@@ -4,29 +4,9 @@ const url = self.location.toString();
4
4
let index = url . lastIndexOf ( '/' ) ;
5
5
index = url . lastIndexOf ( '/' , index - 1 ) ;
6
6
console . log ( url , index ) ;
7
- const TF_JS_URL = url . substring ( 0 , index ) + "/@tensorflow/[email protected] /dist/tf.min.js" ;
8
-
9
- console . log ( TF_JS_URL ) ;
10
-
7
+ const TF_JS_URL = url . substring ( 0 , index ) + "/@tensorflow/tfjs/dist/tf.min.js" ;
8
+ const TF_JS_CDN_URL = "https://cdn.jsdelivr.net/npm/@tensorflow/[email protected] /dist/tf.min.js" ;
11
9
async function load_model ( ) {
12
- try {
13
- //await import('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected] /dist/tf-backend-webgl.min.js');
14
- //await tf.setBackend('webgl');
15
-
16
- //tf.wasm.setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected] /dist/')
17
- //const backend = await import('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected] /dist/tf-backend-wasm.min.js');
18
- //tf.wasm.setWasmPaths('https://cdn.jsdelivr.net/npm/@tensorflow/[email protected] /dist/');
19
-
20
-
21
-
22
-
23
-
24
- self . postMessage ( { type : 'progress' , progress : 0.3 , message : 'Loading model' } ) ;
25
- } catch ( e ) {
26
- self . postMessage ( { type : 'error' , message : 'Failed to load tensorflow WASM backend' } ) ;
27
- console . log ( 'Failed to initialize tensorflow backend' , e ) ;
28
- return null ;
29
- }
30
10
31
11
const model = await ( async ( ) => {
32
12
try {
@@ -49,196 +29,203 @@ async function load_model() {
49
29
return ;
50
30
}
51
31
52
- self . postMessage ( { type : 'ready' } ) ;
32
+
53
33
return model ;
54
34
}
55
35
36
+ async function main ( ) {
37
+ const model = await load_model ( ) ;
56
38
39
+ const stable_sqrt = ( number ) => number > 0 && number < 0.001 ? Math . exp ( 0.5 * Math . log ( Math . max ( number , 1e-20 ) ) ) : Math . sqrt ( number ) ;
57
40
58
- self . postMessage ( { type : 'progress' , progress : 0 , message : 'Loading ' + TF_JS_URL } ) ;
59
- import ( TF_JS_URL )
60
- . then ( async ( ) => {
61
-
62
- const model = await load_model ( ) ;
63
-
64
- const stable_sqrt = ( number ) => Math . sqrt ( number ) ; // Math.exp( 0.5 * Math.log(Math.max(number, 1e-20)) );
65
-
66
- // cosine schedule as proposed in https://arxiv.org/abs/2102.09672
67
- const cosine_beta_schedule = ( timesteps , s ) => {
68
- const steps = timesteps + 1.0 ;
69
- let alphas_cumprod = tf . linspace ( 0 , timesteps , steps ) . arraySync ( ) . map ( x => {
70
- return Math . pow ( Math . cos ( ( ( x / timesteps ) + s ) / ( 1 + s ) * Math . PI * 0.5 ) , 2 ) ;
71
- } ) ;
72
- const base = alphas_cumprod [ 0 ] ;
73
- alphas_cumprod = alphas_cumprod . map ( x => x / base ) ;
74
- const betas = new Array ( timesteps ) ;
75
- for ( let index = 0 ; index < betas . length ; index ++ ) {
76
- const b = 1 - ( alphas_cumprod [ index + 1 ] / alphas_cumprod [ index ] ) ;
77
- betas [ index ] = Math . min ( 0.9999 , Math . max ( 0.0001 , b ) ) ;
78
- }
79
- return betas ;
80
- } ;
81
-
82
-
83
- const image_size = 64 ;
84
- const timesteps = 300 ;
85
- //const betas = tf.linspace( beta_start, beta_end, timesteps).arraySync();
86
- const betas = cosine_beta_schedule ( timesteps , 0.008 ) ;
87
- const alphas = new Array ( timesteps ) ;
88
- const alphas_cumprod = new Array ( timesteps ) ;
89
- const alphas_cumprod_prev = new Array ( timesteps ) ;
90
- const sqrt_one_minus_alphas_cumprod = new Array ( timesteps ) ;
91
- const sqrt_recip_alphas_cumprod = new Array ( timesteps ) ;
92
- const sqrt_recipm1_alphas_cumprod = new Array ( timesteps ) ;
93
- const stddevs = new Array ( timesteps ) ;
94
-
95
- alphas_cumprod_prev [ 0 ] = 1.0 ;
96
-
97
- // prepare variables
98
- betas . forEach ( ( beta , index ) => {
99
- alphas [ index ] = 1.0 - beta ;
100
-
101
- alphas_cumprod [ index ] = ( index > 0 ) ? alphas_cumprod [ index - 1 ] * alphas [ index ] : alphas [ index ] ;
102
-
103
- if ( index < timesteps - 1 ) {
104
- alphas_cumprod_prev [ index + 1 ] = alphas_cumprod [ index ] ;
105
- }
106
-
107
- sqrt_recip_alphas_cumprod [ index ] = stable_sqrt ( 1.0 / alphas_cumprod [ index ] ) ;
108
- sqrt_recipm1_alphas_cumprod [ index ] = stable_sqrt ( 1.0 / alphas_cumprod [ index ] - 1 ) ;
109
-
110
- sqrt_one_minus_alphas_cumprod [ index ] = Math . sqrt ( 1.0 - alphas_cumprod [ index ] ) ;
111
-
112
- const variance = beta * ( 1.0 - alphas_cumprod_prev [ index ] ) / ( 1.0 - alphas_cumprod [ index ] ) ;
113
- stddevs [ index ] = Math . exp ( 0.5 * Math . log ( Math . max ( variance , 1e-20 ) ) ) ; // Log calculation clipped because the posterior variance is 0 at the beginning
41
+ // cosine schedule as proposed in https://arxiv.org/abs/2102.09672
42
+ const cosine_beta_schedule = ( timesteps , s ) => {
43
+ const steps = timesteps + 1.0 ;
44
+ let alphas_cumprod = tf . linspace ( 0 , timesteps , steps ) . arraySync ( ) . map ( x => {
45
+ return Math . pow ( Math . cos ( ( ( x / timesteps ) + s ) / ( 1 + s ) * Math . PI * 0.5 ) , 2 ) ;
114
46
} ) ;
115
-
116
- if ( typeof ( console . log ) === 'function' ) {
117
- console . log ( "betas=" , betas ) ;
118
- console . log ( "alphas=" , alphas ) ;
119
- console . log ( "alphas_cumprod=" , alphas_cumprod ) ;
120
- console . log ( "alphas_cumprod_prev=" , alphas_cumprod_prev ) ;
121
- console . log ( "sqrt_one_minus_alphas_cumprod=" , sqrt_one_minus_alphas_cumprod ) ;
122
- console . log ( "stddevs=" , stddevs ) ;
47
+ const base = alphas_cumprod [ 0 ] ;
48
+ alphas_cumprod = alphas_cumprod . map ( x => x / base ) ;
49
+ const betas = new Array ( timesteps ) ;
50
+ for ( let index = 0 ; index < betas . length ; index ++ ) {
51
+ const b = 1 - ( alphas_cumprod [ index + 1 ] / alphas_cumprod [ index ] ) ;
52
+ betas [ index ] = Math . min ( 0.9999 , Math . max ( 0.0001 , b ) ) ;
123
53
}
124
-
125
-
126
- const ddpm_p_sample = ( image_input , timeStep ) => {
127
- // When using WebGL backend, tf.Tensor memory must be managed explicitly (it is not sufficient to let a tf.Tensor go out of scope for its memory to be released).
128
- // Here we use an array to collect all tensors to be disposed when this method exits
129
- const collection = new Array ( ) ;
130
-
131
- const time_input = tf . tensor ( timeStep , [ 1 ] /*shape*/ , 'int32' /* model.signature.inputs.time_input.dtype */ ) ;
132
- collection . push ( time_input ) ;
133
-
134
- const inputs = {
135
- time_input : time_input ,
136
- image_input : image_input
137
- } ;
138
-
139
- const epsilon = model . predict ( inputs ) ;
140
- collection . push ( epsilon ) ;
54
+ return betas ;
55
+ } ;
141
56
142
- const epsilon2 = epsilon . mul ( Math . sqrt ( 1 - alphas_cumprod [ timeStep ] ) ) ;
143
- collection . push ( epsilon2 ) ;
57
+
58
+ const image_size = 64 ;
59
+ const timesteps = 300 ;
60
+ //const betas = tf.linspace( beta_start, beta_end, timesteps).arraySync();
61
+ const betas = cosine_beta_schedule ( timesteps , 0.008 ) ;
62
+ const alphas = new Array ( timesteps ) ;
63
+ const alphas_cumprod = new Array ( timesteps ) ;
64
+ const alphas_cumprod_prev = new Array ( timesteps ) ;
65
+ const sqrt_one_minus_alphas_cumprod = new Array ( timesteps ) ;
66
+ const sqrt_recip_alphas_cumprod = new Array ( timesteps ) ;
67
+ const sqrt_recipm1_alphas_cumprod = new Array ( timesteps ) ;
68
+ const stddevs = new Array ( timesteps ) ;
69
+
70
+ alphas_cumprod_prev [ 0 ] = 1.0 ;
71
+
72
+ // prepare variables
73
+ betas . forEach ( ( beta , index ) => {
74
+ alphas [ index ] = 1.0 - beta ;
75
+
76
+ alphas_cumprod [ index ] = ( index > 0 ) ? alphas_cumprod [ index - 1 ] * alphas [ index ] : alphas [ index ] ;
77
+
78
+ if ( index < timesteps - 1 ) {
79
+ alphas_cumprod_prev [ index + 1 ] = alphas_cumprod [ index ] ;
80
+ }
144
81
145
- const xt_sub_epsilon2 = image_input . sub ( epsilon2 ) ;
146
- collection . push ( xt_sub_epsilon2 ) ;
82
+ sqrt_recip_alphas_cumprod [ index ] = stable_sqrt ( 1.0 / alphas_cumprod [ index ] ) ;
83
+ sqrt_recipm1_alphas_cumprod [ index ] = stable_sqrt ( 1.0 / alphas_cumprod [ index ] - 1 ) ;
147
84
148
- const x0 = xt_sub_epsilon2 . div ( Math . sqrt ( alphas_cumprod [ timeStep ] ) ) ;
149
- collection . push ( x0 ) ;
85
+ sqrt_one_minus_alphas_cumprod [ index ] = stable_sqrt ( 1.0 - alphas_cumprod [ index ] ) ;
150
86
151
- const clipped_x0 = tf . clipByValue ( x0 , - 1.0 , 1.0 ) ;
152
- collection . push ( clipped_x0 ) ;
87
+ const variance = beta * ( 1.0 - alphas_cumprod_prev [ index ] ) / ( 1.0 - alphas_cumprod [ index ] ) ;
88
+ stddevs [ index ] = Math . exp ( 0.5 * Math . log ( Math . max ( variance , 1e-20 ) ) ) ; // Log calculation clipped because the posterior variance is 0 at the beginning
89
+ } ) ;
153
90
154
- const x0_coefficient = Math . sqrt ( alphas_cumprod_prev [ timeStep ] ) * betas [ timeStep ] / ( 1 - alphas_cumprod [ timeStep ] ) ;
155
- const x0_multipled_by_coef = clipped_x0 . mul ( x0_coefficient ) ;
156
- collection . push ( x0_multipled_by_coef ) ;
91
+ if ( typeof ( console . log ) === 'function' ) {
92
+ console . log ( "betas=" , betas ) ;
93
+ console . log ( "alphas=" , alphas ) ;
94
+ console . log ( "alphas_cumprod=" , alphas_cumprod ) ;
95
+ console . log ( "alphas_cumprod_prev=" , alphas_cumprod_prev ) ;
96
+ console . log ( "sqrt_one_minus_alphas_cumprod=" , sqrt_one_minus_alphas_cumprod ) ;
97
+ console . log ( "stddevs=" , stddevs ) ;
98
+ }
157
99
158
- const xt_coefficient = Math . sqrt ( alphas [ timeStep ] ) * ( 1 - alphas_cumprod_prev [ timeStep ] ) / ( 1 - alphas_cumprod [ timeStep ] ) ;
159
- const xt_multipled_by_coef = image_input . mul ( xt_coefficient ) ;
160
- collection . push ( xt_multipled_by_coef ) ;
161
-
162
- const mean = x0_multipled_by_coef . add ( xt_multipled_by_coef ) ;
163
- collection . push ( mean ) ;
164
-
165
- let normal_noise = tf . randomNormal ( image_input . shape , 0 /*mean*/ , 1 /*stddev*/ , 'float32' , Math . random ( ) * 10000 /*seed*/ ) ;
166
- collection . push ( normal_noise ) ;
100
+ self . postMessage ( { type : 'ready' } ) ;
167
101
168
- normal_noise = normal_noise . mul ( stddevs [ timeStep ] ) ;
169
- collection . push ( normal_noise ) ;
102
+ const ddpm_p_sample = ( image_input , timeStep ) => {
103
+ // When using WebGL backend, tf.Tensor memory must be managed explicitly (it is not sufficient to let a tf.Tensor go out of scope for its memory to be released).
104
+ // Here we use an array to collect all tensors to be disposed when this method exits
105
+ const collection = new Array ( ) ;
106
+
107
+ const time_input = tf . tensor ( timeStep , [ 1 ] /*shape*/ , 'int32' /* model.signature.inputs.time_input.dtype */ ) ;
108
+ collection . push ( time_input ) ;
170
109
171
- const xt_minus_one = mean . add ( normal_noise ) ;
172
-
173
- collection . forEach ( ( t ) => t . dispose ( ) ) ;
174
- return xt_minus_one ;
110
+ const inputs = {
111
+ time_input : time_input ,
112
+ image_input : image_input
175
113
} ;
176
114
177
- const map = { } ;
178
- let keySource = 0 ;
179
- self . onmessage = ( evt ) => {
180
- const request = evt . data ;
181
-
182
- switch ( request . type ) {
183
- case 'ddpmStart' : {
184
- const timestep = timesteps - 1 ;
185
- const shape = [ 1 , image_size , image_size , 3 ] ;
186
- const initial_noises = tf . randomNormal ( shape , 0 /*mean*/ , 1 /*stddev*/ , 'float32' , Math . random ( ) * 10000 /*seed*/ ) ;
187
- const image = ddpm_p_sample ( initial_noises , timestep ) ;
188
- initial_noises . dispose ( ) ;
189
- const key = ++ keySource ;
190
- map [ key ] = {
191
- step : timestep ,
192
- image : image
193
- } ;
194
-
115
+ const epsilon = model . predict ( inputs ) ;
116
+ collection . push ( epsilon ) ;
117
+
118
+ const epsilon2 = epsilon . mul ( stable_sqrt ( 1 - alphas_cumprod [ timeStep ] ) ) ;
119
+ collection . push ( epsilon2 ) ;
120
+
121
+ const xt_sub_epsilon2 = image_input . sub ( epsilon2 ) ;
122
+ collection . push ( xt_sub_epsilon2 ) ;
123
+
124
+ const x0 = xt_sub_epsilon2 . div ( stable_sqrt ( alphas_cumprod [ timeStep ] ) ) ;
125
+ collection . push ( x0 ) ;
126
+
127
+ const clipped_x0 = tf . clipByValue ( x0 , - 1.0 , 1.0 ) ;
128
+ collection . push ( clipped_x0 ) ;
129
+
130
+ const x0_coefficient = stable_sqrt ( alphas_cumprod_prev [ timeStep ] ) * betas [ timeStep ] / ( 1 - alphas_cumprod [ timeStep ] ) ;
131
+ const x0_multipled_by_coef = clipped_x0 . mul ( x0_coefficient ) ;
132
+ collection . push ( x0_multipled_by_coef ) ;
133
+
134
+ const xt_coefficient = stable_sqrt ( alphas [ timeStep ] ) * ( 1 - alphas_cumprod_prev [ timeStep ] ) / ( 1 - alphas_cumprod [ timeStep ] ) ;
135
+ const xt_multipled_by_coef = image_input . mul ( xt_coefficient ) ;
136
+ collection . push ( xt_multipled_by_coef ) ;
137
+
138
+ const mean = x0_multipled_by_coef . add ( xt_multipled_by_coef ) ;
139
+ collection . push ( mean ) ;
140
+
141
+ let normal_noise = tf . randomNormal ( image_input . shape , 0 /*mean*/ , 1 /*stddev*/ , 'float32' , Math . random ( ) * 10000 /*seed*/ ) ;
142
+ collection . push ( normal_noise ) ;
143
+
144
+ normal_noise = normal_noise . mul ( stddevs [ timeStep ] ) ;
145
+ collection . push ( normal_noise ) ;
146
+
147
+ const xt_minus_one = mean . add ( normal_noise ) ;
148
+
149
+ collection . forEach ( ( t ) => t . dispose ( ) ) ;
150
+ return xt_minus_one ;
151
+ } ;
152
+
153
+ const map = { } ;
154
+ let keySource = 0 ;
155
+ self . onmessage = ( evt ) => {
156
+ const request = evt . data ;
157
+
158
+ switch ( request . type ) {
159
+ case 'ddpmStart' : {
160
+ const timestep = timesteps - 1 ;
161
+ const shape = [ 1 , image_size , image_size , 3 ] ;
162
+ const initial_noises = tf . randomNormal ( shape , 0 /*mean*/ , 1 /*stddev*/ , 'float32' , Math . random ( ) * 10000 /*seed*/ ) ;
163
+ const image = ddpm_p_sample ( initial_noises , timestep ) ;
164
+ initial_noises . dispose ( ) ;
165
+ const key = ++ keySource ;
166
+ map [ key ] = {
167
+ step : timestep ,
168
+ image : image
169
+ } ;
170
+
171
+ self . postMessage ( { type : 'reply' , id : request . id , data : {
172
+ step : timestep ,
173
+ image : image . arraySync ( ) ,
174
+ key : key ,
175
+ percent : ( timesteps - timestep ) / ( 1.0 * timesteps ) ,
176
+ } } ) ;
177
+ break ;
178
+ }
179
+
180
+ case 'ddpmNext' : {
181
+ const prev = map [ request . kwargs && request . kwargs . key ] ;
182
+ if ( ! prev ) {
183
+ console . log ( 'invalid request' ) ;
184
+ } else {
185
+ delete map [ request . kwargs . key ] ;
186
+ const timestep = prev . step - 1 ;
187
+ const image = ddpm_p_sample ( prev . image , timestep ) ;
188
+ prev . image . dispose ( ) ;
189
+ if ( timestep > 0 ) {
190
+ map [ request . kwargs . key ] = {
191
+ step : timestep ,
192
+ image : image
193
+ } ;
194
+ }
195
195
self . postMessage ( { type : 'reply' , id : request . id , data : {
196
196
step : timestep ,
197
197
image : image . arraySync ( ) ,
198
- key : key ,
199
- percent : ( timesteps - timestep ) / ( 1.0 * timesteps ) ,
198
+ key : request . kwargs . key ,
199
+ percent : ( timesteps - timestep ) / ( 1.0 * timesteps )
200
200
} } ) ;
201
- break ;
202
201
}
203
-
204
- case 'ddpmNext' : {
205
- const prev = map [ request . kwargs && request . kwargs . key ] ;
206
- if ( ! prev ) {
207
- console . log ( 'invalid request' ) ;
208
- } else {
209
- delete map [ request . kwargs . key ] ;
210
- const timestep = prev . step - 1 ;
211
- const image = ddpm_p_sample ( prev . image , timestep ) ;
212
- prev . image . dispose ( ) ;
213
- if ( timestep > 0 ) {
214
- map [ request . kwargs . key ] = {
215
- step : timestep ,
216
- image : image
217
- } ;
218
- }
219
- self . postMessage ( { type : 'reply' , id : request . id , data : {
220
- step : timestep ,
221
- image : image . arraySync ( ) ,
222
- key : request . kwargs . key ,
223
- percent : ( timesteps - timestep ) / ( 1.0 * timesteps )
224
- } } ) ;
225
- }
226
- break ;
227
- }
228
-
202
+ break ;
229
203
}
230
- } ;
231
-
232
-
204
+
205
+ }
206
+ } ;
233
207
234
208
209
+ }
235
210
236
211
212
+ self . postMessage ( { type : 'progress' , progress : 0 , message : 'Loading ' + TF_JS_URL } ) ;
213
+ import ( TF_JS_URL )
214
+ . then ( async ( ) => {
215
+
216
+ await main ( ) ;
237
217
238
218
} )
239
- . catch ( ( err ) => {
240
- self . postMessage ( { type : 'error' , message : 'Unable to load tensorflow.js' } ) ;
241
- console . log ( 'Unable to load tensorflow.js' , err ) ;
219
+ . catch ( async ( err ) => {
220
+
221
+ try {
222
+ await import ( TF_JS_CDN_URL ) ;
223
+ await main ( ) ;
224
+ }
225
+ catch {
226
+ self . postMessage ( { type : 'error' , message : 'Unable to load ' + TF_JS_URL } ) ;
227
+ console . log ( 'Unable to load ' + TF_JS_URL , err ) ;
228
+ }
242
229
} ) ;
243
230
244
231
0 commit comments