Skip to content

Commit a9d5962

Browse files
tugsbayasgalan authored and facebook-github-bot committed
Fix bug in graph partitioner
Summary: Title

Test Plan: CI

Reviewed By: PaulZhang12

Differential Revision: D56688411
1 parent faee0e5 commit a9d5962

File tree

4 files changed

+73
-8
lines changed

4 files changed

+73
-8
lines changed

test/export/test_export.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3968,16 +3968,16 @@ def forward(self, b_pred, b_t, x, y):
39683968
"""\
39693969
def forward(self, b_t, x, y):
39703970
submod_3 = self.submod_1
3971-
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
3971+
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
39723972
return (add_1,)""",
39733973
)
39743974

39753975
self.assertExpectedInline(
39763976
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
39773977
"""\
3978-
def forward(self, b_t, x, y):
3979-
sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
3980-
add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
3978+
def forward(self, x, b_t, y):
3979+
sub = torch.ops.aten.sub.Tensor(x, 1); x = None
3980+
add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
39813981
add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
39823982
return add_1""",
39833983
)

torch/_export/passes/replace_set_grad_with_hop_pass.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,58 @@ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
160160
return gm
161161

162162

163-
def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
163+
def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
164+
original_output_nodes = ()
165+
for node in gm.graph.nodes:
166+
if node.op == "output":
167+
original_output_nodes = node.args
168+
169+
gm = _replace_set_grad_with_hop_pass(gm)
170+
171+
updated_output_nodes = ()
172+
for node in gm.graph.nodes:
173+
if node.op == "output":
174+
updated_output_nodes = node.args
175+
176+
# It could be updated because replace_set_grad_with_hop_pass adds new
177+
# getitem nodes into the graph.
178+
assert len(original_output_nodes) == len(updated_output_nodes)
179+
old_output_name_to_new_output_name = {}
180+
for k, v in zip(*original_output_nodes, *updated_output_nodes):
181+
if isinstance(k, torch.fx.Node) and isinstance(k, torch.fx.Node):
182+
old_output_name_to_new_output_name[k.name] = v.name
183+
# If there is constant return in the end, it should be true for updated outputs too
184+
else:
185+
constant_types = (int, float, str, type(None))
186+
assert k == v
187+
assert isinstance(k, constant_types)
188+
189+
buffers_to_mutate_copy = graph_signature.buffers_to_mutate.copy()
190+
user_inputs_to_mutate_copy = graph_signature.user_inputs_to_mutate.copy()
191+
for k in old_output_name_to_new_output_name:
192+
if k in graph_signature.buffers_to_mutate:
193+
graph_signature.buffers_to_mutate[
194+
old_output_name_to_new_output_name[k]
195+
] = buffers_to_mutate_copy[k]
196+
if k not in old_output_name_to_new_output_name.values():
197+
del graph_signature.buffers_to_mutate[k]
198+
199+
if k in graph_signature.user_inputs_to_mutate:
200+
graph_signature.user_inputs_to_mutate[
201+
old_output_name_to_new_output_name[k]
202+
] = user_inputs_to_mutate_copy[k]
203+
if k not in old_output_name_to_new_output_name.values():
204+
del graph_signature.user_inputs_to_mutate[k]
205+
206+
for i, k in enumerate(graph_signature.user_outputs):
207+
if k in old_output_name_to_new_output_name:
208+
new_k = old_output_name_to_new_output_name[k]
209+
graph_signature.user_outputs[i] = new_k
210+
211+
return gm, graph_signature
212+
213+
214+
def _replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
164215
new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)
165216

166217
# recursively call
@@ -169,7 +220,7 @@ def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
169220
subgm = getattr(new_gm, node.target)
170221
if not isinstance(subgm, torch.fx.GraphModule):
171222
continue
172-
new_subgm = replace_set_grad_with_hop_pass(subgm)
223+
new_subgm = _replace_set_grad_with_hop_pass(subgm)
173224
setattr(new_gm, node.target, new_subgm)
174225

175226
new_gm.recompile()

torch/export/_trace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def _compiling_state_context():
539539
replace_set_grad_with_hop_pass,
540540
)
541541

542-
gm = replace_set_grad_with_hop_pass(gm)
542+
gm, graph_signature = replace_set_grad_with_hop_pass(gm, graph_signature)
543543

544544
# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
545545
for _mod in gm.modules():

torch/fx/passes/split_module.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,11 @@ def instantiate_node_partition_mapping(node):
457457
)
458458

459459
already_constructed_attr_nodes = set()
460+
461+
# We actually need to insert the placeholder nodes in the original order
462+
# otherwise graph signature will be wrong.
463+
original_order = [node for node in m.graph.nodes if node.op == "placeholder"]
464+
460465
for partition_name in construct_order_partitions:
461466
partition = partitions[partition_name]
462467

@@ -475,8 +480,17 @@ def instantiate_node_partition_mapping(node):
475480
if keep_original_order:
476481
# first get the attr nodes required by this partition
477482
orig_mod_attr_nodes: List[Node] = [
478-
orig_mod_env[key] for key in partition.inputs
483+
orig_mod_env[key] for key in partition.inputs if key not in original_order
479484
]
485+
486+
for node in original_order:
487+
if node in already_constructed_attr_nodes:
488+
continue # already added this attr to the base graph
489+
base_mod_env, based_mod_attrs = construct_graph(
490+
node, base_mod_env, base_mod_attrs
491+
)
492+
already_constructed_attr_nodes.add(node)
493+
480494
# Construct GraphModule for this partition
481495
for node in orig_mod_attr_nodes: # type: ignore[attr-defined]
482496
if node in already_constructed_attr_nodes:

0 commit comments

Comments (0)