Skip to content

Fix GaussAdjoint with callbacks #1060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,10 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg
f::fType
sol::solType
Δλas::ΔλasType
no_start::Bool
end

function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time, no_start)
(; sensealg, y) = sensefun
isq = (sensealg isa QuadratureAdjoint)

Expand All @@ -585,18 +586,20 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
if ArrayInterface.ismutable(y)
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.f,
nothing, Δλas)
nothing, Δλas, no_start)
else
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.f,
sensefun.sol, Δλas)
sensefun.sol, Δλas, no_start)
end
end

function (f::ReverseLossCallback)(integrator)
(; isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol) = f
(; isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol, no_start) = f
(; diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config) = f.diffcache

no_start && !(sensealg isa BacksolveAdjoint) && cur_time[] == 1 && return nothing

p, u = integrator.p, integrator.u

if sensealg isa BacksolveAdjoint
Expand Down Expand Up @@ -657,7 +660,7 @@ end

# handle discrete loss contributions
function generate_callbacks(sensefun, dgdu, dgdp, λ, t, t0, callback, init_cb,
terminated = false)
terminated = false, no_start = false)
if sensefun isa NILSASSensitivityFunction
(; sensealg) = sensefun.S
else
Expand All @@ -678,7 +681,7 @@ function generate_callbacks(sensefun, dgdu, dgdp, λ, t, t0, callback, init_cb,
# callbacks can lead to non-unique time points
_t, duplicate_iterator_times = separate_nonunique(t)

rlcb = ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
rlcb = ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time, no_start)

if eltype(_t) !== typeof(t0)
_t = convert.(typeof(t0), _t)
Expand All @@ -688,7 +691,7 @@ function generate_callbacks(sensefun, dgdu, dgdp, λ, t, t0, callback, init_cb,
# handle duplicates (currently only for double occurrences)
if duplicate_iterator_times !== nothing
# use same ref for cur_time to cope with concrete_solve
cbrev_dupl_affect = ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
cbrev_dupl_affect = ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time, no_start)
cb_dupl = PresetTimeCallback(duplicate_iterator_times[1], cbrev_dupl_affect)
return CallbackSet(cb, reverse_cbs, cb_dupl), rlcb, duplicate_iterator_times
else
Expand Down
27 changes: 23 additions & 4 deletions src/callback_tracking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ function _setup_reverse_callbacks(
du = first(get_tmp_cache(integrator))
λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S)

if sensealg isa GaussAdjoint
dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache
recursive_copyto!(dgrad, 0)
end

# if save_positions[2] = false, then the right limit is not saved. Thus, for
# the QuadratureAdjoint we would need to lift y from the left to the right limit.
# However, one also needs to update dgrad later on.
Expand Down Expand Up @@ -330,16 +335,25 @@ function _setup_reverse_callbacks(
fakeSp = CallbackSensitivityFunction(wp, sensealg, diffcaches[2],
integrator.sol.prob)
#vjp with Jacobian given by dw/dp before event and vector given by grad
vecjacobian!(dgrad, integrator.p, grad, y, integrator.t, fakeSp;

if sensealg isa GaussAdjoint
vecjacobian!(dgrad, integrator.p,
integrator.f.f.integrating_cb.affect!.integrand_values.integrand,
y, integrator.t, fakeSp; dgrad = nothing, dy = nothing)
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad
else
vecjacobian!(dgrad, integrator.p, grad, y, integrator.t, fakeSp;
dgrad = nothing, dy = nothing)
grad .= dgrad
grad .= dgrad
end
end
end

vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS;
dgrad = dgrad, dy = dy)

dgrad !== nothing && (dgrad .*= -1)
dgrad !== nothing && !(sensealg isa QuadratureAdjoint) && (dgrad .*= -1)

if cb isa Union{ContinuousCallback, VectorContinuousCallback}
# second correction to correct for left limit
(; Lu_left) = correction
Expand All @@ -358,7 +372,12 @@ function _setup_reverse_callbacks(

λ .= dλ

if !(sensealg isa QuadratureAdjoint)
if sensealg isa GaussAdjoint
@assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .-= dgrad

#recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad)
elseif !(sensealg isa QuadratureAdjoint)
grad .-= dgrad
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -678,14 +678,14 @@ function DiffEqBase._concrete_solve_adjoint(
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
dgdu_discrete = ArrayInterface.ismutable(eltype(state_values(sol))) ? df_iip : df_oop,
sensealg = sensealg,
callback = cb2,
callback = cb2, no_start = !save_start && _prob.tspan[1] ∈ ts,
initializealg = BrownFullBasicInit(),
kwargs_init...)
else
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
dgdu_discrete = ArrayInterface.ismutable(eltype(state_values(sol))) ? df_iip : df_oop,
sensealg = sensealg,
callback = cb2,
callback = cb2, no_start = !save_start && _prob.tspan[2] ∈ ts,
kwargs_init...)
end

Expand Down
30 changes: 17 additions & 13 deletions src/gauss_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
G}
G, SAlg <: GaussAdjoint}
sol::S
p::pType
y::uType
Expand All @@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
f_cache::rateType
pJ::PJT
paramjac_config::PJC
sensealg::GaussAdjoint
sensealg::SAlg
dgdp_cache::DGP
dgdp::G
end

struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
Alg <: GaussAdjoint,
uType, SType, CPS, pType,
fType <: AbstractDiffEqFunction} <: SensitivityFunction
fType <: DiffEqBase.AbstractDiffEqFunction,
GI <: GaussIntegrand,
ICB} <: SensitivityFunction
diffcache::C
sensealg::Alg
discrete::Bool
Expand All @@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
checkpoint_sol::CPS
prob::pType
f::fType
GaussInt::GaussIntegrand
GaussInt::GI
integrating_cb::ICB
end

mutable struct GaussCheckpointSolution{S, I, T, T2}
Expand All @@ -39,7 +42,7 @@ end
function ODEGaussAdjointSensitivityFunction(
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
f, alg,
checkpoints, tols, tstops = nothing;
checkpoints, integrating_cb, tols, tstops = nothing;
tspan = reverse(sol.prob.tspan))
checkpointing = ischeckpointing(sensealg, sol)
(checkpointing && checkpoints === nothing) &&
Expand Down Expand Up @@ -82,7 +85,7 @@ function ODEGaussAdjointSensitivityFunction(
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
quad = true)
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
y, sol, checkpoint_sol, sol.prob, f, gaussint)
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
end

function Gaussfindcursor(intervals, t)
Expand Down Expand Up @@ -200,7 +203,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
end

@noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg,
GaussInt::GaussIntegrand,
GaussInt::GaussIntegrand, integrating_cb,
t = nothing,
dgdu_discrete::DG1 = nothing,
dgdp_discrete::DG2 = nothing,
Expand All @@ -209,7 +212,7 @@ end
g::G = nothing,
::Val{RetCB} = Val(false);
checkpoints = current_time(sol),
callback = CallbackSet(),
callback = CallbackSet(), no_start = false,
reltol = nothing, abstol = nothing, kwargs...) where {DG1, DG2, DG3, DG4, G,
RetCB}
dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing &&
Expand Down Expand Up @@ -273,14 +276,14 @@ end
λ = zero(u0)
end
sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol,
dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)

init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end]
z0 = vec(zero(λ))
cb, rcb, _ = generate_callbacks(sense, dgdu_discrete, dgdp_discrete,
λ, t, tspan[2],
callback, init_cb, terminated)
callback, init_cb, terminated, no_start)

jac_prototype = sol.prob.f.jac_prototype
adjoint_jac_prototype = !sense.discrete || jac_prototype === nothing ? nothing :
Expand Down Expand Up @@ -551,7 +554,7 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
abstol = 1e-6, reltol = 1e-3,
checkpoints = current_time(sol),
corfunc_analytical = false,
callback = CallbackSet(),
callback = CallbackSet(), no_start = false,
kwargs...)
p = SymbolicIndexingInterface.parameter_values(sol)
if !isscimlstructure(p) && !isfunctor(p)
Expand All @@ -577,11 +580,12 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,

if sol.prob isa ODEProblem
adj_prob, cb2, rcb = ODEAdjointProblem(
sol, sensealg, alg, integrand, t, dgdu_discrete,
sol, sensealg, alg, integrand, cb,
t, dgdu_discrete,
dgdp_discrete,
dgdu_continuous, dgdp_continuous, g, Val(true);
checkpoints = checkpoints,
callback = callback,
callback = callback, no_start = no_start,
abstol = abstol, reltol = reltol, kwargs...)
else
error("Continuous adjoint sensitivities are only supported for ODE problems.")
Expand Down
4 changes: 2 additions & 2 deletions src/interpolating_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ end
g::G = nothing,
::Val{RetCB} = Val(false);
checkpoints = current_time(sol),
callback = CallbackSet(),
callback = CallbackSet(), no_start = false,
reltol = nothing, abstol = nothing,
kwargs...) where {DG1, DG2, DG3, DG4, G, RetCB}
dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing &&
Expand Down Expand Up @@ -355,7 +355,7 @@ end
cb, rcb, duplicate_iterator_times = generate_callbacks(sense, dgdu_discrete,
dgdp_discrete,
λ, t, tspan[2],
callback, init_cb, terminated)
callback, init_cb, terminated, no_start)
z0 = vec(zero(λ))
original_mm = sol.prob.f.mass_matrix
if original_mm === I || original_mm === (I, I)
Expand Down
8 changes: 4 additions & 4 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ end
dgdu_continuous::DG3 = nothing,
dgdp_continuous::DG4 = nothing,
g::G = nothing,
::Val{RetCB} = Val(false);
::Val{RetCB} = Val(false); no_start = false,
callback = CallbackSet()) where {DG1, DG2, DG3, DG4, G,
RetCB}
dgdu_discrete === nothing && dgdu_continuous === nothing && g === nothing &&
Expand Down Expand Up @@ -133,7 +133,7 @@ end
z0 = vec(zero(λ))
cb, rcb, _ = generate_callbacks(sense, dgdu_discrete, dgdp_discrete,
λ, t, tspan[2],
callback, init_cb, terminated)
callback, init_cb, terminated, no_start)

jac_prototype = sol.prob.f.jac_prototype
adjoint_jac_prototype = !sense.discrete || jac_prototype === nothing ? nothing :
Expand Down Expand Up @@ -345,13 +345,13 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
dgdp_discrete = nothing,
dgdu_continuous = nothing,
dgdp_continuous = nothing,
g = nothing,
g = nothing, no_start = false,
abstol = sensealg.abstol, reltol = sensealg.reltol,
callback = CallbackSet(),
kwargs...)
adj_prob, rcb = ODEAdjointProblem(sol, sensealg, alg, t, dgdu_discrete, dgdp_discrete,
dgdu_continuous, dgdp_continuous, g, Val(true);
callback)
callback, no_start)
adj_sol = solve(adj_prob, alg; abstol = abstol, reltol = reltol,
save_everystep = true, save_start = true, kwargs...)

Expand Down
12 changes: 9 additions & 3 deletions src/sensitivity_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ adjoint_sensitivities(sol,alg;t=nothing,
checkpoints=sol.t,
corfunc_analytical=nothing,
callback = nothing,
no_start = false,
sensealg=InterpolatingAdjoint(),
kwargs...)
```
Expand Down Expand Up @@ -127,6 +128,10 @@ For more information, see [Sensitivity Math Details](@ref sensitivity_math).
then this term is not required and will be approximated by numerical or (forward-mode) automatic
differentiation (via the `autojacvec` keyword argument in the `sensealg`)
if this term is not given by the user.
- `no_start`: Indicates that the starting time point carries no data for the sensitivity
analysis. The adjoint method requires a solution object which contains the starting time,
so if your forward solve had `save_start=false` then `no_start=true` should be set.
Defaults to `false`.
- `abstol`: the absolute tolerance of the adjoint solve. Defaults to `1e-6`
- `reltol`: the relative tolerance of the adjoint solve. Defaults to `1e-3`
- `checkpoints`: the values to use for the checkpoints of the reverse solve, if the
Expand Down Expand Up @@ -173,7 +178,8 @@ For continuous functionals, the form is:
du0,dp = adjoint_sensitivities(sol,alg;dgdu_continuous=dgdu,g=g,
dgdp_continuous = dgdp,
sensealg=InterpolatingAdjoint(),
checkpoints=sol.t,kwargs...)
checkpoints=sol.t,
no_start = false, kwargs...)
```

for the cost functional
Expand Down Expand Up @@ -406,7 +412,7 @@ function _adjoint_sensitivities(sol, sensealg, alg;
t = nothing,
dgdu_discrete = nothing, dgdp_discrete = nothing,
dgdu_continuous = nothing, dgdp_continuous = nothing,
g = nothing,
g = nothing, no_start = false,
abstol = 1e-6, reltol = 1e-3,
checkpoints = current_time(sol),
corfunc_analytical = nothing,
Expand All @@ -423,7 +429,7 @@ function _adjoint_sensitivities(sol, sensealg, alg;
dgdp_discrete,
dgdu_continuous, dgdp_continuous, g, Val(true);
checkpoints = checkpoints,
callback = callback,
callback = callback, no_start = no_start,
abstol = abstol, reltol = reltol, kwargs...)

elseif sol.prob isa SDEProblem
Expand Down
6 changes: 3 additions & 3 deletions test/HybridNODE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,16 @@ function test_hybridNODE3(sensealg)
end

@testset "PresetTimeCallback: $(sensealg)" for sensealg in [ForwardDiffSensitivity(),
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()]
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint(), GaussAdjoint()]
test_hybridNODE(sensealg)
end

@testset "PeriodicCallback: $(sensealg)" for sensealg in [ReverseDiffAdjoint(),
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()]
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint(), GaussAdjoint()]
test_hybridNODE2(sensealg)
end

@testset "tprevCallback: $(sensealg)" for sensealg in [ReverseDiffAdjoint(),
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint()]
BacksolveAdjoint(), InterpolatingAdjoint(), QuadratureAdjoint(), GaussAdjoint()]
test_hybridNODE3(sensealg)
end
4 changes: 4 additions & 0 deletions test/callbacks/continuous_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,5 +291,9 @@ println("Continuous Callbacks")
sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP())
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
@test gFD≈gZy rtol=1e-10

sensealg = GaussAdjoint(autojacvec = EnzymeVJP())
gZy = Zygote.gradient(p -> loss(p, cb, sensealg), p)[1]
@test gFD≈gZy rtol=1e-10
end
end
Loading
Loading