Hi,

We noticed a strange speed-up when a trivial lax.cond statement is used to call a function, rather than calling the function directly.

In the reproduction below, we JIT-compile a main() function that contains a lax.scan() loop. In each step of the loop we can wrap the function call in a lax.cond() whose condition is that the loop index i (which runs from 0 to Ny-1 over Ny steps) is greater than -1 — a condition that is always true. This seemingly unnecessary wrapper somehow causes a speed-up.
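For reference, the always-true lax.cond pattern in isolation looks like the following minimal sketch (true_fn/false_fn are hypothetical toy functions, not part of our code; both branches must return the same shape and dtype):

```python
import jax
import jax.numpy as jnp
from jax import lax

def true_fn(x):
    # The branch that is always taken, since i > -1 for all i in arange(N).
    return x * 2.0

def false_fn(x):
    # Trivial fallback branch; never actually selected here.
    return jnp.zeros_like(x)

@jax.jit
def step(i, x):
    # lax.cond(pred, true_fun, false_fun, operand): traces both branches,
    # selects one at runtime based on pred.
    return lax.cond(i > -1, true_fn, false_fn, x)

print(step(0, jnp.ones(3)))  # → [2. 2. 2.]
```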
```python
import jax.numpy as jnp
from jax import jit, lax
from jax.scipy.ndimage import map_coordinates
import time

Nx = 300
Ny = 100000
x_axis = jnp.linspace(5., 12.75, Nx)

@jit
def main():
    y_axis = jnp.linspace(0, 1, Ny)
    # Initial value of B is just a (Nx, Ny) array of zeros.
    B = jnp.zeros((Nx, Ny), dtype="float32")

    def loop_in_main(carry, i):
        B = carry
        y = y_axis[i]
        """ Obtain an array A using interp_A_from_B(), picking one of two ways """
        # Case 1: We simply run interp_A_from_B() every step.
        A = interp_A_from_B((y, y_axis, B))
        # Case 2: We use a seemingly trivial lax.cond wrapper, which will still
        # always run interp_A_from_B since index i is always greater than -1.
        # For some reason we observe a speed-up over case 1.
        #A = lax.cond(i > -1, interp_A_from_B, false_func, (y, y_axis, B))

        # Update B array with the values of A from this loop.
        B = set_B_to_A(i, B, A)
        return B, None

    # Use lax.scan to run the loop and update B Ny times.
    # Index i will run through jnp.arange(Ny) = (0, 1, 2, ..., Ny-1).
    B, _ = lax.scan(loop_in_main, B, jnp.arange(Ny))
    return B

def interp_A_from_B(params):
    # B is a (Nx, Ny) array; A is a (Nx,) array.
    y, y_axis, B = params
    # Precise values of y to interpolate at.
    y_prime = y - jnp.log(x_axis[1:Nx] / x_axis[:Nx - 1])
    # Convert to index positions within y_axis, to use with ndimage.map_coordinates.
    y_prime_indices = jnp.interp(y_prime, y_axis, jnp.arange(Ny))
    # Interpolated version of A from B via 2D map_coordinates.
    interp = map_coordinates(B, [jnp.arange(1, Nx), y_prime_indices], order=1)
    # Only use the interpolated result for values of y_prime larger than the
    # smallest y in y_axis.
    condition = y_prime < y_axis[0]
    # Put the A array together, with fill-in values where we don't want the
    # interpolated value.
    A = condition * jnp.exp(-x_axis[:Nx - 1]) \
        + (1 - condition) * interp
    A = jnp.append(A, jnp.exp(-x_axis[-1]))
    return A

def set_B_to_A(i, B, A):
    # Update a column of B with the current value of A.
    B = B.at[:, i].set(A)
    return B

def false_func(params):
    # Trivial false function, sets all entries of A to some fill values if called.
    A = jnp.exp(-x_axis)
    return A

""" Running main() a couple of times to see the speed """
for i in range(5):
    s = time.time()
    B = main()
    print(time.time() - s)
When using case 1 in loop_in_main() and calling main() 5 times, we observe one set of runtimes (in seconds); switching to case 2, the runtimes are consistently shorter. In both cases the first run is longer because of JIT compilation. We checked that this speed-up scales with Ny, the number of steps in lax.scan. In our full code, which does more computation per step, the speed-up is even more significant.
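One way to investigate the difference (a sketch, with toy functions standing in for interp_A_from_B/false_func) is to compare the jaxprs JAX traces for a direct call versus the always-true lax.cond wrapper — the wrapper introduces a cond primitive with two branches instead of inlining the body:

```python
import jax
import jax.numpy as jnp
from jax import lax

def body(x):
    return jnp.exp(-x)

def fallback(x):
    return jnp.zeros_like(x)

def direct(i, x):
    # Case 1: call the body unconditionally; it is inlined into the trace.
    return body(x)

def wrapped(i, x):
    # Case 2: trivial cond wrapper; the trace contains a cond primitive.
    return lax.cond(i > -1, body, fallback, x)

x = jnp.ones(4)
print(jax.make_jaxpr(direct)(0, x))
print(jax.make_jaxpr(wrapped)(0, x))
```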
Thank you in advance for your help and comments!
Tony
System info (python version, jaxlib version, accelerator, etc.)