Skip to content

Commit ad9af5d

Browse files
Merge pull request #88 from flatironinstitute/jax_key_annotate
updates PRNGKey to remove deprecation warning
2 parents e693128 + 3c4b94c commit ad9af5d

File tree

8 files changed

+17
-19
lines changed

8 files changed

+17
-19
lines changed

docs/examples/plot_04_glm_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@
215215
# with the same number of neurons and features (mandatory)
216216
Xnew = np.random.normal(size=(20, ) + X.shape[1:])
217217
# generate a random key given a seed
218-
random_key = jax.random.PRNGKey(123)
218+
random_key = jax.random.key(123)
219219
spikes, rates = model.simulate(random_key, Xnew)
220220

221221
plt.figure()
@@ -343,7 +343,7 @@
343343
# call simulate, with both the recurrent coupling
344344
# and the input
345345
spikes, rates = model.simulate_recurrent(
346-
jax.random.PRNGKey(123),
346+
jax.random.key(123),
347347
feedforward_input=feedforward_input,
348348
coupling_basis_matrix=coupling_basis,
349349
init_y=init_spikes

src/nemos/base_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def score(
205205
@abc.abstractmethod
206206
def simulate(
207207
self,
208-
random_key: jax.random.PRNGKeyArray,
208+
random_key: jax.Array,
209209
feed_forward_input: DESIGN_INPUT_TYPE,
210210
):
211211
"""Simulate neural activity in response to a feed-forward input and recurrent activity."""

src/nemos/glm.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -393,15 +393,15 @@ def fit(
393393

394394
def simulate(
395395
self,
396-
random_key: jax.random.PRNGKeyArray,
396+
random_key: jax.Array,
397397
feedforward_input: DESIGN_INPUT_TYPE,
398398
) -> Tuple[jnp.ndarray, jnp.ndarray]:
399399
"""Simulate neural activity in response to a feed-forward input.
400400
401401
Parameters
402402
----------
403403
random_key :
404-
PRNGKey for seeding the simulation.
404+
jax.random.key for seeding the simulation.
405405
feedforward_input :
406406
External input matrix to the model, representing factors like convolved currents,
407407
light intensities, etc. When not provided, the simulation is done with coupling-only.
@@ -485,7 +485,7 @@ def __init__(
485485

486486
def simulate_recurrent(
487487
self,
488-
random_key: jax.random.PRNGKeyArray,
488+
random_key: jax.Array,
489489
feedforward_input: Union[NDArray, jnp.ndarray],
490490
coupling_basis_matrix: Union[NDArray, jnp.ndarray],
491491
init_y: Union[NDArray, jnp.ndarray],
@@ -501,7 +501,7 @@ def simulate_recurrent(
501501
Parameters
502502
----------
503503
random_key :
504-
PRNGKey for seeding the simulation.
504+
jax.random.key for seeding the simulation.
505505
feedforward_input :
506506
External input matrix to the model, representing factors like convolved currents,
507507
light intensities, etc. When not provided, the simulation is done with coupling-only.
@@ -586,7 +586,7 @@ def simulate_recurrent(
586586
)
587587

588588
def scan_fn(
589-
data: Tuple[jnp.ndarray, int], key: jax.random.PRNGKeyArray
589+
data: Tuple[jnp.ndarray, int], key: jax.Array
590590
) -> Tuple[Tuple[jnp.ndarray, int], Tuple[jnp.ndarray, jnp.ndarray]]:
591591
"""Scan over time steps and simulate activity and rates.
592592

src/nemos/observation_models.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from . import utils
1010
from .base_class import Base
1111

12-
KeyArray = Union[jnp.ndarray, jax.random.PRNGKeyArray]
13-
1412
__all__ = ["PoissonObservations"]
1513

1614

@@ -138,7 +136,7 @@ def negative_log_likelihood(self, predicted_rate, y):
138136

139137
@abc.abstractmethod
140138
def sample_generator(
141-
self, key: KeyArray, predicted_rate: jnp.ndarray
139+
self, key: jax.Array, predicted_rate: jnp.ndarray
142140
) -> jnp.ndarray:
143141
"""
144142
Sample from the estimated distribution.
@@ -399,7 +397,7 @@ def negative_log_likelihood(
399397
return jnp.mean(predicted_rate - x)
400398

401399
def sample_generator(
402-
self, key: KeyArray, predicted_rate: jnp.ndarray
400+
self, key: jax.Array, predicted_rate: jnp.ndarray
403401
) -> jnp.ndarray:
404402
"""
405403
Sample from the Poisson distribution.
@@ -538,7 +536,7 @@ def check_observation_model(observation_model):
538536
"test_scalar_func": True,
539537
},
540538
"sample_generator": {
541-
"input": [jax.random.PRNGKey(123), 0.5 * jnp.array([1.0, 1.0, 1.0])],
539+
"input": [jax.random.key(123), 0.5 * jnp.array([1.0, 1.0, 1.0])],
542540
"test_preserve_shape": True,
543541
},
544542
}

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def poissonGLM_coupled_model_config_simulate():
8282
- coupling_basis (jax.numpy.ndarray): Coupling basis values from the config.
8383
- feedforward_input (jax.numpy.ndarray): Feedforward input values from the config.
8484
- init_spikes (jax.numpy.ndarray): Initial spike values from the config.
85-
- jax.random.PRNGKey(123) (jax.random.PRNGKey): A pseudo-random number generator key.
85+
- jax.random.key(123) (jax.Array): A pseudo-random number generator key.
8686
"""
8787
observations = nmo.observation_models.PoissonObservations(jnp.exp)
8888
regularizer = nmo.regularizer.Ridge("BFGS", regularizer_strength=0.1)
@@ -119,7 +119,7 @@ def poissonGLM_coupled_model_config_simulate():
119119
coupling_basis,
120120
feedforward_input,
121121
init_spikes,
122-
jax.random.PRNGKey(123),
122+
jax.random.key(123),
123123
)
124124

125125

tests/test_base_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def score(
4242

4343
def simulate(
4444
self,
45-
random_key: jax.random.PRNGKeyArray,
45+
random_key: jax.Array,
4646
feed_forward_input: Union[NDArray, jnp.ndarray],
4747
**kwargs,
4848
):

tests/test_glm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,14 +1038,14 @@ def test_simulate_feedforward_GLM_not_fit(self, poissonGLM_model_instantiation):
10381038
with pytest.raises(
10391039
nmo.exceptions.NotFittedError, match="This GLM instance is not fitted yet"
10401040
):
1041-
model.simulate(jax.random.PRNGKey(123), X)
1041+
model.simulate(jax.random.key(123), X)
10421042

10431043
def test_simulate_feedforward_GLM(self, poissonGLM_model_instantiation):
10441044
"""Test that simulate goes through"""
10451045
X, y, model, params, rate = poissonGLM_model_instantiation
10461046
model.coef_ = params[0]
10471047
model.intercept_ = params[1]
1048-
ysim, ratesim = model.simulate(jax.random.PRNGKey(123), X)
1048+
ysim, ratesim = model.simulate(jax.random.key(123), X)
10491049
# check that the expected dimensionality is returned
10501050
assert ysim.ndim == 2
10511051
assert ratesim.ndim == 2

tests/test_observation_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_emission_probability(selfself, poissonGLM_model_instantiation):
160160
Check that the emission probability is set to jax.random.poisson.
161161
"""
162162
_, _, model, _, _ = poissonGLM_model_instantiation
163-
key_array = jax.random.PRNGKey(123)
163+
key_array = jax.random.key(123)
164164
counts = model.observation_model.sample_generator(key_array, np.arange(1, 11))
165165
if not jnp.all(counts == jax.random.poisson(key_array, np.arange(1, 11))):
166166
raise ValueError(

0 commit comments

Comments
 (0)