Skip to content

Test runtime_call inside while loops and optional computations#1695

Open
blasphemetheus wants to merge 4 commits intoelixir-nx:mainfrom
blasphemetheus:fix/runtime-call-while-loop
Open

Test runtime_call inside while loops and optional computations#1695
blasphemetheus wants to merge 4 commits intoelixir-nx:mainfrom
blasphemetheus:fix/runtime-call-while-loop

Conversation

@blasphemetheus
Copy link
Contributor

@blasphemetheus blasphemetheus commented Mar 15, 2026

Thread callback_pid_value through while loop regions and optional computation functions, which are IsolatedFromAbove in StableHLO. Also fix reset_token/merge_outfeed to preserve runtime_callbacks across cache boundaries so callbacks registered inside while bodies are not silently lost.

this is to fix a test hang on runtime_call inside while loops.

also bug has two layers:

  1. MLIR level — callback_pid_value from outer scope used inside IsolatedFromAbove while regions. This is the obvious one.
  2. Cache level — reset_token created a new map dropping runtime_callbacks_key(), and merge_outfeed didn't merge inner callbacks back to outer. So even with correct MLIR threading. callbacks registered inside while bodies were silently lost and the runtime passed zeroed PID bytes. This is the non-obvious one and worth highlighting since it could bite other cache-carried state in the future.

will need rebasing if #1694 merges first for sure (touches some of the same areas)

@polvalente
Copy link
Contributor

Let's wait until my other PR gets merged because the code structure might change a lot.

Thanks for the contribution

@blasphemetheus blasphemetheus marked this pull request as draft March 15, 2026 01:32
end

{typespecs, call_inputs} =
if has_pid do
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PID will always be present

{typespecs, call_inputs} =
if has_pid do
pid_typespec = Value.get_typespec(state.callback_pid_value)
{typespecs ++ [pid_typespec], call_inputs ++ [state.callback_pid_value]}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you want to prepend it, don't you?
Above you do tl(results) which seems to indicate you want to throw out the first item, which I assume is the PID threaded value

{nil, result}
end

result = if has_pid, do: Enum.slice(result, 0..-2//1), else: result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PID will always be present, so you can simplify to tl(result)

Comment on lines +1927 to +1930
ret = List.flatten(res)
ret = if outer_pid, do: ret ++ [state.callback_pid_value], else: ret
ret = if outer_token, do: [get_token(comp_cache) | ret], else: ret
Value.func_return(function, ret)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should prepend callback pid, not append

Copy link
Contributor

@polvalente polvalente left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think :optional is going to be testable, given that it's a private Nx API, maybe we just revert those changes.

However, I think we should do the same with in-line anonymous functions.

Make sure to update your branch with the current state of main!

@blasphemetheus
Copy link
Contributor Author

ope on main after #1694 this bug is fixed.
There are some tests I could include, but that'd just be a test-coverage/regression PR

@polvalente
Copy link
Contributor

@blasphemetheus I was under the impression we would still need to support everything in this PR. Can we at least keep the PR with tests for while and inline anonymous functions?

@polvalente polvalente reopened this Mar 18, 2026
blasphemetheus and others added 3 commits March 19, 2026 01:00
Thread callback_pid_value through while loop regions and optional
computation functions, which are IsolatedFromAbove in StableHLO.
Also fix reset_token/merge_outfeed to preserve runtime_callbacks
across cache boundaries so callbacks registered inside while bodies
are not silently lost.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…nals

Per polvalente's review:
- PID is always present, so remove has_pid conditionals
- Prepend PID (not append) for consistency with while loop ordering
- Use tl/pattern match to extract PID instead of Enum.slice
The runtime_call-in-while-loop bug was fixed by elixir-nx#1694. This PR now
only adds regression tests for:
- runtime_call inside while loops
- runtime_call with inline anonymous functions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@blasphemetheus blasphemetheus force-pushed the fix/runtime-call-while-loop branch from f88d1b3 to 95b35fc Compare March 19, 2026 06:02
@blasphemetheus blasphemetheus changed the title Fix runtime_call inside while loops and optional computations Test runtime_call inside while loops and optional computations Mar 19, 2026
@blasphemetheus blasphemetheus force-pushed the fix/runtime-call-while-loop branch 3 times, most recently from befa007 to 409286d Compare March 19, 2026 06:23
- runtime_call inside while loop (increment to 10)
- while loop with tuple state (multiply by 2 three times)
- cond branches (positive → double, negative → negate)
- multiple runtime_calls in one while body (add 1 then double)
- type-changing callback (s32 → f32)
- nested while loops with runtime_call
- while with separate accumulator
- tuple input inside while (skipped: shape mismatch on EXLA,
  wrong result on evaluator)

All 17 tests pass (1 skipped). Tested on EXLA host.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@blasphemetheus blasphemetheus force-pushed the fix/runtime-call-while-loop branch from 409286d to 1c4f9d5 Compare March 19, 2026 06:26
@blasphemetheus
Copy link
Contributor Author

ok there's a skipped test: runtime_call with tuple input inside while.

  defn runtime_call_tuple_in_while(x, y) do
    {result, _} =
      while {x, count = Nx.tensor(0)}, Nx.less(count, 3) do
        summed = Nx.runtime_call(x, {x, y}, &sum_tuple_callback/2)
        {summed, count + 1}
      end
    result
  end

Behavior

  • EXLA: Crashes with shape_mismatch in args_spec — the runtime can't match the tuple input shape to the expected template when the callback is inside a while body
  • Evaluator (dev mode): Returns 4.0 instead of 31.0 — the captured y variable isn't being threaded correctly through the while iterations, so the callback gets wrong values

@blasphemetheus blasphemetheus marked this pull request as ready for review March 19, 2026 06:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants