Skip to content

Commit 8a7c8e4

Browse files
committed
Use adapt=false and RK4
1 parent 627893e commit 8a7c8e4

File tree

5 files changed

+81
-24
lines changed

5 files changed

+81
-24
lines changed

cnn_model_workflow.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ let
349349
dns_seeds_train,
350350
dns_seeds_valid,
351351
nunroll = conf["posteriori"]["nunroll"],
352+
dt = conf["posteriori"]["dt"],
352353
closure,
353354
closure_name,
354355
θ_start = θ_cnn_prior,
@@ -492,19 +493,20 @@ let
492493
t = sample.t[it],
493494
)
494495
tspan = (data.t[1], data.t[end])
496+
dt = conf["posteriori"]["dt"]
495497

496498
## No model
497499
dudt_nomod = NS.create_right_hand_side_inplace(
498500
setup, psolver)
499501

500-
epost.nomodel[I,:], _ = compute_epost(dudt_nomod, θ_cnn_post[I].*0 , tspan, data, tsave)
502+
epost.nomodel[I,:], _ = compute_epost(dudt_nomod, θ_cnn_post[I].*0 , tspan, data, tsave, dt)
501503
@info "Epost nomodel" epost.nomodel[I,:]
502504
# with closure
503505
dudt = NS.create_right_hand_side_with_closure_inplace(
504506
setup, psolver, closure, st)
505-
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave)
507+
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
506508
@info "Epost model_prior" epost.model_prior[I, :]
507-
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[I]) , tspan, data, tsave)
509+
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[I]) , tspan, data, tsave, dt)
508510
@info "Epost model_post" epost.model_post[I, :]
509511

510512
clean()

configs/local/cnn_1 copy.yaml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
docreatedata: true
2+
docomp: false
3+
ntrajectory: 8
4+
T: "Float32"
5+
params:
6+
D: 2
7+
lims: [0.0, 1.0]
8+
Re: 6000.0
9+
tburn: 0.5
10+
tsim: 2.0
11+
savefreq: 100
12+
ndns: 1024
13+
nles: [32]
14+
filters: ["FaceAverage()"]
15+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
16+
method: "RKMethods.Wray3(; T)"
17+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
18+
issteadybodyforce: true
19+
processors: "(; log = timelogger(; nupdate=100))"
20+
Δt: 0.0001
21+
seeds:
22+
dns: 1111
23+
θ_start: 234
24+
prior: 345
25+
post: 456
26+
closure:
27+
name: "cnn_new"
28+
type: cnn
29+
radii: [2, 2, 2, 2, 2]
30+
channels: [24, 24, 24, 24, 2]
31+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
32+
use_bias: [true, true, true, true, false]
33+
rng: "Xoshiro(seeds.θ_start)"
34+
priori:
35+
dotrain: true
36+
nepoch: 5000
37+
batchsize: 64
38+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
39+
do_plot: false
40+
plot_train: false
41+
posteriori:
42+
dotrain: true
43+
projectorders: "(ProjectOrder.Last, )"
44+
nepoch: 300
45+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
46+
nunroll: 5
47+
nunroll_valid: 5
48+
dt: T(1e-3)
49+
do_plot: false
50+
plot_train: false

configs/local/cnn_2.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ params:
77
lims: [0.0, 1.0]
88
Re: 6000.0
99
tburn: 0.5
10-
tsim: 2.0
10+
tsim: 1.0
1111
savefreq: 100
12-
ndns: 1024
12+
ndns: 256
1313
nles: [32]
1414
filters: ["FaceAverage()"]
1515
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
@@ -19,7 +19,7 @@ params:
1919
processors: "(; log = timelogger(; nupdate=100))"
2020
Δt: 0.0001
2121
seeds:
22-
dns: 12345
22+
dns: 1234567
2323
θ_start: 234
2424
prior: 345
2525
post: 456

extra_model_workflow.jl

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ let
368368
dns_seeds_train,
369369
dns_seeds_valid,
370370
nunroll = conf["posteriori"]["nunroll"],
371+
dt = conf["posteriori"]["dt"],
371372
closure,
372373
closure_name,
373374
θ_start = θ_cnn_prior,
@@ -511,20 +512,20 @@ let
511512
u = selectdim(sample.u, ndims(sample.u), it) |> collect |> device,
512513
t = sample.t[it],
513514
)
514-
dt = T(data.t[2] - data.t[1])
515515
tspan = (data.t[1], data.t[end])
516+
dt = conf["posteriori"]["dt"]
516517

517518
## No model
518519
dudt_nomod = NS.create_right_hand_side_inplace(
519520
setup, psolver)
520-
epost.nomodel[I, :], epost.nomodel_t_post_inference[I] = compute_epost(dudt_nomod, θ_cnn_post[I].*0 , tspan, data, tsave)
521+
epost.nomodel[I, :], epost.nomodel_t_post_inference[I] = compute_epost(dudt_nomod, θ_cnn_post[I].*0 , tspan, data, tsave, dt)
521522
@info "Epost nomodel" epost.nomodel[I,:]
522523
# with closure
523524
dudt = NS.create_right_hand_side_with_closure_inplace(
524525
setup, psolver, closure, st)
525-
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave)
526+
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
526527
@info "Epost model_prior" epost.model_prior[I, :]
527-
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[I]) , tspan, data, tsave)
528+
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[I]) , tspan, data, tsave, dt)
528529
@info "Epost model_post" epost.model_post[I, :]
529530
clean()
530531
end
@@ -657,7 +658,7 @@ let
657658
θ_prior = device(θ_cnn_prior[ig, ifil])
658659
θ_post = device(θ_cnn_post[I])
659660

660-
dt = T(sample.t[2] - sample.t[1])
661+
dt = conf["posteriori"]["dt"]
661662
tspan = (sample.t[1], sample.t[end])
662663
dt_sample = T(0.05) # Sample every 0.05 seconds for the history (same as INS)
663664
tsave = (x*dt_sample for x in 1:(floor(Int, length(sample.t) / 0.05)+1))
@@ -671,10 +672,10 @@ let
671672
pred_prior =
672673
solve(
673674
prob_prior,
674-
Tsit5();
675+
RK4();
675676
u0 = x,
676677
p = θ_prior,
677-
adaptive = true,
678+
adaptive = false,
678679
saveat = tsave,
679680
dt = dt,
680681
tspan = tspan,
@@ -685,10 +686,10 @@ let
685686
pred_post =
686687
solve(
687688
prob_post,
688-
Tsit5();
689+
RK4();
689690
u0 = x,
690691
p = θ_post,
691-
adaptive = true,
692+
adaptive = false,
692693
saveat = tsave,
693694
dt = dt,
694695
tspan = tspan,
@@ -964,7 +965,7 @@ let
964965
θ_prior = device(θ_cnn_prior[I])
965966
θ_post = device(θ_cnn_post[I])
966967

967-
dt = T(1e-4)
968+
dt = conf["posteriori"]["dt"]
968969
tspan = (T(0), times[end]+T(1e-4))
969970

970971
dudt = NS.create_right_hand_side_with_closure_inplace(
@@ -976,23 +977,25 @@ let
976977
pred_prior =
977978
solve(
978979
prob_prior,
979-
Tsit5(),
980+
RK4(),
980981
u0 = x,
981982
p = θ_prior,
982-
adaptive = true,
983+
adaptive = false,
983984
saveat = times,
984985
tspan = tspan,
986+
dt = dt,
985987
)
986988
prob_post = ODEProblem(dudt, x, tspan, θ_post)
987989
pred_post =
988990
solve(
989991
prob_post,
990-
Tsit5(),
992+
RK4(),
991993
u0 = x,
992994
p = θ_post,
993-
adaptive = true,
995+
adaptive = false,
994996
saveat = times,
995997
tspan = tspan,
998+
dt = dt,
996999
)
9971000

9981001
for it in 1:length(times)

src/train.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,9 @@ function trainpost(;
293293
dudt_nn,
294294
griddims,
295295
inside;
296+
dt = dt,
296297
ensemble = nsamples > 1,
297-
sciml_solver = Tsit5(),
298+
sciml_solver = RK4(),
298299
sensealg = sensealg,
299300
)
300301

@@ -386,7 +387,7 @@ function compute_t_prior_inference(closure, θ, st, x, y, nreps = 1000)
386387
end
387388

388389

389-
function compute_epost(rhs, ps, tspan, (u, t), tsave)
390+
function compute_epost(rhs, ps, tspan, (u, t), tsave, dt)
390391
griddims = ((:) for _ = 1:(ndims(u)-2))
391392
inside = ((2:(size(u, 1)-1)) for _ = 1:(ndims(u)-2))
392393
x = u[griddims..., :, 1]
@@ -395,13 +396,14 @@ function compute_epost(rhs, ps, tspan, (u, t), tsave)
395396
t0 = time()
396397
pred = solve(
397398
prob,
398-
Tsit5();
399+
RK4();
399400
u0 = x,
400401
p = ps,
401-
adaptive = true,
402+
adaptive = false,
402403
saveat = Array(t),
403404
tspan = tspan,
404405
save_start = false,
406+
dt = dt,
405407
)
406408

407409
e = 0.0

0 commit comments

Comments
 (0)