Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

represent random.key_impl of builtin RNGs by canonical string name #24593

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,12 @@ def split(key: KeyArrayLike, num: int | tuple[int, ...] = 2) -> KeyArray:
def _key_impl(keys: KeyArray) -> PRNGImpl:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _key_impl(keys: KeyArray) -> PRNGImpl:
def _key_impl(keys: KeyArray) -> str | PRNGSpec:

assert jnp.issubdtype(keys.dtype, dtypes.prng_key)
keys_dtype = typing.cast(prng.KeyTy, keys.dtype)
return keys_dtype._impl
impl = keys_dtype._impl
return impl.name if impl.name in prng.prngs else PRNGSpec(impl)

def key_impl(keys: KeyArrayLike) -> PRNGSpec:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def key_impl(keys: KeyArrayLike) -> PRNGSpec:
def key_impl(keys: KeyArrayLike) -> str | PRNGSpec:

typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True)
return PRNGSpec(_key_impl(typed_keys))
return _key_impl(typed_keys)


def _key_data(keys: KeyArray) -> Array:
Expand Down
30 changes: 18 additions & 12 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,35 +70,41 @@ def test_symbols(self):

class RandomTest(jtu.JaxTestCase):

def test_key_make_with_custom_impl(self):
shape = (4, 2, 7)

def make_custom_impl(self, shape, seed=False, split=False, fold_in=False,
random_bits=False):
assert not split and not fold_in and not random_bits # not yet implemented
def seed_rule(_):
return jnp.ones(shape, dtype=jnp.dtype('uint32'))

def no_rule(*args, **kwargs):
assert False, 'unreachable'

impl = jex.random.define_prng_impl(
key_shape=shape, seed=seed_rule, split=no_rule, fold_in=no_rule,
random_bits=no_rule)
return jex.random.define_prng_impl(
key_shape=shape, seed=seed_rule if seed else no_rule, split=no_rule,
fold_in=no_rule, random_bits=no_rule)

def test_key_make_with_custom_impl(self):
impl = self.make_custom_impl(shape=(4, 2, 7), seed=True)
k = jax.random.key(42, impl=impl)
self.assertEqual(k.shape, ())
self.assertEqual(impl, jax.random.key_impl(k))

def test_key_wrap_with_custom_impl(self):
def no_rule(*args, **kwargs):
assert False, 'unreachable'

shape = (4, 2, 7)
impl = jex.random.define_prng_impl(
key_shape=shape, seed=no_rule, split=no_rule, fold_in=no_rule,
random_bits=no_rule)
impl = self.make_custom_impl(shape=shape)
data = jnp.ones((3, *shape), dtype=jnp.dtype('uint32'))
k = jax.random.wrap_key_data(data, impl=impl)
self.assertEqual(k.shape, (3,))
self.assertEqual(impl, jax.random.key_impl(k))

def test_key_impl_is_spec(self):
# this is counterpart to random_test.py:
# KeyArrayTest.test_key_impl_builtin_is_string_name
spec_ref = self.make_custom_impl(shape=(4, 2, 7), seed=True)
key = jax.random.key(42, impl=spec_ref)
spec = jax.random.key_impl(key)
self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})")


class FfiTest(jtu.JaxTestCase):

Expand Down
4 changes: 2 additions & 2 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,10 +1120,10 @@ class A: pass
jax.random.key(42, impl=A())

@jtu.sample_product(name=[name for name, _ in PRNG_IMPLS])
def test_key_spec_repr(self, name):
def test_key_impl_builtin_is_string_name(self, name):
key = jax.random.key(42, impl=name)
spec = jax.random.key_impl(key)
self.assertEqual(repr(spec), f"PRNGSpec({name!r})")
self.assertEqual(spec, name)

def test_keyarray_custom_vjp(self):
# Regression test for https://github.com/jax-ml/jax/issues/18442
Expand Down
Loading