ComplexWarning in VJP when using complex matrix multiplication #21188

Open
mathisgerdes opened this issue May 11, 2024 · 2 comments
Labels: bug (Something isn't working)

@mathisgerdes
Contributor

Description

Consider a function $f(a) = a b$ with $b$ some complex parameter and $a$ real. I want to compute the vector-Jacobian product $t^* \partial f(a)$. When $a b$ represents matrix multiplication, a warning is raised about an imaginary part being discarded.

For scalar $a$ and $b$, it works as expected and without a warning:

import warnings

import jax
import jax.numpy as jnp

a = jnp.array(1.0)   # real input
b = jnp.array(1.0j)  # complex parameter
t = jnp.array(1.0j)  # complex cotangent

with warnings.catch_warnings():
  # Turn any ComplexWarning into an error so it cannot go unnoticed.
  warnings.simplefilter("error", jnp.ComplexWarning)

  _, vjp_fn = jax.vjp(lambda x: x * b, a)
  vjp_fn(t)

However, if we promote $a$ and $b$ to matrices, this raises a ComplexWarning:

a = jnp.array([[1.0]])
b = jnp.array([[1.0j]])
t = jnp.array([[1.0j]])

with warnings.catch_warnings():
  warnings.simplefilter("error", jnp.ComplexWarning)
  
  _, vjp_fn = jax.vjp(lambda x: x @ b, a)
  vjp_fn(t)

ComplexWarning: Casting complex values to real discards the imaginary part

I would not expect a complex warning in the latter case. Although the output is the same in both cases (up to shapes), the warning suggests something could go wrong. Unless I'm missing something, the two behaviors are at least inconsistent.
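
To make the "same output up to shapes" claim concrete, here is a minimal sketch that computes both VJPs and prints them next to each other. It reuses the values from the two snippets above; the names `g_scalar` and `g_matrix` are only illustrative. Warnings are recorded instead of raised so that the script runs to completion whether or not the installed JAX version emits one.

```python
import warnings

import jax
import jax.numpy as jnp

# Scalar case: f(x) = x * b with real x and complex b.
_, vjp_scalar = jax.vjp(lambda x: x * jnp.array(1.0j), jnp.array(1.0))

# Matrix case: the same values, promoted to 1x1 matrices.
_, vjp_matrix = jax.vjp(lambda x: x @ jnp.array([[1.0j]]), jnp.array([[1.0]]))

with warnings.catch_warnings(record=True):
    # Record warnings instead of raising, so both VJPs are evaluated.
    warnings.simplefilter("always")
    (g_scalar,) = vjp_scalar(jnp.array(1.0j))
    (g_matrix,) = vjp_matrix(jnp.array([[1.0j]]))

# Both cotangents are real (matching the real input) and should agree in value.
print("scalar:", g_scalar, g_scalar.dtype, g_scalar.shape)
print("matrix:", g_matrix, g_matrix.dtype, g_matrix.shape)
```

If the two printed values differ, that would point to an actual bug rather than a spurious warning.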

The warning is not raised when $a$ is a matrix and $b$ is a scalar (matrix-scalar multiplication). It is, however, also raised when the matrix product is written with einsum.
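
The three variants can be checked side by side with a small script like the one below. This is a sketch, not part of the original reproduction: the helper `warns_complex` and the einsum subscripts `"ij,jk->ik"` are my own choices, and the warning is matched by class name so the script does not depend on where `ComplexWarning` lives in a given NumPy or JAX version.

```python
import warnings

import jax
import jax.numpy as jnp


def warns_complex(f, a, t):
    """Return True if building and applying the VJP of f at a emits a ComplexWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _, vjp_fn = jax.vjp(f, a)
        vjp_fn(t)
    # Match by name to avoid importing ComplexWarning from a version-specific location.
    return any(w.category.__name__ == "ComplexWarning" for w in caught)


a = jnp.array([[1.0]])          # real matrix input
b_scalar = jnp.array(1.0j)      # complex scalar parameter
b_matrix = jnp.array([[1.0j]])  # complex matrix parameter
t = jnp.array([[1.0j]])         # complex cotangent

cases = {
    "matrix * scalar": lambda x: x * b_scalar,
    "matmul": lambda x: x @ b_matrix,
    "einsum": lambda x: jnp.einsum("ij,jk->ik", x, b_matrix),
}

results = {name: warns_complex(f, a, t) for name, f in cases.items()}
for name, warned in results.items():
    print(f"{name}: ComplexWarning raised = {warned}")
```

On the versions listed below I would expect `False` for the first case and `True` for the other two, but the outcome may differ on other JAX releases.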

System info (python version, jaxlib version, accelerator, etc.)

Tested on CPU:

jax:    0.4.23
jaxlib: 0.4.23
numpy:  1.26.3
python: 3.9.9 (v3.9.9:ccb0e6a345, Nov 15 2021, 13:06:05)  [Clang 13.0.0 (clang-1300.0.29.3)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

as well as on Colab (jax 0.4.26 both on CPU and GPU).

@mathisgerdes mathisgerdes added the bug Something isn't working label May 11, 2024
@mattjj
Member

mattjj commented May 11, 2024

Thanks for the clear report! I'll take a look, though my initial guess is that this is a spurious warning we should suppress.
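
Until that happens, a user-side workaround could look roughly like the sketch below. It is not the fix discussed in this thread, only a local filter: it assumes the warning really is spurious and silences it by matching the start of the message text quoted in the report, inside a scoped `warnings.catch_warnings()` block.

```python
import warnings

import jax
import jax.numpy as jnp

a = jnp.array([[1.0]])
b = jnp.array([[1.0j]])
t = jnp.array([[1.0j]])

with warnings.catch_warnings():
    # Ignore only warnings whose message starts with this text;
    # all other warnings are still reported as usual.
    warnings.filterwarnings("ignore", message="Casting complex values to real")
    _, vjp_fn = jax.vjp(lambda x: x @ b, a)
    (grad_a,) = vjp_fn(t)

print(grad_a)
```

Because the filter is scoped to the `with` block, a genuine `ComplexWarning` raised elsewhere in the program stays visible.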

@mattjj mattjj self-assigned this May 11, 2024
@mathisgerdes
Contributor Author

Yes, that would be a good resolution. I'm also unaware of anything going wrong. But the warning made me waste some time searching for what I might be doing wrong, so I thought fixing this might be useful.
