Skip to content

Commit

Permalink
Merge pull request #287 from flatironinstitute/bugfix_simulate_pynapple
Browse files Browse the repository at this point in the history
added decorator to simulate
  • Loading branch information
BalzaniEdoardo authored Jan 10, 2025
2 parents 94dae0c + f77c29b commit 316977f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ def _set_coef_and_intercept(self, params):
self.coef_: DESIGN_INPUT_TYPE = params[0]
self.intercept_: jnp.ndarray = params[1]

@support_pynapple(conv_type="jax")
def simulate(
self,
random_key: jax.Array,
Expand Down
56 changes: 56 additions & 0 deletions tests/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pytest
import statsmodels.api as sm
from pynapple import Tsd, TsdFrame
from sklearn.linear_model import GammaRegressor, PoissonRegressor
from sklearn.model_selection import GridSearchCV

Expand Down Expand Up @@ -1332,6 +1333,33 @@ def test_simulate_input_dimensionality(
feedforward_input=X,
)

@pytest.mark.parametrize(
"input_type, expected_out_type",
[
(TsdFrame, Tsd),
(np.ndarray, jnp.ndarray),
(jnp.ndarray, jnp.ndarray),
],
)
def test_simulate_pynapple(
self, input_type, expected_out_type, poissonGLM_model_instantiation
):
"""
Test that the `simulate` method retturns the expected data type for different allowed inputs.
"""
X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
model.coef_ = true_params[0]
model.intercept_ = true_params[1]

if input_type == TsdFrame:
X = TsdFrame(t=np.arange(X.shape[0]), d=X)
count, rate = model.simulate(
random_key=jax.random.key(123),
feedforward_input=X,
)
assert isinstance(count, expected_out_type)
assert isinstance(rate, expected_out_type)

@pytest.mark.parametrize(
"is_fit, expectation",
[
Expand Down Expand Up @@ -3387,6 +3415,34 @@ def test_simulate_input_dimensionality(
feedforward_input=X,
)

@pytest.mark.parametrize(
"input_type, expected_out_type",
[
(TsdFrame, TsdFrame),
(np.ndarray, jnp.ndarray),
(jnp.ndarray, jnp.ndarray),
],
)
def test_simulate_pynapple(
self, input_type, expected_out_type, poisson_population_GLM_model
):
"""
Test that the `simulate` method retturns the expected data type for different allowed inputs.
"""
X, y, model, true_params, firing_rate = poisson_population_GLM_model
model.coef_ = true_params[0]
model.intercept_ = true_params[1]
model._initialize_feature_mask(X, y)
if input_type == TsdFrame:
X = TsdFrame(t=np.arange(X.shape[0]), d=X)

count, rate = model.simulate(
random_key=jax.random.key(123),
feedforward_input=X,
)
assert isinstance(count, expected_out_type)
assert isinstance(rate, expected_out_type)

@pytest.mark.parametrize(
"is_fit, expectation",
[
Expand Down

0 comments on commit 316977f

Please sign in to comment.