Skip to content

Commit 444409d

Browse files
Nush395 and Torax team
authored and committed
Add reverse-mode differentiable jax.lax.while_loop.
PiperOrigin-RevId: 794095925
1 parent 68d7481 commit 444409d

File tree

2 files changed

+101
-2
lines changed

2 files changed

+101
-2
lines changed

torax/_src/jax_utils.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import functools
1919
import inspect
2020
import os
21-
from typing import Any, Callable, Literal, TypeVar
21+
from typing import Any, Callable, Literal, ParamSpec, TypeAlias, TypeVar
2222

2323
import chex
2424
import equinox as eqx
@@ -27,7 +27,8 @@
2727
import numpy as np
2828

2929
T = TypeVar('T')
30-
BooleanNumeric = Any # A bool, or a Boolean array.
30+
BooleanNumeric: TypeAlias = Any # A bool, or a Boolean array.
31+
_State = ParamSpec('_State')
3132

3233

3334
@functools.cache
@@ -296,3 +297,49 @@ def init_array(x):
296297
return x
297298

298299
return jax.tree_util.tree_map(init_array, t)
300+
301+
302+
@functools.partial(jit, static_argnames=['cond_fun', 'body_fun', 'max_steps'])
def max_steps_while_loop(
    cond_fun: Callable[[_State], BooleanNumeric],
    body_fun: Callable[[_State], _State],
    init_val: _State,
    max_steps: int,
) -> _State:
  """A reverse-mode differentiable while_loop using jax.lax.scan.

  The loop is expressed as a scan with a fixed number of iterations. On each
  iteration `body_fun` is applied only if `cond_fun` holds for the current
  state and has held on every previous iteration; otherwise the state is
  passed through unchanged. All `max_steps` iterations of the scan are always
  run, even after the condition has become false.

  Args:
    cond_fun: A function `cond_fun(state)` that returns a boolean, indicating
      whether to continue the loop.
    body_fun: A function `body_fun(state)` that returns the new state. It must
      return a state with the same structure as its input.
    init_val: The initial state for the loop.
    max_steps: An integer, the maximum number of iterations the loop can
      perform. This is crucial for defining a fixed computational graph for
      scan. If `cond_fun` is still true after `max_steps` applications of
      `body_fun`, the loop stops anyway.

  Returns:
    The final state after the loop terminates or `max_steps` are reached.
  """

  def scan_body(carry, _):
    current_state, still_running = carry

    # The flag is sticky: once the condition has failed, the body is never
    # applied again, even if `cond_fun` would return True for a later state.
    should_execute_body = jnp.logical_and(
        still_running, cond_fun(current_state)
    )
    next_state = jax.lax.cond(
        should_execute_body, body_fun, lambda s: s, current_state
    )
    return (next_state, should_execute_body), None

  # Carry for the scan: (current_state, condition_has_held_so_far).
  initial_scan_carry = (init_val, jnp.array(True, dtype=jnp.bool_))

  # There are no per-step inputs, so the trip count is given directly via
  # `length` instead of scanning over a throwaway index array.
  (final_state, _), _ = jax.lax.scan(
      scan_body, initial_scan_carry, xs=None, length=max_steps
  )

  return final_state

torax/_src/tests/jax_utils_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,58 @@ def f(x, z, y=2.0):
150150
x = {'temp1': jnp.array(1.3), 'temp2': jnp.array(2.6)}
151151
chex.assert_trees_all_close(f_non_inlined(x, z='left'), f(x, z='left'))
152152

153+
def test_max_steps_while_loop_grad(self):
  """Checks values and gradients of `max_steps_while_loop`."""
  terminating_step = 4

  def cond_fun(state):
    i, _ = state
    return i < terminating_step

  def body_fun(state):
    i, value = state
    next_i = i + 1
    next_value = jnp.sin(value)
    return next_i, next_value

  init_state = (0, 0.5)
  max_steps = 10

  with self.subTest('forward_agrees_with_while_loop'):
    output_state = jax_utils.max_steps_while_loop(
        cond_fun, body_fun, init_state, max_steps
    )
    chex.assert_trees_all_close(
        output_state, jax.lax.while_loop(cond_fun, body_fun, init_state)
    )

  def f_while(x, max_steps=4):
    init_state = (0, x)
    return jax_utils.max_steps_while_loop(
        cond_fun, body_fun, init_state, max_steps=max_steps
    )[1]

  def f(x):
    """Apply sin to x four times, i.e. `terminating_step` times."""
    return jnp.sin(jnp.sin(jnp.sin(jnp.sin(x))))

  with self.subTest('forward_agrees_with_explicit'):
    chex.assert_trees_all_close(f_while(0.5), f(0.5))
  with self.subTest('grad_agrees_with_explicit'):
    chex.assert_trees_all_close(jax.grad(f_while)(0.5), jax.grad(f)(0.5))

  with self.subTest('grad_agrees_when_loop_terminates_early'):
    # With max_steps > terminating_step the trailing scan iterations only
    # pass the state through, so the gradient must match the explicit chain.
    chex.assert_trees_all_close(
        jax.grad(f_while)(0.5, max_steps=max_steps), jax.grad(f)(0.5)
    )

  with self.subTest('max_steps_is_respected'):

    def double_sin(x):
      return jnp.sin(jnp.sin(x))

    chex.assert_trees_all_close(f_while(0.5, max_steps=2), double_sin(0.5))
    chex.assert_trees_all_close(
        jax.grad(f_while)(0.5, max_steps=2), jax.grad(double_sin)(0.5)
    )
    chex.assert_trees_all_close(
        f_while(0.5, max_steps=3), jnp.sin(double_sin(0.5))
    )
204+
153205

154206
if __name__ == '__main__':
155207
absltest.main()

0 commit comments

Comments
 (0)