Unexpected speedup from wrapping function call in trivial jax.lax.cond statement #21065

Open
TonyZhou729 opened this issue May 3, 2024 · 0 comments
Labels
bug Something isn't working

Description

Hi,

We noticed a strange speedup when a function is called through a trivial lax.cond statement rather than called directly.

The reproduction below JIT-compiles a main() function that contains a lax.scan() loop. In each iteration we call a function either directly or through lax.cond(), using the predicate that the loop index i (which runs from 0 to Ny-1 over Ny steps) is greater than -1, so the predicate is always true. This seemingly unnecessary wrapper somehow makes the code run faster.
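Before the full reproduction, the pattern in isolation can be sketched roughly as follows. This is only an illustration under stated assumptions: `work`, `fallback`, `run_direct`, and `run_with_cond` are hypothetical stand-ins and not part of our code. It shows that the direct call and the always-true lax.cond wrapper should compute the same value.

```python
import jax.numpy as jnp
from jax import jit, lax

def work(x):
    # Hypothetical stand-in for the real per-step computation.
    return jnp.sin(x) + 1.0

def fallback(x):
    # Never selected at run time, because the predicate below is always true.
    return jnp.zeros_like(x)

@jit
def run_direct(xs):
    # Variant 1: call work() directly in every scan step.
    def step(carry, i):
        return carry + work(xs[i]), None
    total, _ = lax.scan(step, jnp.zeros((), xs.dtype), jnp.arange(xs.shape[0]))
    return total

@jit
def run_with_cond(xs):
    # Variant 2: route the same call through lax.cond with an always-true predicate.
    def step(carry, i):
        y = lax.cond(i > -1, work, fallback, xs[i])
        return carry + y, None
    total, _ = lax.scan(step, jnp.zeros((), xs.dtype), jnp.arange(xs.shape[0]))
    return total

xs = jnp.linspace(0.0, 1.0, 16)
same = bool(jnp.allclose(run_direct(xs), run_with_cond(xs)))
print("results match:", same)
```

Both branches passed to lax.cond must return values of the same shape and dtype, which is why `fallback` returns an array shaped like its input.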

import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates
import time

Nx = 300
Ny = 100000
x_axis = jnp.linspace(5., 12.75, Nx)

@jit
def main():
    y_axis = jnp.linspace(0, 1, Ny)

    # Initial value of B is just (Nx, Ny) size arrays of zeros.
    B = jnp.zeros((Nx, Ny), dtype="float32")

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]

        # Obtain an array A using interp_A_from_B(), picking one of two ways.
        # Case 1: We simply run interp_A_from_B() every step
        A = interp_A_from_B((y, y_axis, B))

        # Case 2: We use a seemingly trivial lax.cond wrapper, but will still always run
        # interp_A_from_B since index i is always greater than -1.
        # For some reason we observe a speed up over case 1.
        # A = lax.cond(i > -1, interp_A_from_B, false_func, (y, y_axis, B))

        # Update B array with values of A from this loop.
        B = set_B_to_A(i, B, A)

        return B, None

    # Use lax.scan to run loop and update B Ny times.
    # Index i will run through jnp.arange(Ny) = (0, 1, 2, ..., Ny-1)
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))

    return B

def interp_A_from_B(params):
    # B is a (Nx, Ny) array.
    # A is a (Nx,) array.

    y, y_axis, B    = params
    # Precise value of y to interpolate at.
    y_prime         = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx-1])
    # Convert to index position within y_axis, to use with ndimage.map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates.
    interp          = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Here, only use the interpolated result for values of y_prime larger than the smallest y in y_axis.
    condition       = y_prime < y_axis[0]

    # Put A array together, with some fill in values for where we don't want the interpolated value.
    A               = condition * jnp.exp(-x_axis[:Nx-1]) \
                    + (1-condition) * interp
    A               = jnp.append(A, jnp.exp(-x_axis[-1]))

    return A

def set_B_to_A(i, B, A):
    # Update a column of B with the current value of A.
    B = B.at[:, i].set(A)
    return B

def false_func(params):
    # Trivial false function, sets all entries of A to some fill values if called.
    A = jnp.exp(-x_axis)
    return A

# Run main() a few times to see the speed.
for i in range(5):
    s = time.time()
    B = main()
    print(time.time() - s)
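One caveat about the timing loop above: JAX dispatches work asynchronously, so stopping the clock right after the call may not wait for the computation to finish. A more conservative measurement, sketched here with a hypothetical `timed` helper and a toy `demo` function (neither is part of our code), blocks on the result before reading the clock.

```python
import time
import jax.numpy as jnp
from jax import jit

def timed(fn, *args):
    # Hypothetical helper: time one call, waiting until the result is ready.
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # wait for asynchronous dispatch to complete
    return out, time.perf_counter() - start

@jit
def demo(x):
    # Toy workload standing in for main().
    return (x @ x).sum()

x = jnp.ones((64, 64))
for _ in range(3):
    result, seconds = timed(demo, x)
    print(f"{seconds:.6f}")
```

The same idea applied to the reproduction would be to call `main().block_until_ready()` inside the timed region.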

With case 1 in loop_in_main(), calling main() 5 times gives these runtimes (in seconds):

2.593395233154297
2.245166540145874
2.2276456356048584
2.242725372314453
2.2277612686157227

Switching to case 2, we see:

2.095738649368286
1.769906997680664
1.7625515460968018
1.7623748779296875
1.7783701419830322

In both cases the first run is slower because it includes JIT compilation. We checked that the speedup scales with Ny, the number of steps in lax.scan. In our actual code, which does more computation in each step, the speedup is even more significant.
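To separate compilation time from execution time, and to see what program each variant produces, one option is ahead-of-time lowering. The sketch below uses a toy `scan_sum` function (hypothetical, not from our code); the same calls could be applied to each variant of main() to compare their jaxprs and compiled text.

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

def scan_sum(xs):
    # Toy scan standing in for main(): accumulate sin(x) over the inputs.
    def step(carry, x):
        return carry + jnp.sin(x), None
    total, _ = lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return total

xs = jnp.linspace(0.0, 1.0, 1000)

# Compile ahead of time so compilation is measured on its own.
t0 = time.perf_counter()
compiled = jax.jit(scan_sum).lower(xs).compile()
compile_seconds = time.perf_counter() - t0

# Measure execution only, blocking until the result is ready.
t0 = time.perf_counter()
out = compiled(xs)
out.block_until_ready()
run_seconds = time.perf_counter() - t0

# Text forms that can be diffed between two variants of a function.
jaxpr_text = str(jax.make_jaxpr(scan_sum)(xs))
hlo_text = compiled.as_text()

print(f"compile: {compile_seconds:.4f} s, run: {run_seconds:.6f} s")
```

Diffing `jaxpr_text` or `hlo_text` between the direct and the lax.cond variants may show where the two programs differ.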

Thank you in advance for your help and comments!
Tony

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.16
jaxlib: 0.4.16
numpy:  1.24.3
python: 3.10.10 (main, Mar 21 2023, 18:45:11) [GCC 11.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1