Skip to content

Commit 3038abf

Browse files
cons again
1 parent 75be96e commit 3038abf

File tree

4 files changed

+128
-116
lines changed

4 files changed

+128
-116
lines changed

ext/OptimizationZygoteExt.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,7 @@ function OptimizationBase.instantiate_function(
118118
if f.cons === nothing
119119
cons = nothing
120120
else
121-
function cons(res, θ)
122-
return f.cons(res, θ, p)
123-
end
121+
cons = (res, θ) -> f.cons(res, θ, p)
124122

125123
function cons_oop(x)
126124
_res = Zygote.Buffer(x, num_cons)
@@ -369,7 +367,8 @@ function OptimizationBase.instantiate_function(
369367
end
370368

371369
if hv == true && f.hv === nothing
372-
prep_hvp = prepare_hvp(_f, soadtype.dense_ad, x, zeros(eltype(x), size(x)))
370+
prep_hvp = prepare_hvp(
371+
f.f, soadtype.dense_ad, x, (zeros(eltype(x), size(x)),), Constant(p))
373372
function hv!(H, θ, v)
374373
hvp!(f.f, (H,), prep_hvp, soadtype.dense_ad, θ, (v,), Constant(p))
375374
end
@@ -387,9 +386,7 @@ function OptimizationBase.instantiate_function(
387386
if f.cons === nothing
388387
cons = nothing
389388
else
390-
function cons(res, θ)
391-
f.cons(res, θ, p)
392-
end
389+
cons = (res, θ) -> f.cons(res, θ, p)
393390

394391
function cons_oop(x)
395392
_res = Zygote.Buffer(x, num_cons)

src/OptimizationDIExt.jl

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
1414
hvp, jacobian, Constant
1515
using ADTypes, SciMLBase
1616

17-
function generate_adtype(adtype)
18-
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
19-
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
20-
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
21-
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
22-
else
23-
soadtype = adtype
24-
end
25-
return adtype, soadtype
26-
end
27-
2817
function instantiate_function(
2918
f::OptimizationFunction{true}, x, adtype::ADTypes.AbstractADType,
3019
p = SciMLBase.NullParameters(), num_cons = 0;
@@ -122,7 +111,10 @@ function instantiate_function(
122111
hv! = nothing
123112
end
124113

125-
if !(f.cons === nothing)
114+
if f.cons === nothing
115+
cons = nothing
116+
else
117+
cons = (res, x) -> f.cons(res, x, p)
126118
function cons_oop(x)
127119
_res = zeros(eltype(x), num_cons)
128120
f.cons(_res, x, p)
@@ -257,7 +249,7 @@ function instantiate_function(
257249

258250
return OptimizationFunction{true}(f.f, adtype;
259251
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
260-
cons = (res, x) -> f.cons(res, x, p), cons_j = cons_j!, cons_h = cons_h!,
252+
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
261253
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
262254
hess_prototype = hess_sparsity,
263255
hess_colorvec = hess_colors,
@@ -379,7 +371,11 @@ function instantiate_function(
379371
hv! = nothing
380372
end
381373

382-
if !(f.cons === nothing)
374+
if f.cons === nothing
375+
cons = nothing
376+
else
377+
cons = Base.Fix2(f.cons, p)
378+
383379
function lagrangian(θ, σ, λ, p)
384380
return σ * f.f(θ, p) + dot(λ, f.cons(θ, p))
385381
end
@@ -482,7 +478,7 @@ function instantiate_function(
482478

483479
return OptimizationFunction{false}(f.f, adtype;
484480
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
485-
cons = Base.Fix2(f.cons, p), cons_j = cons_j!, cons_h = cons_h!,
481+
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
486482
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
487483
hess_prototype = hess_sparsity,
488484
hess_colorvec = hess_colors,

src/OptimizationDISparseExt.jl

Lines changed: 12 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -12,96 +12,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
1212
using ADTypes
1313
using SparseConnectivityTracer, SparseMatrixColorings
1414

15-
function generate_sparse_adtype(adtype)
16-
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
17-
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
18-
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
19-
coloring_algorithm = GreedyColoringAlgorithm())
20-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
21-
soadtype = AutoSparse(
22-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
23-
sparsity_detector = TracerSparsityDetector(),
24-
coloring_algorithm = GreedyColoringAlgorithm())
25-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
26-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
27-
soadtype = AutoSparse(
28-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
29-
sparsity_detector = TracerSparsityDetector(),
30-
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
31-
elseif !(adtype isa SciMLBase.NoAD) &&
32-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
33-
soadtype = AutoSparse(
34-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
35-
sparsity_detector = TracerSparsityDetector(),
36-
coloring_algorithm = GreedyColoringAlgorithm())
37-
end
38-
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
39-
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
40-
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
41-
coloring_algorithm = adtype.coloring_algorithm)
42-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
43-
soadtype = AutoSparse(
44-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
45-
sparsity_detector = TracerSparsityDetector(),
46-
coloring_algorithm = adtype.coloring_algorithm)
47-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
48-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
49-
soadtype = AutoSparse(
50-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
51-
sparsity_detector = TracerSparsityDetector(),
52-
coloring_algorithm = adtype.coloring_algorithm)
53-
elseif !(adtype isa SciMLBase.NoAD) &&
54-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
55-
soadtype = AutoSparse(
56-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
57-
sparsity_detector = TracerSparsityDetector(),
58-
coloring_algorithm = adtype.coloring_algorithm)
59-
end
60-
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
61-
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
62-
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
63-
coloring_algorithm = GreedyColoringAlgorithm())
64-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
65-
soadtype = AutoSparse(
66-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
67-
sparsity_detector = adtype.sparsity_detector,
68-
coloring_algorithm = GreedyColoringAlgorithm())
69-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
70-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
71-
soadtype = AutoSparse(
72-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
73-
sparsity_detector = adtype.sparsity_detector,
74-
coloring_algorithm = GreedyColoringAlgorithm())
75-
elseif !(adtype isa SciMLBase.NoAD) &&
76-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
77-
soadtype = AutoSparse(
78-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
79-
sparsity_detector = adtype.sparsity_detector,
80-
coloring_algorithm = GreedyColoringAlgorithm())
81-
end
82-
else
83-
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
84-
soadtype = AutoSparse(
85-
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
86-
sparsity_detector = adtype.sparsity_detector,
87-
coloring_algorithm = adtype.coloring_algorithm)
88-
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
89-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
90-
soadtype = AutoSparse(
91-
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
92-
sparsity_detector = adtype.sparsity_detector,
93-
coloring_algorithm = adtype.coloring_algorithm)
94-
elseif !(adtype isa SciMLBase.NoAD) &&
95-
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
96-
soadtype = AutoSparse(
97-
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
98-
sparsity_detector = adtype.sparsity_detector,
99-
coloring_algorithm = adtype.coloring_algorithm)
100-
end
101-
end
102-
return adtype, soadtype
103-
end
104-
10515
function instantiate_function(
10616
f::OptimizationFunction{true}, x, adtype::ADTypes.AutoSparse{<:AbstractADType},
10717
p = SciMLBase.NullParameters(), num_cons = 0;
@@ -205,7 +115,11 @@ function instantiate_function(
205115
hv! = nothing
206116
end
207117

208-
if !(f.cons === nothing)
118+
if f.cons === nothing
119+
cons = nothing
120+
else
121+
cons = (res, θ) -> f.cons(res, θ, p)
122+
209123
function cons_oop(x)
210124
_res = zeros(eltype(x), num_cons)
211125
f.cons(_res, x, p)
@@ -347,7 +261,7 @@ function instantiate_function(
347261
end
348262
return OptimizationFunction{true}(f.f, adtype;
349263
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
350-
cons = (res, x) -> f.cons(res, x, p), cons_j = cons_j!, cons_h = cons_h!,
264+
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
351265
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
352266
hess_prototype = hess_sparsity,
353267
hess_colorvec = hess_colors,
@@ -475,7 +389,11 @@ function instantiate_function(
475389
hv! = nothing
476390
end
477391

478-
if !(f.cons === nothing)
392+
if f.cons === nothing
393+
cons = nothing
394+
else
395+
cons = Base.Fix2(f.cons, p)
396+
479397
function lagrangian(θ, σ, λ, p)
480398
return σ * f.f(θ, p) + dot(λ, f.cons(θ, p))
481399
end
@@ -585,7 +503,7 @@ function instantiate_function(
585503
end
586504
return OptimizationFunction{false}(f.f, adtype;
587505
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
588-
cons = Base.Fix2(f.cons, p), cons_j = cons_j!, cons_h = cons_h!,
506+
cons = cons, cons_j = cons_j!, cons_h = cons_h!,
589507
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
590508
hess_prototype = hess_sparsity,
591509
hess_colorvec = hess_colors,

src/adtypes.jl

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,104 @@ if a `hess` function is supplied to the `OptimizationFunction`, then the
218218
Hessian is not defined via Zygote.
219219
"""
220220
AutoZygote
221+
222+
function generate_adtype(adtype)
223+
if !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ForwardMode
224+
soadtype = DifferentiationInterface.SecondOrder(adtype, AutoReverseDiff()) #make zygote?
225+
elseif !(adtype isa SciMLBase.NoAD) && ADTypes.mode(adtype) isa ADTypes.ReverseMode
226+
soadtype = DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype)
227+
else
228+
soadtype = adtype
229+
end
230+
return adtype, soadtype
231+
end
232+
233+
function generate_sparse_adtype(adtype)
234+
if adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
235+
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
236+
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
237+
coloring_algorithm = GreedyColoringAlgorithm())
238+
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
239+
soadtype = AutoSparse(
240+
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
241+
sparsity_detector = TracerSparsityDetector(),
242+
coloring_algorithm = GreedyColoringAlgorithm())
243+
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
244+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
245+
soadtype = AutoSparse(
246+
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
247+
sparsity_detector = TracerSparsityDetector(),
248+
coloring_algorithm = GreedyColoringAlgorithm()) #make zygote?
249+
elseif !(adtype isa SciMLBase.NoAD) &&
250+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
251+
soadtype = AutoSparse(
252+
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
253+
sparsity_detector = TracerSparsityDetector(),
254+
coloring_algorithm = GreedyColoringAlgorithm())
255+
end
256+
elseif adtype.sparsity_detector isa ADTypes.NoSparsityDetector &&
257+
!(adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm)
258+
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = TracerSparsityDetector(),
259+
coloring_algorithm = adtype.coloring_algorithm)
260+
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
261+
soadtype = AutoSparse(
262+
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
263+
sparsity_detector = TracerSparsityDetector(),
264+
coloring_algorithm = adtype.coloring_algorithm)
265+
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
266+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
267+
soadtype = AutoSparse(
268+
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
269+
sparsity_detector = TracerSparsityDetector(),
270+
coloring_algorithm = adtype.coloring_algorithm)
271+
elseif !(adtype isa SciMLBase.NoAD) &&
272+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
273+
soadtype = AutoSparse(
274+
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
275+
sparsity_detector = TracerSparsityDetector(),
276+
coloring_algorithm = adtype.coloring_algorithm)
277+
end
278+
elseif !(adtype.sparsity_detector isa ADTypes.NoSparsityDetector) &&
279+
adtype.coloring_algorithm isa ADTypes.NoColoringAlgorithm
280+
adtype = AutoSparse(adtype.dense_ad; sparsity_detector = adtype.sparsity_detector,
281+
coloring_algorithm = GreedyColoringAlgorithm())
282+
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
283+
soadtype = AutoSparse(
284+
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
285+
sparsity_detector = adtype.sparsity_detector,
286+
coloring_algorithm = GreedyColoringAlgorithm())
287+
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
288+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
289+
soadtype = AutoSparse(
290+
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
291+
sparsity_detector = adtype.sparsity_detector,
292+
coloring_algorithm = GreedyColoringAlgorithm())
293+
elseif !(adtype isa SciMLBase.NoAD) &&
294+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
295+
soadtype = AutoSparse(
296+
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
297+
sparsity_detector = adtype.sparsity_detector,
298+
coloring_algorithm = GreedyColoringAlgorithm())
299+
end
300+
else
301+
if adtype.dense_ad isa ADTypes.AutoFiniteDiff
302+
soadtype = AutoSparse(
303+
DifferentiationInterface.SecondOrder(adtype.dense_ad, adtype.dense_ad),
304+
sparsity_detector = adtype.sparsity_detector,
305+
coloring_algorithm = adtype.coloring_algorithm)
306+
elseif !(adtype.dense_ad isa SciMLBase.NoAD) &&
307+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ForwardMode
308+
soadtype = AutoSparse(
309+
DifferentiationInterface.SecondOrder(adtype.dense_ad, AutoReverseDiff()),
310+
sparsity_detector = adtype.sparsity_detector,
311+
coloring_algorithm = adtype.coloring_algorithm)
312+
elseif !(adtype isa SciMLBase.NoAD) &&
313+
ADTypes.mode(adtype.dense_ad) isa ADTypes.ReverseMode
314+
soadtype = AutoSparse(
315+
DifferentiationInterface.SecondOrder(AutoForwardDiff(), adtype.dense_ad),
316+
sparsity_detector = adtype.sparsity_detector,
317+
coloring_algorithm = adtype.coloring_algorithm)
318+
end
319+
end
320+
return adtype, soadtype
321+
end

0 commit comments

Comments
 (0)