Skip to content

Commit a9a9cfa

Browse files
committed
autopush
1 parent 78e5300 commit a9a9cfa

File tree

12 files changed

+110
-19
lines changed

12 files changed

+110
-19
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
3636
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
3737
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3838
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
39+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
3940
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
4041

4142
[weakdeps]
@@ -78,6 +79,7 @@ Optimisers = "0.4"
7879
OptimizationCMAEvolutionStrategy = "0.3.0"
7980
OptimizationOptimJL = "0.4.3"
8081
ParameterSchedulers = "0.4"
82+
SciMLSensitivity = "7.84.0"
8183
Statistics = "1.11.1"
8284
cuDNN = "1"
8385
julia = "1.11"

benchmark.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ ispath(compdir) || mkpath(compdir)
1818

1919
# List configurations files
2020
using Glob
21-
exclude_patterns = ["att", "cno"]
22-
exclude_patterns = ["cno"]
21+
exclude_patterns = ["att", "cno", "cnn_ins", "_1"]
2322
@warn "Excluding configurations with patterns: $(exclude_patterns)"
2423
all_confs = glob("*.yaml", confdir)
2524
list_confs = filter(conf -> all(!occursin(pat, conf) for pat in exclude_patterns), all_confs)

cnn_model_workflow.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ using OptimizationOptimJL
106106
using OptimizationCMAEvolutionStrategy
107107
using ParameterSchedulers
108108
using Random
109+
using SciMLSensitivity
109110

110111

111112
# ## Random number seeds
@@ -350,6 +351,19 @@ projectorders = eval(Meta.parse(conf["posteriori"]["projectorders"]))
350351
nprojectorders = length(projectorders)
351352
@assert nprojectorders == 1 "Only DCF should be done"
352353

354+
sensealg = haskey(conf["posteriori"], "sensealg") ? eval(Meta.parse(conf["posteriori"]["sensealg"])) : nothing
355+
sciml_solver = haskey(conf["posteriori"], "sciml_solver") ? eval(Meta.parse(conf["posteriori"]["sciml_solver"])) : nothing
356+
if sensealg !== nothing
357+
@info "Using sensitivity algorithm: $sensealg"
358+
else
359+
@info "No sensitivity algorithm specified"
360+
end
361+
if sciml_solver !== nothing
362+
@info "Using SciML solver: $sciml_solver"
363+
else
364+
@info "No SciML solver specified"
365+
end
366+
353367
# Train
354368
for i = 1:ntrajectory
355369
if i%numtasks == taskid -1
@@ -377,7 +391,8 @@ let
377391
nepoch,
378392
do_plot = conf["posteriori"]["do_plot"],
379393
plot_train = conf["posteriori"]["plot_train"],
380-
sensealg = haskey(conf["posteriori"],:sensealg) ? eval(Meta.parse(conf["posteriori"]["sensealg"])) : nothing,
394+
sensealg = sensealg,
395+
sciml_solver = sciml_solver,
381396
dataproj = conf["dataproj"],
382397
)
383398
end
@@ -522,14 +537,14 @@ let
522537
dudt_nomod = NS.create_right_hand_side_inplace(
523538
setup, psolver)
524539

525-
epost.nomodel[I,:], _ = compute_epost(dudt_nomod, θ_cnn_post[I].*0 , tspan, data, tsave, dt)
540+
epost.nomodel[I,:], _ = compute_epost(dudt_nomod, sciml_solver, θ_cnn_post[I].*0 , tspan, data, tsave, dt)
526541
@info "Epost nomodel" epost.nomodel[I,:]
527542
# with closure
528543
dudt = NS.create_right_hand_side_with_closure_inplace(
529544
setup, psolver, closure, st)
530-
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
545+
epost.model_prior[I, :], _ = compute_epost(dudt, sciml_solver, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
531546
@info "Epost model_prior" epost.model_prior[I, :]
532-
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
547+
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, sciml_solver, device(θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
533548
@info "Epost model_post" epost.model_post[I, :]
534549

535550
clean()

configs/snellius/cnn_1.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ posteriori:
5050
dt: 0.0001
5151
do_plot: false
5252
plot_train: false
53+
sciml_solver: "Tsit5()"

configs/snellius/cnn_CMA.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ posteriori:
5151
dt: 0.0001
5252
do_plot: false
5353
plot_train: false
54-
sensealg: "InterpolatingAdjoint()"
54+
sciml_solver: "Tsit5()"

configs/snellius/cnn_backsol.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ docreatedata: true
22
docomp: true
33
ntrajectory: 8
44
T: "Float32"
5-
dataproj: true
5+
dataproj: false
66
params:
77
D: 2
88
lims: [0.0, 1.0]
@@ -33,7 +33,7 @@ closure:
3333
use_bias: [true, true, true, true, false]
3434
rng: "Xoshiro(seeds.θ_start)"
3535
priori:
36-
reuse: "cnn_project"
36+
reuse: "cnn_noproj"
3737
dotrain: true
3838
nepoch: 50000
3939
batchsize: 64
@@ -53,3 +53,4 @@ posteriori:
5353
do_plot: false
5454
plot_train: false
5555
sensealg: "BacksolveAdjoint()"
56+
sciml_solver: "Tsit5()"

configs/snellius/cnn_ins.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,4 @@ posteriori:
5050
do_plot: false
5151
plot_train: false
5252
nsamples: 1
53+
sciml_solver: "Tsit5()"

configs/snellius/cnn_interp.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ docreatedata: true
22
docomp: true
33
ntrajectory: 8
44
T: "Float32"
5-
dataproj: true
5+
dataproj: false
66
params:
77
D: 2
88
lims: [0.0, 1.0]
@@ -33,7 +33,7 @@ closure:
3333
use_bias: [true, true, true, true, false]
3434
rng: "Xoshiro(seeds.θ_start)"
3535
priori:
36-
reuse: "cnn_project"
36+
reuse: "cnn_noproj"
3737
dotrain: true
3838
nepoch: 50000
3939
batchsize: 64
@@ -51,4 +51,6 @@ posteriori:
5151
dt: 0.0001
5252
do_plot: false
5353
plot_train: false
54+
#sensealg: "InterpolatingAdjoint(autojacvec=EnzymeVJP())"
5455
sensealg: "InterpolatingAdjoint()"
56+
sciml_solver: "Tsit5()"

configs/snellius/cnn_lbfgs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,4 @@ posteriori:
5353
do_plot: false
5454
plot_train: false
5555
sensealg: "InterpolatingAdjoint()"
56+
sciml_solver: "Tsit5()"

configs/snellius/cnn_rk4.yaml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: false
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_noproj"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_noproj"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 1500
47+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
48+
nunroll: 5
49+
nunroll_valid: 10
50+
dt: 0.0001
51+
do_plot: false
52+
plot_train: false
53+
nsamples: 1
54+
sciml_solver: "RK4()"

0 commit comments

Comments
 (0)