@@ -209,14 +209,135 @@ defmodule EXLA.Defn.RuntimeCallTest do
209209 assert_receive { :container_fun , ^ ref }
210210 end
211211
212+ def add_one_callback ( t , _opts ) , do: Nx . add ( t , 1 )
213+ def double_callback ( t , _opts ) , do: Nx . multiply ( t , 2 )
214+ def negate_callback ( t , _opts ) , do: Nx . negate ( t )
215+
212216 defn runtime_call_in_while ( x ) do
213217 while x , Nx . less ( x , 10 ) do
214- Nx . runtime_call ( x , x , fn t -> Nx . add ( t , 1 ) end )
218+ Nx . runtime_call ( x , x , & add_one_callback / 2 )
215219 end
216220 end
217221
218222 test "runtime_call inside while loop" do
219223 result = runtime_call_in_while ( Nx . tensor ( 0 ) )
220224 assert_equal ( result , Nx . tensor ( 10 ) )
221225 end
226+
227+ defn runtime_call_in_while_with_tuple ( x ) do
228+ { result , _count } =
229+ while { x , count = Nx . tensor ( 0 ) } , Nx . less ( count , 3 ) do
230+ doubled = Nx . runtime_call ( x , x , & double_callback / 2 )
231+ { doubled , count + 1 }
232+ end
233+
234+ result
235+ end
236+
237+ test "runtime_call inside while loop with tuple state" do
238+ result = runtime_call_in_while_with_tuple ( Nx . tensor ( 1.0 ) )
239+ # 1.0 * 2 * 2 * 2 = 8.0
240+ assert_equal ( result , Nx . tensor ( 8.0 ) )
241+ end
242+
243+ defn runtime_call_in_cond ( x ) do
244+ if Nx . greater ( x , 0 ) do
245+ Nx . runtime_call ( x , x , & double_callback / 2 )
246+ else
247+ Nx . runtime_call ( x , x , & negate_callback / 2 )
248+ end
249+ end
250+
251+ test "runtime_call inside cond branches" do
252+ assert_equal ( runtime_call_in_cond ( Nx . tensor ( 5.0 ) ) , Nx . tensor ( 10.0 ) )
253+ assert_equal ( runtime_call_in_cond ( Nx . tensor ( - 3.0 ) ) , Nx . tensor ( 3.0 ) )
254+ end
255+
256+ defn multiple_runtime_calls_in_while ( x ) do
257+ while x , Nx . less ( x , 100 ) do
258+ step1 = Nx . runtime_call ( x , x , & add_one_callback / 2 )
259+ Nx . runtime_call ( step1 , step1 , & double_callback / 2 )
260+ end
261+ end
262+
263+ test "multiple runtime_calls in one while body" do
264+ # (0+1)*2=2, (2+1)*2=6, (6+1)*2=14, (14+1)*2=30, (30+1)*2=62, (62+1)*2=126
265+ result = multiple_runtime_calls_in_while ( Nx . tensor ( 0.0 ) )
266+ assert_equal ( result , Nx . tensor ( 126.0 ) )
267+ end
268+
269+ def cast_to_float_callback ( t , _opts ) , do: Nx . as_type ( t , :f32 )
270+
271+ defn runtime_call_type_change ( x ) do
272+ out = % { x | type: { :f , 32 } }
273+ Nx . runtime_call ( out , x , & cast_to_float_callback / 2 )
274+ end
275+
276+ test "runtime_call where callback changes type" do
277+ result = runtime_call_type_change ( Nx . tensor ( [ 1 , 2 , 3 ] , type: :s32 ) )
278+ assert Nx . type ( result ) == { :f , 32 }
279+ assert_equal ( result , Nx . tensor ( [ 1.0 , 2.0 , 3.0 ] ) )
280+ end
281+
282+ def add_ten_callback ( t , _opts ) , do: Nx . add ( t , 10 )
283+
284+ defn nested_while_with_runtime_call ( x ) do
285+ { result , _ } =
286+ while { x , outer = Nx . tensor ( 0 ) } , Nx . less ( outer , 2 ) do
287+ { inner_result , _ } =
288+ while { x , inner = Nx . tensor ( 0 ) } , Nx . less ( inner , 3 ) do
289+ { Nx . runtime_call ( x , x , & add_one_callback / 2 ) , inner + 1 }
290+ end
291+
292+ { inner_result , outer + 1 }
293+ end
294+
295+ result
296+ end
297+
298+ test "runtime_call inside nested while loops" do
299+ result = nested_while_with_runtime_call ( Nx . tensor ( 0.0 ) )
300+ # Inner while runs 3 times each outer iteration, outer runs 2 times
301+ # 0 → +3 = 3 → +3 = 6
302+ assert_equal ( result , Nx . tensor ( 6.0 ) )
303+ end
304+
305+
306+ def sum_tuple_callback ( { a , b } , _opts ) , do: Nx . add ( a , b )
307+
308+ defn runtime_call_tuple_in_while ( x , y ) do
309+ { result , _ } =
310+ while { x , count = Nx . tensor ( 0 ) } , Nx . less ( count , 3 ) do
311+ summed = Nx . runtime_call ( x , { x , y } , & sum_tuple_callback / 2 )
312+ { summed , count + 1 }
313+ end
314+
315+ result
316+ end
317+
318+ @ tag :skip
319+ test "runtime_call with tuple input inside while" do
320+ # Crashes on EXLA with shape_mismatch in args_spec.
321+ # Returns wrong value (4.0 instead of 31.0) on Evaluator.
322+ # Tuple inputs to runtime_call inside while with captured
323+ # variables don't work correctly.
324+ result = runtime_call_tuple_in_while ( Nx . tensor ( 1.0 ) , Nx . tensor ( 10.0 ) )
325+ assert_equal ( result , Nx . tensor ( 31.0 ) )
326+ end
327+
328+ defn runtime_call_in_while_accumulating ( x ) do
329+ { _x , acc } =
330+ while { x , acc = Nx . tensor ( 0.0 ) } , Nx . less ( acc , 100 ) do
331+ val = Nx . runtime_call ( x , x , & double_callback / 2 )
332+ { val , acc + val }
333+ end
334+
335+ acc
336+ end
337+
338+ test "runtime_call in while with separate accumulator" do
339+ # x=5, iter 1: val=10, acc=10; iter 2: val=20, acc=30; iter 3: val=40, acc=70; iter 4: val=80, acc=150
340+ result = runtime_call_in_while_accumulating ( Nx . tensor ( 5.0 ) )
341+ assert_equal ( result , Nx . tensor ( 150.0 ) )
342+ end
222343end
0 commit comments