Skip to content

Commit 409286d

Browse files
Add runtime_call regression tests for while, cond, and inline fns
- runtime_call inside while loop (increment to 10) - while loop with tuple state (multiply by 2 three times) - cond branches (positive → double, negative → negate) - multiple runtime_calls in one while body (add 1 then double) - type-changing callback (s32 → f32) - nested while loops with runtime_call - while with separate accumulator - tuple input inside while (skipped: shape mismatch on EXLA, wrong result on evaluator) All 17 tests pass (1 skipped). Tested on EXLA host. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 95b35fc commit 409286d

File tree

1 file changed

+122
-1
lines changed

1 file changed

+122
-1
lines changed

exla/test/exla/defn/runtime_call_test.exs

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,135 @@ defmodule EXLA.Defn.RuntimeCallTest do
209209
assert_receive {:container_fun, ^ref}
210210
end
211211

212+
def add_one_callback(t, _opts), do: Nx.add(t, 1)
213+
def double_callback(t, _opts), do: Nx.multiply(t, 2)
214+
def negate_callback(t, _opts), do: Nx.negate(t)
215+
212216
defn runtime_call_in_while(x) do
213217
while x, Nx.less(x, 10) do
214-
Nx.runtime_call(x, x, fn t -> Nx.add(t, 1) end)
218+
Nx.runtime_call(x, x, &add_one_callback/2)
215219
end
216220
end
217221

218222
test "runtime_call inside while loop" do
219223
result = runtime_call_in_while(Nx.tensor(0))
220224
assert_equal(result, Nx.tensor(10))
221225
end
226+
227+
defn runtime_call_in_while_with_tuple(x) do
228+
{result, _count} =
229+
while {x, count = Nx.tensor(0)}, Nx.less(count, 3) do
230+
doubled = Nx.runtime_call(x, x, &double_callback/2)
231+
{doubled, count + 1}
232+
end
233+
234+
result
235+
end
236+
237+
test "runtime_call inside while loop with tuple state" do
238+
result = runtime_call_in_while_with_tuple(Nx.tensor(1.0))
239+
# 1.0 * 2 * 2 * 2 = 8.0
240+
assert_equal(result, Nx.tensor(8.0))
241+
end
242+
243+
defn runtime_call_in_cond(x) do
244+
if Nx.greater(x, 0) do
245+
Nx.runtime_call(x, x, &double_callback/2)
246+
else
247+
Nx.runtime_call(x, x, &negate_callback/2)
248+
end
249+
end
250+
251+
test "runtime_call inside cond branches" do
252+
assert_equal(runtime_call_in_cond(Nx.tensor(5.0)), Nx.tensor(10.0))
253+
assert_equal(runtime_call_in_cond(Nx.tensor(-3.0)), Nx.tensor(3.0))
254+
end
255+
256+
defn multiple_runtime_calls_in_while(x) do
257+
while x, Nx.less(x, 100) do
258+
step1 = Nx.runtime_call(x, x, &add_one_callback/2)
259+
Nx.runtime_call(step1, step1, &double_callback/2)
260+
end
261+
end
262+
263+
test "multiple runtime_calls in one while body" do
264+
# (0+1)*2=2, (2+1)*2=6, (6+1)*2=14, (14+1)*2=30, (30+1)*2=62, (62+1)*2=126
265+
result = multiple_runtime_calls_in_while(Nx.tensor(0.0))
266+
assert_equal(result, Nx.tensor(126.0))
267+
end
268+
269+
def cast_to_float_callback(t, _opts), do: Nx.as_type(t, :f32)
270+
271+
defn runtime_call_type_change(x) do
272+
out = %{x | type: {:f, 32}}
273+
Nx.runtime_call(out, x, &cast_to_float_callback/2)
274+
end
275+
276+
test "runtime_call where callback changes type" do
277+
result = runtime_call_type_change(Nx.tensor([1, 2, 3], type: :s32))
278+
assert Nx.type(result) == {:f, 32}
279+
assert_equal(result, Nx.tensor([1.0, 2.0, 3.0]))
280+
end
281+
282+
def add_ten_callback(t, _opts), do: Nx.add(t, 10)
283+
284+
defn nested_while_with_runtime_call(x) do
285+
{result, _} =
286+
while {x, outer = Nx.tensor(0)}, Nx.less(outer, 2) do
287+
{inner_result, _} =
288+
while {x, inner = Nx.tensor(0)}, Nx.less(inner, 3) do
289+
{Nx.runtime_call(x, x, &add_one_callback/2), inner + 1}
290+
end
291+
292+
{inner_result, outer + 1}
293+
end
294+
295+
result
296+
end
297+
298+
test "runtime_call inside nested while loops" do
299+
result = nested_while_with_runtime_call(Nx.tensor(0.0))
300+
# Inner while runs 3 times each outer iteration, outer runs 2 times
301+
# 0 → +3 = 3 → +3 = 6
302+
assert_equal(result, Nx.tensor(6.0))
303+
end
304+
305+
306+
def sum_tuple_callback({a, b}, _opts), do: Nx.add(a, b)
307+
308+
defn runtime_call_tuple_in_while(x, y) do
309+
{result, _} =
310+
while {x, count = Nx.tensor(0)}, Nx.less(count, 3) do
311+
summed = Nx.runtime_call(x, {x, y}, &sum_tuple_callback/2)
312+
{summed, count + 1}
313+
end
314+
315+
result
316+
end
317+
318+
@tag :skip
319+
test "runtime_call with tuple input inside while" do
320+
# Crashes on EXLA with shape_mismatch in args_spec.
321+
# Returns wrong value (4.0 instead of 31.0) on Evaluator.
322+
# Tuple inputs to runtime_call inside while with captured
323+
# variables don't work correctly.
324+
result = runtime_call_tuple_in_while(Nx.tensor(1.0), Nx.tensor(10.0))
325+
assert_equal(result, Nx.tensor(31.0))
326+
end
327+
328+
defn runtime_call_in_while_accumulating(x) do
329+
{_x, acc} =
330+
while {x, acc = Nx.tensor(0.0)}, Nx.less(acc, 100) do
331+
val = Nx.runtime_call(x, x, &double_callback/2)
332+
{val, acc + val}
333+
end
334+
335+
acc
336+
end
337+
338+
test "runtime_call in while with separate accumulator" do
339+
# x=5, iter 1: val=10, acc=10; iter 2: val=20, acc=30; iter 3: val=40, acc=70; iter 4: val=80, acc=150
340+
result = runtime_call_in_while_accumulating(Nx.tensor(5.0))
341+
assert_equal(result, Nx.tensor(150.0))
342+
end
222343
end

0 commit comments

Comments
 (0)