Test runtime_call inside while loops and optional computations#1695
Test runtime_call inside while loops and optional computations#1695blasphemetheus wants to merge 4 commits intoelixir-nx:mainfrom
Conversation
|
Let's wait until my other PR gets merged because the code structure might change a lot. Thanks for the contribution |
exla/lib/exla/defn.ex
Outdated
| end | ||
|
|
||
| {typespecs, call_inputs} = | ||
| if has_pid do |
There was a problem hiding this comment.
PID will always be present
exla/lib/exla/defn.ex
Outdated
| {typespecs, call_inputs} = | ||
| if has_pid do | ||
| pid_typespec = Value.get_typespec(state.callback_pid_value) | ||
| {typespecs ++ [pid_typespec], call_inputs ++ [state.callback_pid_value]} |
There was a problem hiding this comment.
I think you want to prepend it, don't you?
Above you do tl(results) which seems to indicate you want to throw out the first item, which I assume is the PID threaded value
exla/lib/exla/defn.ex
Outdated
| {nil, result} | ||
| end | ||
|
|
||
| result = if has_pid, do: Enum.slice(result, 0..-2//1), else: result |
There was a problem hiding this comment.
PID will always be present, so you can simplify to tl(result)
exla/lib/exla/defn.ex
Outdated
| ret = List.flatten(res) | ||
| ret = if outer_pid, do: ret ++ [state.callback_pid_value], else: ret | ||
| ret = if outer_token, do: [get_token(comp_cache) | ret], else: ret | ||
| Value.func_return(function, ret) |
There was a problem hiding this comment.
we should prepend callback pid, not append
polvalente
left a comment
There was a problem hiding this comment.
I don't think :optional is going to be testable, given that it's a private Nx API, maybe we just revert those changes.
However, I think we should do the same with in-line anonymous functions.
Make sure to update your branch with the current state of main!
|
ope on main after #1694 this bug is fixed. |
|
@blasphemetheus I was under the impression we would still need to support everything in this PR. Can we at least keep the PR with tests for while and inline anonymous functions? |
Thread callback_pid_value through while loop regions and optional computation functions, which are IsolatedFromAbove in StableHLO. Also fix reset_token/merge_outfeed to preserve runtime_callbacks across cache boundaries so callbacks registered inside while bodies are not silently lost. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…nals Per polvalente's review: - PID is always present, so remove has_pid conditionals - Prepend PID (not append) for consistency with while loop ordering - Use tl/pattern match to extract PID instead of Enum.slice
The runtime_call-in-while-loop bug was fixed by elixir-nx#1694. This PR now only adds regression tests for: - runtime_call inside while loops - runtime_call with inline anonymous functions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
f88d1b3 to
95b35fc
Compare
befa007 to
409286d
Compare
- runtime_call inside while loop (increment to 10) - while loop with tuple state (multiply by 2 three times) - cond branches (positive → double, negative → negate) - multiple runtime_calls in one while body (add 1 then double) - type-changing callback (s32 → f32) - nested while loops with runtime_call - while with separate accumulator - tuple input inside while (skipped: shape mismatch on EXLA, wrong result on evaluator) All 17 tests pass (1 skipped). Tested on EXLA host. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
409286d to
1c4f9d5
Compare
|
ok there's a skipped test: Behavior
|
Thread callback_pid_value through while loop regions and optional computation functions, which are IsolatedFromAbove in StableHLO. Also fix reset_token/merge_outfeed to preserve runtime_callbacks across cache boundaries so callbacks registered inside while bodies are not silently lost.
this is to fix a test hang on runtime_call inside while loops.
also bug has two layers:
will need rebasing if #1694 merges first for sure (touches some of the same areas)