Skip to content

Commit 4f2b300

Browse files
committed
Reformat
1 parent 8a7c8e4 commit 4f2b300

File tree

13 files changed

+142
-261
lines changed

13 files changed

+142
-261
lines changed

.github/workflows/CI.yml

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,17 @@ jobs:
6767
using Pkg
6868
Pkg.develop(PackageSpec(path=pwd()))
6969
Pkg.instantiate()'
70-
- name: Run doctests
71-
run: |
72-
julia --project=docs -e '
73-
using Documenter: DocMeta, doctest
74-
using CoupledNODE
75-
DocMeta.setdocmeta!(CoupledNODE, :DocTestSetup, :(using CoupledNODE); recursive=true)
76-
doctest(CoupledNODE)'
77-
- name: Generate documentation
78-
run: julia --project=docs docs/make.jl
79-
env:
80-
JULIA_PKG_SERVER: ""
81-
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
82-
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
83-
GKSwstype: "100"
70+
#- name: Run doctests
71+
# run: |
72+
# julia --project=docs -e '
73+
# using Documenter: DocMeta, doctest
74+
# using CoupledNODE
75+
# DocMeta.setdocmeta!(CoupledNODE, :DocTestSetup, :(using CoupledNODE); recursive=true)
76+
# doctest(CoupledNODE)'
77+
#- name: Generate documentation
78+
# run: julia --project=docs docs/make.jl
79+
# env:
80+
# JULIA_PKG_SERVER: ""
81+
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
82+
# DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
83+
# GKSwstype: "100"

benchmark.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,36 @@ end
6161
# Global variables for setting linestyle and colors in all plots
6262
PLOT_STYLES = Dict(
6363
:no_closure => (color="black", linestyle=:dash, linewidth=2),
64+
:no_closure_proj => (color="red", linestyle=:dash, linewidth=2),
6465
:reference => (color="black", linestyle=:dot, linewidth=2),
6566
:prior => (color="black", linestyle=:solid, linewidth=1),
6667
:post => (color="black", linestyle=:dashdot, linewidth=1),
6768
:inertia => (color="cyan", linestyle=:dot, linewidth=1),
6869
:smag => (color="darkgreen", linestyle=:dot, linewidth=1),
6970
)
7071

71-
# Color list: if there are more models, add more colors here
72-
# colors black, cyan and lightgreen are reserved, see above!
73-
# Bright Red-Orange, Sky Blue, Deep Purple, Hot Pink,
74-
# Bright Green, Dark Blue, Violet, Teal
72+
# Color list: high-contrast, colorblind-friendly palette
7573
colors_list = [
76-
"#ff3300", "#3399ff", "#9933cc", "#ff33cc",
77-
"#33cc33", "#00008B", "#6600cc", "#00cc99"
74+
"#E41A1C", # Red
75+
"#377EB8", # Blue
76+
"#4DAF4A", # Green
77+
"#984EA3", # Purple
78+
"#FF7F00", # Orange
79+
"#A65628", # Brown
80+
"#F781BF", # Pink
81+
"#999999", # Grey
82+
"#FFD700", # Gold
83+
"#00CED1", # Dark Turquoise
84+
"#1E90FF", # Dodger Blue
85+
"#228B22", # Forest Green
86+
"#D2691E", # Chocolate
87+
"#DC143C", # Crimson
88+
"#8B008B", # Dark Magenta
89+
"#FF1493", # Deep Pink
90+
"#00FF7F", # Spring Green
91+
"#4682B4", # Steel Blue
92+
"#B22222", # Firebrick
93+
"#20B2AA", # Light Sea Green
7894
]
7995

8096
# Loop over plot types and configurations
@@ -128,9 +144,9 @@ plot_labels = Dict(
128144
ylabel = "A-posteriori error",
129145
),
130146
:epost_vs_t => (
131-
title = "A-posteriori error as a function of time" * ", " * L"\frac{e_{M}(Nt)}{e_\text{no model}(Nt)}",
132-
xlabel = "Nt",
133-
ylabel = L"\frac{e_{M}(Nt)}{e_\text{no model}(Nt)}",
147+
title = "A-posteriori error as a function of time",
148+
xlabel = "t",
149+
ylabel = L"e_{M}(t)",
134150
),
135151
)
136152

cnn_model_workflow.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -220,15 +220,15 @@ closure_INS, θ_INS = NeuralClosure.cnn(;
220220
# Give the CNN a test run
221221
# Note: Data and parameters are stored on the CPU, and
222222
# must be moved to the GPU before use (with `device`)
223-
let
224-
@info "CNN warm up run"
225-
using NeuralClosure.Zygote
226-
u = randn(T, 32, 32, 2, 10) |> device
227-
θ = θ_start |> device
228-
closure(u, θ, st)
229-
gradient-> sum(closure(u, θ, st)[1]), θ)
230-
clean()
231-
end
223+
#let
224+
# @info "CNN warm up run"
225+
# using NeuralClosure.Zygote
226+
# u = randn(T, 32, 32, 2, 10) |> device
227+
# θ = θ_start |> device
228+
# closure(u, θ, st)
229+
# gradient(θ -> sum(closure(u, θ, st)[1]), θ)
230+
# clean()
231+
#end
232232

233233
########################################################################## #src
234234

@@ -349,7 +349,8 @@ let
349349
dns_seeds_train,
350350
dns_seeds_valid,
351351
nunroll = conf["posteriori"]["nunroll"],
352-
dt = conf["posteriori"]["dt"],
352+
nsamples = conf["posteriori"]["nsamples"],
353+
dt = T(conf["posteriori"]["dt"]),
353354
closure,
354355
closure_name,
355356
θ_start = θ_cnn_prior,
@@ -468,15 +469,16 @@ let
468469
end
469470

470471
let
471-
tsave = [5,10,50,100,199]
472+
tsave = [5, 10, 25, 50, 100, 200, 500, 750, 1000]
473+
tsave .-=1
472474
s = (length(params.nles), length(params.filters), length(projectorders))
473475
swt = (length(params.nles), length(params.filters), length(projectorders), length(tsave))
474476
epost = (;
475477
nomodel = zeros(T, swt),
476478
model_prior = zeros(T, swt),
477479
model_post = zeros(T, swt),
478480
model_t_post_inference = zeros(T, s),
479-
nts = tsave,
481+
nts = zeros(T, length(tsave)),
480482
)
481483
for (iorder, projectorder) in enumerate(projectorders),
482484
(ifil, Φ) in enumerate(params.filters),
@@ -492,8 +494,10 @@ let
492494
u = selectdim(sample.u, ndims(sample.u), it) |> collect |> device,
493495
t = sample.t[it],
494496
)
497+
epost.nts[:] = [data.t[i] for i in tsave]
498+
@info epost.nts
495499
tspan = (data.t[1], data.t[end])
496-
dt = conf["posteriori"]["dt"]
500+
dt = T(conf["posteriori"]["dt"])
497501

498502
## No model
499503
dudt_nomod = NS.create_right_hand_side_inplace(

configs/local/att_1.yaml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ closure:
3939
rng: "Xoshiro(seeds.θ_start)"
4040
priori:
4141
dotrain: true
42-
nepoch: 5000
42+
nepoch: 1000
4343
batchsize: 64
4444
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
4545
do_plot: false
@@ -48,9 +48,10 @@ posteriori:
4848
dotrain: true
4949
projectorders: "(ProjectOrder.Last, )"
5050
nepoch: 300
51-
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
51+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
5252
nunroll: 5
5353
nunroll_valid: 5
54-
dt: T(1e-4)
54+
dt: 0.0001
55+
nsamples: 1
5556
do_plot: false
5657
plot_train: false

configs/local/cnn_1.yaml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ params:
77
lims: [0.0, 1.0]
88
Re: 6000.0
99
tburn: 0.5
10-
tsim: 2.0
11-
savefreq: 100
10+
tsim: 5.0
11+
savefreq: 50
1212
ndns: 1024
1313
nles: [32]
1414
filters: ["FaceAverage()"]
@@ -33,18 +33,19 @@ closure:
3333
rng: "Xoshiro(seeds.θ_start)"
3434
priori:
3535
dotrain: true
36-
nepoch: 5000
36+
nepoch: 3000
3737
batchsize: 64
3838
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
3939
do_plot: false
4040
plot_train: false
4141
posteriori:
4242
dotrain: true
4343
projectorders: "(ProjectOrder.Last, )"
44-
nepoch: 300
45-
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
44+
nepoch: 100
45+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
4646
nunroll: 5
4747
nunroll_valid: 5
48-
dt: T(1e-4)
48+
dt: 0.0001
49+
nsamples: 1
4950
do_plot: false
5051
plot_train: false

configs/local/cnn_2.yaml

Lines changed: 0 additions & 50 deletions
This file was deleted.
Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
docreatedata: true
1+
docreatedata: false
22
docomp: false
33
ntrajectory: 8
44
T: "Float32"
@@ -7,8 +7,8 @@ params:
77
lims: [0.0, 1.0]
88
Re: 6000.0
99
tburn: 0.5
10-
tsim: 2.0
11-
savefreq: 100
10+
tsim: 5.0
11+
savefreq: 50
1212
ndns: 1024
1313
nles: [32]
1414
filters: ["FaceAverage()"]
@@ -19,12 +19,12 @@ params:
1919
processors: "(; log = timelogger(; nupdate=100))"
2020
Δt: 0.0001
2121
seeds:
22-
dns: 1111
22+
dns: 123
2323
θ_start: 234
2424
prior: 345
2525
post: 456
2626
closure:
27-
name: "cnn_new"
27+
name: "cnn_INS"
2828
type: cnn
2929
radii: [2, 2, 2, 2, 2]
3030
channels: [24, 24, 24, 24, 2]
@@ -33,18 +33,19 @@ closure:
3333
rng: "Xoshiro(seeds.θ_start)"
3434
priori:
3535
dotrain: true
36-
nepoch: 5000
36+
nepoch: 3000
3737
batchsize: 64
3838
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
3939
do_plot: false
4040
plot_train: false
4141
posteriori:
4242
dotrain: true
4343
projectorders: "(ProjectOrder.Last, )"
44-
nepoch: 300
45-
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
44+
nepoch: 100
45+
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
4646
nunroll: 5
4747
nunroll_valid: 5
48-
dt: T(1e-3)
48+
dt: 0.0001
49+
nsamples: 1
4950
do_plot: false
5051
plot_train: false

configs/local/cnn_backsolve.yaml

Lines changed: 0 additions & 51 deletions
This file was deleted.

0 commit comments

Comments
 (0)