
Fix bug in graph partitioner
Summary: Title

Test Plan: CI

Differential Revision: D56688411
tugsbayasgalan authored and facebook-github-bot committed Apr 29, 2024
1 parent 3d1dd79 commit c9f67c6
Showing 4 changed files with 70 additions and 8 deletions.
8 changes: 4 additions & 4 deletions test/export/test_export.py
@@ -3968,16 +3968,16 @@ def forward(self, b_pred, b_t, x, y):
"""\
def forward(self, b_t, x, y):
submod_3 = self.submod_1
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
return (add_1,)""",
)

self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
"""\
def forward(self, b_t, x, y):
sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
def forward(self, x, b_t, y):
sub = torch.ops.aten.sub.Tensor(x, 1); x = None
add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
return add_1""",
)
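
The expected argument order changes because the partitioner now inserts placeholders in the original input order (see the split_module.py change below). As a minimal sketch of how a grad-mode toggle inside forward surfaces as a wrap_with_set_grad_enabled call in an exported graph, here is a hypothetical module (not the one used in test_export.py); whether and how the wrapper appears depends on the export mode and PyTorch version:

import torch

class Toggle(torch.nn.Module):
    def forward(self, x, y):
        # The no_grad region is what the export pass captures as a
        # wrap_with_set_grad_enabled submodule call.
        with torch.no_grad():
            z = x + 1
        return z + y

# Hypothetical inputs; the printed graph may contain
# torch._higher_order_ops.wrap.wrap_with_set_grad_enabled calls like the
# expected output above, depending on the export mode used.
ep = torch.export.export(Toggle(), (torch.randn(3), torch.randn(3)))
print(ep.graph_module.code)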
52 changes: 50 additions & 2 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -160,7 +160,55 @@ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
return gm


def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
original_output_nodes = ()
for node in gm.graph.nodes:
if node.op == "output":
original_output_nodes = node.args

gm = _replace_set_grad_with_hop_pass(gm)

updated_output_nodes = ()
for node in gm.graph.nodes:
if node.op == "output":
updated_output_nodes = node.args

    # The output args may have been updated because _replace_set_grad_with_hop_pass
    # adds new getitem nodes into the graph.
assert len(original_output_nodes) == len(updated_output_nodes)
old_output_name_to_new_output_name = {}
for k, v in zip(*original_output_nodes, *updated_output_nodes):
        if isinstance(k, torch.fx.Node) and isinstance(v, torch.fx.Node):
old_output_name_to_new_output_name[k.name] = v.name
        # If the original graph returned a constant here, the updated graph should return the same constant.
else:
constant_types = (int, float, str, type(None))
assert k == v
assert isinstance(k, constant_types)

buffers_to_mutate_copy = graph_signature.buffers_to_mutate.copy()
user_inputs_to_mutate_copy = graph_signature.user_inputs_to_mutate.copy()
for k in old_output_name_to_new_output_name:
if k in graph_signature.buffers_to_mutate:
graph_signature.buffers_to_mutate[old_output_name_to_new_output_name[k]] = buffers_to_mutate_copy[k]
if k not in old_output_name_to_new_output_name.values():
del graph_signature.buffers_to_mutate[k]

if k in graph_signature.user_inputs_to_mutate:
graph_signature.user_inputs_to_mutate[old_output_name_to_new_output_name[k]] = user_inputs_to_mutate_copy[k]
if k not in old_output_name_to_new_output_name.values():
del graph_signature.user_inputs_to_mutate[k]

for i, k in enumerate(graph_signature.user_outputs):
if k in old_output_name_to_new_output_name:
new_k = old_output_name_to_new_output_name[k]
graph_signature.user_outputs[i] = new_k

return gm, graph_signature
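
Below is a standalone sketch of the output-renaming bookkeeping performed above, using plain dicts and made-up node names rather than a real graph signature:

# Illustrative only: the names and structures below are made up, not real
# ExportGraphSignature fields.
old_to_new = {"add": "getitem", "add_1": "getitem_1"}

buffers_to_mutate = {"add": "my_buffer"}   # old output name -> mutated buffer name
user_outputs = ["add_1"]

for old_name, new_name in old_to_new.items():
    if old_name in buffers_to_mutate:
        # Re-key the entry under the new output name, dropping the old key.
        buffers_to_mutate[new_name] = buffers_to_mutate.pop(old_name)

# User outputs are renamed in place, preserving their order.
user_outputs = [old_to_new.get(name, name) for name in user_outputs]

print(buffers_to_mutate)  # {'getitem': 'my_buffer'}
print(user_outputs)       # ['getitem_1']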


def _replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):

new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)

# recursively call
@@ -169,7 +217,7 @@ def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
subgm = getattr(new_gm, node.target)
if not isinstance(subgm, torch.fx.GraphModule):
continue
new_subgm = replace_set_grad_with_hop_pass(subgm)
new_subgm = _replace_set_grad_with_hop_pass(subgm)
setattr(new_gm, node.target, new_subgm)

new_gm.recompile()
2 changes: 1 addition & 1 deletion torch/export/_trace.py
@@ -539,7 +539,7 @@ def _compiling_state_context():
replace_set_grad_with_hop_pass,
)

gm = replace_set_grad_with_hop_pass(gm)
gm, graph_signature = replace_set_grad_with_hop_pass(gm, graph_signature)

# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
for _mod in gm.modules():
16 changes: 15 additions & 1 deletion torch/fx/passes/split_module.py
@@ -457,6 +457,11 @@ def instantiate_node_partition_mapping(node):
)

already_constructed_attr_nodes = set()

    # We need to insert the placeholder nodes in the original input order;
    # otherwise the graph signature will be wrong.
original_order = [node for node in m.graph.nodes if node.op == "placeholder"]

for partition_name in construct_order_partitions:
partition = partitions[partition_name]

@@ -475,8 +480,17 @@
if keep_original_order:
# first get the attr nodes required by this partition
orig_mod_attr_nodes: List[Node] = [
orig_mod_env[key] for key in partition.inputs
orig_mod_env[key] for key in partition.inputs if key not in original_order
]

for node in original_order:
if node in already_constructed_attr_nodes:
continue # already added this attr to the base graph
                base_mod_env, base_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
already_constructed_attr_nodes.add(node)

# Construct GraphModule for this partition
for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
if node in already_constructed_attr_nodes:
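
For context on why the placeholder order matters here, a small self-contained torch.fx sketch (not split_module itself): the order in which placeholders are created determines the argument order of the generated forward(), which is what the export graph signature is keyed on.

import torch
import torch.fx as fx

g = fx.Graph()
y = g.placeholder("y")   # created first
x = g.placeholder("x")   # created second
out = g.call_function(torch.add, (x, y))
g.output(out)

gm = fx.GraphModule(torch.nn.Module(), g)
# The generated signature follows placeholder creation order: forward(self, y, x).
print(gm.code)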
