Skip to content

Commit 033154e

Browse files
committed
autopush
1 parent b868de6 commit 033154e

File tree

2 files changed

+181
-29
lines changed

2 files changed

+181
-29
lines changed

benchmark.jl

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ using Glob
2121
exclude_patterns = ["att", "cno", "int", "back", "rk4", "cnn_1" ]
2222
exclude_patterns = ["att", "cno" ]
2323
include_patterns = ["base"]
24+
include_patterns = []
25+
exclude_patterns = []
2426

2527
if !isempty(include_patterns)
2628
@warn "Including only configurations with patterns: $(include_patterns)"
@@ -129,6 +131,11 @@ colors_list = [
129131

130132
# Loop over plot types and configurations
131133
plot_labels = Dict(
134+
:dns_solution => (
135+
title = "DNS solution for different configurations",
136+
xlabel = "t",
137+
ylabel = L"\frac{|u(t)-u_{proj}(t)|}{|u(t)|}",
138+
),
132139
:prior_hist => (
133140
title = "A-priori training history for different configurations",
134141
xlabel = "Iteration",
@@ -139,11 +146,6 @@ plot_labels = Dict(
139146
xlabel = "Iteration",
140147
ylabel = "DCF",
141148
),
142-
:dns_solution => (
143-
title = "DNS solution for different configurations",
144-
xlabel = "t",
145-
ylabel = L"\frac{|u(t)-u_{proj}(t)|}{|u(t)|}",
146-
),
147149
:divergence => (
148150
title = "Divergence for different configurations",
149151
xlabel = "t",
@@ -162,21 +164,6 @@ plot_labels = Dict(
162164
:energy_spectra => (
163165
title = "Energy spectra",
164166
),
165-
:training_time => (
166-
title = "Training time for different configurations",
167-
xlabel = "Model",
168-
ylabel = "Training time (s) (per iteration)",
169-
),
170-
:training_comptime => (
171-
title = "Training time for different configurations",
172-
xlabel = "Model",
173-
ylabel = "Full Training time (s)",
174-
),
175-
:inference_time => (
176-
title = "Inference time for different configurations",
177-
xlabel = "Model",
178-
ylabel = "Inference time (s)",
179-
),
180167
:num_parameters => (
181168
title = "Number of parameters for different configurations",
182169
xlabel = "Model",
@@ -197,6 +184,21 @@ plot_labels = Dict(
197184
xlabel = "t",
198185
ylabel = L"e_{M}(t)",
199186
),
187+
:training_time => (
188+
title = "Training time for different configurations",
189+
xlabel = "Model",
190+
ylabel = "Training time (s) (per iteration)",
191+
),
192+
:training_comptime => (
193+
title = "Training time for different configurations",
194+
xlabel = "Model",
195+
ylabel = "Full Training time (s)",
196+
),
197+
:inference_time => (
198+
title = "Inference time for different configurations",
199+
xlabel = "Model",
200+
ylabel = "Inference time (s)",
201+
),
200202
)
201203

202204
dns_seeds = splitseed(123456, 8)
@@ -312,7 +314,7 @@ for key in keys(plot_labels)
312314
outdir, closure_name, "eprior_nles=$(nles).jld2"
313315
)
314316
bar_label, bar_position = plot_error(
315-
error_file, closure_name, nles, data_index, col_index, ax, color, PLOT_STYLES
317+
error_file, closure_name, nles, data_index, col_index, ax, color, PLOT_STYLES; outdir=outdir
316318
)
317319
append!(bar_positions, bar_position)
318320
append!(bar_labels, bar_label)

src/plots.jl

Lines changed: 158 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,20 @@ function plot_training_time(
649649
label = "$closure_name (n = $nles)",
650650
color = color, # dont change this color
651651
)
652+
653+
# Save training time data to CSV (create a minimal error data structure if needed)
654+
# Try to find the error file to get real error data, otherwise create dummy data
655+
error_file_path = joinpath(dirname(outdir), closure_name, "epost_nles=$(nles).jld2")
656+
if isfile(error_file_path)
657+
error_data = namedtupleload(error_file_path)
658+
_save_error_data_to_csv(error_file_path, error_data, closure_name, nles, 1, model_index, training_time_post; outdir=outdir)
659+
else
660+
# Create a dummy error data structure and save training time only
661+
dummy_error_data = (model_post = [NaN],) # Use NaN as placeholder for missing error
662+
dummy_error_file = joinpath(dirname(outdir), closure_name, "dummy_epost.jld2")
663+
_save_error_data_to_csv(dummy_error_file, dummy_error_data, closure_name, nles, 1, model_index, training_time_post; outdir=outdir)
664+
end
665+
652666
return labels, labels_positions
653667

654668
end
@@ -760,7 +774,8 @@ function plot_error(
760774
model_index,
761775
ax,
762776
color,
763-
PLOT_STYLES,
777+
PLOT_STYLES;
778+
outdir=nothing,
764779
)
765780
error_data = namedtupleload(error_file)
766781

@@ -811,9 +826,116 @@ function plot_error(
811826
gap = bar_gap,
812827
)
813828

829+
# Store data in CSV format
830+
_save_error_data_to_csv(error_file, error_data, closure_name, nles, data_index, model_index; outdir=outdir)
831+
814832
return labels, labels_positions
815833
end
816834

835+
function _save_error_data_to_csv(error_file, error_data, closure_name, nles, data_index, model_index, training_time_post=nothing; outdir=nothing)
836+
# Create single CSV file in the comparison directory
837+
if outdir === nothing
838+
# If outdir not provided, try to derive it from error_file path
839+
# error_file is typically: basedir/output/kolmogorov/ClosureName/epost_nles=X.jld2
840+
# We want: basedir/output/kolmogorov/comparison/a_posteriori_errors.csv
841+
error_dir = dirname(error_file) # .../output/kolmogorov/ClosureName
842+
outdir = dirname(error_dir) # .../output/kolmogorov
843+
end
844+
845+
comparison_dir = joinpath(outdir, "comparison")
846+
# Ensure comparison directory exists
847+
ispath(comparison_dir) || mkpath(comparison_dir)
848+
csv_file = joinpath(comparison_dir, "a_posteriori_errors.csv")
849+
850+
# Get a posteriori (post) error data, handle case where it might be missing
851+
post_error = nothing
852+
if haskey(error_data, :model_post)
853+
try
854+
# Handle both scalar indices and CartesianIndex
855+
potential_error = error_data.model_post[data_index]
856+
if !isnan(potential_error)
857+
post_error = potential_error
858+
end
859+
catch ex
860+
if isa(ex, BoundsError)
861+
# Index is out of bounds, leave post_error as nothing
862+
@warn "data_index $data_index is out of bounds for model_post array"
863+
else
864+
# Other indexing errors, leave post_error as nothing
865+
@warn "Error accessing model_post with index $data_index: $ex"
866+
end
867+
end
868+
end
869+
870+
# Read existing data if file exists
871+
existing_data = Dict{String, Tuple{Union{Float64, Nothing}, Union{Float64, Nothing}}}()
872+
file_exists = isfile(csv_file)
873+
874+
if file_exists
875+
# Read existing data into a dictionary
876+
lines = readlines(csv_file)
877+
if length(lines) > 1
878+
header = split(lines[1], ",")
879+
has_training_time = length(header) >= 3 && strip(header[3]) == "training_time_post"
880+
881+
for line in lines[2:end] # Skip header
882+
if !isempty(strip(line))
883+
parts = split(line, ",")
884+
if length(parts) >= 2
885+
existing_closure_name = strip(parts[1])
886+
887+
# Parse error (might be empty)
888+
existing_error = nothing
889+
if !isempty(strip(parts[2]))
890+
existing_error = parse(Float64, strip(parts[2]))
891+
end
892+
893+
# Parse training time (might be empty)
894+
existing_training_time = nothing
895+
if has_training_time && length(parts) >= 3 && !isempty(strip(parts[3]))
896+
existing_training_time = parse(Float64, strip(parts[3]))
897+
end
898+
899+
existing_data[existing_closure_name] = (existing_error, existing_training_time)
900+
end
901+
end
902+
end
903+
end
904+
end
905+
906+
# Update or add the current closure_name, preserving existing values when new ones are not provided
907+
current_error = post_error
908+
current_training_time = training_time_post
909+
910+
if haskey(existing_data, closure_name)
911+
# Keep existing values if new ones are not provided (nothing/missing)
912+
if post_error === nothing
913+
current_error = existing_data[closure_name][1]
914+
end
915+
if training_time_post === nothing
916+
current_training_time = existing_data[closure_name][2]
917+
end
918+
end
919+
920+
existing_data[closure_name] = (current_error, current_training_time)
921+
922+
# Write the complete data back to file
923+
open(csv_file, "w") do io
924+
# Write header
925+
println(io, "closure_name,a_posteriori_error,training_time_post")
926+
927+
# Write all data rows (sorted by closure_name for consistency)
928+
for closure in sort(collect(keys(existing_data)))
929+
error_val, training_time_val = existing_data[closure]
930+
error_str = error_val === nothing ? "" : string(error_val)
931+
training_time_str = training_time_val === nothing ? "" : string(training_time_val)
932+
println(io, join([closure, error_str, training_time_str], ","))
933+
end
934+
end
935+
936+
@info "Data saved to CSV: $csv_file (closure_name: $closure_name, error: $(current_error === nothing ? "missing" : current_error), training_time: $(current_training_time === nothing ? "missing" : current_training_time))"
937+
end
938+
817939

818940
function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
819941
error_data = namedtupleload(error_file)
@@ -892,15 +1014,14 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
8921014
return nothing
8931015
end
8941016

895-
function plot_dns_solution(ax, frameskip, infile, savepath)
1017+
function plot_dns_solution(ax, frameskip, infile, savepath; zidx=1, frame_to_save=300)
8961018
data = load(infile)
897-
# use inside
8981019
ref = data["uref"]
8991020
proj = data["u"]
9001021
t = data["t"]
9011022
tref = data["tref"]
9021023
y = []
903-
tref = tref[:length(tref)-10]
1024+
tref = tref[1:(length(tref)-1)]
9041025

9051026
for i in 1:length(tref)
9061027
# @assert t[i] ≈ tref[i] "Time values do not match: t[i] = $(t[i]), tref[i] = $(tref[i])"
@@ -915,17 +1036,46 @@ function plot_dns_solution(ax, frameskip, infile, savepath)
9151036

9161037
# Assume we use the first dataset to define dimensions
9171038
nframes = length(tref)
918-
# Choose z-slice
919-
zidx = 1
9201039

921-
fig = Figure(resolution = (900, 300))
1040+
# Save specified frame as a separate PNG first
1041+
if nframes >= frame_to_save
1042+
t_save = t[frame_to_save]
1043+
frame_fig = Figure(resolution = (1200, 300))
1044+
frame_ax1 = Makie.Axis(frame_fig[1, 1], title="Reference (t=$t_save s)")
1045+
frame_ax2 = Makie.Axis(frame_fig[1, 2], title="Projected (t=$t_save s)")
1046+
frame_ax3 = Makie.Axis(frame_fig[1, 3], title="Diff (t=$t_save s)")
1047+
1048+
u_ref_frame = ref[:, :, zidx, frame_to_save]
1049+
u_proj_frame = proj[:, :, zidx, frame_to_save]
1050+
diff_frame = u_ref_frame - u_proj_frame
1051+
1052+
heatmap!(frame_ax1, u_ref_frame)
1053+
heatmap!(frame_ax2, u_proj_frame)
1054+
hm_diff = heatmap!(frame_ax3, diff_frame, colormap=:reds)
1055+
1056+
# Add colorbar for the difference plot
1057+
Colorbar(frame_fig[1, 4], hm_diff, label="Difference")
1058+
1059+
# Create PNG filename from GIF savepath
1060+
png_savepath = replace(savepath, r"\.(gif|GIF)$" => "_frame$(frame_to_save).png")
1061+
save(png_savepath, frame_fig)
1062+
println("Frame $frame_to_save PNG saved to $png_savepath")
1063+
else
1064+
println("Warning: Only $nframes frames available, cannot save frame $frame_to_save")
1065+
end
1066+
1067+
# Now create the GIF animation
1068+
fig = Figure(resolution = (1200, 300))
9221069
ax1 = Makie.Axis(fig[1, 1], title="Reference")
9231070
ax2 = Makie.Axis(fig[1, 2], title="Projected")
9241071
ax3 = Makie.Axis(fig[1, 3], title="Diff")
9251072

9261073
hm1 = heatmap!(ax1, zeros(size(ref, 1), size(ref, 2)))
9271074
hm2 = heatmap!(ax2, zeros(size(ref, 1), size(ref, 2)))
928-
hm3 = heatmap!(ax3, zeros(size(ref, 1), size(ref, 2)))
1075+
hm3 = heatmap!(ax3, zeros(size(ref, 1), size(ref, 2)), colormap=:reds)
1076+
1077+
# Add colorbar for the difference plot
1078+
Colorbar(fig[1, 4], hm3, label="Difference")
9291079

9301080
Makie.record(fig, savepath, 1:frameskip:nframes) do t
9311081
u_ref_t = ref[:, :, zidx, t]

0 commit comments

Comments
 (0)