Skip to content

Commit

Permalink
Merge pull request #88 from flatironinstitute/jax_key_annotate
Browse files Browse the repository at this point in the history
updates PRNGKey to remove deprecation warning
  • Loading branch information
BalzaniEdoardo authored Jan 31, 2024
2 parents e693128 + 3c4b94c commit ad9af5d
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/examples/plot_04_glm_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@
# with the same number of neurons and features (mandatory)
Xnew = np.random.normal(size=(20, ) + X.shape[1:])
# generate a random key given a seed
random_key = jax.random.PRNGKey(123)
random_key = jax.random.key(123)
spikes, rates = model.simulate(random_key, Xnew)

plt.figure()
Expand Down Expand Up @@ -343,7 +343,7 @@
# call simulate, with both the recurrent coupling
# and the input
spikes, rates = model.simulate_recurrent(
jax.random.PRNGKey(123),
jax.random.key(123),
feedforward_input=feedforward_input,
coupling_basis_matrix=coupling_basis,
init_y=init_spikes
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def score(
@abc.abstractmethod
def simulate(
self,
random_key: jax.random.PRNGKeyArray,
random_key: jax.Array,
feed_forward_input: DESIGN_INPUT_TYPE,
):
"""Simulate neural activity in response to a feed-forward input and recurrent activity."""
Expand Down
10 changes: 5 additions & 5 deletions src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,15 +393,15 @@ def fit(

def simulate(
self,
random_key: jax.random.PRNGKeyArray,
random_key: jax.Array,
feedforward_input: DESIGN_INPUT_TYPE,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Simulate neural activity in response to a feed-forward input.
Parameters
----------
random_key :
PRNGKey for seeding the simulation.
jax.random.key for seeding the simulation.
feedforward_input :
External input matrix to the model, representing factors like convolved currents,
light intensities, etc. When not provided, the simulation is done with coupling-only.
Expand Down Expand Up @@ -485,7 +485,7 @@ def __init__(

def simulate_recurrent(
self,
random_key: jax.random.PRNGKeyArray,
random_key: jax.Array,
feedforward_input: Union[NDArray, jnp.ndarray],
coupling_basis_matrix: Union[NDArray, jnp.ndarray],
init_y: Union[NDArray, jnp.ndarray],
Expand All @@ -501,7 +501,7 @@ def simulate_recurrent(
Parameters
----------
random_key :
PRNGKey for seeding the simulation.
jax.random.key for seeding the simulation.
feedforward_input :
External input matrix to the model, representing factors like convolved currents,
light intensities, etc. When not provided, the simulation is done with coupling-only.
Expand Down Expand Up @@ -586,7 +586,7 @@ def simulate_recurrent(
)

def scan_fn(
data: Tuple[jnp.ndarray, int], key: jax.random.PRNGKeyArray
data: Tuple[jnp.ndarray, int], key: jax.Array
) -> Tuple[Tuple[jnp.ndarray, int], Tuple[jnp.ndarray, jnp.ndarray]]:
"""Scan over time steps and simulate activity and rates.
Expand Down
8 changes: 3 additions & 5 deletions src/nemos/observation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from . import utils
from .base_class import Base

KeyArray = Union[jnp.ndarray, jax.random.PRNGKeyArray]

__all__ = ["PoissonObservations"]


Expand Down Expand Up @@ -138,7 +136,7 @@ def negative_log_likelihood(self, predicted_rate, y):

@abc.abstractmethod
def sample_generator(
self, key: KeyArray, predicted_rate: jnp.ndarray
self, key: jax.Array, predicted_rate: jnp.ndarray
) -> jnp.ndarray:
"""
Sample from the estimated distribution.
Expand Down Expand Up @@ -399,7 +397,7 @@ def negative_log_likelihood(
return jnp.mean(predicted_rate - x)

def sample_generator(
self, key: KeyArray, predicted_rate: jnp.ndarray
self, key: jax.Array, predicted_rate: jnp.ndarray
) -> jnp.ndarray:
"""
Sample from the Poisson distribution.
Expand Down Expand Up @@ -538,7 +536,7 @@ def check_observation_model(observation_model):
"test_scalar_func": True,
},
"sample_generator": {
"input": [jax.random.PRNGKey(123), 0.5 * jnp.array([1.0, 1.0, 1.0])],
"input": [jax.random.key(123), 0.5 * jnp.array([1.0, 1.0, 1.0])],
"test_preserve_shape": True,
},
}
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def poissonGLM_coupled_model_config_simulate():
- coupling_basis (jax.numpy.ndarray): Coupling basis values from the config.
- feedforward_input (jax.numpy.ndarray): Feedforward input values from the config.
- init_spikes (jax.numpy.ndarray): Initial spike values from the config.
- jax.random.PRNGKey(123) (jax.random.PRNGKey): A pseudo-random number generator key.
- jax.random.key(123) (jax.Array): A pseudo-random number generator key.
"""
observations = nmo.observation_models.PoissonObservations(jnp.exp)
regularizer = nmo.regularizer.Ridge("BFGS", regularizer_strength=0.1)
Expand Down Expand Up @@ -119,7 +119,7 @@ def poissonGLM_coupled_model_config_simulate():
coupling_basis,
feedforward_input,
init_spikes,
jax.random.PRNGKey(123),
jax.random.key(123),
)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def score(

def simulate(
self,
random_key: jax.random.PRNGKeyArray,
random_key: jax.Array,
feed_forward_input: Union[NDArray, jnp.ndarray],
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,14 +1038,14 @@ def test_simulate_feedforward_GLM_not_fit(self, poissonGLM_model_instantiation):
with pytest.raises(
nmo.exceptions.NotFittedError, match="This GLM instance is not fitted yet"
):
model.simulate(jax.random.PRNGKey(123), X)
model.simulate(jax.random.key(123), X)

def test_simulate_feedforward_GLM(self, poissonGLM_model_instantiation):
"""Test that simulate goes through"""
X, y, model, params, rate = poissonGLM_model_instantiation
model.coef_ = params[0]
model.intercept_ = params[1]
ysim, ratesim = model.simulate(jax.random.PRNGKey(123), X)
ysim, ratesim = model.simulate(jax.random.key(123), X)
# check that the expected dimensionality is returned
assert ysim.ndim == 2
assert ratesim.ndim == 2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_observation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_emission_probability(selfself, poissonGLM_model_instantiation):
Check that the emission probability is set to jax.random.poisson.
"""
_, _, model, _, _ = poissonGLM_model_instantiation
key_array = jax.random.PRNGKey(123)
key_array = jax.random.key(123)
counts = model.observation_model.sample_generator(key_array, np.arange(1, 11))
if not jnp.all(counts == jax.random.poisson(key_array, np.arange(1, 11))):
raise ValueError(
Expand Down

0 comments on commit ad9af5d

Please sign in to comment.