Skip to content

Commit

Permalink
[numpy] Fix test failures under NumPy 2.0.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 662483478
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Aug 13, 2024
1 parent 95767a1 commit 4724420
Showing 1 changed file with 17 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,18 @@ def _bcast_shape(base_shape, args):
return bcast_shape


def _rng_from_seed(seed):
if seed is None:
return np.random
elif isinstance(seed, int):
return np.random.RandomState(seed & 0xFFFFFFFF)
else:
return np.random.RandomState(np.array(seed, dtype=np.uint32))


def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None): # pylint: disable=unused-argument
"""Massaging dtype and nan handling of np.random.binomial."""
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
invalid_count = (np.int64(counts) < 0) != (counts < 0)
if np.any(invalid_count):
raise ValueError('int64 overflow: {} -> {}'.format(
Expand All @@ -82,7 +91,7 @@ def _binomial(shape, seed, counts, probs, output_dtype=np.int32, name=None): #


def _categorical(logits, num_samples, dtype=None, seed=None, name=None): # pylint: disable=unused-argument
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
dtype = utils.numpy_dtype(dtype or np.int64)
if not hasattr(logits, 'shape'):
logits = np.array(logits, np.float32)
Expand All @@ -107,7 +116,7 @@ def _categorical_jax(logits, num_samples, dtype=None, seed=None, name=None): #

def _gamma(shape, alpha, beta=None, dtype=np.float32, seed=None,
name=None): # pylint: disable=unused-argument
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
scale = 1. if beta is None else (1. / beta)
shape = _ensure_shape_tuple(shape)
return rng.gamma(shape=alpha, scale=scale, size=shape).astype(dtype)
Expand All @@ -133,7 +142,7 @@ def _gamma_jax(shape, alpha, beta=None, dtype=np.float32, seed=None, name=None):

def _normal(shape, mean=0.0, stddev=1.0, dtype=np.float32, seed=None,
name=None): # pylint: disable=unused-argument
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
dtype = utils.common_dtype([mean, stddev], dtype_hint=dtype)
shape = _bcast_shape(shape, [mean, stddev])
return rng.normal(loc=mean, scale=stddev, size=shape).astype(dtype)
Expand All @@ -151,7 +160,7 @@ def _normal_jax(shape, mean=0.0, stddev=1.0, dtype=np.float32, seed=None,

def _poisson(shape, lam, dtype=np.float32, seed=None,
name=None): # pylint: disable=unused-argument
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
dtype = utils.common_dtype([lam], dtype_hint=dtype)
shape = _ensure_shape_tuple(shape)
return rng.poisson(lam=lam, size=shape).astype(dtype)
Expand Down Expand Up @@ -209,7 +218,7 @@ def _poisson_jax(shape, lam, dtype=np.float32, seed=None,


def _shuffle(value, seed=None, name=None): # pylint: disable=unused-argument
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
ret = np.array(value)
rng.shuffle(ret)
return ret
Expand All @@ -225,7 +234,7 @@ def _shuffle_jax(value, seed=None, name=None): # pylint: disable=unused-argumen
def _truncated_normal(
shape, seed, means=0.0, stddevs=1.0, minvals=-2.0, maxvals=2.0, name=None): # pylint: disable=unused-argument
from scipy import stats # pylint: disable=g-import-not-at-top
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
std_low = (minvals - means) / stddevs
std_high = (maxvals - means) / stddevs
std_samps = stats.truncnorm.rvs(
Expand All @@ -248,7 +257,7 @@ def _truncated_normal_jax(
def _uniform(shape, minval=0, maxval=None, dtype=np.float32, seed=None,
name=None): # pylint: disable=unused-argument
"""Numpy uniform random sampler."""
rng = np.random if seed is None else np.random.RandomState(seed & 0xffffffff)
rng = _rng_from_seed(seed)
if minval is not None:
minval = ops.convert_to_tensor(minval, dtype=dtype)
if maxval is not None:
Expand Down

0 comments on commit 4724420

Please sign in to comment.