Skip to content

Commit 311edad

Browse files
committed
Improve plots
1 parent 13606c7 commit 311edad

File tree

4 files changed

+108
-87
lines changed

4 files changed

+108
-87
lines changed

benchmark.jl

Lines changed: 96 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111
basedir = haskey(ENV, "DEEPDIP") ? ENV["DEEPDIP"] : @__DIR__
1212
outdir = joinpath(basedir, "output", "kolmogorov")
1313
confdir = joinpath(basedir, "configs/local")
14-
#confdir = joinpath(basedir, "configs/snellius64")
14+
confdir = joinpath(basedir, "configs/snellius64")
1515
@warn "Using configuration files from $confdir"
1616
compdir = joinpath(outdir, "comparison")
1717
ispath(compdir) || mkpath(compdir)
@@ -20,12 +20,10 @@ ispath(compdir) || mkpath(compdir)
2020
using Glob
2121
#exclude_patterns = ["att", "cno", "cnn_ins", "_1", "nopr"]
2222
exclude_patterns = ["att", "cno", "int", "back", "rk4", "cnn_1" ]
23+
exclude_patterns = ["att", "cno" ]
2324
@warn "Excluding configurations with patterns: $(exclude_patterns)"
2425
all_confs = glob("*.yaml", confdir)
2526
list_confs = filter(conf -> all(!occursin(pat, conf) for pat in exclude_patterns), all_confs)
26-
if isempty(list_confs)
27-
@error "No configuration files found in $confdir"
28-
end
2927

3028

3129
using Pkg
@@ -63,6 +61,29 @@ else
6361
backend = IncompressibleNavierStokes.CPU()
6462
end
6563

64+
# Create a new list to store valid configs
65+
valid_confs = String[]
66+
for conf_file in list_confs
67+
closure_name, params, conf = read_config(outdir, conf_file, backend)
68+
if !check_necessary_files(
69+
outdir,
70+
closure_name,
71+
params.nles[1],
72+
params.filters[1],
73+
eval(Meta.parse(conf["posteriori"]["projectorders"]))[1]
74+
)
75+
@error "Some files are missing for configuration $conf_file. Skipping"
76+
continue
77+
end
78+
push!(valid_confs, conf_file)
79+
end
80+
# Replace list_confs with only the valid ones
81+
list_confs = valid_confs
82+
83+
if isempty(list_confs)
84+
@error "No configuration files found in $confdir"
85+
end
86+
6687
# Global variables for setting linestyle and colors in all plots
6788
PLOT_STYLES = Dict(
6889
:no_closure => (color="black", linestyle=:dash, linewidth=2),
@@ -101,82 +122,82 @@ colors_list = [
101122

102123
# Loop over plot types and configurations
103124
plot_labels = Dict(
104-
#:prior_hist => (
105-
# title = "A-priori training history for different configurations",
106-
# xlabel = "Iteration",
107-
# ylabel = "A-priori error",
108-
#),
109-
#:posteriori_hist => (
110-
# title = "A-posteriori training history for different configurations",
111-
# xlabel = "Iteration",
112-
# ylabel = "DCF",
113-
#),
125+
:prior_hist => (
126+
title = "A-priori training history for different configurations",
127+
xlabel = "Iteration",
128+
ylabel = "A-priori error",
129+
),
130+
:posteriori_hist => (
131+
title = "A-posteriori training history for different configurations",
132+
xlabel = "Iteration",
133+
ylabel = "DCF",
134+
),
114135
:dns_solution => (
115136
title = "DNS solution for different configurations",
116137
xlabel = "t",
117138
ylabel = L"\frac{|u(t)-u_{proj}(t)|}{|u(t)|}",
118139
),
119-
#:divergence => (
120-
# title = "Divergence for different configurations",
121-
# xlabel = "t",
122-
# ylabel = "Face-average",
123-
#),
124-
#:energy_evolution => (
125-
# title = "Energy evolution for different configurations",
126-
# xlabel = "t",
127-
# ylabel = "E(t)",
128-
#),
129-
#:energy_evolution_hist => (
130-
# title = "Energy histogram for different configurations",
131-
# xlabel = "frequency",
132-
# ylabel = "E(t)",
133-
#),
134-
#:energy_spectra => (
135-
# title = "Energy spectra",
136-
#),
137-
#:training_time => (
138-
# title = "Training time for different configurations",
139-
# xlabel = "Model",
140-
# ylabel = "Training time (s) (per iteration)",
141-
#),
142-
#:training_comptime => (
143-
# title = "Training time for different configurations",
144-
# xlabel = "Model",
145-
# ylabel = "Full Training time (s)",
146-
#),
147-
#:inference_time => (
148-
# title = "Inference time for different configurations",
149-
# xlabel = "Model",
150-
# ylabel = "Inference time (s)",
151-
#),
152-
#:num_parameters => (
153-
# title = "Number of parameters for different configurations",
154-
# xlabel = "Model",
155-
# ylabel = "Number of parameters",
156-
#),
157-
#:eprior => (
158-
# title = "A-prior error for different configurations",
159-
# xlabel = "Model",
160-
# ylabel = "A-prior error",
161-
#),
162-
#:epost => (
163-
# title = "A-posteriori error for different configurations",
164-
# xlabel = "Model",
165-
# ylabel = "A-posteriori error",
166-
#),
167-
#:epost_vs_t => (
168-
# title = "A-posteriori error as a function of time",
169-
# xlabel = "t",
170-
# ylabel = L"e_{M}(t)",
171-
#),
140+
:divergence => (
141+
title = "Divergence for different configurations",
142+
xlabel = "t",
143+
ylabel = "Face-average",
144+
),
145+
:energy_evolution => (
146+
title = "Energy evolution for different configurations",
147+
xlabel = "t",
148+
ylabel = "E(t)",
149+
),
150+
:energy_evolution_hist => (
151+
title = "Energy histogram for different configurations",
152+
xlabel = "frequency",
153+
ylabel = "E(t)",
154+
),
155+
:energy_spectra => (
156+
title = "Energy spectra",
157+
),
158+
:training_time => (
159+
title = "Training time for different configurations",
160+
xlabel = "Model",
161+
ylabel = "Training time (s) (per iteration)",
162+
),
163+
:training_comptime => (
164+
title = "Training time for different configurations",
165+
xlabel = "Model",
166+
ylabel = "Full Training time (s)",
167+
),
168+
:inference_time => (
169+
title = "Inference time for different configurations",
170+
xlabel = "Model",
171+
ylabel = "Inference time (s)",
172+
),
173+
:num_parameters => (
174+
title = "Number of parameters for different configurations",
175+
xlabel = "Model",
176+
ylabel = "Number of parameters",
177+
),
178+
:eprior => (
179+
title = "A-prior error for different configurations",
180+
xlabel = "Model",
181+
ylabel = "A-prior error",
182+
),
183+
:epost => (
184+
title = "A-posteriori error for different configurations",
185+
xlabel = "Model",
186+
ylabel = "A-posteriori error",
187+
),
188+
:epost_vs_t => (
189+
title = "A-posteriori error as a function of time",
190+
xlabel = "t",
191+
ylabel = L"e_{M}(t)",
192+
),
172193
)
173194

174195
dns_seeds = splitseed(123456, 8)
175196
dns_seeds = splitseed(16, 8)
176197
#dns_seeds = [0x185efb6b]
177198

178199
for key in keys(plot_labels)
179-
@info "Plotting $key"
200+
@info "******************** Plotting $key"
180201

181202
# Create the figure
182203
fig = Figure(; size = (950, 600))
@@ -209,16 +230,6 @@ for key in keys(plot_labels)
209230
continue
210231
end
211232

212-
#if !check_necessary_files(
213-
# outdir,
214-
# closure_name,
215-
# nles,
216-
# Φ,
217-
# projectorders[1],
218-
#)
219-
# @error "Some files are missing for configuration $conf_file. Skipping"
220-
# continue
221-
#end
222233

223234
# make sure each combination has a consistent color
224235
#TODO this function should be tested
@@ -315,9 +326,13 @@ for key in keys(plot_labels)
315326
error_file, closure_name, nles, ax, color, PLOT_STYLES
316327
)
317328
elseif key == :dns_solution
318-
data = load("output/kolmogorov/test_dns_proj.jld2")
319-
plot_dns_solution(
320-
data, ax, 5, joinpath(compdir, "projection_dns_test_nles=$(nles).gif")
329+
infile = "output/kolmogorov/test_dns_proj.jld2"
330+
outfile = joinpath(
331+
compdir, "projection_dns_test_nles=$(nles).gif"
332+
)
333+
334+
isfile(outfile) || plot_dns_solution(
335+
ax, 5, infile, outfile
321336
)
322337
else
323338
@error "Unknown plot type: $key"

configs/snellius64/cnn_test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ seeds:
2626
prior: 345
2727
post: 456
2828
closure:
29-
name: "Base"
29+
name: "Test"
3030
type: cnn
3131
radii: [2, 2, 2, 2, 2]
3232
channels: [24, 24, 24, 24, 2]

configs/snellius64/cnn_vern.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ seeds:
2626
prior: 345
2727
post: 456
2828
closure:
29-
name: "Rodas"
29+
name: "Vern"
3030
type: cnn
3131
radii: [2, 2, 2, 2, 2]
3232
channels: [24, 24, 24, 24, 2]

src/plots.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,9 @@ function plot_energy_evolution(
254254
end
255255
end
256256

257-
if closure_name == "Project"
257+
if closure_name == "Base"
258258
label = "No closure (projected dyn)"
259+
label = "No closure"
259260
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
260261
lines!(
261262
ax,
@@ -267,6 +268,7 @@ function plot_energy_evolution(
267268
)
268269
end
269270
label = "Reference (projected dyn)"
271+
label = "Reference"
270272
if _missing_label(ax, label) && haskey(energyhistory, Symbol("ref"))
271273
lines!(
272274
ax,
@@ -361,8 +363,9 @@ function plot_energy_evolution_hist(
361363
end
362364
end
363365

364-
if closure_name == "Project"
366+
if closure_name == "Base"
365367
label = "No closure (projected dyn)"
368+
label = "No closure"
366369
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
367370
_plot_histogram(
368371
ax,
@@ -374,6 +377,7 @@ function plot_energy_evolution_hist(
374377
)
375378
end
376379
label = "Reference (projected dyn)"
380+
label = "Reference"
377381
if _missing_label(ax, label) && haskey(energyhistory, Symbol("ref"))
378382
_plot_histogram(
379383
ax,
@@ -868,8 +872,9 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
868872
end
869873
end
870874

871-
if closure_name == "NoProjection"
875+
if closure_name == "Base"
872876
label = "No model (projected dyn)"
877+
label = "No model"
873878
if _missing_label(ax, label) # add No closure only once
874879
scatterlines!(
875880
ax,
@@ -887,7 +892,8 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
887892
return nothing
888893
end
889894

890-
function plot_dns_solution(data, ax, frameskip=5, savepath = "dns_solution.gif")
895+
function plot_dns_solution(ax, frameskip, infile, savepath)
896+
data = load(infile)
891897
# use inside
892898
ref = data["uref"]
893899
proj = data["u"]

0 commit comments

Comments
 (0)