Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions exir/backend/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,85 @@ def inputs(self):
torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
)

def test_arrange_graph_outputs_reorders_mutations_before_user_outputs(self):
"""
Directly test that arrange_graph_outputs correctly reorders a
submodule's output tuple so that BUFFER_MUTATION outputs come before
USER_OUTPUT outputs, and that getitem indices in the parent graph are
remapped accordingly.
"""
from executorch.exir.lowered_backend_module import arrange_graph_outputs
from torch.export.exported_program import OutputKind, OutputSpec, TensorArgument

# Build a submodule graph with 3 outputs in order:
# [user_out_0, buffer_mut_1, user_out_2]
# The expected reordering is:
# [buffer_mut_1, user_out_0, user_out_2]
sub_graph = torch.fx.Graph()
x = sub_graph.placeholder("x")
buf = sub_graph.placeholder("buf")
add_node = sub_graph.call_function(torch.ops.aten.add.Tensor, (x, x))
mul_node = sub_graph.call_function(torch.ops.aten.mul.Tensor, (buf, x))
sub_node = sub_graph.call_function(torch.ops.aten.sub.Tensor, (x, x))
# Output order: user, mutation, user
sub_graph.output((add_node, mul_node, sub_node))
sub_gm = torch.fx.GraphModule({}, sub_graph)

output_specs = [
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=TensorArgument(name="add"),
target=None,
),
OutputSpec(
kind=OutputKind.BUFFER_MUTATION,
arg=TensorArgument(name="mul"),
target="buf",
),
OutputSpec(
kind=OutputKind.USER_OUTPUT,
arg=TensorArgument(name="sub"),
target=None,
),
]

# Build a parent graph with a call_module node and getitem users
parent_graph = torch.fx.Graph()
px = parent_graph.placeholder("x")
call_mod = parent_graph.call_module("sub_mod", (px,))
gi0 = parent_graph.call_function(operator.getitem, (call_mod, 0))
gi1 = parent_graph.call_function(operator.getitem, (call_mod, 1))
gi2 = parent_graph.call_function(operator.getitem, (call_mod, 2))
parent_graph.output((gi0, gi1, gi2))

# Run arrange_graph_outputs
arrange_graph_outputs(sub_gm, output_specs, call_mod)

# Verify output_specs are reordered: mutation first
self.assertEqual(output_specs[0].kind, OutputKind.BUFFER_MUTATION)
self.assertEqual(output_specs[1].kind, OutputKind.USER_OUTPUT)
self.assertEqual(output_specs[2].kind, OutputKind.USER_OUTPUT)
self.assertEqual(output_specs[0].target, "buf")

# Verify the submodule graph output tuple is reordered
output_node = None
for node in sub_gm.graph.nodes:
if node.op == "output":
output_node = node
break
reordered = list(output_node.args[0])
self.assertIs(reordered[0], mul_node) # buffer mutation first
self.assertIs(reordered[1], add_node) # then user outputs
self.assertIs(reordered[2], sub_node)

# Verify getitem indices were remapped:
# old 0 (user) -> new 1
# old 1 (mutation) -> new 0
# old 2 (user) -> new 2 (unchanged)
self.assertEqual(gi0.args[1], 1)
self.assertEqual(gi1.args[1], 0)
self.assertEqual(gi2.args[1], 2)

def test_prohibited_nested_backends(self):
class MyBackend(BackendDetails):
@staticmethod
Expand Down
98 changes: 96 additions & 2 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,97 @@ def arrange_graph_placeholders(
return gm


def arrange_graph_outputs(
gm: torch.fx.GraphModule,
output_specs: List[OutputSpec],
call_module_node: torch.fx.Node,
) -> torch.fx.GraphModule:
"""
Reorders the output tuple of the graph so that buffer mutation outputs come
before user outputs, matching the ordering that ExportedProgram's verifier
expects: [buffer_mutations..., user_outputs...].

The partitioner may produce a submodule whose output tuple has buffer
mutations and user outputs interleaved in arbitrary order. The verifier
determines which outputs are mutations by position (first N outputs where
N = number of mutation specs), so a misordered tuple causes a
SpecViolationError.

This function builds a permutation from the output_specs (which
_get_new_signature already classified correctly) and rewrites the graph's
output node to match. It also remaps getitem indices on the parent
graph's call_module_node so the parent continues to extract the correct
outputs.

Args:
gm: The graph module whose output ordering may need adjustment.
output_specs: The output specs built by _get_new_signature, with
correct kind annotations but potentially mismatched ordering
relative to the graph's output tuple.
call_module_node: The call_module node in the parent graph whose
getitem users need index remapping.

Returns:
The graph module with reordered outputs (modified in-place).
"""
# Find the output node
output_node = None
for node in gm.graph.nodes:
if node.op == "output":
output_node = node
break

if output_node is None or not output_node.args[0]:
return gm

old_outputs = list(output_node.args[0])

if len(old_outputs) != len(output_specs):
raise RuntimeError(
f"Mismatch between graph outputs ({len(old_outputs)}) and "
f"output_specs ({len(output_specs)}). This indicates a bug in "
"_get_new_signature."
)

# Separate indices by kind: mutations first, then user outputs
mutation_indices = []
user_output_indices = []
for i, spec in enumerate(output_specs):
if spec.kind in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION):
mutation_indices.append(i)
else:
user_output_indices.append(i)

new_order = mutation_indices + user_output_indices

# Check if already in correct order
if new_order == list(range(len(old_outputs))):
return gm

# Build reverse mapping: old_index -> new_index
old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(new_order)}

# Reorder the output tuple in the submodule graph
new_outputs = [old_outputs[i] for i in new_order]
output_node.args = (tuple(new_outputs),)

# Reorder the output_specs to match (in-place)
reordered_specs = [output_specs[i] for i in new_order]
output_specs.clear()
output_specs.extend(reordered_specs)

# Remap getitem indices in the parent graph
for user in list(call_module_node.users.keys()):
if user.op == "call_function" and user.target == operator.getitem:
old_idx = user.args[1]
if isinstance(old_idx, int) and old_idx in old_to_new:
user.args = (user.args[0], old_to_new[old_idx])

gm.graph.lint()

return gm


# TODO Don't regenerate new signature manually.
def _get_new_signature( # noqa: C901
original_program: ExportedProgram,
Expand Down Expand Up @@ -706,8 +797,6 @@ def create_exported_program_from_submodule(
# Arrange the submodule's placeholders in order
submodule = arrange_graph_placeholders(submodule, owning_program, tag)

# TODO: we probably need to arrange the outputs wrt buffer mutations.

# Get updated graph signature
(
subgraph_signature,
Expand All @@ -719,6 +808,11 @@ def create_exported_program_from_submodule(
owning_program, submodule, call_module_node, tag, is_submodule
)

# Reorder outputs: buffer mutations first, then user outputs.
# The verifier expects this ordering but _get_new_signature produces
# output_specs in graph order which may interleave the two kinds.
arrange_graph_outputs(submodule, subgraph_signature.output_specs, call_module_node)

in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]

Expand Down
Loading