@@ -3,19 +3,26 @@ function getdatafile(outdir, nles, filter, seed)
33 joinpath (outdir, " data" , splatfileparts (; seed = repr (seed), filter, nles) * " .jld2" )
44end
55
6- function load_data_set (outdir, nles, Φ, seeds)
6+ function load_data_set (outdir, nles, Φ, seeds, dataproj )
77 data = []
88 for s in seeds
9- data_i = namedtupleload (getdatafile (outdir, nles, Φ, s))
9+ filename = getdatafile (outdir, nles, Φ, s)
10+ if dataproj
11+ filename = replace (filename, " .jld2" => " _projected.jld2" )
12+ end
13+ data_i = namedtupleload (filename)
1014 push! (data, data_i)
1115 end
1216 return data
1317end
1418
15- function createdata (; params, seed, outdir, backend)
19+ function createdata (; params, seed, outdir, backend, dataproj )
1620 for (nles, Φ) in Iterators. product (params. nles, params. filters)
1721
1822 filename = getdatafile (outdir, nles, Φ, seed)
23+ if dataproj
24+ filename = replace (filename, " .jld2" => " _projected.jld2" )
25+ end
1926 datadir = dirname (filename)
2027 ispath (datadir) || mkpath (datadir)
2128
@@ -85,6 +92,7 @@ function trainprior(;
8592 do_plot = false ,
8693 plot_train = false ,
8794 nepoch,
95+ dataproj
8896)
8997 device (x) = adapt (params. backend, x)
9098 itotal = 0
@@ -118,8 +126,8 @@ function trainprior(;
118126 NS = Base. get_extension (CoupledNODE, :NavierStokes )
119127
120128 # Read the data in the format expected by the CoupledNODE
121- data_train = load_data_set (outdir, nles, Φ, dns_seeds_train)
122- data_valid = load_data_set (outdir, nles, Φ, dns_seeds_valid)
129+ data_train = load_data_set (outdir, nles, Φ, dns_seeds_train, dataproj )
130+ data_valid = load_data_set (outdir, nles, Φ, dns_seeds_valid, dataproj )
123131 @assert length (nles) == 1 " Only one nles for a-priori training"
124132 io_train = NS. create_io_arrays_priori (data_train, setup[1 ], device)
125133 io_valid = NS. create_io_arrays_priori (data_valid, setup[1 ], device)
@@ -244,6 +252,7 @@ function trainpost(;
244252 do_plot = false ,
245253 plot_train = false ,
246254 sensealg = nothing ,
255+ dataproj,
247256)
248257 device (x) = adapt (params. backend, x)
249258 itotal = 0
@@ -280,8 +289,8 @@ function trainpost(;
280289 end
281290
282291 # Read the data in the format expected by the CoupledNODE
283- data_train = load_data_set (outdir, nles, Φ, dns_seeds_train)
284- data_valid = load_data_set (outdir, nles, Φ, dns_seeds_valid)
292+ data_train = load_data_set (outdir, nles, Φ, dns_seeds_train, dataproj )
293+ data_valid = load_data_set (outdir, nles, Φ, dns_seeds_valid, dataproj )
285294
286295 NS = Base. get_extension (CoupledNODE, :NavierStokes )
287296 io_train = NS. create_io_arrays_posteriori (data_train, setup[1 ], device)
0 commit comments