
Commit

Fix bug in graph partitioner
Summary: Title

Test Plan: CI

Reviewed By: PaulZhang12

Differential Revision: D56688411
tugsbayasgalan authored and facebook-github-bot committed May 1, 2024
1 parent 39eb5d4 commit e83a1e8
Showing 4 changed files with 81 additions and 59 deletions.
8 changes: 4 additions & 4 deletions test/export/test_export.py
@@ -4061,16 +4061,16 @@ def forward(self, b_pred, b_t, x, y):
"""\
def forward(self, b_t, x, y):
submod_3 = self.submod_1
- add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
+ add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
return (add_1,)""",
)

self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
"""\
- def forward(self, b_t, x, y):
- sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
- add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
+ def forward(self, x, b_t, y):
+ sub = torch.ops.aten.sub.Tensor(x, 1); x = None
+ add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
return add_1""",
)
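The assertions above compare the generated code of the exported submodules against inline expectations; the update reflects the corrected placeholder ordering produced by the partitioner fix below. A minimal, self-contained sketch of inspecting exported graph code the same way (the module, buffer, and shapes here are made up, not the one under test):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("t", torch.ones(3))

    def forward(self, x, y):
        # A buffer and two user inputs, loosely mirroring the b_t/x/y setup above.
        return x - 1 + self.t + y

ep = torch.export.export(M(), (torch.randn(3), torch.randn(3)))
# The printed code lists placeholders first: lifted buffers, then user inputs
# in their original order.
print(ep.graph_module.code)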
46 changes: 27 additions & 19 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -1,3 +1,4 @@
+ import contextlib
import copy

import torch
@@ -125,7 +126,9 @@ def _remove_set_grad_and_inline(node: torch.fx.Node):
node_inline_(node)


- def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
+ def _sequential_split_and_maybe_inline_subgraphs(
+ gm: torch.fx.GraphModule, graph_signature
+ ):
"""
Helper function for replace_set_grad_with_hop_pass().
Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
@@ -141,35 +144,40 @@ def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
if need_replacing:
new_gm = sequential_split(gm, _is_set_grad_enabled_node)

- def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
- if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
- _replace_with_hop(node)
- else:
- _remove_set_grad_and_inline(node)
-
- nodes_map(
- list(new_gm.graph.nodes),
- lambda node: (
- _maybe_inline_or_replace_with_hop(node)
- if node.op == "call_module"
- else node
- ),
- )
+ replace_ctx = contextlib.nullcontext()
+ if graph_signature is not None:
+ replace_ctx = new_gm._set_replace_hook(graph_signature.get_replace_hook()) # type: ignore[assignment]
+
+ with replace_ctx:
+
+ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
+ if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
+ _replace_with_hop(node)
+ else:
+ _remove_set_grad_and_inline(node)
+
+ nodes_map(
+ list(new_gm.graph.nodes),
+ lambda node: (
+ _maybe_inline_or_replace_with_hop(node)
+ if node.op == "call_module"
+ else node
+ ),
+ )
return new_gm

return gm


- def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
- new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)
-
+ def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
+ new_gm = _sequential_split_and_maybe_inline_subgraphs(gm, graph_signature)
# recursively call
for node in new_gm.graph.nodes:
if node.op == "get_attr":
subgm = getattr(new_gm, node.target)
if not isinstance(subgm, torch.fx.GraphModule):
continue
- new_subgm = replace_set_grad_with_hop_pass(subgm)
+ new_subgm = replace_set_grad_with_hop_pass(subgm, None)
setattr(new_gm, node.target, new_subgm)

new_gm.recompile()
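In the pass above, the replace hook keeps the export graph signature in sync while call_module nodes are replaced, and the fallback to contextlib.nullcontext() lets the same body run when no signature is supplied (as in the recursive calls, which pass None). A small standalone sketch of that conditional context-manager idiom, with illustrative names rather than the torch.fx hook API:

import contextlib

def run_with_optional_hook(work, hook_cm=None):
    # Use the supplied context manager if there is one; otherwise a no-op one,
    # so the call below does not need a separate unguarded branch.
    ctx = hook_cm if hook_cm is not None else contextlib.nullcontext()
    with ctx:
        return work()

# The same call site works with or without a hook installed.
print(run_with_optional_hook(lambda: 2 + 2))
print(run_with_optional_hook(lambda: 2 + 2, hook_cm=contextlib.suppress(ValueError)))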
70 changes: 35 additions & 35 deletions torch/export/_trace.py
@@ -509,41 +509,6 @@ def _compiling_state_context():
if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
gm.meta.update(mod.meta)

- if pre_dispatch:
- from torch._export.passes.replace_set_grad_with_hop_pass import (
- replace_set_grad_with_hop_pass,
- )
-
- gm = replace_set_grad_with_hop_pass(gm)
-
- # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
- for _mod in gm.modules():
- if not isinstance(_mod, torch.fx.GraphModule):
- continue
- for node in _mod.graph.nodes:
- if node.op in ["placeholder", "output"]:
- node.meta.pop("nn_module_stack", None)
- node.meta.pop("stack_trace", None)
-
- # NOTE: aot_export adds symint metadata for placeholders with int values;
- # since these become specialized, we replace such metadata with the original values
- flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
- index = 0
- total_non_user_inputs = (
- len(graph_signature.parameters)
- + len(graph_signature.buffers)
- + len(graph_signature.input_tokens)
- )
- for node in gm.graph.nodes:
- if node.op == "placeholder":
- if index >= total_non_user_inputs:
- user_arg = flat_args[index - total_non_user_inputs]
- if not isinstance(user_arg, torch.Tensor):
- node.meta["val"] = user_arg
- index += 1
-
- is_joint = graph_signature.backward_signature is not None
-
def make_argument_spec(i, node) -> ArgumentSpec:
if isinstance(node, (int, bool, float, type(None))):
# For const outputs we just directly return this
@@ -570,6 +535,25 @@ def make_argument_spec(i, node) -> ArgumentSpec:
f"while writing the metadata for exported program"
)

+ is_joint = graph_signature.backward_signature is not None
+
+ # NOTE: aot_export adds symint metadata for placeholders with int values;
+ # since these become specialized, we replace such metadata with the original values
+ flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
+ index = 0
+ total_non_user_inputs = (
+ len(graph_signature.parameters)
+ + len(graph_signature.buffers)
+ + len(graph_signature.input_tokens)
+ )
+ for node in gm.graph.nodes:
+ if node.op == "placeholder":
+ if index >= total_non_user_inputs:
+ user_arg = flat_args[index - total_non_user_inputs]
+ if not isinstance(user_arg, torch.Tensor):
+ node.meta["val"] = user_arg
+ index += 1
+
input_specs, output_specs = _sig_to_specs(
user_inputs=set(graph_signature.user_inputs),
inputs_to_parameters=graph_signature.inputs_to_parameters, # type: ignore[arg-type]
@@ -598,6 +582,22 @@ def make_argument_spec(i, node) -> ArgumentSpec:
input_specs=input_specs, output_specs=output_specs
)

+ if pre_dispatch:
+ from torch._export.passes.replace_set_grad_with_hop_pass import (
+ replace_set_grad_with_hop_pass,
+ )
+
+ gm = replace_set_grad_with_hop_pass(gm, export_graph_signature)
+
+ # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
+ for _mod in gm.modules():
+ if not isinstance(_mod, torch.fx.GraphModule):
+ continue
+ for node in _mod.graph.nodes:
+ if node.op in ["placeholder", "output"]:
+ node.meta.pop("nn_module_stack", None)
+ node.meta.pop("stack_trace", None)
+
constants = rewrite_script_object_meta(gm)
attr_constants = lift_constants_pass(gm, export_graph_signature, constant_attrs)
assert not any(
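The relocated NOTE block above restores the original values for non-tensor user inputs by skipping the lifted parameter, buffer, and token placeholders when indexing into the flattened arguments. A standalone sketch of that index arithmetic with made-up counts and names (not the real signature objects):

# Placeholder order in the exported graph: params, buffers, tokens, then user args.
num_params, num_buffers, num_tokens = 2, 1, 0
total_non_user_inputs = num_params + num_buffers + num_tokens

placeholders = ["p_w", "p_b", "b_t", "arg0", "arg1"]   # hypothetical node names
flat_user_args = [8, 0.5]                               # e.g. an int and a float input

for index, name in enumerate(placeholders):
    if index >= total_non_user_inputs:
        user_arg = flat_user_args[index - total_non_user_inputs]
        # In the pass, non-tensor values like these overwrite the specialized
        # symint metadata on the corresponding placeholder node.
        print(name, "keeps original value", user_arg)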
16 changes: 15 additions & 1 deletion torch/fx/passes/split_module.py
@@ -457,6 +457,11 @@ def instantiate_node_partition_mapping(node):
)

already_constructed_attr_nodes = set()
+
+ # We actually need to insert the placeholder nodes in the original order
+ # otherwise graph signature will be wrong.
+ original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
+
for partition_name in construct_order_partitions:
partition = partitions[partition_name]

@@ -475,8 +480,17 @@ def instantiate_node_partition_mapping(node):
if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
- orig_mod_env[key] for key in partition.inputs
+ orig_mod_env[key] for key in partition.inputs if key not in original_order
]
+
+ for node in original_order:
+ if node in already_constructed_attr_nodes:
+ continue # already added this attr to the base graph
+ base_mod_env, based_mod_attrs = construct_graph(
+ node, base_mod_env, base_mod_attrs
+ )
+ already_constructed_attr_nodes.add(node)
+
# Construct GraphModule for this partition
for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
if node in already_constructed_attr_nodes:
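The new keep_original_order handling above re-inserts placeholder nodes in the original graph's order because, as the added comment says, the graph signature otherwise comes out wrong: the signature pairs with placeholders positionally, so emitting them in whatever order partitions happen to request them shuffles that pairing. A toy illustration of the invariant in plain Python (not the split_module API):

# The signature pairs positionally with placeholder order.
signature = ["x", "b_t", "y"]              # order recorded in the graph signature
discovery_order = ["b_t", "y", "x"]        # order a partition happens to request them

# Wrong: emitting placeholders in discovery order breaks the positional pairing.
print(list(zip(signature, discovery_order)))   # [('x', 'b_t'), ('b_t', 'y'), ('y', 'x')]

# Right: emit placeholders in the original order first, as the fix does,
# and the pairing lines up again.
emitted = [name for name in signature if name in discovery_order]
print(list(zip(signature, emitted)))           # [('x', 'x'), ('b_t', 'b_t'), ('y', 'y')]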
