diff --git a/pyproject.toml b/pyproject.toml index 725c7c2c..28716766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nemos" -version = "0.2.0" +version = "0.2.1" authors = [{name = "nemos authors"}] description = "NEural MOdelS, a statistical modeling framework for neuroscience." readme = "README.md" diff --git a/src/nemos/__init__.py b/src/nemos/__init__.py index ecde7b4f..ee8e8841 100644 --- a/src/nemos/__init__.py +++ b/src/nemos/__init__.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -__version__ = "0.2.0" +__version__ = "0.2.1" from . import ( basis, diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 5f7ab0d4..010aefef 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -1126,6 +1126,12 @@ def _get_optimal_solver_params_config(self): def __repr__(self): return format_repr(self, multiline=True) + def __sklearn_clone__(self) -> GLM: + """Clone the PopulationGLM, dropping feature_mask""" + params = self.get_params(deep=False) + klass = self.__class__(**params) + return klass + class PopulationGLM(GLM): """ @@ -1633,7 +1639,7 @@ def _predict( + bs ) - def __sklearn_clone__(self) -> GLM: + def __sklearn_clone__(self) -> PopulationGLM: """Clone the PopulationGLM, dropping feature_mask""" params = self.get_params(deep=False) params.pop("feature_mask") diff --git a/src/nemos/utils.py b/src/nemos/utils.py index b25644d0..68215c2c 100644 --- a/src/nemos/utils.py +++ b/src/nemos/utils.py @@ -534,7 +534,7 @@ def format_repr( ) if repr_param: if k in use_name_keys: - v = v.__name__ + v = getattr(v, "__name__", repr(v)) elif isinstance(v, str): v = repr(v) disp_params.append(f"{k}={v}") diff --git a/tests/test_glm.py b/tests/test_glm.py index f6d53ad7..a54b55b2 100644 --- a/tests/test_glm.py +++ b/tests/test_glm.py @@ -1478,12 +1478,16 @@ def test_deviance_against_statsmodels(self, poissonGLM_model_instantiation): def test_compatibility_with_sklearn_cv(self, poissonGLM_model_instantiation): X, y, model, true_params, firing_rate = poissonGLM_model_instantiation param_grid = {"solver_name": ["BFGS", "GradientDescent"]} - GridSearchCV(model, param_grid).fit(X, y) + cls = GridSearchCV(model, param_grid).fit(X, y) + # check that the repr works after cloning + repr(cls) def test_compatibility_with_sklearn_cv_gamma(self, gammaGLM_model_instantiation): X, y, model, true_params, firing_rate = gammaGLM_model_instantiation param_grid = {"solver_name": ["BFGS", "GradientDescent"]} - GridSearchCV(model, param_grid).fit(X, y) + cls = GridSearchCV(model, param_grid).fit(X, y) + # check that the repr works after cloning + repr(cls) @pytest.mark.parametrize( "regr_setup, glm_class", @@ -3572,12 +3576,16 @@ def test_deviance_against_statsmodels(self, poisson_population_GLM_model): def test_compatibility_with_sklearn_cv(self, poisson_population_GLM_model): X, y, model, true_params, firing_rate = poisson_population_GLM_model param_grid = {"solver_name": ["BFGS", "GradientDescent"]} - GridSearchCV(model, param_grid).fit(X, y) + cls = GridSearchCV(model, param_grid).fit(X, y) + # check that the repr works after cloning + repr(cls) def test_compatibility_with_sklearn_cv_gamma(self, gamma_population_GLM_model): X, y, model, true_params, firing_rate = gamma_population_GLM_model param_grid = {"solver_name": ["BFGS", "GradientDescent"]} - GridSearchCV(model, param_grid).fit(X, y) + cls = GridSearchCV(model, param_grid).fit(X, y) + # check that the repr works after cloning + repr(cls) def test_sklearn_clone(self, poisson_population_GLM_model): X, y, model, true_params, firing_rate = poisson_population_GLM_model diff --git a/tests/test_utils.py b/tests/test_utils.py index e267c3f6..5e3331da 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ import warnings from contextlib import nullcontext as does_not_raise +from copy import deepcopy import jax import jax.numpy as jnp @@ -595,6 +596,8 @@ def __repr__(self): (Example(a=0, b=False, c=None), None, [], "Example(a=0, b=False, d=1)"), # Falsey values excluded2 (Example(a=0, b=[], c={}), None, [], "Example(a=0, d=1)"), + # function without the __name__ + (nmo.observation_models.PoissonObservations(deepcopy(jax.numpy.exp)),None, [], "PoissonObservations(inverse_link_function=)") ], ) def test_format_repr(obj, exclude_keys, use_name_keys, expected):