Skip to content

Commit d3b3435

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Future-proof calls to jnp.solve on batched 1D inputs.
This has been deprecated since JAX v0.4.25, and is no longer supported in JAX v0.5.0. PiperOrigin-RevId: 713077633
1 parent 8ca8408 commit d3b3435

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

trax/layers/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,8 +808,7 @@ def log_gaussian_pdf(x, mu, sigma): # pylint: disable=invalid-name
808808
"""
809809
a = mu.shape[-1] * jnp.log(2 * jnp.pi)
810810
_, b = jnp.linalg.slogdet(sigma)
811-
y = jnp.linalg.solve(sigma, x - mu)
812-
y = jnp.expand_dims(y, axis=-1)
811+
y = jnp.linalg.solve(sigma, (x - mu)[..., None])
813812
xm = jnp.expand_dims(x - mu, axis=-2)
814813
c = jnp.matmul(xm, y)
815814
c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)

0 commit comments

Comments
 (0)