Skip to content

Commit

Permalink
Allow for AlgorithmType in estimation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Sep 27, 2024
1 parent 216b41e commit a7160d0
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 7 deletions.
9 changes: 6 additions & 3 deletions src/estimagic/estimate_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def estimate_ml(
optimize_options to False. Pytrees can be a numpy array, a pandas Series, a
DataFrame with "value" column, a float and any kind of (nested) dictionary
or list containing these elements. See :ref:`params` for examples.
optimize_options (dict, str or False): Keyword arguments that govern the
numerical optimization. Valid entries are all arguments of
optimize_options (dict, Algorithm, str or False): Keyword arguments that govern
the numerical optimization. Valid entries are all arguments of
:func:`~estimagic.optimization.optimize.minimize` except for those that are
passed explicilty to ``estimate_ml``. If you pass False as optimize_options
you signal that ``params`` are already the optimal parameters and no
Expand Down Expand Up @@ -199,7 +199,10 @@ def estimate_ml(
is_optimized = optimize_options is False

if not is_optimized:
if isinstance(optimize_options, str):
# If optimize_options is not a dictionary and not False, we assume it represents
# an algorithm. The actual testing of whether it is a valid algorithm is done
# when `maximize` is called.
if not isinstance(optimize_options, dict):
optimize_options = {"algorithm": optimize_options}

check_optimization_options(
Expand Down
9 changes: 6 additions & 3 deletions src/estimagic/estimate_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def estimate_msm(
optimize_options to False. Pytrees can be a numpy array, a pandas Series, a
DataFrame with "value" column, a float and any kind of (nested) dictionary
or list containing these elements. See :ref:`params` for examples.
optimize_options (dict, str or False): Keyword arguments that govern the
numerical optimization. Valid entries are all arguments of
optimize_options (dict, Algorithm, str or False): Keyword arguments that govern
the numerical optimization. Valid entries are all arguments of
:func:`~estimagic.optimization.optimize.minimize` except for those that can
be passed explicitly to ``estimate_msm``. If you pass False as
``optimize_options`` you signal that ``params`` are already
Expand Down Expand Up @@ -199,7 +199,10 @@ def estimate_msm(
is_optimized = optimize_options is False

if not is_optimized:
if isinstance(optimize_options, str):
# If optimize_options is not a dictionary and not False, we assume it represents
# an algorithm. The actual testing of whether it is a valid algorithm is done
# when `maximize` is called.
if not isinstance(optimize_options, dict):
optimize_options = {"algorithm": optimize_options}

check_optimization_options(
Expand Down
2 changes: 1 addition & 1 deletion src/optimagic/shared/check_option_dicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ def check_optimization_options(options, usage, algorithm_mandatory=True):
msg = (
"The following are not valid entries of optimize_options because they are "
"not only relevant for minimization but also for inference: "
"{invalid_general}"
f"{invalid_general}"
)
raise ValueError(msg)
29 changes: 29 additions & 0 deletions tests/estimagic/test_estimate_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
scalar_logit_fun_and_jac,
)
from optimagic import mark
from optimagic.optimizers import scipy_optimizers
from optimagic.parameters.bounds import Bounds


Expand Down Expand Up @@ -349,6 +350,34 @@ def test_estimate_ml_optimize_options_false(fitted_logit_model, logit_np_inputs)
aaae(got.cov(method="jacobian"), fitted_logit_model.covjac, decimal=4)


def test_estimate_ml_algorithm_type(logit_np_inputs):
"""Test that estimate_ml computes correct covariances given correct params."""
kwargs = {"y": logit_np_inputs["y"], "x": logit_np_inputs["x"]}

params = pd.DataFrame({"value": logit_np_inputs["params"]})

estimate_ml(
loglike=logit_loglike,
params=params,
loglike_kwargs=kwargs,
optimize_options=scipy_optimizers.ScipyLBFGSB,
)


def test_estimate_ml_algorithm(logit_np_inputs):
"""Test that estimate_ml computes correct covariances given correct params."""
kwargs = {"y": logit_np_inputs["y"], "x": logit_np_inputs["x"]}

params = pd.DataFrame({"value": logit_np_inputs["params"]})

estimate_ml(
loglike=logit_loglike,
params=params,
loglike_kwargs=kwargs,
optimize_options=scipy_optimizers.ScipyLBFGSB(stopping_maxfun=10),
)


# ======================================================================================
# Univariate normal case using dict params
# ======================================================================================
Expand Down
35 changes: 35 additions & 0 deletions tests/estimagic/test_estimate_msm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from estimagic.estimate_msm import estimate_msm
from optimagic.optimization.optimize_result import OptimizeResult
from optimagic.optimizers import scipy_optimizers
from optimagic.shared.check_option_dicts import (
check_optimization_options,
)
Expand Down Expand Up @@ -161,6 +162,40 @@ def test_estimate_msm_with_jacobian():
aaae(calculated.cov(), cov_np)


def test_estimate_msm_with_algorithm_type():
start_params = np.array([3, 2, 1])
expected_params = np.zeros(3)
empirical_moments = _sim_np(expected_params)
if isinstance(empirical_moments, dict):
empirical_moments = empirical_moments["simulated_moments"]

estimate_msm(
simulate_moments=_sim_np,
empirical_moments=empirical_moments,
moments_cov=cov_np,
params=start_params,
optimize_options=scipy_optimizers.ScipyLBFGSB,
jacobian=lambda x: np.eye(len(x)),
)


def test_estimate_msm_with_algorithm():
start_params = np.array([3, 2, 1])
expected_params = np.zeros(3)
empirical_moments = _sim_np(expected_params)
if isinstance(empirical_moments, dict):
empirical_moments = empirical_moments["simulated_moments"]

estimate_msm(
simulate_moments=_sim_np,
empirical_moments=empirical_moments,
moments_cov=cov_np,
params=start_params,
optimize_options=scipy_optimizers.ScipyLBFGSB(stopping_maxfun=10),
jacobian=lambda x: np.eye(len(x)),
)


def test_to_pickle(tmp_path):
start_params = np.array([3, 2, 1])

Expand Down

0 comments on commit a7160d0

Please sign in to comment.