From f77c29b8755d1f74bd1ef4705ae2e5322f39fa19 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 10 Jan 2025 10:49:28 -0500 Subject: [PATCH] added decorator to simulate --- src/nemos/glm.py | 1 + tests/test_glm.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/src/nemos/glm.py b/src/nemos/glm.py index 509d8beb..34c94d1f 100644 --- a/src/nemos/glm.py +++ b/src/nemos/glm.py @@ -760,6 +760,7 @@ def _set_coef_and_intercept(self, params): self.coef_: DESIGN_INPUT_TYPE = params[0] self.intercept_: jnp.ndarray = params[1] + @support_pynapple(conv_type="jax") def simulate( self, random_key: jax.Array, diff --git a/tests/test_glm.py b/tests/test_glm.py index 9cdb38eb..8c16371f 100644 --- a/tests/test_glm.py +++ b/tests/test_glm.py @@ -8,6 +8,7 @@ import numpy as np import pytest import statsmodels.api as sm +from pynapple import Tsd, TsdFrame from sklearn.linear_model import GammaRegressor, PoissonRegressor from sklearn.model_selection import GridSearchCV @@ -1332,6 +1333,33 @@ def test_simulate_input_dimensionality( feedforward_input=X, ) + @pytest.mark.parametrize( + "input_type, expected_out_type", + [ + (TsdFrame, Tsd), + (np.ndarray, jnp.ndarray), + (jnp.ndarray, jnp.ndarray), + ], + ) + def test_simulate_pynapple( + self, input_type, expected_out_type, poissonGLM_model_instantiation + ): + """ + Test that the `simulate` method retturns the expected data type for different allowed inputs. + """ + X, y, model, true_params, firing_rate = poissonGLM_model_instantiation + model.coef_ = true_params[0] + model.intercept_ = true_params[1] + + if input_type == TsdFrame: + X = TsdFrame(t=np.arange(X.shape[0]), d=X) + count, rate = model.simulate( + random_key=jax.random.key(123), + feedforward_input=X, + ) + assert isinstance(count, expected_out_type) + assert isinstance(rate, expected_out_type) + @pytest.mark.parametrize( "is_fit, expectation", [ @@ -3387,6 +3415,34 @@ def test_simulate_input_dimensionality( feedforward_input=X, ) + @pytest.mark.parametrize( + "input_type, expected_out_type", + [ + (TsdFrame, TsdFrame), + (np.ndarray, jnp.ndarray), + (jnp.ndarray, jnp.ndarray), + ], + ) + def test_simulate_pynapple( + self, input_type, expected_out_type, poisson_population_GLM_model + ): + """ + Test that the `simulate` method retturns the expected data type for different allowed inputs. + """ + X, y, model, true_params, firing_rate = poisson_population_GLM_model + model.coef_ = true_params[0] + model.intercept_ = true_params[1] + model._initialize_feature_mask(X, y) + if input_type == TsdFrame: + X = TsdFrame(t=np.arange(X.shape[0]), d=X) + + count, rate = model.simulate( + random_key=jax.random.key(123), + feedforward_input=X, + ) + assert isinstance(count, expected_out_type) + assert isinstance(rate, expected_out_type) + @pytest.mark.parametrize( "is_fit, expectation", [