Fix bug in graph partitioner and update graph signature after partitioning. #125133
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125133
Note: Links to docs will display an error until the docs builds have been completed.
❌ 27 New Failures, 1 Unrelated Failure as of commit cddf9d6 with merge base faf0015.
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed d1c2102 to 8a2978d, then 8a2978d to c78101d.
torch/export/_trace.py (Outdated)

    if node.op == "output":
        updated_output_names = node.args

        # It could be updated because replace_set_grad_with_hop_pass adds new
add a test case for this?
The internal diff has a torchrec unittest.
torch/fx/passes/split_module.py (Outdated)

    @@ -477,6 +477,16 @@ def instantiate_node_partition_mapping(node):
        orig_mod_attr_nodes: List[Node] = [
            orig_mod_env[key] for key in partition.inputs
        ]

        # We actually need to insert the placeholder nodes in the original order
test case?
The internal diff has a torchrec unittest.
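For reference, the ordering fix under discussion could look roughly like this (a sketch, not the actual diff; `orig_graph` and the exact placement are illustrative, and it assumes `orig_mod_env` maps partition input names to nodes of the original graph):

    # Index node names by their position in the original graph, then materialize
    # the partition's inputs in that original order instead of set-iteration order.
    node_order = {node.name: idx for idx, node in enumerate(orig_graph.nodes)}
    ordered_input_names = sorted(partition.inputs, key=lambda name: node_order[name])
    orig_mod_attr_nodes: List[Node] = [
        orig_mod_env[key] for key in ordered_input_names
    ]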
torch/export/_trace.py (Outdated)

    updated_output_names = node.args

    # It could be updated because replace_set_grad_with_hop_pass adds new
    # getitem nodes into the graph.
Should we move this to replace_set_grad_with_hop_pass, or are there other cases that also require this pass? Alternatively, we could organize the change into a pass or a function.
We should probably have a simplified version of the internal unittest in OSS so that we can get an early signal when we add new features in OSS in the future.
Right now it is only replace_set_grad_with_hop_pass, but in theory it could be useful for any pass that adds submodules to the graph. I can make it a free function.
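A rough sketch of what that free function could look like (the name and exact field handling are hypothetical; it only assumes the signature fields already used in this diff):

    def _remap_output_names(graph_signature, name_map):
        # Hypothetical helper: re-key signature entries from old output names to
        # new ones. Assumes old and new names do not collide with each other.
        for mutations in (graph_signature.buffers_to_mutate,
                          graph_signature.user_inputs_to_mutate):
            for old_name, new_name in name_map.items():
                if old_name in mutations and old_name != new_name:
                    mutations[new_name] = mutations.pop(old_name)
        # Rename user outputs in place, leaving unmapped names untouched.
        graph_signature.user_outputs[:] = [
            name_map.get(name, name) for name in graph_signature.user_outputs
        ]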
Force-pushed c78101d to c9f67c6, then 0bd001d to a9d5962.
    # It could be updated because replace_set_grad_with_hop_pass adds new
    # getitem nodes into the graph.
    assert len(original_output_nodes) == len(updated_output_nodes)
    old_output_name_to_new_output_name = {}
    for k, v in zip(*original_output_nodes, *updated_output_nodes):
        if isinstance(k, torch.fx.Node) and isinstance(v, torch.fx.Node):
            old_output_name_to_new_output_name[k.name] = v.name
        # If there is a constant return at the end, it should be true for the
        # updated outputs too
        else:
            constant_types = (int, float, str, type(None))
            assert k == v
            assert isinstance(k, constant_types)

    buffers_to_mutate_copy = graph_signature.buffers_to_mutate.copy()
    user_inputs_to_mutate_copy = graph_signature.user_inputs_to_mutate.copy()
    for k in old_output_name_to_new_output_name:
        if k in graph_signature.buffers_to_mutate:
            graph_signature.buffers_to_mutate[
                old_output_name_to_new_output_name[k]
            ] = buffers_to_mutate_copy[k]
            if k not in old_output_name_to_new_output_name.values():
                del graph_signature.buffers_to_mutate[k]

        if k in graph_signature.user_inputs_to_mutate:
            graph_signature.user_inputs_to_mutate[
                old_output_name_to_new_output_name[k]
            ] = user_inputs_to_mutate_copy[k]
            if k not in old_output_name_to_new_output_name.values():
                del graph_signature.user_inputs_to_mutate[k]

    for i, k in enumerate(graph_signature.user_outputs):
        if k in old_output_name_to_new_output_name:
            new_k = old_output_name_to_new_output_name[k]
            graph_signature.user_outputs[i] = new_k
I think you can get rid of this regeneration code with:

    with gm._set_replace_hook(ep.graph_signature.get_replace_hook()):
        _replace_set_grad_with_hop_pass(gm)
Hmm, we actually cannot use ep here because this runs before we create the ep. We need to do it before ep creation because the ep verifier won't pass, as we still have non-functional ops in the graph at this point.
qq: is ep.graph_signature the same type as the graph signature here? If so, could we still use that method?
The bunch of loops and mappings here is difficult to parse at first sight lol
Discussed offline with @tugsbayasgalan: we should still use something more structural here. Will follow up.
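If the signature type here exposes the same hook, the more structural version could be roughly this (a sketch, assuming `graph_signature` provides `get_replace_hook` the way `ExportGraphSignature` does):

    # Let the replace hook keep the signature in sync while the pass rewrites
    # nodes, instead of remapping the output names by hand afterwards.
    with gm._set_replace_hook(graph_signature.get_replace_hook()):
        gm = replace_set_grad_with_hop_pass(gm, graph_signature)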
sorry wrong button clicked lol
    @@ -160,7 +160,58 @@ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
         return gm


    -def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
    +def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
nit: maybe a type annotation?
yeah adding a type annotation here is better.
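For example, assuming the argument is the export graph signature (the import path below is the public `torch.export.graph_signature` module):

    from torch.export.graph_signature import ExportGraphSignature

    def replace_set_grad_with_hop_pass(
        gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature
    ) -> torch.fx.GraphModule:
        ...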
Overall looks good to me. Left some minor comments on code structure.
Force-pushed a9d5962 to d334c5f.
some nits.
    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
            _replace_with_hop(node)
        else:
            _remove_set_grad_and_inline(node)
We can define this function outside of the context manager, e.g.:
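A sketch of the suggested shape (the pass body and hook usage are assumptions, not the actual diff):

    # Helper defined once at module scope rather than re-created inside the pass.
    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
            _replace_with_hop(node)
        else:
            _remove_set_grad_and_inline(node)


    def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
        # Only the loop that mutates the graph stays under the context manager.
        with gm._set_replace_hook(graph_signature.get_replace_hook()):
            for node in list(gm.graph.nodes):
                _maybe_inline_or_replace_with_hop(node)
        return gm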
Force-pushed d334c5f to 5b6e078.
This pull request was exported from Phabricator. Differential Revision: D56688411
Force-pushed 5b6e078 through e83a1e8, 7842ce4, 0630e97, def6fb1, 24f8cb9, 98753f6, 321e839, and 60b1ec3, ending at cddf9d6. Each push was re-exported from Phabricator (Differential Revision: D56688411); the recurring commit summaries were:

Summary: Pull Request resolved: pytorch#125133 Title
Test Plan: CI
Reviewed By: PaulZhang12
Differential Revision: D56688411

Summary: X-link: pytorch/torchrec#1934 Title
Test Plan: CI
Reviewed By: PaulZhang12
Differential Revision: D56688411
Summary: This fix does two things: it fixes a bug in the graph partitioner (placeholder nodes are now inserted in the original order) and updates the graph signature after partitioning.

Test Plan: CI

Differential Revision: D56688411