Skip to content

Commit

Permalink
Do not use .__algo_info__ when .algo_info is available.
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg committed Nov 12, 2024
1 parent d16c36f commit 1dcb642
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 31 deletions.
33 changes: 15 additions & 18 deletions .tools/create_algo_selection_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def _get_algorithms_in_module(module: ModuleType) -> dict[str, Type[Algorithm]]:
}
algos = {}
for candidate in candidate_dict.values():
name = candidate.__algo_info__.name
name = candidate.algo_info.name
if issubclass(candidate, Algorithm) and candidate is not Algorithm:
algos[name] = candidate
return algos
Expand All @@ -119,47 +119,47 @@ def _get_algorithms_in_module(module: ModuleType) -> dict[str, Type[Algorithm]]:
# Functions to filter algorithms by selectors
# ======================================================================================
def _is_gradient_based(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.needs_jac # type: ignore
return algo.algo_info.needs_jac


def _is_gradient_free(algo: Type[Algorithm]) -> bool:
return not _is_gradient_based(algo)


def _is_global(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.is_global # type: ignore
return algo.algo_info.is_global


def _is_local(algo: Type[Algorithm]) -> bool:
return not _is_global(algo)


def _is_bounded(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.supports_bounds # type: ignore
return algo.algo_info.supports_bounds


def _is_linear_constrained(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.supports_linear_constraints # type: ignore
return algo.algo_info.supports_linear_constraints


def _is_nonlinear_constrained(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.supports_nonlinear_constraints # type: ignore
return algo.algo_info.supports_nonlinear_constraints


def _is_scalar(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.solver_type == AggregationLevel.SCALAR # type: ignore
return algo.algo_info.solver_type == AggregationLevel.SCALAR


def _is_least_squares(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.solver_type == AggregationLevel.LEAST_SQUARES # type: ignore
return algo.algo_info.solver_type == AggregationLevel.LEAST_SQUARES


def _is_likelihood(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.solver_type == AggregationLevel.LIKELIHOOD # type: ignore
return algo.algo_info.solver_type == AggregationLevel.LIKELIHOOD


def _is_parallel(algo: Type[Algorithm]) -> bool:
return algo.__algo_info__.supports_parallelism # type: ignore
return algo.algo_info.supports_parallelism


def _get_filters() -> dict[str, Callable[[Type[Algorithm]], bool]]:
Expand Down Expand Up @@ -385,7 +385,7 @@ def _all(self) -> list[Type[Algorithm]]:
def _available(self) -> list[Type[Algorithm]]:
_all = self._all()
return [
a for a in _all if a.__algo_info__.is_available # type: ignore
a for a in _all if a.algo_info.is_available # type: ignore
]
@property
Expand All @@ -398,22 +398,19 @@ def Available(self) -> list[Type[Algorithm]]:
@property
def AllNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._all()] # type: ignore
return [str(a.name) for a in self._all()]
@property
def AvailableNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._available()] # type: ignore
return [str(a.name) for a in self._available()]
@property
def _all_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
return {a.__algo_info__.name: a for a in self._all()} # type: ignore
return {str(a.name): a for a in self._all()}
@property
def _available_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
return {
a.__algo_info__.name: a # type: ignore
for a in self._available()
}
return {str(a.name): a for a in self._available()}
""")
return out
Expand Down
13 changes: 5 additions & 8 deletions src/optimagic/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _available(self) -> list[Type[Algorithm]]:
return [
a
for a in _all
if a.__algo_info__.is_available # type: ignore
if a.algo_info.is_available # type: ignore
]

@property
Expand All @@ -103,22 +103,19 @@ def Available(self) -> list[Type[Algorithm]]:

@property
def AllNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._all()] # type: ignore
return [str(a.name) for a in self._all()]

Check warning on line 106 in src/optimagic/algorithms.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/algorithms.py#L106

Added line #L106 was not covered by tests

@property
def AvailableNames(self) -> list[str]:
return [a.__algo_info__.name for a in self._available()] # type: ignore
return [str(a.name) for a in self._available()]

Check warning on line 110 in src/optimagic/algorithms.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/algorithms.py#L110

Added line #L110 was not covered by tests

@property
def _all_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
return {a.__algo_info__.name: a for a in self._all()} # type: ignore
return {str(a.name): a for a in self._all()}

@property
def _available_algorithms_dict(self) -> dict[str, Type[Algorithm]]:
return {
a.__algo_info__.name: a # type: ignore
for a in self._available()
}
return {str(a.name): a for a in self._available()}


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions src/optimagic/visualization/history_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def _harmonize_inputs_to_dict(results, names):

def _convert_key_to_str(key: Any) -> str:
if inspect.isclass(key) and issubclass(key, Algorithm):
out = key.__algo_info__.name # type: ignore
out = str(key.name)

Check warning on line 202 in src/optimagic/visualization/history_plots.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/history_plots.py#L202

Added line #L202 was not covered by tests
elif isinstance(key, Algorithm):
out = key.__algo_info__.name # type: ignore
out = str(key.name)

Check warning on line 204 in src/optimagic/visualization/history_plots.py

View check run for this annotation

Codecov / codecov/patch

src/optimagic/visualization/history_plots.py#L204

Added line #L204 was not covered by tests
else:
out = str(key)
return out
Expand Down
2 changes: 1 addition & 1 deletion tests/optimagic/optimization/test_history_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
OPTIMIZERS = []
BOUNDED = []
for name, algo in AVAILABLE_ALGORITHMS.items():
info = algo.__algo_info__
info = algo.algo_info
if not info.disable_history:
if info.supports_parallelism:
OPTIMIZERS.append(name)
Expand Down
2 changes: 1 addition & 1 deletion tests/optimagic/optimization/test_many_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

BOUNDED_ALGORITHMS = []
for name, algo in LOCAL_ALGORITHMS.items():
if algo.__algo_info__.supports_bounds:
if algo.algo_info.supports_bounds:
BOUNDED_ALGORITHMS.append(name)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_nonlinear_optimization(nlc_2d_example, algorithm, constr_type):
warnings.simplefilter("ignore")
result = maximize(algorithm=algorithm, **kwargs[constr_type])

if NLC_ALGORITHMS[algorithm].__algo_info__.is_global:
if NLC_ALGORITHMS[algorithm].algo_info.is_global:
decimal = 0
else:
decimal = 4
Expand Down

0 comments on commit 1dcb642

Please sign in to comment.