@@ -106,6 +106,7 @@ using OptimizationOptimJL
106106using OptimizationCMAEvolutionStrategy
107107using ParameterSchedulers
108108using Random
109+ using SciMLSensitivity
109110
110111
111112# ## Random number seeds
@@ -350,6 +351,19 @@ projectorders = eval(Meta.parse(conf["posteriori"]["projectorders"]))
350351nprojectorders = length (projectorders)
351352@assert nprojectorders == 1 " Only DCF should be done"
352353
354+ sensealg = haskey (conf[" posteriori" ], " sensealg" ) ? eval (Meta. parse (conf[" posteriori" ][" sensealg" ])) : nothing
355+ sciml_solver = haskey (conf[" posteriori" ], " sciml_solver" ) ? eval (Meta. parse (conf[" posteriori" ][" sciml_solver" ])) : nothing
356+ if sensealg != = nothing
357+ @info " Using sensitivity algorithm: $sensealg "
358+ else
359+ @info " No sensitivity algorithm specified"
360+ end
361+ if sciml_solver != = nothing
362+ @info " Using SciML solver: $sciml_solver "
363+ else
364+ @info " No SciML solver specified"
365+ end
366+
353367# Train
354368for i = 1 : ntrajectory
355369 if i% numtasks == taskid - 1
377391 nepoch,
378392 do_plot = conf[" posteriori" ][" do_plot" ],
379393 plot_train = conf[" posteriori" ][" plot_train" ],
380- sensealg = haskey (conf[" posteriori" ],:sensealg ) ? eval (Meta. parse (conf[" posteriori" ][" sensealg" ])) : nothing ,
394+ sensealg = sensealg,
395+ sciml_solver = sciml_solver,
381396 dataproj = conf[" dataproj" ],
382397 )
383398end
@@ -522,14 +537,14 @@ let
522537 dudt_nomod = NS. create_right_hand_side_inplace (
523538 setup, psolver)
524539
525- epost. nomodel[I,:], _ = compute_epost (dudt_nomod, θ_cnn_post[I]. * 0 , tspan, data, tsave, dt)
540+ epost. nomodel[I,:], _ = compute_epost (dudt_nomod, sciml_solver, θ_cnn_post[I]. * 0 , tspan, data, tsave, dt)
526541 @info " Epost nomodel" epost. nomodel[I,:]
527542 # with closure
528543 dudt = NS. create_right_hand_side_with_closure_inplace (
529544 setup, psolver, closure, st)
530- epost. model_prior[I, :], _ = compute_epost (dudt, device (θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
545+ epost. model_prior[I, :], _ = compute_epost (dudt, sciml_solver, device (θ_cnn_prior[ig, ifil]) , tspan, data, tsave, dt)
531546 @info " Epost model_prior" epost. model_prior[I, :]
532- epost. model_post[I, :], epost. model_t_post_inference[I] = compute_epost (dudt, device (θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
547+ epost. model_post[I, :], epost. model_t_post_inference[I] = compute_epost (dudt, sciml_solver, device (θ_cnn_post[ig, ifil, iorder]) , tspan, data, tsave, dt)
533548 @info " Epost model_post" epost. model_post[I, :]
534549
535550 clean ()
0 commit comments