Skip to content

Commit f799bd0

Browse files
committed
autopush
1 parent 31c6151 commit f799bd0

File tree

4 files changed

+134
-57
lines changed

4 files changed

+134
-57
lines changed

benchmark.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ ispath(compdir) || mkpath(compdir)
2020
using Glob
2121
exclude_patterns = ["att", "cno", "int", "back", "rk4", "cnn_1" ]
2222
exclude_patterns = ["att", "cno" ]
23-
include_patterns = ["base"]
24-
include_patterns = []
25-
exclude_patterns = []
23+
exclude_patterns = ["rod", "ken", "ow" ]
24+
include_patterns = ["cnn_base", "ins"]
2625

2726
if !isempty(include_patterns)
2827
@warn "Including only configurations with patterns: $(include_patterns)"
@@ -354,9 +353,9 @@ for key in keys(plot_labels)
354353
end
355354

356355
# Add xticks in barplot
357-
if key in (:training_time, :training_comptime, :inference_time, :num_parameters, :eprior, :epost)
358-
ax.xticks = (bar_positions, bar_labels)
359-
end
356+
#if key in (:training_time, :training_comptime, :inference_time, :num_parameters, :eprior, :epost)
357+
# ax.xticks = (bar_positions, bar_labels)
358+
#end
360359

361360
# Set log-log scale
362361
if key == :epost_vs_t

configs/snellius64/ins.yaml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: false
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "INS.jl"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
dotrain: true
37+
nepoch: 50000
38+
batchsize: 64
39+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
40+
do_plot: false
41+
plot_train: false
42+
posteriori:
43+
dotrain: true
44+
projectorders: "(ProjectOrder.Last, )"
45+
nepoch: 3000
46+
opt: "Adam(T(1.0e-4))"
47+
nunroll: 5
48+
nunroll_valid: 10
49+
dt: 0.0001
50+
do_plot: false
51+
plot_train: false
52+
nsamples: 5
53+
sciml_solver: "Tsit5()"
54+

multisub.sh

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010
#sbatch -J CNN_nopr job_a100.sh configs/snellius/cnn_noproj.yaml
1111

1212
#sbatch -J base job_a100.sh configs/snellius64/cnn_base.yaml
13-
#sbatch -J backsolve job_a100.sh configs/snellius64/cnn_backsolve.yaml
14-
#sbatch -J gauss job_a100.sh configs/snellius64/cnn_gauss.yaml
15-
#sbatch -J interp job_a100.sh configs/snellius64/cnn_interp.yaml
16-
#sbatch -J multishooting job_a100.sh configs/snellius64/cnn_multishooting.yaml
17-
#sbatch -J tsit5 job_a100.sh configs/snellius64/cnn_tsit5.yaml
18-
#sbatch -J rodas job_a100.sh configs/snellius64/cnn_rodas.yaml
19-
#sbatch -J rodaskryl job_a100.sh configs/snellius64/cnn_rodaskryl.yaml
20-
#sbatch -J rosenb job_a100.sh configs/snellius64/cnn_rosenb.yaml
21-
#sbatch -J kencarp job_a100.sh configs/snellius64/cnn_kencarp.yaml
22-
#sbatch -J vern job_a100.sh configs/snellius64/cnn_vern.yaml
13+
sbatch -J backsolve job_a100.sh configs/snellius64/cnn_backsolve.yaml
14+
sbatch -J gauss job_a100.sh configs/snellius64/cnn_gauss.yaml
15+
sbatch -J interp job_a100.sh configs/snellius64/cnn_interp.yaml
16+
sbatch -J multishooting job_a100.sh configs/snellius64/cnn_multishooting.yaml
17+
sbatch -J tsit5 job_a100.sh configs/snellius64/cnn_tsit5.yaml
18+
sbatch -J rodas job_a100.sh configs/snellius64/cnn_rodas.yaml
19+
sbatch -J rodaskryl job_a100.sh configs/snellius64/cnn_rodaskryl.yaml
20+
sbatch -J rosenb job_a100.sh configs/snellius64/cnn_rosenb.yaml
21+
sbatch -J kencarp job_a100.sh configs/snellius64/cnn_kencarp.yaml
22+
sbatch -J vern job_a100.sh configs/snellius64/cnn_vern.yaml
2323
#sbatch -J owr job_a100.sh configs/snellius64/cnn_owr.yaml
2424
#sbatch -J bs3 job_a100.sh configs/snellius64/cnn_bs3.yaml
25-
#sbatch -J composite job_a100.sh configs/snellius64/cnn_composite.yaml
25+
sbatch -J composite job_a100.sh configs/snellius64/cnn_composite.yaml
2626

2727
#sbatch -J att job_a100_extra.sh configs/snellius64/att_base.yaml
2828
sbatch -J cno job_a100_extra.sh configs/snellius64/cno_base.yaml

src/plots.jl

Lines changed: 64 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ function plot_posteriori_traininghistory(
120120
end
121121

122122
function plot_divergence(outdir, closure_name, nles, Φ, data_index, ax, color, PLOT_STYLES)
123+
if closure_name == "INS.jl"
124+
return
125+
end
123126
# Load learned parameters
124127
divergence_dir = joinpath(outdir, closure_name, "history_nles=$(nles).jld2")
125128
if !ispath(divergence_dir)
@@ -228,7 +231,7 @@ function plot_energy_evolution(
228231
end
229232
energyhistory = namedtupleload(energy_dir).energyhistory;
230233

231-
if closure_name == "INS.jl"
234+
if closure_name == "INS.jl" && false
232235
label = "No closure "
233236
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
234237
lines!(
@@ -337,7 +340,7 @@ function plot_energy_evolution_hist(
337340
end
338341
energyhistory = namedtupleload(energy_dir).energyhistory;
339342

340-
if closure_name == "INS.jl"
343+
if closure_name == "INS.jl" && false
341344
label = "No closure "
342345
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
343346
_plot_histogram(
@@ -471,7 +474,11 @@ function plot_energy_spectra(
471474
num_of_models,
472475
color,
473476
PLOT_STYLES,
477+
single_legend = false
474478
)
479+
if closure_name == "INS.jl"
480+
return
481+
end
475482
# Load learned parameters
476483
energy_dir = joinpath(outdir, closure_name, "solutions_nles=$(nles).jld2")
477484
if !ispath(energy_dir)
@@ -488,9 +495,14 @@ function plot_energy_spectra(
488495

489496
# Create a grid of plots and legends
490497
gtitle = fig[1, 1] # Title of the figure
491-
gplot = fig[2, 1]
492-
gplot_ax = gplot[1, 1] # Axis for each plot
493-
gplot_leg = gplot[1, 2] # The common legend for all plots
498+
if single_legend
499+
gplot = fig[1, 1]
500+
gplot_ax = gplot[1, 1]
501+
else
502+
gplot = fig[2, 1]
503+
gplot_ax = gplot[1, 1] # Axis for each plot
504+
gplot_leg = gplot[1, 2] # The common legend for all plots
505+
end
494506

495507
Label(
496508
gtitle,
@@ -593,15 +605,14 @@ function plot_energy_spectra(
593605
[prior_label, post_label],
594606
labelsize = 8,
595607
)
608+
# Add legend that is common for all plots
609+
Legend(
610+
gplot_leg,
611+
[no_closure_plt, reference_plt, inertia_plt],
612+
[no_closure_label, reference_label, inertia_label],
613+
labelsize = 8,
614+
)
596615
end
597-
598-
# Add legend that is common for all plots
599-
Legend(
600-
gplot_leg,
601-
[no_closure_plt, reference_plt, inertia_plt],
602-
[no_closure_label, reference_label, inertia_label],
603-
labelsize = 8,
604-
)
605616
end
606617
end
607618

@@ -831,16 +842,24 @@ function _save_error_data_to_csv(error_data, closure_name, data_index; outdir, n
831842

832843
# Get a posteriori (post) error data, handle case where it might be missing
833844
post_error = error_data.model_post[data_index]
845+
if closure_name == "INS.jl"
846+
post_error = error_data.smag[data_index]
847+
end
834848

835849
# Try to read training time data from training files
836850
training_time_post = nothing
837851
if nles !== nothing && Φ !== nothing && projectorders !== nothing
838852
try
839853
if closure_name == "INS.jl"
840-
posttraining = namedtupleload(
841-
Benchmark.getpostfile(outdir, closure_name, nles, Φ, projectorders[1]),
854+
#posttraining = namedtupleload(
855+
# Benchmark.getpostfile(outdir, closure_name, nles, Φ, projectorders[1]),
856+
#)
857+
#training_time_post = posttraining.single_stored_object.time_per_epoch
858+
859+
smagtrain= namedtupleload(
860+
joinpath(outdir, "smagorinski", "projectorder=Last_filter=FaceAverage()_nles=64.jld2"),
842861
)
843-
training_time_post = posttraining.single_stored_object.time_per_epoch
862+
training_time_post = smagtrain.comptime/300
844863
else
845864
posttraining = loadpost(outdir, closure_name, [nles], [Φ], projectorders)
846865
training_time_post = posttraining[1].time_per_epoch
@@ -890,6 +909,9 @@ function _save_error_data_to_csv(error_data, closure_name, data_index; outdir, n
890909
# Update or add the current closure_name, preserving existing values when new ones are not provided
891910
current_error = post_error
892911
current_training_time = training_time_post
912+
if closure_name == "INS.jl"
913+
closure_name = "Smagorinsky"
914+
end
893915

894916
if haskey(existing_data, closure_name)
895917
# Keep existing values if new ones are not provided (nothing/missing)
@@ -926,35 +948,37 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
926948

927949
x = error_data.nts # use time from error data
928950

929-
# Prior
930-
scatterlines!(
931-
ax,
932-
x,
933-
vec(error_data.model_prior);
934-
label = "$closure_name prior (n = $nles)",
935-
color = color,
936-
linestyle = PLOT_STYLES[:prior].linestyle,
937-
linewidth = PLOT_STYLES[:prior].linewidth,
938-
)
939-
940-
# Post
941-
scatterlines!(
942-
ax,
943-
x,
944-
vec(error_data.model_post);
945-
label = "$closure_name post (n = $nles)",
946-
color = color,
947-
linestyle = PLOT_STYLES[:post].linestyle,
948-
linewidth = PLOT_STYLES[:post].linewidth,
949-
)
951+
if closure_name != "INS.jl"
952+
# Prior
953+
scatterlines!(
954+
ax,
955+
x,
956+
vec(error_data.model_prior);
957+
label = "$closure_name prior (n = $nles)",
958+
color = color,
959+
linestyle = PLOT_STYLES[:prior].linestyle,
960+
linewidth = PLOT_STYLES[:prior].linewidth,
961+
)
962+
# Post
963+
scatterlines!(
964+
ax,
965+
x,
966+
vec(error_data.model_post);
967+
label = "$closure_name post (n = $nles)",
968+
color = color,
969+
linestyle = PLOT_STYLES[:post].linestyle,
970+
linewidth = PLOT_STYLES[:post].linewidth,
971+
)
972+
end
950973

951974
# Smagorinsky (optional)
952975
if haskey(error_data, Symbol("smag"))
953976
scatterlines!(
954977
ax,
955-
x,
956-
vec(error_data.smag);
957-
label = "$closure_name smag (n = $nles)",
978+
x[1:end-2],
979+
vec(error_data.smag[1:end-2]);
980+
#label = "$closure_name smag (n = $nles)",
981+
label = "Smagorinsky (n = $nles)",
958982
color = PLOT_STYLES[:smag].color,
959983
linestyle = PLOT_STYLES[:smag].linestyle,
960984
linewidth = PLOT_STYLES[:smag].linewidth,

0 commit comments

Comments
 (0)