Skip to content

Commit 0a37829

Browse files
committed
Add extra confs
1 parent c7a3783 commit 0a37829

File tree

9 files changed

+261
-1
lines changed

9 files changed

+261
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ NeuralClosure = "099dac27-d7f2-4047-93d5-0baee36b9c25"
3131
Observables = "510215fc-4207-5dde-b226-833fc4488ee2"
3232
OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
3333
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
34+
OptimizationCMAEvolutionStrategy = "bd407f91-200f-4536-9381-e4ba712f53f8"
35+
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
3436
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
3537
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
3638
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -73,6 +75,8 @@ NeuralClosure = "1.0.0"
7375
Observables = "0.5"
7476
OpenSSL_jll = "3.0.13"
7577
Optimisers = "0.4"
78+
OptimizationCMAEvolutionStrategy = "0.3.0"
79+
OptimizationOptimJL = "0.4.3"
7680
ParameterSchedulers = "0.4"
7781
Statistics = "1.11.1"
7882
cuDNN = "1"

cnn_model_workflow.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ using Lux
101101
using LuxCUDA
102102
using NNlib
103103
using Optimisers
104+
using Optimisers: Adam
105+
using OptimizationOptimJL
106+
using OptimizationCMAEvolutionStrategy
104107
using ParameterSchedulers
105108
using Random
106109

@@ -248,6 +251,11 @@ if haskey(conf["priori"], "reuse")
248251
@info "Reuse a-priori training from closure named: $reuse"
249252
reusepriorfile(reuse, outdir, closure_name)
250253
end
254+
if haskey(conf["posteriori"], "reuse")
255+
reuse = conf["posteriori"]["reuse"]
256+
@info "Reuse a-posteriori training from closure named: $reuse"
257+
reusepostfile(reuse, outdir, closure_name)
258+
end
251259

252260
# Train
253261
for i = 1:ntrajectory

configs/snellius/cnn_CMA.yaml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: true
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_cma"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_project"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
reuse: "cnn_project"
45+
dotrain: true
46+
projectorders: "(ProjectOrder.Last, )"
47+
nepoch: 2000
48+
opt: "OptimizationCMAEvolutionStrategy.CMAEvolutionStrategyOpt()"
49+
nunroll: 5
50+
nunroll_valid: 10
51+
nsamples: 1
52+
dt: 0.0001
53+
do_plot: false
54+
plot_train: false
55+
sensealg: "InterpolatingAdjoint()"

configs/snellius/cnn_backsol.yaml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: true
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_backsol"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_project"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
#reuse: "cnn_project"
45+
dotrain: true
46+
projectorders: "(ProjectOrder.Last, )"
47+
nepoch: 1500
48+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
49+
nunroll: 5
50+
nunroll_valid: 10
51+
nsamples: 1
52+
dt: 0.0001
53+
do_plot: false
54+
plot_train: false
55+
sensealg: "BacksolveAdjoint()"

configs/snellius/cnn_interp.yaml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: true
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_interp"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_project"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 1500
47+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
48+
nunroll: 5
49+
nunroll_valid: 10
50+
nsamples: 1
51+
dt: 0.0001
52+
do_plot: false
53+
plot_train: false
54+
sensealg: "InterpolatingAdjoint()"

configs/snellius/cnn_lbfgs.yaml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: true
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_lbfgs"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_project"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
reuse: "cnn_project"
45+
dotrain: true
46+
projectorders: "(ProjectOrder.Last, )"
47+
nepoch: 2000
48+
opt: "Optim.LBFGS()"
49+
nunroll: 5
50+
nunroll_valid: 10
51+
nsamples: 1
52+
dt: 0.0001
53+
do_plot: false
54+
plot_train: false
55+
sensealg: "InterpolatingAdjoint()"

extra_model_workflow.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,18 @@ end
259259
# Save parameters to disk after each run.
260260
# Plot training progress (for a validation data batch).
261261

262+
# Check if it is asked to re-use the a-priori training from a different model
263+
if haskey(conf["priori"], "reuse")
264+
reuse = conf["priori"]["reuse"]
265+
@info "Reuse a-priori training from closure named: $reuse"
266+
reusepriorfile(reuse, outdir, closure_name)
267+
end
268+
if haskey(conf["posteriori"], "reuse")
269+
reuse = conf["posteriori"]["reuse"]
270+
@info "Reuse a-posteriori training from closure named: $reuse"
271+
reusepostfile(reuse, outdir, closure_name)
272+
end
273+
262274
# Train
263275
@info "A priori training"
264276
for i = 1:ntrajectory

src/Benchmark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,6 @@ export _convert_to_single_index,
112112
plot_epost_vs_t
113113

114114
export compute_eprior, compute_epost, compute_t_prior_inference
115-
export reusepriorfile
115+
export reusepriorfile, reusepostfile
116116

117117
end # module Benchmark

src/train.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ function reusepriorfile(reuse, outdir, closure_name)
6868
end
6969
end
7070

71+
function reusepostfile(reuse, outdir, closure_name)
72+
reusepath = joinpath(outdir, "posttraining", reuse)
73+
targetpath = joinpath(outdir, "posttraining", closure_name)
74+
# If the reuse path exists, copy it to the target path
75+
if ispath(reusepath)
76+
@info "Reusing post training from $(reusepath) to $(targetpath)"
77+
ispath(targetpath) || mkpath(targetpath)
78+
for file in readdir(reusepath, join = true)
79+
@info "Copying post training file $(file) to $(targetpath)"
80+
cp(file, joinpath(targetpath, basename(file)); force = true)
81+
end
82+
else
83+
@warn "Reuse path $(reusepath) does not exist. Not reusing post training."
84+
end
85+
end
86+
87+
7188
"Load a-priori training results from correct file names."
7289
loadprior(outdir, closure_name, nles, filters) = map(
7390
splat((nles, Φ) -> load_object(getpriorfile(outdir, closure_name, nles, Φ))),

0 commit comments

Comments
 (0)