@@ -239,6 +239,7 @@ function trainpost(;
239239 postseed,
240240 dns_seeds_train,
241241 dns_seeds_valid,
242+ dns_seeds_test,
242243 nunroll,
243244 nsamples = 1 ,
244245 closure,
@@ -281,21 +282,15 @@ function trainpost(;
281282 checkfile = join (splitext (postfile), " _checkpoint" )
282283 setup = getsetup (; params, nles)
283284 psolver = default_psolver (setup)
284- # Read the data in the format expected by the CoupledNODE
285285 T = eltype (params. Re)
286- setup = []
287- for nl in nles
288- x = ntuple (α -> LinRange (T (0.0 ), T (1.0 ), nl + 1 ), params. D)
289- push! (setup, Setup (; x = x, Re = params. Re, params. backend))
290- end
291286
292287 # Read the data in the format expected by the CoupledNODE
293288 data_train = load_data_set (outdir, nles, Φ, dns_seeds_train, dataproj)
294289 data_valid = load_data_set (outdir, nles, Φ, dns_seeds_valid, dataproj)
295290
296291 NS = Base. get_extension (CoupledNODE, :NavierStokes )
297- io_train = NS. create_io_arrays_posteriori (data_train, setup[ 1 ] , device)
298- io_valid = NS. create_io_arrays_posteriori (data_valid, setup[ 1 ] , device)
292+ io_train = NS. create_io_arrays_posteriori (data_train, setup, device)
293+ io_valid = NS. create_io_arrays_posteriori (data_valid, setup, device)
299294 θ = device (copy (θ_start[itotal]))
300295 dataloader_post = NS. create_dataloader_posteriori (
301296 io_train;
@@ -305,7 +300,7 @@ function trainpost(;
305300 device = device,
306301 )
307302
308- dudt_nn = NS. create_right_hand_side_with_closure (setup[ 1 ] , psolver, closure, st)
303+ dudt_nn = NS. create_right_hand_side_with_closure (setup, psolver, closure, st)
309304 griddims = ((:) for _ = 1 : params. D)
310305 inside = ((2 : (nles+ 1 )) for _ = 1 : params. D)
311306 loss = CoupledNODE. create_loss_post_lux (
@@ -338,11 +333,25 @@ function trainpost(;
338333 nepochs_left = nepoch
339334 end
340335
336+
337+ # For the callback I am going to use the a-posteriori error estimator
338+ sample = namedtupleload (getdatafile (outdir, nles, Φ, dns_seeds_test[1 ]))
339+ it = 1 : (nunroll_valid+ 1 )
340+ data_cb = (;
341+ u = selectdim (sample. u, ndims (sample. u), it) |> collect |> device,
342+ t = sample. t[it],
343+ )
344+ tspan = (data_cb. t[1 ], data_cb. t[end ])
345+ tsave = [nunroll_valid]
346+ dudt_cb = NS. create_right_hand_side_with_closure_inplace (
347+ setup, psolver, closure, st)
348+ loss_cb (_model, pp, _st, _data ) = compute_epost (dudt_cb, pp , tspan, data_cb, tsave, dt)[1 ][end ]
349+
341350 callbackstate, callback = NS. create_callback (
342351 closure,
343352 θ,
344353 io_valid,
345- loss ,
354+ loss_cb ,
346355 st;
347356 callbackstate = callbackstate,
348357 nunroll = nunroll_valid,
@@ -420,9 +429,10 @@ function compute_epost(rhs, ps, tspan, (u, t), tsave, dt)
420429 p = ps,
421430 adaptive = true ,
422431 saveat = Array (t),
423- tspan = tspan,
424- save_start = false ,
425- dt = dt,
432+ # tstops = Array(t),
433+ # tspan = tspan,
434+ # save_start = false,
435+ # dt = dt,
426436 )
427437
428438 e = 0.0
@@ -441,6 +451,12 @@ function compute_epost(rhs, ps, tspan, (u, t), tsave, dt)
441451 push! (es, e / (it - 1 ))
442452 end
443453 end
454+ # for it in tsave
455+ # yref = y[inside..., :, 1:it]
456+ # ypred = pred[inside..., :, 1:it]
457+
458+ # Lux.MSELoss()(ypred, yref) |> e -> push!(es, e)
459+ # end
444460
445461 return es, time () - t0
446462
0 commit comments