7878 logfile = " log_$(Dates. now ()) .out"
7979end
8080logfile = joinpath (logdir, logfile)
81- setsnelliuslogger (logfile)
81+ # check if I am planning to use Enzyme, in which case I can not touch the logger
82+ if (haskey (conf[" priori" ], " ad_type" ) && occursin (" Enzyme" , conf[" priori" ][" ad_type" ])) ||
83+ (haskey (conf[" posteriori" ], " ad_type" ) && occursin (" Enzyme" , conf[" posteriori" ][" ad_type" ]))
84+ @warn " Enzyme is used, so logger will not be set to ConsoleLogger"
85+ else
86+ setsnelliuslogger (logfile)
87+ end
8288
8389@info " # A-posteriori analysis: Forced turbulence (2D)"
8490
@@ -98,6 +104,7 @@ using CairoMakie
98104using CoupledNODE: loss_priori_lux, create_loss_post_lux
99105using CUDA
100106using DifferentialEquations
107+ using Enzyme
101108using IncompressibleNavierStokes. RKMethods
102109using JLD2
103110using LaTeXStrings
@@ -111,6 +118,7 @@ using OptimizationOptimJL
111118using OptimizationCMAEvolutionStrategy
112119using ParameterSchedulers
113120using Random
121+ using SciMLSensitivity
114122
115123
116124# ## Random number seeds
@@ -174,6 +182,22 @@ dns_seeds_train = dns_seeds[1:ntrajectory-2]
174182dns_seeds_valid = dns_seeds[ntrajectory- 1 : ntrajectory- 1 ]
175183dns_seeds_test = dns_seeds[ntrajectory: ntrajectory]
176184
185+ doprojtest = conf[" projtest" ]
186+ if doprojtest && taskid == 1
187+ testprojfile = joinpath (outdir, " test_dns_proj.jld2" )
188+ if isfile (testprojfile)
189+ @info " Test DNS projection file already exists."
190+ else
191+ create_test_dns_proj (
192+ nchunks = 8000 ;
193+ params... ,
194+ rng = Xoshiro (2406 ),
195+ backend = backend,
196+ filename = testprojfile,
197+ )
198+ end
199+ end
200+
177201# Create data
178202docreatedata = conf[" docreatedata" ]
179203for i = 1 : ntrajectory
247271 u = randn (T, params. nles[1 ], params. nles[1 ], 2 , 10 ) |> device
248272 θ = θ_start |> device
249273 closure (u, θ, st)
250- gradient (θ -> sum (closure (u, θ, st)[1 ]), θ)
274+ Zygote . gradient (θ -> sum (closure (u, θ, st)[1 ]), θ)
251275 clean ()
252276end
253277
299323 plot_train = conf[" priori" ][" plot_train" ],
300324 nepoch,
301325 dataproj = conf[" dataproj" ],
302- λ = conf[" priori" ][" lambda" ],
326+ λ = haskey (conf[" priori" ], " λ" ) ? eval (Meta. parse (conf[" priori" ][" λ" ])) : nothing ,
327+ ad_type = haskey (conf[" priori" ], " ad_type" ) ? eval (Meta. parse (conf[" priori" ][" ad_type" ])) : Optimization. AutoZygote (),
303328 )
304329end
305330end
412437 sensealg = sensealg,
413438 sciml_solver = sciml_solver,
414439 dataproj = conf[" dataproj" ],
415- λ = conf[" posteriori" ][" lambda" ],
440+ λ = haskey (conf[" posteriori" ], " λ" ) ? eval (Meta. parse (conf[" posteriori" ][" λ" ])) : nothing ,
441+ multishoot_nt = haskey (conf[" posteriori" ], " multishoot_nt" ) ? conf[" posteriori" ][" multishoot_nt" ] : 0 ,
442+ ad_type = haskey (conf[" posteriori" ], " ad_type" ) ? eval (Meta. parse (conf[" posteriori" ][" ad_type" ])) : Optimization. AutoZygote (),
416443 )
417444end
418445end
0 commit comments