Skip to content

Commit b8114b7

Browse files
committed
Add dataproj key to use same ref seed
1 parent eaeff3a commit b8114b7

File tree

5 files changed

+25
-11
lines changed

5 files changed

+25
-11
lines changed

cnn_model_workflow.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ dns_seeds_test = dns_seeds[ntrajectory:ntrajectory]
169169
docreatedata = conf["docreatedata"]
170170
for i = 1:ntrajectory
171171
if i%numtasks == taskid - 1
172-
docreatedata && createdata(; params, seed = dns_seeds[i], outdir, backend)
172+
docreatedata && createdata(; params, seed = dns_seeds[i], outdir, backend, dataproj = conf["dataproj"])
173173
end
174174
end
175175
@info "Data generated"
@@ -271,7 +271,8 @@ let
271271
batchsize = conf["priori"]["batchsize"],
272272
do_plot = conf["priori"]["do_plot"],
273273
plot_train = conf["priori"]["plot_train"],
274-
nepoch,
274+
nepoch = nepoch,
275+
dataproj = conf["dataproj"],
275276
)
276277
end
277278
end
@@ -369,6 +370,7 @@ let
369370
do_plot = conf["posteriori"]["do_plot"],
370371
plot_train = conf["posteriori"]["plot_train"],
371372
sensealg = haskey(conf["posteriori"],:sensealg) ? eval(Meta.parse(conf["posteriori"]["sensealg"])) : nothing,
373+
dataproj = conf["dataproj"],
372374
)
373375
end
374376
end

configs/snellius/cnn_1.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ docreatedata: true
22
docomp: true
33
ntrajectory: 8
44
T: "Float32"
5+
dataproj: true
56
params:
67
D: 2
78
lims: [0.0, 1.0]

configs/snellius/cnn_ins.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ docreatedata: true
22
docomp: true
33
ntrajectory: 8
44
T: "Float32"
5+
dataproj: false
56
params:
67
D: 2
78
lims: [0.0, 1.0]
@@ -48,3 +49,4 @@ posteriori:
4849
dt: 0.0001
4950
do_plot: false
5051
plot_train: false
52+
nsamples: 1

src/plots.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ function plot_energy_evolution(
236236
end
237237
end
238238

239-
if closure_name == "cnn_1"
239+
if closure_name == "cnn_proj"
240240
label = "No closure (projected dyn)"
241241
if _missing_label(ax, label) && haskey(energyhistory, Symbol("nomodel"))
242242
lines!(
@@ -702,7 +702,7 @@ function plot_epost_vs_t(error_file, closure_name, nles, ax, color, PLOT_STYLES)
702702
end
703703
end
704704

705-
if closure_name == "cnn_INS"
705+
if closure_name == "cnn_noproj"
706706
label = "No model (projected dyn)"
707707
if _missing_label(ax, label) # add No closure only once
708708
scatterlines!(

src/train.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,26 @@ function getdatafile(outdir, nles, filter, seed)
33
joinpath(outdir, "data", splatfileparts(; seed = repr(seed), filter, nles) * ".jld2")
44
end
55

6-
function load_data_set(outdir, nles, Φ, seeds)
6+
function load_data_set(outdir, nles, Φ, seeds, dataproj)
77
data = []
88
for s in seeds
9-
data_i = namedtupleload(getdatafile(outdir, nles, Φ, s))
9+
filename = getdatafile(outdir, nles, Φ, s)
10+
if dataproj
11+
filename = replace(filename, ".jld2" => "_projected.jld2")
12+
end
13+
data_i = namedtupleload(filename)
1014
push!(data, data_i)
1115
end
1216
return data
1317
end
1418

15-
function createdata(; params, seed, outdir, backend)
19+
function createdata(; params, seed, outdir, backend, dataproj)
1620
for (nles, Φ) in Iterators.product(params.nles, params.filters)
1721

1822
filename = getdatafile(outdir, nles, Φ, seed)
23+
if dataproj
24+
filename = replace(filename, ".jld2" => "_projected.jld2")
25+
end
1926
datadir = dirname(filename)
2027
ispath(datadir) || mkpath(datadir)
2128

@@ -85,6 +92,7 @@ function trainprior(;
8592
do_plot = false,
8693
plot_train = false,
8794
nepoch,
95+
dataproj
8896
)
8997
device(x) = adapt(params.backend, x)
9098
itotal = 0
@@ -118,8 +126,8 @@ function trainprior(;
118126
NS = Base.get_extension(CoupledNODE, :NavierStokes)
119127

120128
# Read the data in the format expected by the CoupledNODE
121-
data_train = load_data_set(outdir, nles, Φ, dns_seeds_train)
122-
data_valid = load_data_set(outdir, nles, Φ, dns_seeds_valid)
129+
data_train = load_data_set(outdir, nles, Φ, dns_seeds_train, dataproj)
130+
data_valid = load_data_set(outdir, nles, Φ, dns_seeds_valid, dataproj)
123131
@assert length(nles) == 1 "Only one nles for a-priori training"
124132
io_train = NS.create_io_arrays_priori(data_train, setup[1], device)
125133
io_valid = NS.create_io_arrays_priori(data_valid, setup[1], device)
@@ -244,6 +252,7 @@ function trainpost(;
244252
do_plot = false,
245253
plot_train = false,
246254
sensealg = nothing,
255+
dataproj,
247256
)
248257
device(x) = adapt(params.backend, x)
249258
itotal = 0
@@ -280,8 +289,8 @@ function trainpost(;
280289
end
281290

282291
# Read the data in the format expected by the CoupledNODE
283-
data_train = load_data_set(outdir, nles, Φ, dns_seeds_train)
284-
data_valid = load_data_set(outdir, nles, Φ, dns_seeds_valid)
292+
data_train = load_data_set(outdir, nles, Φ, dns_seeds_train, dataproj)
293+
data_valid = load_data_set(outdir, nles, Φ, dns_seeds_valid, dataproj)
285294

286295
NS = Base.get_extension(CoupledNODE, :NavierStokes)
287296
io_train = NS.create_io_arrays_posteriori(data_train, setup[1], device)

0 commit comments

Comments
 (0)