Skip to content

Commit b504f76

Browse files
committed
.
1 parent 3bc1f9a commit b504f76

File tree

4 files changed

+70
-14
lines changed

4 files changed

+70
-14
lines changed

benchmark.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ ispath(compdir) || mkpath(compdir)
1919
# List configurations files
2020
using Glob
2121
#exclude_patterns = ["att", "cno", "cnn_ins", "_1", "nopr"]
22-
exclude_patterns = ["att", "cno", "int", "back", "rk4"]
22+
exclude_patterns = ["att", "cno", "int", "back", "rk4", "cnn_1"]
2323
@warn "Excluding configurations with patterns: $(exclude_patterns)"
2424
all_confs = glob("*.yaml", confdir)
2525
list_confs = filter(conf -> all(!occursin(pat, conf) for pat in exclude_patterns), all_confs)
@@ -101,16 +101,16 @@ colors_list = [
101101

102102
# Loop over plot types and configurations
103103
plot_labels = Dict(
104-
:prior_hist => (
105-
title = "A-priori training history for different configurations",
106-
xlabel = "Iteration",
107-
ylabel = "A-priori error",
108-
),
109-
:posteriori_hist => (
110-
title = "A-posteriori training history for different configurations",
111-
xlabel = "Iteration",
112-
ylabel = "DCF",
113-
),
104+
#:prior_hist => (
105+
# title = "A-priori training history for different configurations",
106+
# xlabel = "Iteration",
107+
# ylabel = "A-priori error",
108+
#),
109+
#:posteriori_hist => (
110+
# title = "A-posteriori training history for different configurations",
111+
# xlabel = "Iteration",
112+
# ylabel = "DCF",
113+
#),
114114
:divergence => (
115115
title = "Divergence for different configurations",
116116
xlabel = "t",

configs/snellius/cnn_noproj.yaml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: false
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "NoProjection"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
dotrain: true
37+
nepoch: 50000
38+
batchsize: 64
39+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
40+
do_plot: false
41+
plot_train: false
42+
posteriori:
43+
dotrain: true
44+
projectorders: "(ProjectOrder.Last, )"
45+
nepoch: 3000
46+
opt: "Adam(T(1.0e-4))"
47+
nunroll: 5
48+
nunroll_valid: 10
49+
dt: 0.0001
50+
do_plot: false
51+
plot_train: false
52+
nsamples: 5
53+
sciml_solver: "Tsit5()"

multisub.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ sbatch -J CNN job_a100.sh configs/snellius/cnn_1.yaml
55
sbatch -J backsol job_a100.sh configs/snellius/cnn_backsol.yaml
66
sbatch -J interp job_a100.sh configs/snellius/cnn_interp.yaml
77
sbatch -J rk4 job_a100.sh configs/snellius/cnn_rk4.yaml
8+
9+
sbatch -J CNN_proj job_a100.sh configs/snellius/cnn_proj.yaml
10+
sbatch -J CNN_nopr job_a100.sh configs/snellius/cnn_noproj.yaml

src/plots.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ function plot_energy_evolution(
254254
end
255255
end
256256

257-
if closure_name == "cnn_proj"
257+
if closure_name == "Project"
258258
label = "No closure (projected dyn)"
259259
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
260260
lines!(
@@ -361,7 +361,7 @@ function plot_energy_evolution_hist(
361361
end
362362
end
363363

364-
if closure_name == "cnn_1"
364+
if closure_name == "Project"
365365
label = "No closure (projected dyn)"
366366
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
367367
_plot_histogram(
@@ -868,7 +868,7 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
868868
end
869869
end
870870

871-
if closure_name == "cnn_noproj"
871+
if closure_name == "NoProjection"
872872
label = "No model (projected dyn)"
873873
if _missing_label(ax, label) # add No closure only once
874874
scatterlines!(

0 commit comments

Comments
 (0)