Fix bug in graph partitioner
Summary: Fix a bug in the graph partitioner: when `split_module` is asked to keep the original order, insert the placeholder nodes into the base graph in their original order, otherwise the exported graph signature no longer lines up with the split graph. Also thread the export graph signature into `replace_set_grad_with_hop_pass` and run that pass after the signature is constructed, installing the signature's replace hook so the signature stays in sync while set_grad regions are rewritten into higher-order ops.

Test Plan: CI

Reviewed By: PaulZhang12

Differential Revision: D56688411
tugsbayasgalan authored and facebook-github-bot committed Apr 30, 2024
1 parent ea347fa commit d334c5f
Showing 4 changed files with 50 additions and 31 deletions.
8 changes: 4 additions & 4 deletions test/export/test_export.py
@@ -4061,16 +4061,16 @@ def forward(self, b_pred, b_t, x, y):
             """\
 def forward(self, b_t, x, y):
     submod_3 = self.submod_1
-    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
+    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
     return (add_1,)""",
         )
 
         self.assertExpectedInline(
             str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
             """\
-def forward(self, b_t, x, y):
-    sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
-    add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
+def forward(self, x, b_t, y):
+    sub = torch.ops.aten.sub.Tensor(x, 1); x = None
+    add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
     add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
     return add_1""",
         )
43 changes: 24 additions & 19 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 
 import torch
@@ -125,7 +126,7 @@ def _remove_set_grad_and_inline(node: torch.fx.Node):
     node_inline_(node)
 
 
-def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
+def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule, graph_signature):
     """
     Helper function for replace_set_grad_with_hop_pass().
     Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
@@ -141,35 +142,39 @@ def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
     if need_replacing:
         new_gm = sequential_split(gm, _is_set_grad_enabled_node)
 
-        def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
-            if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
-                _replace_with_hop(node)
-            else:
-                _remove_set_grad_and_inline(node)
-
-        nodes_map(
-            list(new_gm.graph.nodes),
-            lambda node: (
-                _maybe_inline_or_replace_with_hop(node)
-                if node.op == "call_module"
-                else node
-            ),
-        )
+        replace_ctx = contextlib.nullcontext()
+        if graph_signature is not None:
+            replace_ctx = new_gm._set_replace_hook(graph_signature.get_replace_hook()) # type: ignore[assignment]
+
+        with replace_ctx:
+            def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
+                if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
+                    _replace_with_hop(node)
+                else:
+                    _remove_set_grad_and_inline(node)
+
+            nodes_map(
+                list(new_gm.graph.nodes),
+                lambda node: (
+                    _maybe_inline_or_replace_with_hop(node)
+                    if node.op == "call_module"
+                    else node
+                ),
+            )
         return new_gm
 
     return gm
 
 
-def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
-    new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)
-
+def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
+    new_gm = _sequential_split_and_maybe_inline_subgraphs(gm, graph_signature)
     # recursively call
     for node in new_gm.graph.nodes:
         if node.op == "get_attr":
             subgm = getattr(new_gm, node.target)
             if not isinstance(subgm, torch.fx.GraphModule):
                 continue
-            new_subgm = replace_set_grad_with_hop_pass(subgm)
+            new_subgm = replace_set_grad_with_hop_pass(subgm, None)
             setattr(new_gm, node.target, new_subgm)
 
     new_gm.recompile()
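Reviewer note: the rewrite above uses an optional-context idiom: `contextlib.nullcontext()` stands in when there is no graph signature to keep in sync (the recursive calls pass `None`), otherwise the signature's replace hook is installed for the duration of the node rewrites. A minimal self-contained sketch of that idiom, with hypothetical names standing in for `GraphModule._set_replace_hook` and `ExportGraphSignature.get_replace_hook()`:

import contextlib
from typing import Callable, List, Optional

# Hypothetical stand-ins for GraphModule._set_replace_hook and the signature's hook.
_active_hooks: List[Callable[[str, str], None]] = []

@contextlib.contextmanager
def _set_replace_hook(hook: Callable[[str, str], None]):
    # Install the hook for the duration of the block, then remove it.
    _active_hooks.append(hook)
    try:
        yield
    finally:
        _active_hooks.pop()

def _rename_node(old: str, new: str) -> None:
    # Whoever rewrites the graph notifies every active hook about the rename.
    for hook in _active_hooks:
        hook(old, new)

def rewrite(signature: Optional[dict]) -> None:
    # No signature: fall back to a no-op context, like the nullcontext() branch above.
    ctx = contextlib.nullcontext()
    if signature is not None:
        def sync(old: str, new: str) -> None:
            signature["inputs"] = [new if name == old else name for name in signature["inputs"]]
        ctx = _set_replace_hook(sync)
    with ctx:
        _rename_node("add", "add_1")

sig = {"inputs": ["x", "add"]}
rewrite(sig)
print(sig)  # {'inputs': ['x', 'add_1']}: the bookkeeping followed the graph rewrite
rewrite(None)  # nothing to keep in sync, still runs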
14 changes: 7 additions & 7 deletions torch/export/_trace.py
@@ -509,13 +509,6 @@ def _compiling_state_context():
     if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
         gm.meta.update(mod.meta)
 
-    if pre_dispatch:
-        from torch._export.passes.replace_set_grad_with_hop_pass import (
-            replace_set_grad_with_hop_pass,
-        )
-
-        gm = replace_set_grad_with_hop_pass(gm)
-
     # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
     for _mod in gm.modules():
         if not isinstance(_mod, torch.fx.GraphModule):
@@ -619,6 +612,13 @@ def make_argument_spec(i, node) -> ArgumentSpec:
         constants,
     )
 
+    if pre_dispatch:
+        from torch._export.passes.replace_set_grad_with_hop_pass import (
+            replace_set_grad_with_hop_pass,
+        )
+
+        gm = replace_set_grad_with_hop_pass(gm, export_graph_signature)
+
     @dataclasses.dataclass
     class _ExportedProgramNonStrict:
         gm: torch.fx.GraphModule
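Reviewer note: with the relocation above, the pass now runs after `export_graph_signature` exists, so it can hand the signature (and thus its replace hook) to `replace_set_grad_with_hop_pass`. The alignment this protects is that the signature's input specs are positional bookkeeping for the graph's placeholders; a hedged sketch of that invariant on a plain `torch.export.export` call (illustrative only, not this exact pre-dispatch path):

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

ep = export(M(), (torch.randn(2), torch.randn(2)))
placeholders = [n.name for n in ep.graph.nodes if n.op == "placeholder"]
spec_names = [spec.arg.name for spec in ep.graph_signature.input_specs]
# Expected True: passes that replace or reorder placeholders must preserve this.
print(placeholders == spec_names)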
16 changes: 15 additions & 1 deletion torch/fx/passes/split_module.py
@@ -457,6 +457,11 @@ def instantiate_node_partition_mapping(node):
     )
 
     already_constructed_attr_nodes = set()
+
+    # We actually need to insert the placeholder nodes in the original order
+    # otherwise graph signature will be wrong.
+    original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
+
     for partition_name in construct_order_partitions:
         partition = partitions[partition_name]
 
@@ -475,8 +480,17 @@
         if keep_original_order:
             # first get the attr nodes required by this partition
             orig_mod_attr_nodes: List[Node] = [
-                orig_mod_env[key] for key in partition.inputs
+                orig_mod_env[key] for key in partition.inputs if key not in original_order
             ]
+
+            for node in original_order:
+                if node in already_constructed_attr_nodes:
+                    continue # already added this attr to the base graph
+                base_mod_env, based_mod_attrs = construct_graph(
+                    node, base_mod_env, base_mod_attrs
+                )
+                already_constructed_attr_nodes.add(node)
+
             # Construct GraphModule for this partition
             for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
                 if node in already_constructed_attr_nodes:
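Reviewer note: the two hunks above are the partitioner fix itself. With `keep_original_order=True`, placeholders are now added to the base graph in the order they appear in the original graph rather than in the order the partitions happen to consume them, which is what keeps positional consumers such as export's graph signature correct. A hedged, self-contained sketch of the invariant (the module and split callback below are made up for illustration):

import operator

import torch
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, a, b):
        x = b - 1      # the first partition only needs `b` ...
        return a + x   # ... while the second also needs `a`

traced = torch.fx.symbolic_trace(M())

# Route the subtraction to partition 0 and everything else to partition 1, so the
# partition construction order encounters `b` before `a`.
def split_callback(node: torch.fx.Node) -> int:
    return 0 if node.target is operator.sub else 1

split = split_module(traced, M(), split_callback, keep_original_order=True)
top_level_placeholders = [n.target for n in split.graph.nodes if n.op == "placeholder"]
# Expected ['a', 'b'], matching the original signature; before this fix the order
# could instead follow partition construction order.
print(top_level_placeholders)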
