Skip to content

Commit 77d4368

Browse files
committed
Fix attention model
1 parent 2d93472 commit 77d4368

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

configs/snellius64/att_base.yaml

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
docreatedata: true
2+
docomp: true
3+
ntrajectory: 8
4+
T: "Float64"
5+
dataproj: true
6+
projtest: false
7+
params:
8+
D: 2
9+
lims: [0.0, 1.0]
10+
Re: 6000.0
11+
tburn: 0.5
12+
tsim: 5.0
13+
savefreq: 50
14+
ndns: 4096
15+
nles: [64]
16+
filters: ["FaceAverage()"]
17+
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
18+
method: "RKMethods.RK44(; T)"
19+
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
20+
issteadybodyforce: true
21+
processors: "(; log = timelogger(; nupdate=100))"
22+
Δt: 0.00005
23+
seeds:
24+
dns: 123456
25+
θ_start: 234
26+
prior: 345
27+
post: 456
28+
closure:
29+
name: "Attention"
30+
type: attentioncnn
31+
radii: [2, 2, 2, 2, 2]
32+
channels: [24, 24, 24, 24, 2]
33+
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
34+
use_bias: [true, true, true, true, false]
35+
use_attention: [true, false, false, false, false]
36+
emb_sizes: [124, 124, 124, 124, 124]
37+
# Ns: [148, 144, 140, 136, 132]
38+
Ns: [ 84, 80, 76, 72, 68]
39+
# patch_sizes: [37, 36, 35, 34, 33]
40+
patch_sizes: [21, 20, 19, 18, 17]
41+
n_heads: [4, 4, 4, 4, 4]
42+
sum_attention: [false, false, false, false, false]
43+
rng: "Xoshiro(seeds.θ_start)"
44+
priori:
45+
dotrain: true
46+
nepoch: 1000
47+
batchsize: 64
48+
opt: "Adam(T(1.0e-3))"
49+
do_plot: false
50+
plot_train: false
51+
lambda: 0.00005
52+
posteriori:
53+
dotrain: true
54+
projectorders: "(ProjectOrder.Last, )"
55+
nepoch: 3000
56+
opt: "Adam(T(1.0e-4))"
57+
nunroll: 5
58+
nunroll_valid: 10
59+
nsamples: 5
60+
dt: 0.0001
61+
do_plot: false
62+
plot_train: false
63+
lambda: 0.0000005
64+
sciml_solver: "Tsit5()"

extra_model_workflow.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,13 @@ else
7878
logfile = "log_$(Dates.now()).out"
7979
end
8080
logfile = joinpath(logdir, logfile)
81-
setsnelliuslogger(logfile)
81+
# check if I am planning to use Enzyme, in which case I can not touch the logger
82+
if (haskey(conf["priori"], "ad_type") && occursin("Enzyme", conf["priori"]["ad_type"])) ||
83+
(haskey(conf["posteriori"], "ad_type") && occursin("Enzyme", conf["posteriori"]["ad_type"]))
84+
@warn "Enzyme is used, so logger will not be set to ConsoleLogger"
85+
else
86+
setsnelliuslogger(logfile)
87+
end
8288

8389
@info "# A-posteriori analysis: Forced turbulence (2D)"
8490

@@ -98,6 +104,7 @@ using CairoMakie
98104
using CoupledNODE: loss_priori_lux, create_loss_post_lux
99105
using CUDA
100106
using DifferentialEquations
107+
using Enzyme
101108
using IncompressibleNavierStokes.RKMethods
102109
using JLD2
103110
using LaTeXStrings
@@ -111,6 +118,7 @@ using OptimizationOptimJL
111118
using OptimizationCMAEvolutionStrategy
112119
using ParameterSchedulers
113120
using Random
121+
using SciMLSensitivity
114122

115123

116124
# ## Random number seeds
@@ -174,6 +182,22 @@ dns_seeds_train = dns_seeds[1:ntrajectory-2]
174182
dns_seeds_valid = dns_seeds[ntrajectory-1:ntrajectory-1]
175183
dns_seeds_test = dns_seeds[ntrajectory:ntrajectory]
176184

185+
doprojtest = conf["projtest"]
186+
if doprojtest && taskid == 1
187+
testprojfile = joinpath(outdir, "test_dns_proj.jld2")
188+
if isfile(testprojfile)
189+
@info "Test DNS projection file already exists."
190+
else
191+
create_test_dns_proj(
192+
nchunks = 8000;
193+
params...,
194+
rng = Xoshiro(2406),
195+
backend = backend,
196+
filename = testprojfile,
197+
)
198+
end
199+
end
200+
177201
# Create data
178202
docreatedata = conf["docreatedata"]
179203
for i = 1:ntrajectory
@@ -247,7 +271,7 @@ let
247271
u = randn(T, params.nles[1], params.nles[1], 2, 10) |> device
248272
θ = θ_start |> device
249273
closure(u, θ, st)
250-
gradient-> sum(closure(u, θ, st)[1]), θ)
274+
Zygote.gradient-> sum(closure(u, θ, st)[1]), θ)
251275
clean()
252276
end
253277

@@ -299,7 +323,8 @@ let
299323
plot_train = conf["priori"]["plot_train"],
300324
nepoch,
301325
dataproj = conf["dataproj"],
302-
λ = conf["priori"]["lambda"],
326+
λ = haskey(conf["priori"], "λ") ? eval(Meta.parse(conf["priori"]["λ"])) : nothing,
327+
ad_type = haskey(conf["priori"], "ad_type") ? eval(Meta.parse(conf["priori"]["ad_type"])) : Optimization.AutoZygote(),
303328
)
304329
end
305330
end
@@ -412,7 +437,9 @@ let
412437
sensealg = sensealg,
413438
sciml_solver = sciml_solver,
414439
dataproj = conf["dataproj"],
415-
λ = conf["posteriori"]["lambda"],
440+
λ = haskey(conf["posteriori"], "λ") ? eval(Meta.parse(conf["posteriori"]["λ"])) : nothing,
441+
multishoot_nt = haskey(conf["posteriori"], "multishoot_nt") ? conf["posteriori"]["multishoot_nt"] : 0,
442+
ad_type = haskey(conf["posteriori"], "ad_type") ? eval(Meta.parse(conf["posteriori"]["ad_type"])) : Optimization.AutoZygote(),
416443
)
417444
end
418445
end

multisub.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@
2222
sbatch -J owr job_a100.sh configs/snellius64/cnn_owr.yaml
2323
sbatch -J bs3 job_a100.sh configs/snellius64/cnn_bs3.yaml
2424
#sbatch -J composite job_a100.sh configs/snellius64/cnn_composite.yaml
25+
26+
sbatch -J att job_a100_extra.sh configs/snellius64/att_base.yaml

0 commit comments

Comments
 (0)