@@ -12,96 +12,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
12
12
using ADTypes
13
13
using SparseConnectivityTracer, SparseMatrixColorings
14
14
15
- function generate_sparse_adtype (adtype)
16
- if adtype. sparsity_detector isa ADTypes. NoSparsityDetector &&
17
- adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
18
- adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
19
- coloring_algorithm = GreedyColoringAlgorithm ())
20
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
21
- soadtype = AutoSparse (
22
- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
23
- sparsity_detector = TracerSparsityDetector (),
24
- coloring_algorithm = GreedyColoringAlgorithm ())
25
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
26
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
27
- soadtype = AutoSparse (
28
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
29
- sparsity_detector = TracerSparsityDetector (),
30
- coloring_algorithm = GreedyColoringAlgorithm ()) # make zygote?
31
- elseif ! (adtype isa SciMLBase. NoAD) &&
32
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
33
- soadtype = AutoSparse (
34
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
35
- sparsity_detector = TracerSparsityDetector (),
36
- coloring_algorithm = GreedyColoringAlgorithm ())
37
- end
38
- elseif adtype. sparsity_detector isa ADTypes. NoSparsityDetector &&
39
- ! (adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm)
40
- adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
41
- coloring_algorithm = adtype. coloring_algorithm)
42
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
43
- soadtype = AutoSparse (
44
- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
45
- sparsity_detector = TracerSparsityDetector (),
46
- coloring_algorithm = adtype. coloring_algorithm)
47
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
48
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
49
- soadtype = AutoSparse (
50
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
51
- sparsity_detector = TracerSparsityDetector (),
52
- coloring_algorithm = adtype. coloring_algorithm)
53
- elseif ! (adtype isa SciMLBase. NoAD) &&
54
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
55
- soadtype = AutoSparse (
56
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
57
- sparsity_detector = TracerSparsityDetector (),
58
- coloring_algorithm = adtype. coloring_algorithm)
59
- end
60
- elseif ! (adtype. sparsity_detector isa ADTypes. NoSparsityDetector) &&
61
- adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
62
- adtype = AutoSparse (adtype. dense_ad; sparsity_detector = adtype. sparsity_detector,
63
- coloring_algorithm = GreedyColoringAlgorithm ())
64
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
65
- soadtype = AutoSparse (
66
- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
67
- sparsity_detector = adtype. sparsity_detector,
68
- coloring_algorithm = GreedyColoringAlgorithm ())
69
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
70
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
71
- soadtype = AutoSparse (
72
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
73
- sparsity_detector = adtype. sparsity_detector,
74
- coloring_algorithm = GreedyColoringAlgorithm ())
75
- elseif ! (adtype isa SciMLBase. NoAD) &&
76
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
77
- soadtype = AutoSparse (
78
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
79
- sparsity_detector = adtype. sparsity_detector,
80
- coloring_algorithm = GreedyColoringAlgorithm ())
81
- end
82
- else
83
- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
84
- soadtype = AutoSparse (
85
- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
86
- sparsity_detector = adtype. sparsity_detector,
87
- coloring_algorithm = adtype. coloring_algorithm)
88
- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
89
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
90
- soadtype = AutoSparse (
91
- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
92
- sparsity_detector = adtype. sparsity_detector,
93
- coloring_algorithm = adtype. coloring_algorithm)
94
- elseif ! (adtype isa SciMLBase. NoAD) &&
95
- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
96
- soadtype = AutoSparse (
97
- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
98
- sparsity_detector = adtype. sparsity_detector,
99
- coloring_algorithm = adtype. coloring_algorithm)
100
- end
101
- end
102
- return adtype, soadtype
103
- end
104
-
105
15
function instantiate_function (
106
16
f:: OptimizationFunction{true} , x, adtype:: ADTypes.AutoSparse{<:AbstractADType} ,
107
17
p = SciMLBase. NullParameters (), num_cons = 0 ;
@@ -205,7 +115,11 @@ function instantiate_function(
205
115
hv! = nothing
206
116
end
207
117
208
- if ! (f. cons === nothing )
118
+ if f. cons === nothing
119
+ cons = nothing
120
+ else
121
+ cons = (res, θ) -> f. cons (res, θ, p)
122
+
209
123
function cons_oop (x)
210
124
_res = zeros (eltype (x), num_cons)
211
125
f. cons (_res, x, p)
@@ -347,7 +261,7 @@ function instantiate_function(
347
261
end
348
262
return OptimizationFunction {true} (f. f, adtype;
349
263
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
350
- cons = (res, x) -> f . cons (res, x, p) , cons_j = cons_j!, cons_h = cons_h!,
264
+ cons = cons, cons_j = cons_j!, cons_h = cons_h!,
351
265
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
352
266
hess_prototype = hess_sparsity,
353
267
hess_colorvec = hess_colors,
@@ -475,7 +389,11 @@ function instantiate_function(
475
389
hv! = nothing
476
390
end
477
391
478
- if ! (f. cons === nothing )
392
+ if f. cons === nothing
393
+ cons = nothing
394
+ else
395
+ cons = Base. Fix2 (f. cons, p)
396
+
479
397
function lagrangian (θ, σ, λ, p)
480
398
return σ * f. f (θ, p) + dot (λ, f. cons (θ, p))
481
399
end
@@ -585,7 +503,7 @@ function instantiate_function(
585
503
end
586
504
return OptimizationFunction {false} (f. f, adtype;
587
505
grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
588
- cons = Base . Fix2 (f . cons, p) , cons_j = cons_j!, cons_h = cons_h!,
506
+ cons = cons, cons_j = cons_j!, cons_h = cons_h!,
589
507
cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
590
508
hess_prototype = hess_sparsity,
591
509
hess_colorvec = hess_colors,
0 commit comments