
Reductions along leading axes can be incredibly slow in C and Numba backends #935

Open

ricardoV94 opened this issue Jul 15, 2024 · 2 comments

ricardoV94 (Member) commented Jul 15, 2024

Description

Reported by @aseyboldt

from timeit import timeit
from functools import partial

import numpy as np
import pytensor

N = 256
r = 10  # number of timing repetitions

x_test = np.random.uniform(size=(N, N, N))
x = pytensor.shared(x_test, name="x", shape=x_test.shape)

for axis in [0, 1, 2]:
    y = x.sum(axis)
    c_fn = pytensor.function([], y, mode="FAST_RUN")
    numba_fn = pytensor.function([], y, mode="NUMBA")
    jax_fn_ = pytensor.function([], y, mode="JAX")
    # np.asarray forces the JAX computation to finish before timing stops
    jax_fn = lambda: np.asarray(jax_fn_())
    numpy_fn = partial(np.sum, x_test, axis=axis)

    # Sanity-check that all backends agree with NumPy
    np.testing.assert_allclose(c_fn(), numpy_fn())
    np.testing.assert_allclose(numba_fn(), numpy_fn())
    np.testing.assert_allclose(jax_fn(), numpy_fn())

    print(f"\n{axis=}")
    for name, fn in [("C", c_fn), ("numba", numba_fn), ("jax", jax_fn), ("numpy", numpy_fn)]:
        print(f"  | {name}: {timeit(fn, number=r) / r: .4f}s")

I'm running JAX on a CPU

axis=0
  | C:  1.8741s
  | numba:  1.7838s
  | jax:  0.7674s
  | numpy:  0.0075s
axis=1
  | C:  0.0286s
  | numba:  0.0280s
  | jax:  0.0133s
  | numpy:  0.0083s
axis=2
  | C:  0.0142s
  | numba:  0.0153s
  | jax:  0.0046s
  | numpy:  0.0069s
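For reference, a quick pure-NumPy sketch suggests the axis=0 reduction itself is not inherently expensive: accumulating the contiguous 2-D slices into an output buffer keeps memory access sequential and runs at NumPy-like speed. This is only an illustration of the access pattern, not what any of the backends currently generates (it reuses N and x_test from the script above):

out = np.zeros((N, N))
for i in range(N):
    out += x_test[i]  # each slice x_test[i] is C-contiguous
np.testing.assert_allclose(out, x_test.sum(axis=0))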

#931 makes Numba much faster on axis=0, at the expense of doing worse on axis=2:

axis=0
  | C:  1.6623s
  | numba:  0.0203s
  | jax:  0.6826s
  | numpy:  0.0098s
axis=1
  | C:  0.0294s
  | numba:  0.0288s
  | jax:  0.0126s
  | numpy:  0.0078s
axis=2
  | C:  0.0141s
  | numba:  0.1160s
  | jax:  0.0044s
  | numpy:  0.0064s

In any case, NumPy is beating all of our backends handily :)

Surprisingly, JAX also does poorly in the axis=0 case, although not as poorly as C/Numba. The poor performance is probably due to a bad iteration order / cache access pattern.
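As a rough illustration of the iteration-order effect, here is a standalone Numba sketch (not the code PyTensor generates; the function names are made up, and float64 input is assumed as in the benchmark above). Summing over the leading axis is fast when the contiguous trailing axis stays in the innermost loop, and much slower when the reduced axis is iterated innermost, because each read then jumps N*N elements through memory:

import numba
import numpy as np

@numba.njit
def sum_axis0_friendly(x):
    # Reduced (leading) axis outermost, contiguous axis innermost -> sequential reads
    n0, n1, n2 = x.shape
    out = np.zeros((n1, n2))
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                out[j, k] += x[i, j, k]
    return out

@numba.njit
def sum_axis0_hostile(x):
    # Reduced axis innermost -> each read strides n1*n2 elements through memory
    n0, n1, n2 = x.shape
    out = np.empty((n1, n2))
    for j in range(n1):
        for k in range(n2):
            acc = 0.0
            for i in range(n0):
                acc += x[i, j, k]
            out[j, k] = acc
    return out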

ricardoV94 changed the title from "Reduction along leading axis is incredibly slow in C and Numba backends" to "Reductions can be incredibly slow in C and Numba backends" on Jul 15, 2024
ricardoV94 changed the title from "Reductions can be incredibly slow in C and Numba backends" to "Reductions along leading axes can be incredibly slow in C and Numba backends" on Jul 15, 2024
ricardoV94 (Member, Author) commented:

Opened an issue with Numba: numba/numba#9679

ricardoV94 (Member, Author) commented:

The C case is now better after #971.
