Skip to content

Commit 9f2161b

Browse files
committed
Add function to reuse apriori-trained model
1 parent c5b23b8 commit 9f2161b

File tree

8 files changed

+134
-6
lines changed

8 files changed

+134
-6
lines changed

benchmark.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ end
1111
basedir = haskey(ENV, "DEEPDIP") ? ENV["DEEPDIP"] : @__DIR__
1212
outdir = joinpath(basedir, "output", "kolmogorov")
1313
confdir = joinpath(basedir, "configs/local")
14+
confdir = joinpath(basedir, "configs/snellius")
1415
@warn "Using configuration files from $confdir"
1516
compdir = joinpath(outdir, "comparison")
1617
ispath(compdir) || mkpath(compdir)

cnn_model_workflow.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ end
176176

177177
# Computational time
178178
docomp = conf["docomp"]
179+
docomp = false
179180
docomp && let
180181
comptime, datasize = 0.0, 0.0
181182
for seed in dns_seeds
@@ -241,6 +242,13 @@ closure_INS, θ_INS = NeuralClosure.cnn(;
241242
# Save parameters to disk after each run.
242243
# Plot training progress (for a validation data batch).
243244

245+
# Check if it is asked to re-use the a-priori training from a different model
246+
if haskey(conf["priori"], "reuse")
247+
reuse = conf["priori"]["reuse"]
248+
@info "Reuse a-priori training from closure named: $reuse"
249+
reusepriorfile(reuse, outdir, closure_name)
250+
end
251+
244252
# Train
245253
for i = 1:ntrajectory
246254
if i%numtasks == taskid -1

configs/snellius/cnn_1.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ params:
1010
tsim: 5.0
1111
savefreq: 50
1212
ndns: 4096
13-
nles: [128]
13+
nles: [64]
1414
filters: ["FaceAverage()"]
1515
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
1616
method: "RKMethods.Wray3(; T)"
@@ -24,7 +24,7 @@ seeds:
2424
prior: 345
2525
post: 456
2626
closure:
27-
name: "cnn"
27+
name: "cnn_project"
2828
type: cnn
2929
radii: [2, 2, 2, 2, 2]
3030
channels: [24, 24, 24, 24, 2]
@@ -42,9 +42,10 @@ posteriori:
4242
dotrain: true
4343
projectorders: "(ProjectOrder.Last, )"
4444
nepoch: 300
45-
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
45+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
4646
nunroll: 5
4747
nunroll_valid: 10
48+
nsamples: 1
4849
dt: 0.0001
4950
do_plot: false
5051
plot_train: false

configs/snellius/cnn_2.yaml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
params:
6+
D: 2
7+
lims: [0.0, 1.0]
8+
Re: 6000.0
9+
tburn: 0.5
10+
tsim: 5.0
11+
savefreq: 50
12+
ndns: 4096
13+
nles: [64]
14+
filters: ["FaceAverage()"]
15+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
16+
method: "RKMethods.Wray3(; T)"
17+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
18+
issteadybodyforce: true
19+
processors: "(; log = timelogger(; nupdate=100))"
20+
Δt: 0.00005
21+
seeds:
22+
dns: 123456
23+
θ_start: 234
24+
prior: 345
25+
post: 456
26+
closure:
27+
name: "cnn_remove"
28+
type: cnn
29+
radii: [2, 2, 2, 2, 2]
30+
channels: [24, 24, 24, 24, 2]
31+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
32+
use_bias: [true, true, true, true, false]
33+
rng: "Xoshiro(seeds.θ_start)"
34+
priori:
35+
reuse: "cnn_project"
36+
dotrain: true
37+
nepoch: 10000
38+
batchsize: 64
39+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
40+
do_plot: false
41+
plot_train: false
42+
posteriori:
43+
dotrain: true
44+
projectorders: "(ProjectOrder.Last, )"
45+
nepoch: 300
46+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
47+
nunroll: 5
48+
nunroll_valid: 10
49+
nsamples: 1
50+
dt: 0.0001
51+
do_plot: false
52+
plot_train: false

configs/snellius/cnn_ins.yaml

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
params:
6+
D: 2
7+
lims: [0.0, 1.0]
8+
Re: 6000.0
9+
tburn: 0.5
10+
tsim: 5.0
11+
savefreq: 50
12+
ndns: 4096
13+
nles: [64]
14+
filters: ["FaceAverage()"]
15+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
16+
method: "RKMethods.Wray3(; T)"
17+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
18+
issteadybodyforce: true
19+
processors: "(; log = timelogger(; nupdate=100))"
20+
Δt: 0.00005
21+
seeds:
22+
dns: 123
23+
θ_start: 234
24+
prior: 345
25+
post: 456
26+
closure:
27+
name: "cnn_noproj"
28+
type: cnn
29+
radii: [2, 2, 2, 2, 2]
30+
channels: [24, 24, 24, 24, 2]
31+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
32+
use_bias: [true, true, true, true, false]
33+
rng: "Xoshiro(seeds.θ_start)"
34+
priori:
35+
dotrain: true
36+
nepoch: 10000
37+
batchsize: 64
38+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
39+
do_plot: false
40+
plot_train: false
41+
posteriori:
42+
dotrain: true
43+
projectorders: "(ProjectOrder.Last, )"
44+
nepoch: 300
45+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
46+
nunroll: 5
47+
nunroll_valid: 10
48+
dt: 0.0001
49+
do_plot: false
50+
plot_train: false

configs/snellius/conf_INS.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ params:
1212
tsim: 5.0
1313
savefreq: 50
1414
ndns: 4096
15-
nles: [128]
15+
nles: [64]
1616
filters: ["FaceAverage()"]
1717
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
1818
method: "RKMethods.Wray3(; T)"

src/Benchmark.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,6 @@ export _convert_to_single_index,
111111
plot_epost_vs_t
112112

113113
export compute_eprior, compute_epost, compute_t_prior_inference
114+
export reusepriorfile
114115

115116
end # module Benchmark

src/train.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function createdata(; params, seed, outdir, backend)
2626
@info "Creating data for" nles Φ seed
2727

2828
data = NS.create_les_data_projected(
29-
nchunks = 500;
29+
nchunks = 8000;
3030
params...,
3131
rng = Xoshiro(seed),
3232
backend = backend,
@@ -44,6 +44,22 @@ function getpriorfile(outdir, closure_name, nles, filter)
4444
)
4545
end
4646

47+
function reusepriorfile(reuse, outdir, closure_name)
48+
reusepath = joinpath(outdir, "priortraining", reuse)
49+
targetpath = joinpath(outdir, "priortraining", closure_name)
50+
# If the reuse path exists, copy it to the target path
51+
if ispath(reusepath)
52+
@info "Reusing prior training from $(reusepath) to $(targetpath)"
53+
ispath(targetpath) || mkpath(targetpath)
54+
for file in readdir(reusepath, join = true)
55+
@info "Copying prior training file $(file) to $(targetpath)"
56+
cp(file, joinpath(targetpath, basename(file)); force = true)
57+
end
58+
else
59+
@warn "Reuse path $(reusepath) does not exist. Not reusing prior training."
60+
end
61+
end
62+
4763
"Load a-priori training results from correct file names."
4864
loadprior(outdir, closure_name, nles, filters) = map(
4965
splat((nles, Φ) -> load_object(getpriorfile(outdir, closure_name, nles, Φ))),
@@ -104,7 +120,6 @@ function trainprior(;
104120
# Read the data in the format expected by the CoupledNODE
105121
data_train = load_data_set(outdir, nles, Φ, dns_seeds_train)
106122
data_valid = load_data_set(outdir, nles, Φ, dns_seeds_valid)
107-
108123
@assert length(nles) == 1 "Only one nles for a-priori training"
109124
io_train = NS.create_io_arrays_priori(data_train, setup[1], device)
110125
io_valid = NS.create_io_arrays_priori(data_valid, setup[1], device)

0 commit comments

Comments
 (0)