|
18 | 18 | import functools |
19 | 19 | import inspect |
20 | 20 | import os |
21 | | -from typing import Any, Callable, Literal, TypeVar |
| 21 | +from typing import Any, Callable, Literal, ParamSpec, TypeAlias, TypeVar |
22 | 22 |
|
23 | 23 | import chex |
24 | 24 | import equinox as eqx |
|
27 | 27 | import numpy as np |
28 | 28 |
|
29 | 29 | T = TypeVar('T') |
30 | | -BooleanNumeric = Any # A bool, or a Boolean array. |
| 30 | +BooleanNumeric: TypeAlias = Any # A bool, or a Boolean array. |
| 31 | +_State = ParamSpec('_State') |
31 | 32 |
|
32 | 33 |
|
33 | 34 | @functools.cache |
@@ -296,3 +297,49 @@ def init_array(x): |
296 | 297 | return x |
297 | 298 |
|
298 | 299 | return jax.tree_util.tree_map(init_array, t) |
| 300 | + |
| 301 | + |
@functools.partial(jit, static_argnames=['cond_fun', 'body_fun', 'max_steps'])
def max_steps_while_loop(
    cond_fun: Callable[[_State], BooleanNumeric],
    body_fun: Callable[[_State], _State],
    init_val: _State,
    max_steps: int,
) -> _State:
  """A reverse-mode differentiable while_loop using jax.lax.scan.

  The loop is unrolled into a scan of exactly `max_steps` iterations. Each
  iteration applies `body_fun` only while `cond_fun` has returned True on every
  state seen so far; after the first False the state is passed through
  unchanged for the remaining iterations.

  Notes:
    * If `cond_fun` is still True after `max_steps` iterations, the loop is
      silently truncated and the state reached at that point is returned.
    * `cond_fun` is evaluated on every one of the `max_steps` iterations,
      including those after the loop has logically terminated, so it must be
      safe to call on the final state.
    * `cond_fun`, `body_fun` and `max_steps` are passed as static arguments to
      `jit`. Assuming `jit` follows `jax.jit` semantics, they must be hashable
      and each distinct value triggers a new compilation.

  Args:
    cond_fun: A function `cond_fun(state)` that returns a boolean, indicating
      whether to continue the loop.
    body_fun: A function `body_fun(state)` that returns the new state. The
      result must have the same structure as its input, since it is carried
      through `jax.lax.scan` and selected against the unchanged state.
    init_val: The initial state for the loop.
    max_steps: An integer, the maximum number of iterations the loop can
      perform. This is crucial for defining a fixed computational graph for
      scan. A negative value is treated as zero iterations.

  Returns:
    The final state after the loop terminates or `max_steps` are reached.
  """

  def scan_body(carry, _):
    current_state, still_running = carry

    # The flag is sticky: once the condition has been False it stays False,
    # even if `cond_fun` would become True again on the frozen state.
    should_execute_body = jnp.logical_and(
        still_running, cond_fun(current_state)
    )

    # `cond` (rather than an elementwise select) keeps the untaken branch out
    # of the computation for the current step.
    next_state = jax.lax.cond(
        should_execute_body, body_fun, lambda s: s, current_state
    )

    return (next_state, should_execute_body), None

  # Carry for the scan: (current_state, loop_still_running).
  initial_scan_carry = (init_val, jnp.array(True, dtype=jnp.bool_))

  # There is no per-step input, so give scan an explicit length instead of
  # materialising a dummy array to iterate over.
  (final_state, _), _ = jax.lax.scan(
      scan_body, initial_scan_carry, xs=None, length=max(max_steps, 0)
  )

  return final_state
0 commit comments