@@ -526,15 +526,6 @@ defmodule EXLA.Defn do
526526
527527 { initial , cache } = recur_composite ( initial_arg , state , cache )
528528
529- # Prepend PID before token so the order in the while state is [token, pid | values].
530- # Both must be threaded explicitly because while regions are IsolatedFromAbove.
531- initial =
532- if state . callback_pid_value do
533- [ state . callback_pid_value | initial ]
534- else
535- initial
536- end
537-
538529 initial =
539530 if token = get_token ( cache ) do
540531 [ token | initial ]
@@ -548,19 +539,14 @@ defmodule EXLA.Defn do
548539 results =
549540 Value . while ( function , pred_computation , body_computation , List . flatten ( initial ) )
550541
551- # Extract token and PID from results in the same order they were prepended.
552- { token , results } =
553- if get_token ( cache ) do
554- [ token | rest ] = results
555- { token , rest }
556- else
557- { nil , results }
558- end
559-
560- results = if state . callback_pid_value , do: tl ( results ) , else: results
561- result = wrap_tuple_result ( results , initial_arg )
562- cache = if token , do: update_token ( cache , token ) , else: cache
563- { result , cache }
542+ if get_token ( cache ) do
543+ [ token | results ] = results
544+ result = wrap_tuple_result ( results , initial_arg )
545+ { result , update_token ( cache , token ) }
546+ else
547+ result = wrap_tuple_result ( results , initial_arg )
548+ { result , cache }
549+ end
564550 end
565551
566552 defp cached_recur_operator ( :cond , % T { data: % Expr { args: args } } = t , state , cache ) do
@@ -761,37 +747,15 @@ defmodule EXLA.Defn do
761747 { computation , Map . put ( cache , key , computation ) }
762748 end
763749
764- has_token = get_token ( cache ) != nil
765-
766- typespecs = container_to_typespecs ( expr )
767- call_inputs = call_args
768-
769- { typespecs , call_inputs } =
770- if has_token do
771- { [ Typespec . token ( ) | typespecs ] , [ get_token ( cache ) | call_inputs ] }
772- else
773- { typespecs , call_inputs }
774- end
775-
776- # PID is always present — prepend it (consistent with while loop ordering)
777- pid_typespec = Value . get_typespec ( state . callback_pid_value )
778- typespecs = [ pid_typespec | typespecs ]
779- call_inputs = [ state . callback_pid_value | call_inputs ]
780-
781- result = Value . call ( state . builder , call_inputs , call_body , typespecs )
782-
783- { token , result } =
784- if has_token do
785- [ token | rest ] = result
786- { token , rest }
787- else
788- { nil , result }
789- end
790-
791- # Drop the PID from results (it was prepended)
792- [ _pid | result ] = result
793- cache = if token , do: update_token ( cache , token ) , else: cache
794- { wrap_tuple_result ( result , expr ) , cache }
750+ if token = get_token ( cache ) do
751+ typespecs = [ Typespec . token ( ) | container_to_typespecs ( expr ) ]
752+ [ token | result ] = Value . call ( state . builder , [ token | call_args ] , call_body , typespecs )
753+ { wrap_tuple_result ( result , expr ) , update_token ( cache , token ) }
754+ else
755+ typespecs = container_to_typespecs ( expr )
756+ result = Value . call ( state . builder , call_args , call_body , typespecs )
757+ { wrap_tuple_result ( result , expr ) , cache }
758+ end
795759 end
796760
797761 defp cached_recur_operator (
@@ -1689,21 +1653,11 @@ defmodule EXLA.Defn do
16891653 defp new_cache ( outfeed ) ,
16901654 do: % { __MODULE__ => outfeed }
16911655
1692- defp merge_outfeed ( % { __MODULE__ => outfeed } = cache , new_cache ) do
1693- new_outfeed = new_cache [ __MODULE__ ]
1694- inner_callbacks = Map . get ( new_cache , runtime_callbacks_key ( ) , [ ] )
1656+ defp merge_outfeed ( % { __MODULE__ => outfeed } = cache , % { __MODULE__ => new_outfeed } ) ,
1657+ do: % { cache | __MODULE__ => Outfeed . with_token ( new_outfeed , outfeed . token ) }
16951658
1696- cache
1697- |> Map . put ( __MODULE__ , Outfeed . with_token ( new_outfeed , outfeed . token ) )
1698- |> Map . update ( runtime_callbacks_key ( ) , inner_callbacks , & ( inner_callbacks ++ & 1 ) )
1699- end
1700-
1701- defp reset_token ( cache , token ) do
1702- % {
1703- __MODULE__ => Outfeed . with_token ( cache [ __MODULE__ ] , token ) ,
1704- runtime_callbacks_key ( ) => Map . get ( cache , runtime_callbacks_key ( ) , [ ] )
1705- }
1706- end
1659+ defp reset_token ( % { __MODULE__ => outfeed } , token ) ,
1660+ do: % { __MODULE__ => Outfeed . with_token ( outfeed , token ) }
17071661
17081662 defp update_token ( % { __MODULE__ => outfeed } = cache , token ) ,
17091663 do: % { cache | __MODULE__ => Outfeed . with_token ( outfeed , token ) }
@@ -1801,20 +1755,11 @@ defmodule EXLA.Defn do
18011755 { region , args } = Function . push_region ( state . builder , arg_typespecs )
18021756
18031757 outer_token = get_token ( cache )
1804- outer_pid = state . callback_pid_value
18051758
1806- { inner_token , args } =
1759+ { inner_token , arg_params } =
18071760 if outer_token do
1808- [ arg_token | rest ] = args
1809- { arg_token , rest }
1810- else
1811- { nil , args }
1812- end
1813-
1814- { inner_pid , arg_params } =
1815- if outer_pid do
1816- [ arg_pid | rest ] = args
1817- { arg_pid , rest }
1761+ [ arg_token | arg_params ] = args
1762+ { arg_token , arg_params }
18181763 else
18191764 { nil , args }
18201765 end
@@ -1824,8 +1769,7 @@ defmodule EXLA.Defn do
18241769 state = % {
18251770 state
18261771 | params: Map . new ( params ) ,
1827- scope_ids: Tree . scope_ids ( expr ) ,
1828- callback_pid_value: inner_pid
1772+ scope_ids: Tree . scope_ids ( expr )
18291773 }
18301774
18311775 expr =
@@ -1839,10 +1783,11 @@ defmodule EXLA.Defn do
18391783
18401784 res =
18411785 if type == :with_token do
1842- flat = List . flatten ( res )
1843- flat = if outer_pid , do: [ state . callback_pid_value | flat ] , else: flat
1844- flat = if outer_token , do: [ get_token ( comp_cache ) | flat ] , else: flat
1845- flat
1786+ if outer_token do
1787+ [ get_token ( comp_cache ) | List . flatten ( res ) ]
1788+ else
1789+ List . flatten ( res )
1790+ end
18461791 else
18471792 Enum . map ( res , & to_type ( & 1 , type ) )
18481793 end
@@ -1860,9 +1805,7 @@ defmodule EXLA.Defn do
18601805 out_typespecs = container_to_typespecs ( expr )
18611806
18621807 outer_token = get_token ( cache )
1863- outer_pid = state . callback_pid_value
18641808 token_typespec = Typespec . token ( )
1865- pid_typespec = if outer_pid , do: Value . get_typespec ( outer_pid )
18661809
18671810 { arg_typespecs , out_typespecs } =
18681811 if outer_token do
@@ -1871,10 +1814,6 @@ defmodule EXLA.Defn do
18711814 { arg_typespecs , out_typespecs }
18721815 end
18731816
1874- # PID is always present — prepend (consistent with while loop ordering)
1875- arg_typespecs = [ pid_typespec | arg_typespecs ]
1876- out_typespecs = [ pid_typespec | out_typespecs ]
1877-
18781817 function = EXLA.MLIR.Module . add_function ( module , name , arg_typespecs , out_typespecs )
18791818 args = EXLA.MLIR.Function . get_arguments ( function )
18801819
@@ -1886,25 +1825,22 @@ defmodule EXLA.Defn do
18861825 { nil , args }
18871826 end
18881827
1889- # PID is always the next arg after token
1890- [ inner_pid | args ] = args
1891-
18921828 params = Enum . with_index ( args , fn param , i -> { i , param } end )
18931829
18941830 state = % {
18951831 state
18961832 | builder: function ,
18971833 params: Map . new ( params ) ,
1898- scope_ids: Tree . scope_ids ( expr ) ,
1899- callback_pid_value: inner_pid
1834+ scope_ids: Tree . scope_ids ( expr )
19001835 }
19011836
19021837 { res , comp_cache } = recur_composite ( expr , state , reset_token ( cache , inner_token ) )
19031838
1904- ret = List . flatten ( res )
1905- ret = [ state . callback_pid_value | ret ]
1906- ret = if outer_token , do: [ get_token ( comp_cache ) | ret ] , else: ret
1907- Value . func_return ( function , ret )
1839+ if outer_token do
1840+ Value . func_return ( function , [ get_token ( comp_cache ) | List . flatten ( res ) ] )
1841+ else
1842+ Value . func_return ( function , List . flatten ( res ) )
1843+ end
19081844
19091845 { function , merge_outfeed ( cache , comp_cache ) }
19101846 end
0 commit comments