Skip to content

Commit

Permalink
improved decorator logic
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Feb 10, 2025
1 parent 34c3b06 commit e554a8d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
7 changes: 5 additions & 2 deletions src/nemos/basis/_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def wrapper(self, **params):
current_params = self.__sklearn_get_params__()
for key, val in new_params.items():
current = current_params[key]
if isinstance(val, Basis) and val._has_default_label and not current._has_default_label:
if (
isinstance(val, Basis)
and val._has_default_label
and not current._has_default_label
):
try:
val.label = current.label
except ValueError:
Expand Down Expand Up @@ -367,7 +371,6 @@ def get_params(self, deep=True) -> dict:
def set_params(self, **params: Any):
return super().set_params(**params)


@property
def n_output_features(self) -> int | None:
"""
Expand Down
4 changes: 1 addition & 3 deletions src/nemos/basis/_basis_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import copy
import inspect
import re
import sys
from functools import wraps
from itertools import chain
from typing import TYPE_CHECKING, Generator, List, Literal, Optional, Tuple, Union
Expand Down Expand Up @@ -39,7 +38,7 @@
"OrthExponentialEval",
"OrthExponentialConv",
"AdditiveBasis",
"MultiplicativeBasis"
"MultiplicativeBasis",
]


Expand Down Expand Up @@ -165,7 +164,6 @@ def label(self, label: str | None) -> None:
def _has_default_label(self):
return re.match(rf"^{self.__class__.__name__}(_\d+)?$", self._label) is not None


def _recompute_all_labels(self):
"""
Recompute all labels matching default for self.
Expand Down
7 changes: 6 additions & 1 deletion src/nemos/basis/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from ..typing import FeatureMatrix
from ._basis import add_docstring
from ._basis_mixin import AtomicBasisMixin, ConvBasisMixin, EvalBasisMixin, __PUBLIC_BASES__
from ._basis_mixin import (
__PUBLIC_BASES__,
AtomicBasisMixin,
ConvBasisMixin,
EvalBasisMixin,
)
from ._decaying_exponential import OrthExponentialBasis
from ._identity import HistoryBasis, IdentityBasis
from ._raised_cosine_basis import RaisedCosineBasisLinear, RaisedCosineBasisLog
Expand Down

0 comments on commit e554a8d

Please sign in to comment.