Skip to content

Commit 95b35fc

Browse files
Revert code changes, keep only tests
The runtime_call-in-while-loop bug was fixed by #1694. This PR now only adds regression tests for: - runtime_call inside while loops - runtime_call with inline anonymous functions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 11a411f commit 95b35fc

File tree

1 file changed

+36
-100
lines changed

1 file changed

+36
-100
lines changed

exla/lib/exla/defn.ex

Lines changed: 36 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -526,15 +526,6 @@ defmodule EXLA.Defn do
526526

527527
{initial, cache} = recur_composite(initial_arg, state, cache)
528528

529-
# Prepend PID before token so the order in the while state is [token, pid | values].
530-
# Both must be threaded explicitly because while regions are IsolatedFromAbove.
531-
initial =
532-
if state.callback_pid_value do
533-
[state.callback_pid_value | initial]
534-
else
535-
initial
536-
end
537-
538529
initial =
539530
if token = get_token(cache) do
540531
[token | initial]
@@ -548,19 +539,14 @@ defmodule EXLA.Defn do
548539
results =
549540
Value.while(function, pred_computation, body_computation, List.flatten(initial))
550541

551-
# Extract token and PID from results in the same order they were prepended.
552-
{token, results} =
553-
if get_token(cache) do
554-
[token | rest] = results
555-
{token, rest}
556-
else
557-
{nil, results}
558-
end
559-
560-
results = if state.callback_pid_value, do: tl(results), else: results
561-
result = wrap_tuple_result(results, initial_arg)
562-
cache = if token, do: update_token(cache, token), else: cache
563-
{result, cache}
542+
if get_token(cache) do
543+
[token | results] = results
544+
result = wrap_tuple_result(results, initial_arg)
545+
{result, update_token(cache, token)}
546+
else
547+
result = wrap_tuple_result(results, initial_arg)
548+
{result, cache}
549+
end
564550
end
565551

566552
defp cached_recur_operator(:cond, %T{data: %Expr{args: args}} = t, state, cache) do
@@ -761,37 +747,15 @@ defmodule EXLA.Defn do
761747
{computation, Map.put(cache, key, computation)}
762748
end
763749

764-
has_token = get_token(cache) != nil
765-
766-
typespecs = container_to_typespecs(expr)
767-
call_inputs = call_args
768-
769-
{typespecs, call_inputs} =
770-
if has_token do
771-
{[Typespec.token() | typespecs], [get_token(cache) | call_inputs]}
772-
else
773-
{typespecs, call_inputs}
774-
end
775-
776-
# PID is always present — prepend it (consistent with while loop ordering)
777-
pid_typespec = Value.get_typespec(state.callback_pid_value)
778-
typespecs = [pid_typespec | typespecs]
779-
call_inputs = [state.callback_pid_value | call_inputs]
780-
781-
result = Value.call(state.builder, call_inputs, call_body, typespecs)
782-
783-
{token, result} =
784-
if has_token do
785-
[token | rest] = result
786-
{token, rest}
787-
else
788-
{nil, result}
789-
end
790-
791-
# Drop the PID from results (it was prepended)
792-
[_pid | result] = result
793-
cache = if token, do: update_token(cache, token), else: cache
794-
{wrap_tuple_result(result, expr), cache}
750+
if token = get_token(cache) do
751+
typespecs = [Typespec.token() | container_to_typespecs(expr)]
752+
[token | result] = Value.call(state.builder, [token | call_args], call_body, typespecs)
753+
{wrap_tuple_result(result, expr), update_token(cache, token)}
754+
else
755+
typespecs = container_to_typespecs(expr)
756+
result = Value.call(state.builder, call_args, call_body, typespecs)
757+
{wrap_tuple_result(result, expr), cache}
758+
end
795759
end
796760

797761
defp cached_recur_operator(
@@ -1689,21 +1653,11 @@ defmodule EXLA.Defn do
16891653
defp new_cache(outfeed),
16901654
do: %{__MODULE__ => outfeed}
16911655

1692-
defp merge_outfeed(%{__MODULE__ => outfeed} = cache, new_cache) do
1693-
new_outfeed = new_cache[__MODULE__]
1694-
inner_callbacks = Map.get(new_cache, runtime_callbacks_key(), [])
1656+
defp merge_outfeed(%{__MODULE__ => outfeed} = cache, %{__MODULE__ => new_outfeed}),
1657+
do: %{cache | __MODULE__ => Outfeed.with_token(new_outfeed, outfeed.token)}
16951658

1696-
cache
1697-
|> Map.put(__MODULE__, Outfeed.with_token(new_outfeed, outfeed.token))
1698-
|> Map.update(runtime_callbacks_key(), inner_callbacks, &(inner_callbacks ++ &1))
1699-
end
1700-
1701-
defp reset_token(cache, token) do
1702-
%{
1703-
__MODULE__ => Outfeed.with_token(cache[__MODULE__], token),
1704-
runtime_callbacks_key() => Map.get(cache, runtime_callbacks_key(), [])
1705-
}
1706-
end
1659+
defp reset_token(%{__MODULE__ => outfeed}, token),
1660+
do: %{__MODULE__ => Outfeed.with_token(outfeed, token)}
17071661

17081662
defp update_token(%{__MODULE__ => outfeed} = cache, token),
17091663
do: %{cache | __MODULE__ => Outfeed.with_token(outfeed, token)}
@@ -1801,20 +1755,11 @@ defmodule EXLA.Defn do
18011755
{region, args} = Function.push_region(state.builder, arg_typespecs)
18021756

18031757
outer_token = get_token(cache)
1804-
outer_pid = state.callback_pid_value
18051758

1806-
{inner_token, args} =
1759+
{inner_token, arg_params} =
18071760
if outer_token do
1808-
[arg_token | rest] = args
1809-
{arg_token, rest}
1810-
else
1811-
{nil, args}
1812-
end
1813-
1814-
{inner_pid, arg_params} =
1815-
if outer_pid do
1816-
[arg_pid | rest] = args
1817-
{arg_pid, rest}
1761+
[arg_token | arg_params] = args
1762+
{arg_token, arg_params}
18181763
else
18191764
{nil, args}
18201765
end
@@ -1824,8 +1769,7 @@ defmodule EXLA.Defn do
18241769
state = %{
18251770
state
18261771
| params: Map.new(params),
1827-
scope_ids: Tree.scope_ids(expr),
1828-
callback_pid_value: inner_pid
1772+
scope_ids: Tree.scope_ids(expr)
18291773
}
18301774

18311775
expr =
@@ -1839,10 +1783,11 @@ defmodule EXLA.Defn do
18391783

18401784
res =
18411785
if type == :with_token do
1842-
flat = List.flatten(res)
1843-
flat = if outer_pid, do: [state.callback_pid_value | flat], else: flat
1844-
flat = if outer_token, do: [get_token(comp_cache) | flat], else: flat
1845-
flat
1786+
if outer_token do
1787+
[get_token(comp_cache) | List.flatten(res)]
1788+
else
1789+
List.flatten(res)
1790+
end
18461791
else
18471792
Enum.map(res, &to_type(&1, type))
18481793
end
@@ -1860,9 +1805,7 @@ defmodule EXLA.Defn do
18601805
out_typespecs = container_to_typespecs(expr)
18611806

18621807
outer_token = get_token(cache)
1863-
outer_pid = state.callback_pid_value
18641808
token_typespec = Typespec.token()
1865-
pid_typespec = if outer_pid, do: Value.get_typespec(outer_pid)
18661809

18671810
{arg_typespecs, out_typespecs} =
18681811
if outer_token do
@@ -1871,10 +1814,6 @@ defmodule EXLA.Defn do
18711814
{arg_typespecs, out_typespecs}
18721815
end
18731816

1874-
# PID is always present — prepend (consistent with while loop ordering)
1875-
arg_typespecs = [pid_typespec | arg_typespecs]
1876-
out_typespecs = [pid_typespec | out_typespecs]
1877-
18781817
function = EXLA.MLIR.Module.add_function(module, name, arg_typespecs, out_typespecs)
18791818
args = EXLA.MLIR.Function.get_arguments(function)
18801819

@@ -1886,25 +1825,22 @@ defmodule EXLA.Defn do
18861825
{nil, args}
18871826
end
18881827

1889-
# PID is always the next arg after token
1890-
[inner_pid | args] = args
1891-
18921828
params = Enum.with_index(args, fn param, i -> {i, param} end)
18931829

18941830
state = %{
18951831
state
18961832
| builder: function,
18971833
params: Map.new(params),
1898-
scope_ids: Tree.scope_ids(expr),
1899-
callback_pid_value: inner_pid
1834+
scope_ids: Tree.scope_ids(expr)
19001835
}
19011836

19021837
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))
19031838

1904-
ret = List.flatten(res)
1905-
ret = [state.callback_pid_value | ret]
1906-
ret = if outer_token, do: [get_token(comp_cache) | ret], else: ret
1907-
Value.func_return(function, ret)
1839+
if outer_token do
1840+
Value.func_return(function, [get_token(comp_cache) | List.flatten(res)])
1841+
else
1842+
Value.func_return(function, List.flatten(res))
1843+
end
19081844

19091845
{function, merge_outfeed(cache, comp_cache)}
19101846
end

0 commit comments

Comments
 (0)