Skip to content

Commit 3a1e847

Browse files
committed
autopush
1 parent 9c3aff6 commit 3a1e847

File tree

8 files changed

+145
-21
lines changed

8 files changed

+145
-21
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
4343
AttentionLayer = {rev = "main", url = "https://github.com/DEEPDIP-project/AttentionLayer.jl.git"}
4444
ConvolutionalNeuralOperators = {rev = "main", url = "https://github.com/DEEPDIP-project/ConvolutionalNeuralOperators.jl.git"}
4545
NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"}
46-
CoupledNODE = {rev = "main", url = "https://github.com/DEEPDIP-project/CoupledNODE.jl.git"}
4746

4847
[compat]
4948
Accessors = "0.1"

cnn_model_workflow.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ let
357357
postseed = seeds.post,
358358
dns_seeds_train,
359359
dns_seeds_valid,
360+
dns_seeds_test,
360361
nunroll = conf["posteriori"]["nunroll"],
361362
nsamples = conf["posteriori"]["nsamples"],
362363
dt = T(conf["posteriori"]["dt"]),
@@ -497,7 +498,8 @@ let
497498
@info "Computing a-posteriori errors" projectorder Φ nles
498499
I = CartesianIndex(ig, ifil, iorder)
499500
setup = getsetup(; params, nles)
500-
psolver = psolver_spectral(setup)
501+
#psolver = psolver_spectral(setup)
502+
psolver = default_psolver(setup)
501503
sample = namedtupleload(getdatafile(outdir, nles, Φ, dns_seeds_test[1]))
502504
it = 1:length(sample.t)
503505
data = (;
@@ -520,7 +522,7 @@ let
520522
setup, psolver, closure, st)
521523
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
522524
@info "Epost model_prior" epost.model_prior[I, :]
523-
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[I]) , tspan, data, tsave, dt)
525+
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
524526
@info "Epost model_post" epost.model_post[I, :]
525527

526528
clean()

configs/snellius/cnn_1.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ closure:
3434
rng: "Xoshiro(seeds.θ_start)"
3535
priori:
3636
dotrain: true
37-
nepoch: 10000
37+
nepoch: 50000
3838
batchsize: 64
3939
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
4040
do_plot: false
4141
plot_train: false
4242
posteriori:
4343
dotrain: true
4444
projectorders: "(ProjectOrder.Last, )"
45-
nepoch: 300
45+
nepoch: 1500
4646
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
4747
nunroll: 5
4848
nunroll_valid: 10

configs/snellius/cnn_2.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ docreatedata: true
22
docomp: true
33
ntrajectory: 8
44
T: "Float32"
5+
dataproj: true
56
params:
67
D: 2
78
lims: [0.0, 1.0]
@@ -42,7 +43,7 @@ priori:
4243
posteriori:
4344
dotrain: true
4445
projectorders: "(ProjectOrder.Last, )"
45-
nepoch: 300
46+
nepoch: 100
4647
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
4748
nunroll: 5
4849
nunroll_valid: 10

configs/snellius/cnn_ins.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ closure:
3434
rng: "Xoshiro(seeds.θ_start)"
3535
priori:
3636
dotrain: true
37-
nepoch: 10000
37+
nepoch: 50000
3838
batchsize: 64
3939
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
4040
do_plot: false
4141
plot_train: false
4242
posteriori:
4343
dotrain: true
4444
projectorders: "(ProjectOrder.Last, )"
45-
nepoch: 300
45+
nepoch: 1500
4646
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
4747
nunroll: 5
4848
nunroll_valid: 10

configs/snellius/cnn_nt25.yaml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: true
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_project_nt25"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_project"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 1500
47+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
48+
nunroll: 25
49+
nunroll_valid: 10
50+
nsamples: 1
51+
dt: 0.0001
52+
do_plot: false
53+
plot_train: false

configs/snellius/cnn_test.yaml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float32"
5+
dataproj: true
6+
params:
7+
D: 2
8+
lims: [0.0, 1.0]
9+
Re: 6000.0
10+
tburn: 0.5
11+
tsim: 5.0
12+
savefreq: 50
13+
ndns: 4096
14+
nles: [64]
15+
filters: ["FaceAverage()"]
16+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
17+
method: "RKMethods.Wray3(; T)"
18+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
19+
issteadybodyforce: true
20+
processors: "(; log = timelogger(; nupdate=100))"
21+
Δt: 0.00005
22+
seeds:
23+
dns: 123456
24+
θ_start: 234
25+
prior: 345
26+
post: 456
27+
closure:
28+
name: "cnn_test"
29+
type: cnn
30+
radii: [2, 2, 2, 2, 2]
31+
channels: [24, 24, 24, 24, 2]
32+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
33+
use_bias: [true, true, true, true, false]
34+
rng: "Xoshiro(seeds.θ_start)"
35+
priori:
36+
reuse: "cnn_project"
37+
dotrain: true
38+
nepoch: 50000
39+
batchsize: 64
40+
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
41+
do_plot: false
42+
plot_train: false
43+
posteriori:
44+
dotrain: true
45+
projectorders: "(ProjectOrder.Last, )"
46+
nepoch: 100
47+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.1))"
48+
nunroll: 5
49+
nunroll_valid: 24
50+
nsamples: 5
51+
dt: 0.00005
52+
do_plot: false
53+
plot_train: false

src/train.jl

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ function trainpost(;
239239
postseed,
240240
dns_seeds_train,
241241
dns_seeds_valid,
242+
dns_seeds_test,
242243
nunroll,
243244
nsamples = 1,
244245
closure,
@@ -281,21 +282,15 @@ function trainpost(;
281282
checkfile = join(splitext(postfile), "_checkpoint")
282283
setup = getsetup(; params, nles)
283284
psolver = default_psolver(setup)
284-
# Read the data in the format expected by the CoupledNODE
285285
T = eltype(params.Re)
286-
setup = []
287-
for nl in nles
288-
x = ntuple-> LinRange(T(0.0), T(1.0), nl + 1), params.D)
289-
push!(setup, Setup(; x = x, Re = params.Re, params.backend))
290-
end
291286

292287
# Read the data in the format expected by the CoupledNODE
293288
data_train = load_data_set(outdir, nles, Φ, dns_seeds_train, dataproj)
294289
data_valid = load_data_set(outdir, nles, Φ, dns_seeds_valid, dataproj)
295290

296291
NS = Base.get_extension(CoupledNODE, :NavierStokes)
297-
io_train = NS.create_io_arrays_posteriori(data_train, setup[1], device)
298-
io_valid = NS.create_io_arrays_posteriori(data_valid, setup[1], device)
292+
io_train = NS.create_io_arrays_posteriori(data_train, setup, device)
293+
io_valid = NS.create_io_arrays_posteriori(data_valid, setup, device)
299294
θ = device(copy(θ_start[itotal]))
300295
dataloader_post = NS.create_dataloader_posteriori(
301296
io_train;
@@ -305,7 +300,7 @@ function trainpost(;
305300
device = device,
306301
)
307302

308-
dudt_nn = NS.create_right_hand_side_with_closure(setup[1], psolver, closure, st)
303+
dudt_nn = NS.create_right_hand_side_with_closure(setup, psolver, closure, st)
309304
griddims = ((:) for _ = 1:params.D)
310305
inside = ((2:(nles+1)) for _ = 1:params.D)
311306
loss = CoupledNODE.create_loss_post_lux(
@@ -338,11 +333,25 @@ function trainpost(;
338333
nepochs_left = nepoch
339334
end
340335

336+
337+
# For the callback I am going to use the a-posteriori error estimator
338+
sample = namedtupleload(getdatafile(outdir, nles, Φ, dns_seeds_test[1]))
339+
it = 1:(nunroll_valid+1)
340+
data_cb = (;
341+
u = selectdim(sample.u, ndims(sample.u), it) |> collect |> device,
342+
t = sample.t[it],
343+
)
344+
tspan = (data_cb.t[1], data_cb.t[end])
345+
tsave = [nunroll_valid]
346+
dudt_cb = NS.create_right_hand_side_with_closure_inplace(
347+
setup, psolver, closure, st)
348+
loss_cb(_model, pp, _st, _data ) = compute_epost(dudt_cb, pp , tspan, data_cb, tsave, dt)[1][end]
349+
341350
callbackstate, callback = NS.create_callback(
342351
closure,
343352
θ,
344353
io_valid,
345-
loss,
354+
loss_cb,
346355
st;
347356
callbackstate = callbackstate,
348357
nunroll = nunroll_valid,
@@ -420,9 +429,10 @@ function compute_epost(rhs, ps, tspan, (u, t), tsave, dt)
420429
p = ps,
421430
adaptive = true,
422431
saveat = Array(t),
423-
tspan = tspan,
424-
save_start = false,
425-
dt = dt,
432+
#tstops = Array(t),
433+
#tspan = tspan,
434+
#save_start = false,
435+
#dt = dt,
426436
)
427437

428438
e = 0.0
@@ -441,6 +451,12 @@ function compute_epost(rhs, ps, tspan, (u, t), tsave, dt)
441451
push!(es, e / (it - 1))
442452
end
443453
end
454+
#for it in tsave
455+
# yref = y[inside..., :, 1:it]
456+
# ypred = pred[inside..., :, 1:it]
457+
458+
# Lux.MSELoss()(ypred, yref) |> e -> push!(es, e)
459+
#end
444460

445461
return es, time() - t0
446462

0 commit comments

Comments
 (0)