Skip to content

Commit df43edf

Browse files
committed
split energy evolution and histogram
1 parent 6be0f3d commit df43edf

File tree

3 files changed

+131
-97
lines changed

3 files changed

+131
-97
lines changed

benchmark.jl

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ plot_labels = Dict(
113113
ylabel = "Face-average",
114114
),
115115
:energy_evolution => (
116-
title = "Energy evolution for different configurations",
117-
xlabel1 = "t",
118-
xlabel2 = "frequency",
116+
title = "Energy evolution for different configurations",
117+
xlabel = "t",
118+
ylabel = "E(t)",
119+
),
120+
:energy_evolution_hist => (
121+
title = "Energy histogram for different configurations",
122+
xlabel = "frequency",
119123
ylabel = "E(t)",
120124
),
121125
:energy_spectra => (
@@ -158,36 +162,14 @@ for key in keys(plot_labels)
158162

159163
# Create the figure
160164
fig = Figure(; size = (950, 600))
161-
if key != :energy_spectra && key != :energy_evolution
165+
if key != :energy_spectra
162166
ax = Axis(
163167
fig[1, 1];
164168
title = plot_labels[key].title,
165169
xlabel = plot_labels[key].xlabel,
166170
ylabel = plot_labels[key].ylabel,
167171
)
168172
end
169-
if key == :energy_evolution
170-
Label(
171-
fig[1, 1],
172-
plot_labels[key].title;
173-
font = :bold,
174-
tellwidth = false,
175-
)
176-
gplot = fig[2, 1]
177-
ax1 = Axis(
178-
gplot[1, 1];
179-
xlabel = plot_labels[key].xlabel1,
180-
ylabel = plot_labels[key].ylabel,
181-
)
182-
ax = ax1
183-
ax2 = Axis(
184-
gplot[1, 2];
185-
xlabel = plot_labels[key].xlabel2,
186-
ylabel = plot_labels[key].ylabel,
187-
yaxisposition = :right,
188-
xreversed=true,
189-
)
190-
end
191173

192174
# empty list for barplots
193175
bar_positions = Float64[]
@@ -248,7 +230,12 @@ for key in keys(plot_labels)
248230

249231
elseif key == :energy_evolution
250232
plot_energy_evolution(
251-
outdir, closure_name, nles, Φ, data_index, ax1, ax2, color, PLOT_STYLES
233+
outdir, closure_name, nles, Φ, data_index, ax, color, PLOT_STYLES
234+
)
235+
236+
elseif key == :energy_evolution_hist
237+
plot_energy_evolution_hist(
238+
outdir, closure_name, nles, Φ, data_index, ax, color, PLOT_STYLES
252239
)
253240

254241
elseif key== :energy_spectra

src/Benchmark.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ export _convert_to_single_index,
103103
plot_posteriori_traininghistory,
104104
plot_divergence,
105105
plot_energy_evolution,
106+
plot_energy_evolution_hist,
106107
plot_energy_spectra,
107108
plot_training_time,
108109
plot_inference_time,

src/plots.jl

Lines changed: 116 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,14 @@ function plot_divergence(outdir, closure_name, nles, Φ, data_index, ax, color,
191191
ax.yscale = log10
192192
end
193193

194-
function _plot_histogram(ax, data, color, linestyle, linewidth)
194+
function _plot_histogram(ax, data, label, color, linestyle, linewidth)
195195
hist_data = hist!(ax, [p[2] for p in data]; bins = 10, color = (:transparent, 0.0))
196196
centers = hist_data.plots[1][1][]
197197
lines!(
198198
ax,
199199
[p[2] for p in centers], # x-values are frequencies
200200
[p[1] for p in centers], # y-values are bin centers
201+
label = label,
201202
linestyle = linestyle,
202203
linewidth = linewidth,
203204
color = color,
@@ -210,8 +211,7 @@ function plot_energy_evolution(
210211
nles,
211212
Φ,
212213
data_index,
213-
ax1,
214-
ax2,
214+
ax,
215215
color,
216216
PLOT_STYLES,
217217
)
@@ -223,150 +223,196 @@ function plot_energy_evolution(
223223
end
224224
energyhistory = namedtupleload(energy_dir).energyhistory;
225225

226-
num_bins = 10
227-
228-
229226
if closure_name == "INS_ref"
230227
label = "No closure "
231-
if _missing_label(ax1, label) && haskey(energyhistory, Symbol("nomodel"))
228+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
232229
lines!(
233-
ax1,
230+
ax,
234231
energyhistory.nomodel[data_index];
235232
label = label,
236233
linestyle = PLOT_STYLES[:no_closure].linestyle,
237234
linewidth = PLOT_STYLES[:no_closure].linewidth,
238235
color = PLOT_STYLES[:no_closure].color,
239236
)
240-
_plot_histogram(
241-
ax2,
242-
energyhistory.nomodel[data_index],
243-
PLOT_STYLES[:no_closure].color,
244-
PLOT_STYLES[:no_closure].linestyle,
245-
PLOT_STYLES[:no_closure].linewidth,
246-
)
247237
end
248238
# add reference only once
249239
label = "Reference"
250-
if _missing_label(ax1, label) && haskey(energyhistory, Symbol("ref"))
240+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("ref"))
251241
lines!(
252-
ax1,
242+
ax,
253243
energyhistory.ref[data_index];
254244
color = PLOT_STYLES[:reference].color,
255245
linestyle = PLOT_STYLES[:reference].linestyle,
256246
linewidth = PLOT_STYLES[:reference].linewidth,
257247
label = label,
258248
)
259-
_plot_histogram(
260-
ax2,
261-
energyhistory.ref[data_index],
262-
PLOT_STYLES[:reference].color,
263-
PLOT_STYLES[:reference].linestyle,
264-
PLOT_STYLES[:reference].linewidth,
265-
)
266249
end
267250
end
268251

269252
if closure_name == "cnn_1"
270253
label = "No closure (projected dyn)"
271-
if _missing_label(ax1, label) && haskey(energyhistory, Symbol("nomodel"))
254+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
272255
lines!(
273-
ax1,
256+
ax,
274257
energyhistory.nomodel[data_index];
275258
label = label,
276259
linestyle = PLOT_STYLES[:no_closure_proj].linestyle,
277260
linewidth = PLOT_STYLES[:no_closure_proj].linewidth,
278261
color = PLOT_STYLES[:no_closure_proj].color,
279262
)
280-
_plot_histogram(
281-
ax2,
282-
energyhistory.nomodel[data_index],
283-
PLOT_STYLES[:no_closure_proj].color,
284-
PLOT_STYLES[:no_closure_proj].linestyle,
285-
PLOT_STYLES[:no_closure_proj].linewidth,
286-
)
287263
end
288264
label = "Reference (projected dyn)"
289-
if _missing_label(ax1, label) && haskey(energyhistory, Symbol("ref"))
265+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("ref"))
290266
lines!(
291-
ax1,
267+
ax,
292268
energyhistory.ref[data_index];
293269
color = PLOT_STYLES[:reference_proj].color,
294270
linestyle = PLOT_STYLES[:reference_proj].linestyle,
295271
linewidth = PLOT_STYLES[:reference_proj].linewidth,
296272
label = label,
297273
)
298-
_plot_histogram(
299-
ax2,
300-
energyhistory.ref[data_index],
301-
PLOT_STYLES[:reference_proj].color,
302-
PLOT_STYLES[:reference_proj].linestyle,
303-
PLOT_STYLES[:reference_proj].linewidth,
304-
)
305274
end
306275
end
307276

308277
if haskey(energyhistory, Symbol("smag"))
309278
lines!(
310-
ax1,
279+
ax,
311280
energyhistory.smag[data_index];
312281
color = PLOT_STYLES[:smag].color,
313282
linestyle = PLOT_STYLES[:smag].linestyle,
314283
linewidth = PLOT_STYLES[:smag].linewidth,
315284
label = "$closure_name (smag) (n = $nles)",
316285
)
317-
_plot_histogram(
318-
ax2,
319-
energyhistory.smag[data_index],
320-
PLOT_STYLES[:smag].color,
321-
PLOT_STYLES[:smag].linestyle,
322-
PLOT_STYLES[:smag].linewidth,
323-
)
324286
end
325287

326288
label = Φ isa FaceAverage ? "FA" : "VA"
327289
# prior
328290
lines!(
329-
ax1,
291+
ax,
330292
energyhistory.model_prior[data_index];
331293
label = "$closure_name (prior) (n = $nles, $label)",
332294
linestyle = PLOT_STYLES[:prior].linestyle,
333295
linewidth = PLOT_STYLES[:prior].linewidth,
334296
color = color, # dont change this color
335297
)
336-
_plot_histogram(
337-
ax2,
338-
energyhistory.model_prior[data_index],
339-
color,
340-
PLOT_STYLES[:prior].linestyle,
341-
PLOT_STYLES[:prior].linewidth,
342-
)
343-
344298
# post
345299
lines!(
346-
ax1,
300+
ax,
347301
energyhistory.model_post[data_index];
348302
label = "$closure_name (post) (n = $nles, $label)",
349303
linestyle = PLOT_STYLES[:post].linestyle,
350304
linewidth = PLOT_STYLES[:post].linewidth,
351305
color = color, # dont change this color
352306
)
307+
308+
# update axis limits
309+
x_values = [point[1] for v in values(energyhistory) for point in v[data_index]]
310+
y_values = [point[2] for v in values(energyhistory) for point in v[data_index]]
311+
ax = _update_ax_limits(ax, x_values, y_values)
312+
313+
end
314+
315+
function plot_energy_evolution_hist(
316+
outdir,
317+
closure_name,
318+
nles,
319+
Φ,
320+
data_index,
321+
ax,
322+
color,
323+
PLOT_STYLES,
324+
)
325+
# Load learned parameters
326+
energy_dir = joinpath(outdir, closure_name, "history.jld2")
327+
if !ispath(energy_dir)
328+
@warn "Energy history not found in $energy_dir"
329+
return
330+
end
331+
energyhistory = namedtupleload(energy_dir).energyhistory;
332+
333+
if closure_name == "INS_ref"
334+
label = "No closure "
335+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
336+
_plot_histogram(
337+
ax,
338+
energyhistory.nomodel[data_index],
339+
label,
340+
PLOT_STYLES[:no_closure].color,
341+
PLOT_STYLES[:no_closure].linestyle,
342+
PLOT_STYLES[:no_closure].linewidth,
343+
)
344+
end
345+
# add reference only once
346+
label = "Reference"
347+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("ref"))
348+
_plot_histogram(
349+
ax,
350+
energyhistory.ref[data_index],
351+
label,
352+
PLOT_STYLES[:reference].color,
353+
PLOT_STYLES[:reference].linestyle,
354+
PLOT_STYLES[:reference].linewidth,
355+
)
356+
end
357+
end
358+
359+
if closure_name == "cnn_1"
360+
label = "No closure (projected dyn)"
361+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
362+
_plot_histogram(
363+
ax,
364+
energyhistory.nomodel[data_index],
365+
label,
366+
PLOT_STYLES[:no_closure_proj].color,
367+
PLOT_STYLES[:no_closure_proj].linestyle,
368+
PLOT_STYLES[:no_closure_proj].linewidth,
369+
)
370+
end
371+
label = "Reference (projected dyn)"
372+
if _missing_label(ax, label) && haskey(energyhistory, Symbol("ref"))
373+
_plot_histogram(
374+
ax,
375+
energyhistory.ref[data_index],
376+
label,
377+
PLOT_STYLES[:reference_proj].color,
378+
PLOT_STYLES[:reference_proj].linestyle,
379+
PLOT_STYLES[:reference_proj].linewidth,
380+
)
381+
end
382+
end
383+
384+
if haskey(energyhistory, Symbol("smag"))
385+
_plot_histogram(
386+
ax,
387+
energyhistory.smag[data_index],
388+
"$closure_name (smag) (n = $nles)",
389+
PLOT_STYLES[:smag].color,
390+
PLOT_STYLES[:smag].linestyle,
391+
PLOT_STYLES[:smag].linewidth,
392+
)
393+
end
394+
395+
label = Φ isa FaceAverage ? "FA" : "VA"
396+
# prior
397+
_plot_histogram(
398+
ax,
399+
energyhistory.model_prior[data_index],
400+
"$closure_name (prior) (n = $nles, $label)",
401+
color,
402+
PLOT_STYLES[:prior].linestyle,
403+
PLOT_STYLES[:prior].linewidth,
404+
)
405+
406+
# post
353407
_plot_histogram(
354-
ax2,
408+
ax,
355409
energyhistory.model_post[data_index],
410+
"$closure_name (post) (n = $nles, $label)",
356411
color,
357412
PLOT_STYLES[:post].linestyle,
358413
PLOT_STYLES[:post].linewidth,
359414
)
360415

361-
# update axis limits
362-
x_values = [point[1] for v in values(energyhistory) for point in v[data_index]]
363-
y_values = [point[2] for v in values(energyhistory) for point in v[data_index]]
364-
ax1 = _update_ax_limits(ax1, x_values, y_values)
365-
366-
# set the y-axis limits of ax2 to be the same as ax1
367-
ymin, ymax = ax1.limits[][2]
368-
ylims!(ax2, ymin, ymax)
369-
370416
end
371417

372418
function _get_spectra(setup, u)

0 commit comments

Comments
 (0)