Skip to content

Commit c7a3783

Browse files
committed
Fix extra_model_workflow
1 parent 2e6abdd commit c7a3783

File tree

5 files changed

+22
-16
lines changed

5 files changed

+22
-16
lines changed

benchmark.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ ispath(compdir) || mkpath(compdir)
1818

1919
# List configurations files
2020
using Glob
21-
list_confs = glob("*.yaml", confdir)
21+
exclude_patterns = ["att", "cno"]
22+
exclude_patterns = ["cno"]
23+
@warn "Excluding configurations with patterns: $(exclude_patterns)"
24+
all_confs = glob("*.yaml", confdir)
25+
list_confs = filter(conf -> all(!occursin(pat, conf) for pat in exclude_patterns), all_confs)
2226
if isempty(list_confs)
2327
@error "No configuration files found in $confdir"
2428
end

cnn_model_workflow.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,6 @@ let
357357
postseed = seeds.post,
358358
dns_seeds_train,
359359
dns_seeds_valid,
360-
dns_seeds_test,
361360
nunroll = conf["posteriori"]["nunroll"],
362361
nsamples = conf["posteriori"]["nsamples"],
363362
dt = T(conf["posteriori"]["dt"]),

configs/snellius/att_1.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ params:
1111
tsim: 5.0
1212
savefreq: 50
1313
ndns: 4096
14-
nles: [128]
14+
nles: [64]
1515
filters: ["FaceAverage()"]
1616
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
1717
method: "RKMethods.Wray3(; T)"
@@ -33,8 +33,10 @@ closure:
3333
use_bias: [true, true, true, true, false]
3434
use_attention: [true, false, false, false, false]
3535
emb_sizes: [124, 124, 124, 124, 124]
36-
Ns: [148, 144, 140, 136, 132]
37-
patch_sizes: [37, 36, 35, 34, 33]
36+
# Ns: [148, 144, 140, 136, 132]
37+
Ns: [ 84, 80, 76, 72, 68]
38+
# patch_sizes: [37, 36, 35, 34, 33]
39+
patch_sizes: [21, 20, 19, 18, 17]
3840
n_heads: [4, 4, 4, 4, 4]
3941
sum_attention: [false, false, false, false, false]
4042
rng: "Xoshiro(seeds.θ_start)"

extra_model_workflow.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ dns_seeds_test = dns_seeds[ntrajectory:ntrajectory]
175175
docreatedata = conf["docreatedata"]
176176
for i = 1:ntrajectory
177177
if i%numtasks == taskid - 1
178-
docreatedata && createdata(; params, seed = dns_seeds[i], outdir, backend)
178+
docreatedata && createdata(; params, seed = dns_seeds[i], outdir, backend, dataproj = conf["dataproj"])
179179
end
180180
end
181181
@info "Data generated"
@@ -283,6 +283,7 @@ let
283283
do_plot = conf["priori"]["do_plot"],
284284
plot_train = conf["priori"]["plot_train"],
285285
nepoch,
286+
dataproj = conf["dataproj"],
286287
)
287288
end
288289
end
@@ -380,6 +381,7 @@ let
380381
do_plot = conf["posteriori"]["do_plot"],
381382
plot_train = conf["posteriori"]["plot_train"],
382383
sensealg = haskey(conf["posteriori"],:sensealg) ? eval(Meta.parse(conf["posteriori"]["sensealg"])) : nothing,
384+
dataproj = conf["dataproj"],
383385
)
384386
end
385387
end
@@ -470,11 +472,11 @@ let
470472
eprior.model_post[ig, ifil, iorder] = compute_eprior(closure, device(θ_cnn_post[ig, ifil, iorder]), st, testset...)
471473
end
472474
end
473-
jldsave(joinpath(outdir_model, "eprior.jld2"); eprior...)
475+
jldsave(joinpath(outdir_model, "eprior_nles=$(params.nles[1]).jld2"); eprior...)
474476
end
475477
clean()
476478

477-
eprior = namedtupleload(joinpath(outdir_model, "eprior.jld2"))
479+
eprior = namedtupleload(joinpath(outdir_model, "eprior_nles=$(params.nles[1]).jld2"))
478480

479481
########################################################################## #src
480482

@@ -533,10 +535,10 @@ let
533535
@info "Epost model_post" epost.model_post[I, :]
534536
clean()
535537
end
536-
jldsave(joinpath(outdir_model, "epost.jld2"); epost...)
538+
jldsave(joinpath(outdir_model, "epost_nles=$(params.nles[1]).jld2"); epost...)
537539
end
538540

539-
epost = namedtupleload(joinpath(outdir_model, "epost.jld2"))
541+
epost = namedtupleload(joinpath(outdir_model, "epost_nles=$(params.nles[1]).jld2"))
540542

541543

542544
########################################################################## #src
@@ -724,12 +726,12 @@ let
724726
push!(energyhistory[:model_post][I], Point2f(t, e))
725727
end
726728
end
727-
jldsave(joinpath(outdir_model, "history.jld2"); energyhistory, divergencehistory)
729+
jldsave(joinpath(outdir_model, "history_nles=$(params.nles[1]).jld2"); energyhistory, divergencehistory)
728730
clean()
729731
end
730732
end
731733

732-
(; divergencehistory, energyhistory) = namedtupleload(joinpath(outdir_model, "history.jld2"));
734+
(; divergencehistory, energyhistory) = namedtupleload(joinpath(outdir_model, "history_nles=$(params.nles[1]).jld2"));
733735

734736
########################################################################## #src
735737

@@ -1009,11 +1011,11 @@ let
10091011
end
10101012
clean()
10111013
end
1012-
jldsave("$outdir_model/solutions.jld2"; u = utimes, t = times_exact, itime_max_DIF)
1014+
jldsave("$outdir_model/solutions_nles=$(params.nles[1]).jld2"; u = utimes, t = times_exact, itime_max_DIF)
10131015
end;
10141016

10151017
# Load solution
1016-
solutions = namedtupleload("$outdir_model/solutions.jld2");
1018+
solutions = namedtupleload("$outdir_model/solutions_nles=$(params.nles[1]).jld2");
10171019

10181020
########################################################################## #src
10191021

src/train.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ function trainpost(;
239239
postseed,
240240
dns_seeds_train,
241241
dns_seeds_valid,
242-
dns_seeds_test,
243242
nunroll,
244243
nsamples = 1,
245244
closure,
@@ -335,7 +334,7 @@ function trainpost(;
335334

336335

337336
# For the callback I am going to use the a-posteriori error estimator
338-
sample = namedtupleload(getdatafile(outdir, nles, Φ, dns_seeds_test[1]))
337+
sample = namedtupleload(getdatafile(outdir, nles, Φ, dns_seeds_valid[1]))
339338
it = 1:(nunroll_valid+1)
340339
data_cb = (;
341340
u = selectdim(sample.u, ndims(sample.u), it) |> collect |> device,

0 commit comments

Comments
 (0)