Skip to content

Commit ff13176

Browse files
Add fields for manifold and curvature results in cache
1 parent 2a1f84c commit ff13176

File tree

4 files changed

+48
-19
lines changed

4 files changed

+48
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "1.3.1"
4+
version = "1.3.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/OptimizationBase.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ end
1010

1111
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
1212
using SymbolicIndexingInterface
13-
using SymbolicAnalysis: propagate_sign, propagate_curvature, propagate_gcurvature,
14-
getcurvature, getgcurvature, getsign
13+
using SymbolicAnalysis
14+
using SymbolicAnalysis: AnalysisResult
1515
import Symbolics
1616
import Manifolds
1717
import Symbolics: variable, Equation, Inequality, unwrap, @variables

src/cache.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
import Symbolics: , ~
2-
struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C} <:
2+
3+
struct AnalysisResults
4+
objective::Union{Nothing, AnalysisResult}
5+
constraints::Union{Nothing, Vector{AnalysisResult}}
6+
end
7+
8+
struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C, M} <:
39
SciMLBase.AbstractOptimizationCache
410
f::F
511
reinit_cache::RC
@@ -12,6 +18,8 @@ struct OptimizationCache{F, RC, LB, UB, LC, UC, S, O, D, P, C} <:
1218
data::D
1319
progress::P
1420
callback::C
21+
manifold::M
22+
analysis_results::AnalysisResults
1523
solver_args::NamedTuple
1624
end
1725

@@ -101,32 +109,44 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
101109
end
102110

103111
if obj_expr !== nothing
104-
try
105-
obj_expr = obj_expr |> Symbolics.unwrap
106-
obj_expr = propagate_curvature(propagate_sign(obj_expr))
107-
@info "Objective Euclidean curvature: $(getcurvature(obj_expr))"
108-
catch
109-
@info "No euclidean atom available"
112+
obj_expr = obj_expr |> Symbolics.unwrap
113+
if manifold === nothing
114+
obj_res = analyze(obj_expr)
115+
else
116+
obj_res = analyze(obj_expr, manifold)
110117
end
111118

112-
try
113-
obj_expr = propagate_gcurvature(propagate_sign(obj_expr), manifold)
114-
@info "Objective Geodesic curvature: $(getgcurvature(obj_expr))"
115-
catch
116-
@info "No geodesic atom available"
119+
@info "Objective Euclidean curvature: $(obj_res.curvature)"
120+
121+
if obj_res.gcurvature !== nothing
122+
@info "Objective Geodesic curvature: $(obj_res.gcurvature)"
117123
end
124+
else
125+
obj_res = nothing
118126
end
119127

120128
if cons_expr !== nothing
121129
cons_expr = cons_expr .|> Symbolics.unwrap
122-
cons_expr = propagate_curvature.(propagate_sign.(cons_expr))
123-
@info "Constraints Euclidean curvature: $(getcurvature.(cons_expr))"
130+
if manifold === nothing
131+
cons_res = analyze.(cons_expr)
132+
else
133+
cons_res = analyze.(cons_expr, Ref(manifold))
134+
end
135+
for i in 1:num_cons
136+
@info "Constraints Euclidean curvature: $(cons_res[i].curvature)"
137+
138+
if cons_res[i].gcurvature !== nothing
139+
@info "Constraints Geodesic curvature: $(cons_res[i].gcurvature)"
140+
end
141+
end
142+
else
143+
cons_res = nothing
124144
end
125145

126146
return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,
127147
prob.ucons, prob.sense,
128-
opt, data, progress, callback,
129-
merge((; maxiters, maxtime, abstol, reltol, manifold),
148+
opt, data, progress, callback, manifold, AnalysisResults(obj_res, cons_res),
149+
merge((; maxiters, maxtime, abstol, reltol),
130150
NamedTuple(kwargs)))
131151
end
132152

test/cvxtest.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
99
prob = OptimizationProblem(optf, [0.4], structural_analysis = true)
1010

1111
@time sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)
12+
@test sol.cache.analysis_results.objective.curvature == SymbolicAnalysis.Convex
13+
@test sol.cache.analysis_results.constraints === nothing
1214

1315
x0 = zeros(2)
1416
rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
@@ -17,6 +19,7 @@ l1 = rosenbrock(x0)
1719
optf = OptimizationFunction(rosenbrock, AutoEnzyme())
1820
prob = OptimizationProblem(optf, x0, structural_analysis = true)
1921
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
22+
@test res.cache.analysis_results.objective.curvature == SymbolicAnalysis.UnknownCurvature
2023

2124
function con2_c(res, x, p)
2225
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
@@ -26,6 +29,10 @@ optf = OptimizationFunction(rosenbrock, AutoZygote(), cons = con2_c)
2629
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], ucons = [1.0, 0.0],
2730
lb = [-1.0, -1.0], ub = [1.0, 1.0], structural_analysis = true)
2831
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
32+
@test res.cache.analysis_results.objective.curvature == SymbolicAnalysis.Convex
33+
@test res.cache.analysis_results.constraints[1].curvature == SymbolicAnalysis.Convex
34+
@test res.cache.analysis_results.constraints[2].curvature ==
35+
SymbolicAnalysis.UnknownCurvature
2936

3037
m = 100
3138
σ = 0.005
@@ -41,3 +48,5 @@ prob = OptimizationProblem(optf, data2[1]; manifold = M, structural_analysis = t
4148
opt = OptimizationManopt.GradientDescentOptimizer()
4249
@time sol = solve(prob, Optimization.LBFGS(), maxiters = 100)
4350
@test sol.minimizer < 1e-1
51+
@test sol.cache.analysis_results.objective.curvature == SymbolicAnalysis.UnknownCurvature
52+
@test sol.cache.analysis_results.objective.gcurvature == SymbolicAnalysis.GConvex

0 commit comments

Comments
 (0)