Skip to content

Commit 1ebca6f

Browse files
Merge pull request #110 from SciML/DIv6
DifferentiationInterface v6: Constant and arguments orders changes
2 parents e30f933 + 3038abf commit 1ebca6f

File tree

8 files changed

+472
-524
lines changed

8 files changed

+472
-524
lines changed

Project.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "2.0.4"
4+
version = "2.1.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -38,11 +38,11 @@ OptimizationReverseDiffExt = "ReverseDiff"
3838
OptimizationZygoteExt = "Zygote"
3939

4040
[compat]
41-
ADTypes = "1.5"
41+
ADTypes = "1.9"
4242
ArrayInterface = "7.6"
43-
DifferentiationInterface = "0.5"
43+
DifferentiationInterface = "0.6.1"
4444
DocStringExtensions = "0.9"
45-
Enzyme = "0.12.12"
45+
Enzyme = "0.13.2"
4646
FastClosures = "0.3"
4747
FiniteDiff = "2.12"
4848
ForwardDiff = "0.10.26"

ext/OptimizationEnzymeExt.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141
function hv_f2_alloc(x, f, p)
4242
dx = Enzyme.make_zero(x)
4343
Enzyme.autodiff_deferred(Enzyme.Reverse,
44-
firstapply,
44+
Const(firstapply),
4545
Active,
4646
Const(f),
4747
Enzyme.Duplicated(x, dx),
@@ -58,7 +58,8 @@ function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothi
5858
end
5959

6060
function cons_f2(x, dx, fcons, p, num_cons, i)
61-
Enzyme.autodiff_deferred(Enzyme.Reverse, inner_cons, Active, Enzyme.Duplicated(x, dx),
61+
Enzyme.autodiff_deferred(
62+
Enzyme.Reverse, Const(inner_cons), Active, Enzyme.Duplicated(x, dx),
6263
Const(fcons), Const(p), Const(num_cons), Const(i))
6364
return nothing
6465
end
@@ -71,7 +72,7 @@ end
7172

7273
function cons_f2_oop(x, dx, fcons, p, i)
7374
Enzyme.autodiff_deferred(
74-
Enzyme.Reverse, inner_cons_oop, Active, Enzyme.Duplicated(x, dx),
75+
Enzyme.Reverse, Const(inner_cons_oop), Active, Enzyme.Duplicated(x, dx),
7576
Const(fcons), Const(p), Const(i))
7677
return nothing
7778
end
@@ -83,7 +84,8 @@ function lagrangian(x, _f::Function, cons::Function, p, λ, σ = one(eltype(x)))
8384
end
8485

8586
function lag_grad(x, dx, lagrangian::Function, _f::Function, cons::Function, p, σ, λ)
86-
Enzyme.autodiff_deferred(Enzyme.Reverse, lagrangian, Active, Enzyme.Duplicated(x, dx),
87+
Enzyme.autodiff_deferred(
88+
Enzyme.Reverse, Const(lagrangian), Active, Enzyme.Duplicated(x, dx),
8789
Const(_f), Const(cons), Const(p), Const(λ), Const(σ))
8890
return nothing
8991
end
@@ -187,7 +189,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
187189
if hv == true && f.hv === nothing
188190
function hv!(H, θ, v, p = p)
189191
H .= Enzyme.autodiff(
190-
Enzyme.Forward, hv_f2_alloc, DuplicatedNoNeed, Duplicated(θ, v),
192+
Enzyme.Forward, hv_f2_alloc, Duplicated(θ, v),
191193
Const(f.f), Const(p)
192194
)[1]
193195
end
@@ -531,7 +533,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
531533
for i in eachindex(Jaccache)
532534
Enzyme.make_zero!(Jaccache[i])
533535
end
534-
y, Jaccache = Enzyme.autodiff(Enzyme.Forward, f.cons, Duplicated,
536+
Jaccache, y = Enzyme.autodiff(Enzyme.ForwardWithPrimal, f.cons, Duplicated,
535537
BatchDuplicated(θ, seeds), Const(p))
536538
if size(y, 1) == 1
537539
return reduce(vcat, Jaccache)

0 commit comments

Comments
 (0)