Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,36 @@ NeuralClosure = "099dac27-d7f2-4047-93d5-0baee36b9c25"
Observables = "510215fc-4207-5dde-b226-833fc4488ee2"
OpenSSL_jll = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
OptimizationCMAEvolutionStrategy = "bd407f91-200f-4536-9381-e4ba712f53f8"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"

[sources]
AttentionLayer = {rev = "main", url = "https://github.com/DEEPDIP-project/AttentionLayer.jl.git"}
ConvolutionalNeuralOperators = {rev = "main", url = "https://github.com/DEEPDIP-project/ConvolutionalNeuralOperators.jl.git"}
NeuralClosure = {rev = "main", url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"}
[sources.AttentionLayer]
rev = "main"
url = "https://github.com/DEEPDIP-project/AttentionLayer.jl.git"

[sources.ConvolutionalNeuralOperators]
rev = "main"
url = "https://github.com/DEEPDIP-project/ConvolutionalNeuralOperators.jl.git"

[sources.NeuralClosure]
rev = "main"
url = "https://github.com/DEEPDIP-project/NeuralClosure.jl.git"

[compat]
Accessors = "0.1"
Adapt = "4"
AttentionLayer = "0"
CUDA = "5"
CUDSS = "0.4"
CairoMakie = "0.12"
CairoMakie = "0.12, 0.15"
ComponentArrays = "0.15"
ConvolutionalNeuralOperators = "0"
Dates = "1"
Expand All @@ -73,7 +83,10 @@ NeuralClosure = "1.0.0"
Observables = "0.5"
OpenSSL_jll = "3.0.13"
Optimisers = "0.4"
OptimizationCMAEvolutionStrategy = "0.3.0"
OptimizationOptimJL = "0.4.3"
ParameterSchedulers = "0.4"
SciMLSensitivity = "7.84.0"
Statistics = "1.11.1"
cuDNN = "1"
julia = "1.11"
Expand Down
44 changes: 30 additions & 14 deletions benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@ end
basedir = haskey(ENV, "DEEPDIP") ? ENV["DEEPDIP"] : @__DIR__
outdir = joinpath(basedir, "output", "kolmogorov")
confdir = joinpath(basedir, "configs/local")
#confdir = joinpath(basedir, "configs/snellius")
confdir = joinpath(basedir, "configs/snellius")
@warn "Using configuration files from $confdir"
compdir = joinpath(outdir, "comparison")
ispath(compdir) || mkpath(compdir)

# List configurations files
using Glob
list_confs = glob("*.yaml", confdir)
#exclude_patterns = ["att", "cno", "cnn_ins", "_1", "nopr"]
exclude_patterns = ["att", "cno", "int", "back", "rk4", "cnn_1" ]
@warn "Excluding configurations with patterns: $(exclude_patterns)"
all_confs = glob("*.yaml", confdir)
list_confs = filter(conf -> all(!occursin(pat, conf) for pat in exclude_patterns), all_confs)
if isempty(list_confs)
@error "No configuration files found in $confdir"
end
Expand Down Expand Up @@ -97,16 +101,16 @@ colors_list = [

# Loop over plot types and configurations
plot_labels = Dict(
:prior_hist => (
title = "A-priori training history for different configurations",
xlabel = "Iteration",
ylabel = "A-priori error",
),
:posteriori_hist => (
title = "A-posteriori training history for different configurations",
xlabel = "Iteration",
ylabel = "DCF",
),
#:prior_hist => (
# title = "A-priori training history for different configurations",
# xlabel = "Iteration",
# ylabel = "A-priori error",
#),
#:posteriori_hist => (
# title = "A-posteriori training history for different configurations",
# xlabel = "Iteration",
# ylabel = "DCF",
#),
:divergence => (
title = "Divergence for different configurations",
xlabel = "t",
Expand All @@ -128,7 +132,12 @@ plot_labels = Dict(
:training_time => (
title = "Training time for different configurations",
xlabel = "Model",
ylabel = "Training time (s)",
ylabel = "Training time (s) (per iteration)",
),
:training_comptime => (
title = "Training time for different configurations",
xlabel = "Model",
ylabel = "Full Training time (s)",
),
:inference_time => (
title = "Inference time for different configurations",
Expand Down Expand Up @@ -252,6 +261,13 @@ for key in keys(plot_labels)
)
append!(bar_positions, bar_position)
append!(bar_labels, bar_label)
elseif key == :training_comptime
projectorders = eval(Meta.parse(conf["posteriori"]["projectorders"]))
bar_label, bar_position = plot_training_comptime(
outdir, closure_name, nles, Φ, projectorders, col_index, ax, color
)
append!(bar_positions, bar_position)
append!(bar_labels, bar_label)
elseif key == :inference_time
bar_label, bar_position = plot_inference_time(
outdir, closure_name, nles, data_index, col_index, ax, color
Expand Down Expand Up @@ -299,7 +315,7 @@ for key in keys(plot_labels)
end

# Add xticks in barplot
if key in (:training_time, :inference_time, :num_parameters, :eprior, :epost)
if key in (:training_time, :training_comptime, :inference_time, :num_parameters, :eprior, :epost)
ax.xticks = (bar_positions, bar_labels)
end

Expand Down
32 changes: 27 additions & 5 deletions cnn_model_workflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ using Lux
using LuxCUDA
using NNlib
using Optimisers
using Optimisers: Adam
using OptimizationOptimJL
using OptimizationCMAEvolutionStrategy
using ParameterSchedulers
using Random
using SciMLSensitivity


# ## Random number seeds
Expand Down Expand Up @@ -248,6 +252,11 @@ if haskey(conf["priori"], "reuse")
@info "Reuse a-priori training from closure named: $reuse"
reusepriorfile(reuse, outdir, closure_name)
end
if haskey(conf["posteriori"], "reuse")
reuse = conf["posteriori"]["reuse"]
@info "Reuse a-posteriori training from closure named: $reuse"
reusepostfile(reuse, outdir, closure_name)
end

# Train
for i = 1:ntrajectory
Expand Down Expand Up @@ -342,6 +351,19 @@ projectorders = eval(Meta.parse(conf["posteriori"]["projectorders"]))
nprojectorders = length(projectorders)
@assert nprojectorders == 1 "Only DCF should be done"

sensealg = haskey(conf["posteriori"], "sensealg") ? eval(Meta.parse(conf["posteriori"]["sensealg"])) : nothing
sciml_solver = haskey(conf["posteriori"], "sciml_solver") ? eval(Meta.parse(conf["posteriori"]["sciml_solver"])) : nothing
if sensealg !== nothing
@info "Using sensitivity algorithm: $sensealg"
else
@info "No sensitivity algorithm specified"
end
if sciml_solver !== nothing
@info "Using SciML solver: $sciml_solver"
else
@info "No SciML solver specified"
end

# Train
for i = 1:ntrajectory
if i%numtasks == taskid -1
Expand All @@ -357,7 +379,6 @@ let
postseed = seeds.post,
dns_seeds_train,
dns_seeds_valid,
dns_seeds_test,
nunroll = conf["posteriori"]["nunroll"],
nsamples = conf["posteriori"]["nsamples"],
dt = T(conf["posteriori"]["dt"]),
Expand All @@ -370,7 +391,8 @@ let
nepoch,
do_plot = conf["posteriori"]["do_plot"],
plot_train = conf["posteriori"]["plot_train"],
sensealg = haskey(conf["posteriori"],:sensealg) ? eval(Meta.parse(conf["posteriori"]["sensealg"])) : nothing,
sensealg = sensealg,
sciml_solver = sciml_solver,
dataproj = conf["dataproj"],
)
end
Expand Down Expand Up @@ -515,14 +537,14 @@ let
dudt_nomod = NS.create_right_hand_side_inplace(
setup, psolver)

epost.nomodel[I,:], _ = compute_epost(dudt_nomod, θ_cnn_post[I].*0 , tspan, data, tsave, dt)
epost.nomodel[I,:], _ = compute_epost(dudt_nomod, sciml_solver, θ_cnn_post[I].*0 , tspan, data, tsave, dt)
@info "Epost nomodel" epost.nomodel[I,:]
# with closure
dudt = NS.create_right_hand_side_with_closure_inplace(
setup, psolver, closure, st)
epost.model_prior[I, :], _ = compute_epost(dudt, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
epost.model_prior[I, :], _ = compute_epost(dudt, sciml_solver, device(θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
@info "Epost model_prior" epost.model_prior[I, :]
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, device(θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
epost.model_post[I, :], epost.model_t_post_inference[I] = compute_epost(dudt, sciml_solver, device(θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
@info "Epost model_post" epost.model_post[I, :]

clean()
Expand Down
18 changes: 11 additions & 7 deletions configs/snellius/att_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ docreatedata: true
docomp: true
ntrajectory: 8
T: "Float32"
dataproj: true
params:
D: 2
lims: [0.0, 1.0]
Expand All @@ -10,7 +11,7 @@ params:
tsim: 5.0
savefreq: 50
ndns: 4096
nles: [128]
nles: [64]
filters: ["FaceAverage()"]
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
method: "RKMethods.Wray3(; T)"
Expand All @@ -32,25 +33,28 @@ closure:
use_bias: [true, true, true, true, false]
use_attention: [true, false, false, false, false]
emb_sizes: [124, 124, 124, 124, 124]
Ns: [148, 144, 140, 136, 132]
patch_sizes: [37, 36, 35, 34, 33]
# Ns: [148, 144, 140, 136, 132]
Ns: [ 84, 80, 76, 72, 68]
# patch_sizes: [37, 36, 35, 34, 33]
patch_sizes: [21, 20, 19, 18, 17]
n_heads: [4, 4, 4, 4, 4]
sum_attention: [false, false, false, false, false]
rng: "Xoshiro(seeds.θ_start)"
priori:
dotrain: true
nepoch: 10000
nepoch: 50000
batchsize: 64
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
do_plot: false
plot_train: false
posteriori:
dotrain: true
projectorders: "(ProjectOrder.Last, )"
nepoch: 300
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
nepoch: 1500
opt: "OptimizationCMAEvolutionStrategy.CMAEvolutionStrategyOpt()"
nunroll: 5
nunroll_valid: 10
dt: T(5e-5)
dt: 0.0001
nsamples: 1
do_plot: false
plot_train: false
1 change: 1 addition & 0 deletions configs/snellius/cnn_1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ posteriori:
dt: 0.0001
do_plot: false
plot_train: false
sciml_solver: "Tsit5()"
9 changes: 5 additions & 4 deletions configs/snellius/cnn_2.yaml → configs/snellius/cnn_CMA.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ seeds:
prior: 345
post: 456
closure:
name: "cnn_remove"
name: "cnn_cma"
type: cnn
radii: [2, 2, 2, 2, 2]
channels: [24, 24, 24, 24, 2]
Expand All @@ -35,19 +35,20 @@ closure:
priori:
reuse: "cnn_project"
dotrain: true
nepoch: 10000
nepoch: 50000
batchsize: 64
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
do_plot: false
plot_train: false
posteriori:
dotrain: true
projectorders: "(ProjectOrder.Last, )"
nepoch: 100
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
nepoch: 1500
opt: "OptimizationCMAEvolutionStrategy.CMAEvolutionStrategyOpt()"
nunroll: 5
nunroll_valid: 10
nsamples: 1
dt: 0.0001
do_plot: false
plot_train: false
sciml_solver: "Tsit5()"
56 changes: 56 additions & 0 deletions configs/snellius/cnn_backsol.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
docreatedata: true
docomp: true
ntrajectory: 8
T: "Float32"
dataproj: false
params:
D: 2
lims: [0.0, 1.0]
Re: 6000.0
tburn: 0.5
tsim: 5.0
savefreq: 50
ndns: 4096
nles: [64]
filters: ["FaceAverage()"]
icfunc: "(setup, psolver, rng) -> random_field(setup, T(0); kp=20, psolver, rng)"
method: "RKMethods.Wray3(; T)"
bodyforce: "(dim, x, y, t) -> (dim == 1) * 5 * sinpi(8 * y)"
issteadybodyforce: true
processors: "(; log = timelogger(; nupdate=100))"
Δt: 0.00005
seeds:
dns: 123456
θ_start: 234
prior: 345
post: 456
closure:
name: "cnn_backsol"
type: cnn
radii: [2, 2, 2, 2, 2]
channels: [24, 24, 24, 24, 2]
activations: ["tanh", "tanh", "tanh", "tanh", "identity"]
use_bias: [true, true, true, true, false]
rng: "Xoshiro(seeds.θ_start)"
priori:
reuse: "cnn_noproj"
dotrain: true
nepoch: 50000
batchsize: 64
opt: "OptimiserChain(Adam(T(1.0e-3)), ClipGrad(0.1))"
do_plot: false
plot_train: false
posteriori:
#reuse: "cnn_project"
dotrain: true
projectorders: "(ProjectOrder.Last, )"
nepoch: 1500
opt: "OptimiserChain(Adam(T(1.0e-4)), ClipGrad(0.01))"
nunroll: 5
nunroll_valid: 10
nsamples: 1
dt: 0.0001
do_plot: false
plot_train: false
sensealg: "BacksolveAdjoint()"
sciml_solver: "Tsit5()"
Loading