Description
Consider a function $f(a) = a b$, where $b$ is a complex parameter and $a$ is real. I want to compute the vector-Jacobian product $t^* \partial f(a)$. When $a b$ is a matrix multiplication, a warning is raised about the imaginary part being discarded.
For scalar $a$ and $b$ it works as expected and without a warning (with `jax`, `jax.numpy as jnp`, and `warnings` imported).
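The original scalar snippet is not preserved here; a minimal sketch of that case (my reconstruction, mirroring the matrix version below) would be:

```python
import warnings

import jax
import jax.numpy as jnp

a = jnp.array(1.0)
b = jnp.array(1.0j)
t = jnp.array(1.0j)

with warnings.catch_warnings():
    warnings.simplefilter("error", jnp.ComplexWarning)
    # Scalar multiplication: no ComplexWarning is raised, even though
    # the cotangent returned for the real input a has a real dtype.
    _, vjp_fn = jax.vjp(lambda x: x * b, a)
    print(vjp_fn(t))  # expected: (Array(1., dtype=float32),)
```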
However, if we promote $a$ and $b$ to matrices, this raises a ComplexWarning:
```python
import warnings

import jax
import jax.numpy as jnp

a = jnp.array([[1.0]])
b = jnp.array([[1.0j]])
t = jnp.array([[1.0j]])

with warnings.catch_warnings():
    warnings.simplefilter("error", jnp.ComplexWarning)
    _, vjp_fn = jax.vjp(lambda x: x @ b, a)
    vjp_fn(t)
```
```
ComplexWarning: Casting complex values to real discards the imaginary part
```
I would not expect a complex warning in the latter case. Although the output is the same in both cases (up to shapes), the warning suggests something could go wrong. Unless I'm missing something, the two behaviors are at least inconsistent.
The warning is not raised when $a$ is a matrix but $b$ is not (matrix-scalar multiplication), but it does occur when using `einsum` (see the sketch below).
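For completeness, a sketch of the `einsum` variant (the exact contraction string is my assumption; the original snippet is not shown):

```python
import warnings

import jax
import jax.numpy as jnp

a = jnp.array([[1.0]])
b = jnp.array([[1.0j]])
t = jnp.array([[1.0j]])

with warnings.catch_warnings():
    warnings.simplefilter("error", jnp.ComplexWarning)
    # The same contraction as x @ b, expressed as an einsum.
    _, vjp_fn = jax.vjp(lambda x: jnp.einsum("ij,jk->ik", x, b), a)
    vjp_fn(t)  # raises ComplexWarning, just like the @ version
```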
System info (python version, jaxlib version, accelerator, etc.)

Tested on CPU, as well as on Colab (jax 0.4.26, both on CPU and GPU).
Yes, that would be a good resolution. I'm also unaware of anything going wrong. But the warning made me waste some time searching for what I might be doing wrong, so I thought fixing this might be useful.