@@ -120,6 +120,9 @@ function plot_posteriori_traininghistory(
120120end
121121
122122function plot_divergence (outdir, closure_name, nles, Φ, data_index, ax, color, PLOT_STYLES)
123+ if closure_name == " INS.jl"
124+ return
125+ end
123126 # Load learned parameters
124127 divergence_dir = joinpath (outdir, closure_name, " history_nles=$(nles) .jld2" )
125128 if ! ispath (divergence_dir)
@@ -228,7 +231,7 @@ function plot_energy_evolution(
228231 end
229232 energyhistory = namedtupleload (energy_dir). energyhistory;
230233
231- if closure_name == " INS.jl"
234+ if closure_name == " INS.jl" && false
232235 label = " No closure "
233236 if _missing_label (ax, label) && haskey (energyhistory, Symbol (" nomodel" ))
234237 lines! (
@@ -337,7 +340,7 @@ function plot_energy_evolution_hist(
337340 end
338341 energyhistory = namedtupleload (energy_dir). energyhistory;
339342
340- if closure_name == " INS.jl"
343+ if closure_name == " INS.jl" && false
341344 label = " No closure "
342345 if _missing_label (ax, label) && haskey (energyhistory, Symbol (" nomodel" ))
343346 _plot_histogram (
@@ -471,7 +474,11 @@ function plot_energy_spectra(
471474 num_of_models,
472475 color,
473476 PLOT_STYLES,
477+ single_legend = false
474478)
479+ if closure_name == " INS.jl"
480+ return
481+ end
475482 # Load learned parameters
476483 energy_dir = joinpath (outdir, closure_name, " solutions_nles=$(nles) .jld2" )
477484 if ! ispath (energy_dir)
@@ -488,9 +495,14 @@ function plot_energy_spectra(
488495
489496 # Create a grid of plots and legends
490497 gtitle = fig[1 , 1 ] # Title of the figure
491- gplot = fig[2 , 1 ]
492- gplot_ax = gplot[1 , 1 ] # Axis for each plot
493- gplot_leg = gplot[1 , 2 ] # The common legend for all plots
498+ if single_legend
499+ gplot = fig[1 , 1 ]
500+ gplot_ax = gplot[1 , 1 ]
501+ else
502+ gplot = fig[2 , 1 ]
503+ gplot_ax = gplot[1 , 1 ] # Axis for each plot
504+ gplot_leg = gplot[1 , 2 ] # The common legend for all plots
505+ end
494506
495507 Label (
496508 gtitle,
@@ -593,15 +605,14 @@ function plot_energy_spectra(
593605 [prior_label, post_label],
594606 labelsize = 8 ,
595607 )
608+ # Add legend that is common for all plots
609+ Legend (
610+ gplot_leg,
611+ [no_closure_plt, reference_plt, inertia_plt],
612+ [no_closure_label, reference_label, inertia_label],
613+ labelsize = 8 ,
614+ )
596615 end
597-
598- # Add legend that is common for all plots
599- Legend (
600- gplot_leg,
601- [no_closure_plt, reference_plt, inertia_plt],
602- [no_closure_label, reference_label, inertia_label],
603- labelsize = 8 ,
604- )
605616 end
606617end
607618
@@ -831,16 +842,24 @@ function _save_error_data_to_csv(error_data, closure_name, data_index; outdir, n
831842
832843 # Get a posteriori (post) error data, handle case where it might be missing
833844 post_error = error_data. model_post[data_index]
845+ if closure_name == " INS.jl"
846+ post_error = error_data. smag[data_index]
847+ end
834848
835849 # Try to read training time data from training files
836850 training_time_post = nothing
837851 if nles != = nothing && Φ != = nothing && projectorders != = nothing
838852 try
839853 if closure_name == " INS.jl"
840- posttraining = namedtupleload (
841- Benchmark. getpostfile (outdir, closure_name, nles, Φ, projectorders[1 ]),
854+ # posttraining = namedtupleload(
855+ # Benchmark.getpostfile(outdir, closure_name, nles, Φ, projectorders[1]),
856+ # )
857+ # training_time_post = posttraining.single_stored_object.time_per_epoch
858+
859+ smagtrain= namedtupleload (
860+ joinpath (outdir, " smagorinski" , " projectorder=Last_filter=FaceAverage()_nles=64.jld2" ),
842861 )
843- training_time_post = posttraining . single_stored_object . time_per_epoch
862+ training_time_post = smagtrain . comptime / 300
844863 else
845864 posttraining = loadpost (outdir, closure_name, [nles], [Φ], projectorders)
846865 training_time_post = posttraining[1 ]. time_per_epoch
@@ -890,6 +909,9 @@ function _save_error_data_to_csv(error_data, closure_name, data_index; outdir, n
890909 # Update or add the current closure_name, preserving existing values when new ones are not provided
891910 current_error = post_error
892911 current_training_time = training_time_post
912+ if closure_name == " INS.jl"
913+ closure_name = " Smagorinsky"
914+ end
893915
894916 if haskey (existing_data, closure_name)
895917 # Keep existing values if new ones are not provided (nothing/missing)
@@ -926,35 +948,37 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
926948
927949 x = error_data. nts # use time from error data
928950
929- # Prior
930- scatterlines! (
931- ax,
932- x,
933- vec (error_data. model_prior);
934- label = " $closure_name prior (n = $nles )" ,
935- color = color,
936- linestyle = PLOT_STYLES[:prior ]. linestyle,
937- linewidth = PLOT_STYLES[:prior ]. linewidth,
938- )
939-
940- # Post
941- scatterlines! (
942- ax,
943- x,
944- vec (error_data. model_post);
945- label = " $closure_name post (n = $nles )" ,
946- color = color,
947- linestyle = PLOT_STYLES[:post ]. linestyle,
948- linewidth = PLOT_STYLES[:post ]. linewidth,
949- )
951+ if closure_name != " INS.jl"
952+ # Prior
953+ scatterlines! (
954+ ax,
955+ x,
956+ vec (error_data. model_prior);
957+ label = " $closure_name prior (n = $nles )" ,
958+ color = color,
959+ linestyle = PLOT_STYLES[:prior ]. linestyle,
960+ linewidth = PLOT_STYLES[:prior ]. linewidth,
961+ )
962+ # Post
963+ scatterlines! (
964+ ax,
965+ x,
966+ vec (error_data. model_post);
967+ label = " $closure_name post (n = $nles )" ,
968+ color = color,
969+ linestyle = PLOT_STYLES[:post ]. linestyle,
970+ linewidth = PLOT_STYLES[:post ]. linewidth,
971+ )
972+ end
950973
951974 # Smagorinsky (optional)
952975 if haskey (error_data, Symbol (" smag" ))
953976 scatterlines! (
954977 ax,
955- x,
956- vec (error_data. smag);
957- label = " $closure_name smag (n = $nles )" ,
978+ x[1 : end - 2 ],
979+ vec (error_data. smag[1 : end - 2 ]);
980+ # label = "$closure_name smag (n = $nles)",
981+ label = " Smagorinsky (n = $nles )" ,
958982 color = PLOT_STYLES[:smag ]. color,
959983 linestyle = PLOT_STYLES[:smag ]. linestyle,
960984 linewidth = PLOT_STYLES[:smag ]. linewidth,
0 commit comments