Skip to content

Commit cce012c

Browse files
committed
finish wrapping fides
1 parent 0f34d32 commit cce012c

File tree

2 files changed

+177
-51
lines changed

2 files changed

+177
-51
lines changed

src/optimagic/optimizers/fides.py

Lines changed: 140 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Implement the fides optimizer."""
22

33
import logging
4+
from dataclasses import dataclass
5+
from typing import Callable, Literal, cast
46

57
import numpy as np
8+
from numpy.typing import NDArray
69

10+
from optimagic import mark
711
from optimagic.config import IS_FIDES_INSTALLED
8-
from optimagic.decorators import mark_minimizer
912
from optimagic.exceptions import NotInstalledError
1013
from optimagic.optimization.algo_options import (
1114
CONVERGENCE_FTOL_ABS,
@@ -15,40 +18,138 @@
1518
CONVERGENCE_XTOL_ABS,
1619
STOPPING_MAXITER,
1720
)
21+
from optimagic.optimization.algorithm import Algorithm, InternalOptimizeResult
22+
from optimagic.optimization.internal_optimization_problem import (
23+
InternalOptimizationProblem,
24+
)
25+
from optimagic.typing import (
26+
AggregationLevel,
27+
NonNegativeFloat,
28+
PositiveFloat,
29+
PositiveInt,
30+
PyTree,
31+
)
1832

1933
if IS_FIDES_INSTALLED:
2034
from fides import Optimizer, hessian_approximation
2135

2236

23-
@mark_minimizer(
37+
@mark.minimizer(
2438
name="fides",
25-
primary_criterion_entry="value",
26-
needs_scaling=False,
39+
solver_type=AggregationLevel.SCALAR,
2740
is_available=IS_FIDES_INSTALLED,
41+
is_global=False,
42+
needs_jac=True,
43+
needs_hess=False,
44+
supports_parallelism=False,
45+
supports_bounds=True,
46+
supports_linear_constraints=False,
47+
supports_nonlinear_constraints=False,
48+
disable_history=False,
2849
)
29-
def fides(
30-
criterion_and_derivative,
31-
x,
32-
lower_bounds,
33-
upper_bounds,
34-
*,
35-
hessian_update_strategy="bfgs",
36-
convergence_ftol_abs=CONVERGENCE_FTOL_ABS,
37-
convergence_ftol_rel=CONVERGENCE_FTOL_REL,
38-
convergence_xtol_abs=CONVERGENCE_XTOL_ABS,
39-
convergence_gtol_abs=CONVERGENCE_GTOL_ABS,
40-
convergence_gtol_rel=CONVERGENCE_GTOL_REL,
41-
stopping_maxiter=STOPPING_MAXITER,
42-
stopping_max_seconds=np.inf,
43-
trustregion_initial_radius=1.0,
44-
trustregion_stepback_strategy="truncate",
45-
trustregion_subspace_dimension="full",
46-
trustregion_max_stepback_fraction=0.95,
47-
trustregion_decrease_threshold=0.25,
48-
trustregion_increase_threshold=0.75,
49-
trustregion_decrease_factor=0.25,
50-
trustregion_increase_factor=2.0,
51-
):
50+
@dataclass(frozen=True)
51+
class Fides(Algorithm):
52+
hessian_update_strategy: Literal[
53+
"bfgs",
54+
"bb",
55+
"bg",
56+
"dfp",
57+
"sr1",
58+
] = "bfgs"
59+
convergence_ftol_abs: NonNegativeFloat = CONVERGENCE_FTOL_ABS
60+
convergence_ftol_rel: NonNegativeFloat = CONVERGENCE_FTOL_REL
61+
convergence_xtol_abs: NonNegativeFloat = CONVERGENCE_XTOL_ABS
62+
convergence_gtol_abs: NonNegativeFloat = CONVERGENCE_GTOL_ABS
63+
convergence_gtol_rel: NonNegativeFloat = CONVERGENCE_GTOL_REL
64+
stopping_maxiter: PositiveInt = STOPPING_MAXITER
65+
stopping_max_seconds: float = np.inf
66+
trustregion_initial_radius: PositiveFloat = 1.0
67+
trustregion_stepback_strategy: Literal[
68+
"truncate",
69+
"reflect",
70+
"reflect_single",
71+
"mixed",
72+
] = "truncate"
73+
trustregion_subspace_dimension: Literal[
74+
"full",
75+
"2D",
76+
"scg",
77+
] = "full"
78+
trustregion_max_stepback_fraction: float = 0.95
79+
trustregion_decrease_threshold: float = 0.25
80+
trustregion_increase_threshold: float = 0.75
81+
trustregion_decrease_factor: float = 0.25
82+
trustregion_increase_factor: float = 2.0
83+
84+
def _solve_internal_problem(
85+
self, problem: InternalOptimizationProblem, x0: NDArray[np.float64]
86+
) -> InternalOptimizeResult:
87+
res = fides_internal(
88+
fun_and_jac=cast(
89+
Callable[[NDArray[np.float64]], NDArray[np.float64]],
90+
problem.fun_and_jac,
91+
),
92+
x=x0,
93+
lower_bounds=problem.bounds.lower,
94+
upper_bounds=problem.bounds.upper,
95+
hessian_update_strategy=self.hessian_update_strategy,
96+
convergence_ftol_abs=self.convergence_ftol_abs,
97+
convergence_ftol_rel=self.convergence_ftol_rel,
98+
convergence_xtol_abs=self.convergence_xtol_abs,
99+
convergence_gtol_abs=self.convergence_gtol_abs,
100+
convergence_gtol_rel=self.convergence_gtol_rel,
101+
stopping_maxiter=self.stopping_maxiter,
102+
stopping_max_seconds=self.stopping_max_seconds,
103+
trustregion_initial_radius=self.trustregion_initial_radius,
104+
trustregion_stepback_strategy=self.trustregion_stepback_strategy,
105+
trustregion_subspace_dimension=self.trustregion_subspace_dimension,
106+
trustregion_max_stepback_fraction=self.trustregion_max_stepback_fraction,
107+
trustregion_decrease_threshold=self.trustregion_decrease_threshold,
108+
trustregion_increase_threshold=self.trustregion_increase_threshold,
109+
trustregion_decrease_factor=self.trustregion_decrease_factor,
110+
trustregion_increase_factor=self.trustregion_increase_factor,
111+
)
112+
113+
return res
114+
115+
116+
def fides_internal(
117+
fun_and_jac: Callable[[NDArray[np.float64]], NDArray[np.float64]],
118+
x: NDArray[np.float64],
119+
lower_bounds: PyTree | None,
120+
upper_bounds: PyTree | None,
121+
hessian_update_strategy: Literal[
122+
"bfgs",
123+
"bb",
124+
"bg",
125+
"dfp",
126+
"sr1",
127+
],
128+
convergence_ftol_abs: NonNegativeFloat,
129+
convergence_ftol_rel: NonNegativeFloat,
130+
convergence_xtol_abs: NonNegativeFloat,
131+
convergence_gtol_abs: NonNegativeFloat,
132+
convergence_gtol_rel: NonNegativeFloat,
133+
stopping_maxiter: PositiveInt,
134+
stopping_max_seconds: float,
135+
trustregion_initial_radius: PositiveFloat,
136+
trustregion_stepback_strategy: Literal[
137+
"truncate",
138+
"reflect",
139+
"reflect_single",
140+
"mixed",
141+
],
142+
trustregion_subspace_dimension: Literal[
143+
"full",
144+
"2D",
145+
"scg",
146+
],
147+
trustregion_max_stepback_fraction: float,
148+
trustregion_decrease_threshold: float,
149+
trustregion_increase_threshold: float,
150+
trustregion_decrease_factor: float,
151+
trustregion_increase_factor: float,
152+
) -> InternalOptimizeResult:
52153
"""Minimize a scalar function using the Fides Optimizer.
53154
54155
For details see
@@ -82,7 +183,7 @@ def fides(
82183
hessian_instance = _create_hessian_updater_from_user_input(hessian_update_strategy)
83184

84185
opt = Optimizer(
85-
fun=criterion_and_derivative,
186+
fun=fun_and_jac,
86187
lb=lower_bounds,
87188
ub=upper_bounds,
88189
verbose=logging.ERROR,
@@ -93,7 +194,17 @@ def fides(
93194
)
94195
raw_res = opt.minimize(x)
95196
res = _process_fides_res(raw_res, opt)
96-
return res
197+
out = InternalOptimizeResult(
198+
x=res["solution_x"],
199+
fun=res["solution_criterion"],
200+
jac=res["solution_derivative"],
201+
hess=res["solution_hessian"],
202+
success=res["success"],
203+
message=res["message"],
204+
n_iterations=res["n_iterations"],
205+
)
206+
207+
return out
97208

98209

99210
def _process_fides_res(raw_res, opt):

tests/optimagic/optimizers/test_fides_options.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import pytest
55
from numpy.testing import assert_array_almost_equal as aaae
66
from optimagic.config import IS_FIDES_INSTALLED
7+
from optimagic.optimization.optimize import minimize
8+
from optimagic.parameters.bounds import Bounds
79

810
if IS_FIDES_INSTALLED:
911
from fides.hessian_approximation import FX, SR1, Broyden
10-
from optimagic.optimizers.fides import fides
12+
from optimagic.optimizers.fides import Fides
1113
else:
1214
FX = lambda: None
1315
SR1 = lambda: None
@@ -40,17 +42,24 @@ def criterion_and_derivative(x):
4042
return (x**2).sum(), 2 * x
4143

4244

45+
def criterion(x):
46+
return (x**2).sum()
47+
48+
4349
@pytest.mark.skipif(not IS_FIDES_INSTALLED, reason="fides not installed.")
4450
@pytest.mark.parametrize("algo_options", test_cases_no_contribs_needed)
4551
def test_fides_correct_algo_options(algo_options):
46-
res = fides(
47-
criterion_and_derivative=criterion_and_derivative,
48-
x=np.array([1, -5, 3]),
49-
lower_bounds=np.array([-10, -10, -10]),
50-
upper_bounds=np.array([10, 10, 10]),
51-
**algo_options,
52+
res = minimize(
53+
fun_and_jac=criterion_and_derivative,
54+
fun=criterion,
55+
x0=np.array([1, -5, 3]),
56+
bounds=Bounds(
57+
lower=np.array([-10, -10, -10]),
58+
upper=np.array([10, 10, 10]),
59+
),
60+
algorithm=Fides(**algo_options),
5261
)
53-
aaae(res["solution_x"], np.zeros(3), decimal=4)
62+
aaae(res.params, np.zeros(3), decimal=4)
5463

5564

5665
test_cases_needing_contribs = [
@@ -65,23 +74,29 @@ def test_fides_correct_algo_options(algo_options):
6574
@pytest.mark.parametrize("algo_options", test_cases_needing_contribs)
6675
def test_fides_unimplemented_algo_options(algo_options):
6776
with pytest.raises(NotImplementedError):
68-
fides(
69-
criterion_and_derivative=criterion_and_derivative,
70-
x=np.array([1, -5, 3]),
71-
lower_bounds=np.array([-10, -10, -10]),
72-
upper_bounds=np.array([10, 10, 10]),
73-
**algo_options,
77+
minimize(
78+
fun_and_jac=criterion_and_derivative,
79+
fun=criterion,
80+
x0=np.array([1, -5, 3]),
81+
bounds=Bounds(
82+
lower=np.array([-10, -10, -10]),
83+
upper=np.array([10, 10, 10]),
84+
),
85+
algorithm=Fides(**algo_options),
7486
)
7587

7688

7789
@pytest.mark.skipif(not IS_FIDES_INSTALLED, reason="fides not installed.")
7890
def test_fides_stop_after_one_iteration():
79-
res = fides(
80-
criterion_and_derivative=criterion_and_derivative,
81-
x=np.array([1, -5, 3]),
82-
lower_bounds=np.array([-10, -10, -10]),
83-
upper_bounds=np.array([10, 10, 10]),
84-
stopping_maxiter=1,
91+
res = minimize(
92+
fun_and_jac=criterion_and_derivative,
93+
fun=criterion,
94+
x0=np.array([1, -5, 3]),
95+
bounds=Bounds(
96+
lower=np.array([-10, -10, -10]),
97+
upper=np.array([10, 10, 10]),
98+
),
99+
algorithm=Fides(stopping_maxiter=1),
85100
)
86-
assert not res["success"]
87-
assert res["n_iterations"] == 1
101+
assert not res.success
102+
assert res.n_iterations == 1

0 commit comments

Comments
 (0)