368368 dns_seeds_train,
369369 dns_seeds_valid,
370370 nunroll = conf[" posteriori" ][" nunroll" ],
371+ dt = conf[" posteriori" ][" dt" ],
371372 closure,
372373 closure_name,
373374 θ_start = θ_cnn_prior,
@@ -511,20 +512,20 @@ let
511512 u = selectdim (sample. u, ndims (sample. u), it) |> collect |> device,
512513 t = sample. t[it],
513514 )
514- dt = T (data. t[2 ] - data. t[1 ])
515515 tspan = (data. t[1 ], data. t[end ])
516+ dt = conf[" posteriori" ][" dt" ]
516517
517518 # # No model
518519 dudt_nomod = NS. create_right_hand_side_inplace (
519520 setup, psolver)
520- epost. nomodel[I, :], epost. nomodel_t_post_inference[I] = compute_epost (dudt_nomod, θ_cnn_post[I]. * 0 , tspan, data, tsave)
521+ epost. nomodel[I, :], epost. nomodel_t_post_inference[I] = compute_epost (dudt_nomod, θ_cnn_post[I]. * 0 , tspan, data, tsave, dt )
521522 @info " Epost nomodel" epost. nomodel[I,:]
522523 # with closure
523524 dudt = NS. create_right_hand_side_with_closure_inplace (
524525 setup, psolver, closure, st)
525- epost. model_prior[I, :], _ = compute_epost (dudt, device (θ_cnn_prior[ig, ifil]) , tspan, data, tsave)
526+ epost. model_prior[I, :], _ = compute_epost (dudt, device (θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt )
526527 @info " Epost model_prior" epost. model_prior[I, :]
527- epost. model_post[I, :], epost. model_t_post_inference[I] = compute_epost (dudt, device (θ_cnn_post[I]) , tspan, data, tsave)
528+ epost. model_post[I, :], epost. model_t_post_inference[I] = compute_epost (dudt, device (θ_cnn_post[I]) , tspan, data, tsave, dt )
528529 @info " Epost model_post" epost. model_post[I, :]
529530 clean ()
530531 end
657658 θ_prior = device (θ_cnn_prior[ig, ifil])
658659 θ_post = device (θ_cnn_post[I])
659660
660- dt = T (sample . t[ 2 ] - sample . t[ 1 ])
661+ dt = conf[ " posteriori " ][ " dt " ]
661662 tspan = (sample. t[1 ], sample. t[end ])
662663 dt_sample = T (0.05 ) # Sample every 0.05 seconds for the history (same as INS)
663664 tsave = (x* dt_sample for x in 1 : (floor (Int, length (sample. t) / 0.05 )+ 1 ))
@@ -671,10 +672,10 @@ let
671672 pred_prior =
672673 solve (
673674 prob_prior,
674- Tsit5 ();
675+ RK4 ();
675676 u0 = x,
676677 p = θ_prior,
677- adaptive = true ,
678+ adaptive = false ,
678679 saveat = tsave,
679680 dt = dt,
680681 tspan = tspan,
@@ -685,10 +686,10 @@ let
685686 pred_post =
686687 solve (
687688 prob_post,
688- Tsit5 ();
689+ RK4 ();
689690 u0 = x,
690691 p = θ_post,
691- adaptive = true ,
692+ adaptive = false ,
692693 saveat = tsave,
693694 dt = dt,
694695 tspan = tspan,
964965 θ_prior = device (θ_cnn_prior[I])
965966 θ_post = device (θ_cnn_post[I])
966967
967- dt = T ( 1e-4 )
968+ dt = conf[ " posteriori " ][ " dt " ]
968969 tspan = (T (0 ), times[end ]+ T (1e-4 ))
969970
970971 dudt = NS. create_right_hand_side_with_closure_inplace (
@@ -976,23 +977,25 @@ let
976977 pred_prior =
977978 solve (
978979 prob_prior,
979- Tsit5 (),
980+ RK4 (),
980981 u0 = x,
981982 p = θ_prior,
982- adaptive = true ,
983+ adaptive = false ,
983984 saveat = times,
984985 tspan = tspan,
986+ dt = dt,
985987 )
986988 prob_post = ODEProblem (dudt, x, tspan, θ_post)
987989 pred_post =
988990 solve (
989991 prob_post,
990- Tsit5 (),
992+ RK4 (),
991993 u0 = x,
992994 p = θ_post,
993- adaptive = true ,
995+ adaptive = false ,
994996 saveat = times,
995997 tspan = tspan,
998+ dt = dt,
996999 )
9971000
9981001 for it in 1 : length (times)
0 commit comments