Skip to content

Commit 24e9fde

Browse files
Merge pull request #1067 from SciML/makezerobang
Safety make_zero! on repeated Enzyme calls with caches
2 parents f524e9f + 9f14a4f commit 24e9fde

File tree

5 files changed

+24
-8
lines changed

5 files changed

+24
-8
lines changed

src/concrete_solve.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ function DiffEqBase._concrete_solve_adjoint(
578578
(Δu[i] isa NoTangent || eltype(Δu) <: NoTangent) && return
579579
if Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray ||
580580
Δ isa Tangent
581-
x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? Δu[i] : Δ[i]
581+
x = Δ isa AbstractVectorOfArray ? Δu.u[i] : (Δ isa Tangent ? Δu[i] : Δ[i])
582582
if _save_idxs isa Number
583583
_out[_save_idxs] = x[_save_idxs]
584584
elseif _save_idxs isa Colon
@@ -1681,8 +1681,8 @@ function DiffEqBase._concrete_solve_adjoint(prob::SciMLBase.AbstractODEProblem,
16811681

16821682
function adjoint_sensitivity_backpass(Δ)
16831683
function df(_out, u, p, t, i)
1684-
if Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray
1685-
x = Δ isa AbstractVectorOfArray ? Δ.u[i] : Δ[i]
1684+
if Δ isa AbstractArray{<:AbstractArray} Δ isa AbstractVectorOfArray || Δ isa Tangent
1685+
x = (Δ isa AbstractVectorOfArray || Δ isa Tangent) ? unthunk(Δ.u[i]) : Δ[i]
16861686
if _save_idxs isa Number
16871687
_out[_save_idxs] = x[_save_idxs]
16881688
elseif _save_idxs isa Colon

src/derivative_wrappers.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
689689
ytmp = _tmp5
690690
end
691691

692-
tmp1 .= 0 # should be removed for dλ
692+
Enzyme.make_zero!(tmp1) # should be removed for dλ
693693
vec(ytmp) .= vec(y)
694694

695695
#if dgrad !== nothing
@@ -707,13 +707,18 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
707707
#if dy !== nothing
708708
# tmp3 = dy
709709
#else
710-
tmp3 .= 0
710+
Enzyme.make_zero!(tmp3)
711711
#end
712712

713713
vec(tmp4) .= vec(λ)
714714

715715
isautojacvec = get_jacvec(sensealg)
716716

717+
# Correctness over speed
718+
# TODO: Get a fix for `make_zero!` to allow reusing zero'd memory
719+
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
720+
_tmp6 = Enzyme.make_zero(_tmp6)
721+
717722
if inplace_sensitivity(S)
718723
if W === nothing
719724
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6),

src/gauss_adjoint.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,13 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
500500
tmp3, tmp4, tmp6 = paramjac_config
501501
vtmp4 = vec(tmp4)
502502
vtmp4 .= λ
503-
out .= 0
503+
Enzyme.make_zero!(out)
504+
505+
# Correctness over speed
506+
# TODO: Get a fix for `make_zero!` to allow reusing zero'd memory
507+
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
508+
tmp6 = Enzyme.make_zero(tmp6)
509+
504510
Enzyme.autodiff(
505511
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
506512
Enzyme.Duplicated(tmp3, tmp4),

src/quadrature_adjoint.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,12 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
300300
tmp3, tmp4, tmp6 = paramjac_config
301301
vtmp4 = vec(tmp4)
302302
vtmp4 .= λ
303-
out .= 0
303+
Enzyme.make_zero!(out)
304+
305+
# Correctness over speed
306+
# TODO: Get a fix for `make_zero!` to allow reusing zero'd memory
307+
# https://github.com/EnzymeAD/Enzyme.jl/issues/2400
308+
tmp6 = Enzyme.make_zero(tmp6)
304309
Enzyme.autodiff(
305310
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
306311
Enzyme.Duplicated(tmp3, tmp4),

test/hybrid_de.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ end
5151
res = solve(OptimizationProblem(OptimizationFunction(loss_n_ode, AutoZygote()),
5252
ps),
5353
Adam(0.05); callback = cba, maxiters = 200)
54-
@test loss_n_ode(res.u, nothing) < 0.4
54+
@test loss_n_ode(res.u, nothing) < 0.5

0 commit comments

Comments
 (0)