Skip to content

Commit ca5c011

Browse files
committed
Add dns proj vs noproj plot
1 parent 6c02257 commit ca5c011

File tree

4 files changed

+93
-56
lines changed

4 files changed

+93
-56
lines changed

benchmark.jl

Lines changed: 69 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -111,61 +111,69 @@ plot_labels = Dict(
111111
# xlabel = "Iteration",
112112
# ylabel = "DCF",
113113
#),
114-
:divergence => (
115-
title = "Divergence for different configurations",
114+
:dns_solution => (
115+
title = "DNS solution for different configurations",
116116
xlabel = "t",
117-
ylabel = "Face-average",
118-
),
119-
:energy_evolution => (
120-
title = "Energy evolution for different configurations",
121-
xlabel = "t",
122-
ylabel = "E(t)",
123-
),
124-
:energy_evolution_hist => (
125-
title = "Energy histogram for different configurations",
126-
xlabel = "frequency",
127-
ylabel = "E(t)",
128-
),
129-
:energy_spectra => (
130-
title = "Energy spectra",
131-
),
132-
:training_time => (
133-
title = "Training time for different configurations",
134-
xlabel = "Model",
135-
ylabel = "Training time (s) (per iteration)",
136-
),
137-
:training_comptime => (
138-
title = "Training time for different configurations",
139-
xlabel = "Model",
140-
ylabel = "Full Training time (s)",
141-
),
142-
:inference_time => (
143-
title = "Inference time for different configurations",
144-
xlabel = "Model",
145-
ylabel = "Inference time (s)",
146-
),
147-
:num_parameters => (
148-
title = "Number of parameters for different configurations",
149-
xlabel = "Model",
150-
ylabel = "Number of parameters",
151-
),
152-
:eprior => (
153-
title = "A-prior error for different configurations",
154-
xlabel = "Model",
155-
ylabel = "A-prior error",
156-
),
157-
:epost => (
158-
title = "A-posteriori error for different configurations",
159-
xlabel = "Model",
160-
ylabel = "A-posteriori error",
161-
),
162-
:epost_vs_t => (
163-
title = "A-posteriori error as a function of time",
164-
xlabel = "t",
165-
ylabel = L"e_{M}(t)",
117+
ylabel = "MAE",
166118
),
119+
#:divergence => (
120+
# title = "Divergence for different configurations",
121+
# xlabel = "t",
122+
# ylabel = "Face-average",
123+
#),
124+
#:energy_evolution => (
125+
# title = "Energy evolution for different configurations",
126+
# xlabel = "t",
127+
# ylabel = "E(t)",
128+
#),
129+
#:energy_evolution_hist => (
130+
# title = "Energy histogram for different configurations",
131+
# xlabel = "frequency",
132+
# ylabel = "E(t)",
133+
#),
134+
#:energy_spectra => (
135+
# title = "Energy spectra",
136+
#),
137+
#:training_time => (
138+
# title = "Training time for different configurations",
139+
# xlabel = "Model",
140+
# ylabel = "Training time (s) (per iteration)",
141+
#),
142+
#:training_comptime => (
143+
# title = "Training time for different configurations",
144+
# xlabel = "Model",
145+
# ylabel = "Full Training time (s)",
146+
#),
147+
#:inference_time => (
148+
# title = "Inference time for different configurations",
149+
# xlabel = "Model",
150+
# ylabel = "Inference time (s)",
151+
#),
152+
#:num_parameters => (
153+
# title = "Number of parameters for different configurations",
154+
# xlabel = "Model",
155+
# ylabel = "Number of parameters",
156+
#),
157+
#:eprior => (
158+
# title = "A-prior error for different configurations",
159+
# xlabel = "Model",
160+
# ylabel = "A-prior error",
161+
#),
162+
#:epost => (
163+
# title = "A-posteriori error for different configurations",
164+
# xlabel = "Model",
165+
# ylabel = "A-posteriori error",
166+
#),
167+
#:epost_vs_t => (
168+
# title = "A-posteriori error as a function of time",
169+
# xlabel = "t",
170+
# ylabel = L"e_{M}(t)",
171+
#),
167172
)
168173

174+
dns_seeds = splitseed(123456, 8)
175+
dns_seeds = [0x185efb6b]
176+
169177
for key in keys(plot_labels)
170178
@info "Plotting $key"
171179

@@ -305,12 +313,20 @@ for key in keys(plot_labels)
305313
plot_epost_vs_t(
306314
error_file, closure_name, nles, ax, color, PLOT_STYLES
307315
)
316+
elseif key == :dns_solution
317+
data_ref = load_data_set(outdir, nles, Φ, dns_seeds, false)
318+
data_proj = load_data_set(outdir, nles, Φ, dns_seeds, true)
319+
plot_dns_solution(
320+
data_ref, data_proj, ax#, color, PLOT_STYLES
321+
)
322+
else
323+
@error "Unknown plot type: $key"
308324
end
309325
end
310326
end
311327
end
312328
# Add legend
313-
if key != :energy_spectra
329+
if !(key in (:energy_spectra, :dns_solution))
314330
Legend(fig[:, end+1], ax)
315331
end
316332

@@ -320,7 +336,7 @@ for key in keys(plot_labels)
320336
end
321337

322338
# Set log-log scale
323-
if key == :epost_vs_t
339+
if key in (:epost_vs_t, :dns_solution)
324340
ax.xscale = log10
325341
ax.yscale = log10
326342
end

configs/local/cnn_proj.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ params:
1010
tburn: 0.5
1111
tsim: 2.0
1212
savefreq: 50
13-
ndns: 1024
14-
nles: [64]
13+
ndns: 128
14+
nles: [32]
1515
filters: ["FaceAverage()"]
1616
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
1717
method: "RKMethods.Wray3(; T)"

src/Benchmark.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,11 @@ export _convert_to_single_index,
110110
plot_inference_time,
111111
plot_num_parameters,
112112
plot_error,
113-
plot_epost_vs_t
113+
plot_epost_vs_t,
114+
plot_dns_solution
114115

115116
export compute_eprior, compute_epost, compute_t_prior_inference
116117
export reusepriorfile, reusepostfile
118+
export load_data_set
117119

118120
end # module Benchmark

src/plots.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,3 +886,22 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
886886

887887
return nothing
888888
end
889+
890+
function plot_dns_solution(data_ref, data_proj, ax)
891+
892+
for (i,dr) in enumerate(data_ref)
893+
x = dr.t[2:end] #skip t=0
894+
y = []
895+
for (ti, tval) in enumerate(x)
896+
#@info "t(ref) = $(tval) t(proj)=$(data_proj[i].t[ti])"
897+
#@assert tval ≈ data_proj[i].t[ti]
898+
push!(y, sum(abs, dr.u[:,:,:,ti]-data_proj[i].u[:,:,:,ti]))
899+
end
900+
scatterlines!(
901+
ax,
902+
x,
903+
vec(y)
904+
)
905+
end
906+
907+
end

0 commit comments

Comments
 (0)