Skip to content

Commit 6f45feb

Browse files
committed
autopush
1 parent d28470a commit 6f45feb

File tree

7 files changed

+177
-8
lines changed

7 files changed

+177
-8
lines changed

cnn_model_workflow.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ let
282282
plot_train = conf["priori"]["plot_train"],
283283
nepoch = nepoch,
284284
dataproj = conf["dataproj"],
285+
λ = haskey(conf["priori"], "λ") ? eval(Meta.parse(conf["priori"]["λ"])) : nothing
285286
)
286287
end
287288
end
@@ -394,6 +395,8 @@ let
394395
sensealg = sensealg,
395396
sciml_solver = sciml_solver,
396397
dataproj = conf["dataproj"],
398+
λ = haskey(conf["posteriori"], "λ") ? eval(Meta.parse(conf["posteriori"]["λ"])) : nothing,
399+
multi_shooting = haskey(conf["posteriori"], "multi_shooting") ? conf["posteriori"]["multi_shooting"] : 0,
397400
)
398401
end
399402
end

configs/snellius32/cnn_backsolve25.yaml renamed to configs/snellius32/cnn_backsolve15.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ seeds:
2525
prior: 345
2626
post: 456
2727
closure:
28-
name: "Backsolve_25unroll"
28+
name: "Backsolve_15unroll"
2929
type: cnn
3030
radii: [2, 2, 2, 2, 2]
3131
channels: [24, 24, 24, 24, 2]
@@ -43,10 +43,10 @@ priori:
4343
posteriori:
4444
dotrain: true
4545
projectorders: "(ProjectOrder.Last, )"
46-
nepoch: 3000
46+
nepoch: 2000
4747
opt: "Adam(T(1.0e-4))"
48-
nunroll: 25
49-
nunroll_valid: 25
48+
nunroll: 15
49+
nunroll_valid: 15
5050
dt: 0.0001
5151
do_plot: false
5252
plot_train: false
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: false
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [32]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "Multishoot"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "NoProjection"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 2000
47+
opt: "Adam(T(1.0e-4))"
48+
nunroll: 15
49+
multi_shooting: 3
50+
nunroll_valid: 15
51+
dt: 0.0001
52+
do_plot: false
53+
plot_train: false
54+
nsamples: 5
55+
sciml_solver: "Tsit5()"

configs/snellius32/cnn_noproj25.yaml renamed to configs/snellius32/cnn_noproj15.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ seeds:
2525
prior: 345
2626
post: 456
2727
closure:
28-
name: "NoProjection_25unroll"
28+
name: "NoProjection_15unroll"
2929
type: cnn
3030
radii: [2, 2, 2, 2, 2]
3131
channels: [24, 24, 24, 24, 2]
@@ -43,10 +43,10 @@ priori:
4343
posteriori:
4444
dotrain: true
4545
projectorders: "(ProjectOrder.Last, )"
46-
nepoch: 3000
46+
nepoch: 2000
4747
opt: "Adam(T(1.0e-4))"
48-
nunroll: 25
49-
nunroll_valid: 25
48+
nunroll: 15
49+
nunroll_valid: 15
5050
dt: 0.0001
5151
do_plot: false
5252
plot_train: false
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: false
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "Multishooting"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse : "NoProjection"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 2000
47+
opt: "Adam(T(1.0e-4))"
48+
nunroll: 15
49+
multi_shooting: 3
50+
nunroll_valid: 15
51+
dt: 0.0001
52+
do_plot: false
53+
plot_train: false
54+
nsamples: 1
55+
sciml_solver: "Tsit5()"
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: false
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "NoProjection_15unroll"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse : "NoProjection"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 2000
47+
opt: "Adam(T(1.0e-4))"
48+
nunroll: 15
49+
nunroll_valid: 15
50+
dt: 0.0001
51+
do_plot: false
52+
plot_train: false
53+
nsamples: 5
54+
sciml_solver: "Tsit5()"

src/train.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ function trainpost(;
276276
sciml_solver = nothing,
277277
dataproj,
278278
λ = nothing,
279+
multi_shooting = 0,
279280
)
280281
device(x) = adapt(params.backend, x)
281282
itotal = 0
@@ -332,6 +333,7 @@ function trainpost(;
332333
ensemble = nsamples > 1,
333334
sciml_solver = sciml_solver,
334335
sensealg = sensealg,
336+
multiple_shooting = multi_shooting,
335337
)
336338

337339
if loadcheckpoint && isfile(checkfile)

0 commit comments

Comments
 (0)