Fix bug in graph partitioner (pytorch#125133)
Summary:
X-link: pytorch/torchrec#1934


Test Plan: CI

Reviewed By: PaulZhang12

Differential Revision: D56688411
tugsbayasgalan authored and facebook-github-bot committed May 7, 2024
1 parent 7864d28 commit def6fb1
Showing 6 changed files with 321 additions and 245 deletions.
73 changes: 69 additions & 4 deletions test/export/test_export.py
@@ -92,6 +92,11 @@
"(Tensor x) -> (Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)
torch.library.define(
"testlib::foo_unbacked",
"(Scalar x) -> (Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::returns_tensor_symint", "cpu")
@@ -125,6 +130,15 @@ def foo_functional(x):
return a.cos()


@torch.library.impl("testlib::foo_unbacked", "CompositeImplicitAutograd")
def foo_unbacked(x):
if x > 2:
return torch.ones(4, 4)
if x < 6:
return torch.ones(4, 4)
return torch.ones(4, 4)


@dataclass
class Inp:
x: Tensor
@@ -4052,16 +4066,16 @@ def forward(self, b_pred, b_t, x, y):
"""\
def forward(self, b_t, x, y):
submod_3 = self.submod_1
-    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
+    add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
return (add_1,)""",
)

self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
"""\
-def forward(self, b_t, x, y):
-    sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
-    add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
+def forward(self, x, b_t, y):
+    sub = torch.ops.aten.sub.Tensor(x, 1); x = None
+    add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
return add_1""",
)
@@ -4561,6 +4575,57 @@ def forward(self, x, y, div="floor"):
self.assertEqual(div_spec.arg.name, "div")
self.assertEqual(div_spec.arg.value, "floor")

def test_unbacked_deferred_runtime_retrace(self):
class Foo(torch.nn.Module):
def forward(self, x, y):
y_sum = y.sin().sum()
with torch.no_grad():
a = x.item()
torch._check_is_size(a)
torch._check(a > 2)
torch._check(a < 6)
unbacked_shape = torch.ops.testlib.foo_unbacked(a)
return y + y_sum + unbacked_shape.sum()

inps = (torch.tensor(4), torch.randn(5, 5))
from torch.export import _trace
ep_pre = _trace._export(Foo(), inps, pre_dispatch=True, strict=False)
self.assertExpectedInline(str(ep_pre.graph_module.submod_1.code).strip(), """\
def forward(self, x):
item = torch.ops.aten.item.default(x); x = None
sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item)
mul = -1 * item
le = mul <= 0; mul = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u1 <= 0 on node 'le'\\nMore context: %mul : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%le : [num_users=0] = call_function[target=operator.le](args = (%mul, 0), kwargs = {})"); le = None
mul_1 = -1 * item
lt = mul_1 < -2; mul_1 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})"); lt = None
lt_1 = item < 6
_assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'\\nMore context: %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%lt, Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})), kwargs = {})\\n%lt_1 : [num_users=0] = call_function[target=operator.lt](args = (%item, 6), kwargs = {})"); lt_1 = None
foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
return foo_unbacked""")
ep_aot = ep_pre.run_decompositions()
self.assertExpectedInline(str(ep_aot.graph_module.code).strip(), """\
def forward(self, x, y):
sin = torch.ops.aten.sin.default(y)
sum_1 = torch.ops.aten.sum.dim_IntList(sin, []); sin = None
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x); x = None
sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense)
mul = -1 * _local_scalar_dense
le = mul <= 0; mul = None
_assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u1 <= 0 on node 'le'\\nMore context: %mul : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%le : [num_users=0] = call_function[target=operator.le](args = (%mul, 0), kwargs = {})"); le = None
mul_1 = -1 * _local_scalar_dense
lt = mul_1 < -2; mul_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})"); lt = None
lt_1 = _local_scalar_dense < 6; _local_scalar_dense = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'\\nMore context: %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%lt, Runtime assertion failed for expression -u1 < -2 on node 'lt'\\nMore context: %mul_1 : [num_users=1] = call_function[target=operator.mul](args = (-1, %item), kwargs = {})\\n%lt : [num_users=0] = call_function[target=operator.lt](args = (%mul_1, -2), kwargs = {})), kwargs = {})\\n%lt_1 : [num_users=0] = call_function[target=operator.lt](args = (%item, 6), kwargs = {})"); lt_1 = None
full = torch.ops.aten.full.default([4, 4], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add = torch.ops.aten.add.Tensor(y, sum_1); y = sum_1 = None
sum_2 = torch.ops.aten.sum.dim_IntList(full, []); full = None
add_1 = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
return (add_1,)""")


def test_nested_dynamic_shapes_spec(self):
class Foo(torch.nn.Module):
def forward(self, x):
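For readers following along, the new `test_unbacked_deferred_runtime_retrace` test above exercises the user-facing pattern sketched below. This is a minimal, hypothetical example (the module name, inputs, and `torch.zeros(a)` are illustrative, not from the commit): `.item()` yields an unbacked symbol, and the `torch._check*` calls record the bounds that export materializes as the `sym_constrain_range_for_size` and `_assert_scalar` nodes visible in the expected graphs.

```python
import torch

class UnbackedItem(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()             # unbacked SymInt, shown as u1 in the graphs
        torch._check_is_size(a)  # recorded as -u1 <= 0
        torch._check(a > 2)      # recorded as -u1 < -2
        torch._check(a < 6)      # recorded as u1 < 6
        # With a bounded above and below, a data-dependent shape is legal here.
        return y + torch.zeros(a).sum()

ep = torch.export.export(
    UnbackedItem(), (torch.tensor(4), torch.randn(5, 5)), strict=False
)
print(ep.graph_module.code)  # the recorded bounds appear as runtime asserts
```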
46 changes: 27 additions & 19 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -1,3 +1,4 @@
+import contextlib
import copy

import torch
@@ -125,7 +126,9 @@ def _remove_set_grad_and_inline(node: torch.fx.Node):
node_inline_(node)


-def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
+def _sequential_split_and_maybe_inline_subgraphs(
+    gm: torch.fx.GraphModule, graph_signature
+):
"""
Helper function for replace_set_grad_with_hop_pass().
Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
@@ -141,35 +144,40 @@ def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
if need_replacing:
new_gm = sequential_split(gm, _is_set_grad_enabled_node)

-        def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
-            if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
-                _replace_with_hop(node)
-            else:
-                _remove_set_grad_and_inline(node)
-
-        nodes_map(
-            list(new_gm.graph.nodes),
-            lambda node: (
-                _maybe_inline_or_replace_with_hop(node)
-                if node.op == "call_module"
-                else node
-            ),
-        )
+        replace_ctx = contextlib.nullcontext()
+        if graph_signature is not None:
+            replace_ctx = new_gm._set_replace_hook(graph_signature.get_replace_hook())  # type: ignore[assignment]
+
+        with replace_ctx:
+
+            def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
+                if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
+                    _replace_with_hop(node)
+                else:
+                    _remove_set_grad_and_inline(node)
+
+            nodes_map(
+                list(new_gm.graph.nodes),
+                lambda node: (
+                    _maybe_inline_or_replace_with_hop(node)
+                    if node.op == "call_module"
+                    else node
+                ),
+            )
return new_gm

return gm


-def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
-    new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)
+def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
+    new_gm = _sequential_split_and_maybe_inline_subgraphs(gm, graph_signature)
# recursively call
for node in new_gm.graph.nodes:
if node.op == "get_attr":
subgm = getattr(new_gm, node.target)
if not isinstance(subgm, torch.fx.GraphModule):
continue
-            new_subgm = replace_set_grad_with_hop_pass(subgm)
+            new_subgm = replace_set_grad_with_hop_pass(subgm, None)
setattr(new_gm, node.target, new_subgm)

new_gm.recompile()
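The substantive change in this file is that the pass now receives the top-level graph signature and installs its replace hook before mutating the graph, so node replacements are reflected in the signature's input/output specs. A condensed, self-contained sketch of the pattern (`run_my_pass` is a hypothetical stand-in for the real splitting/inlining logic; `_set_replace_hook` and `get_replace_hook` are the internal APIs used in the diff):

```python
import contextlib
import torch


def run_my_pass(gm: torch.fx.GraphModule) -> None:
    # Hypothetical stand-in for _sequential_split_and_maybe_inline_subgraphs:
    # anything here that calls node.replace_all_uses_with(...) or otherwise
    # renames graph inputs/outputs would fire the installed replace hook.
    pass


def pass_with_signature_sync(gm: torch.fx.GraphModule, graph_signature):
    replace_ctx = contextlib.nullcontext()
    if graph_signature is not None:
        # While this context is active, replacements in gm also update the
        # matching entries in graph_signature's input/output specs.
        replace_ctx = gm._set_replace_hook(graph_signature.get_replace_hook())
    with replace_ctx:
        run_my_pass(gm)
    gm.recompile()
    return gm
```

This also explains why the recursive call passes `None` for nested submodules: only the top-level module is described by the export signature, so there is nothing to keep in sync below it.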
70 changes: 35 additions & 35 deletions torch/export/_trace.py
@@ -507,41 +507,6 @@ def _compiling_state_context():
if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
gm.meta.update(mod.meta)

-    if pre_dispatch:
-        from torch._export.passes.replace_set_grad_with_hop_pass import (
-            replace_set_grad_with_hop_pass,
-        )
-
-        gm = replace_set_grad_with_hop_pass(gm)
-
-    # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
-    for _mod in gm.modules():
-        if not isinstance(_mod, torch.fx.GraphModule):
-            continue
-        for node in _mod.graph.nodes:
-            if node.op in ["placeholder", "output"]:
-                node.meta.pop("nn_module_stack", None)
-                node.meta.pop("stack_trace", None)
-
-    # NOTE: aot_export adds symint metadata for placeholders with int values;
-    # since these become specialized, we replace such metadata with the original values
-    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
-    index = 0
-    total_non_user_inputs = (
-        len(graph_signature.parameters)
-        + len(graph_signature.buffers)
-        + len(graph_signature.input_tokens)
-    )
-    for node in gm.graph.nodes:
-        if node.op == "placeholder":
-            if index >= total_non_user_inputs:
-                user_arg = flat_args[index - total_non_user_inputs]
-                if not isinstance(user_arg, torch.Tensor):
-                    node.meta["val"] = user_arg
-            index += 1
-
-    is_joint = graph_signature.backward_signature is not None

def make_argument_spec(i, node) -> ArgumentSpec:
if isinstance(node, (int, bool, float, type(None))):
# For const outputs we just directly return this
@@ -570,6 +535,25 @@ def make_argument_spec(i, node) -> ArgumentSpec:
f"while writing the metadata for exported program"
)

+    is_joint = graph_signature.backward_signature is not None
+
+    # NOTE: aot_export adds symint metadata for placeholders with int values;
+    # since these become specialized, we replace such metadata with the original values
+    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
+    index = 0
+    total_non_user_inputs = (
+        len(graph_signature.parameters)
+        + len(graph_signature.buffers)
+        + len(graph_signature.input_tokens)
+    )
+    for node in gm.graph.nodes:
+        if node.op == "placeholder":
+            if index >= total_non_user_inputs:
+                user_arg = flat_args[index - total_non_user_inputs]
+                if not isinstance(user_arg, torch.Tensor):
+                    node.meta["val"] = user_arg
+            index += 1

input_specs, output_specs = _sig_to_specs(
user_inputs=set(graph_signature.user_inputs),
inputs_to_parameters=graph_signature.inputs_to_parameters, # type: ignore[arg-type]
@@ -598,6 +582,22 @@ def make_argument_spec(i, node) -> ArgumentSpec:
input_specs=input_specs, output_specs=output_specs
)

+    if pre_dispatch:
+        from torch._export.passes.replace_set_grad_with_hop_pass import (
+            replace_set_grad_with_hop_pass,
+        )
+
+        gm = replace_set_grad_with_hop_pass(gm, export_graph_signature)
+
+    # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
+    for _mod in gm.modules():
+        if not isinstance(_mod, torch.fx.GraphModule):
+            continue
+        for node in _mod.graph.nodes:
+            if node.op in ["placeholder", "output"]:
+                node.meta.pop("nn_module_stack", None)
+                node.meta.pop("stack_trace", None)

constants = rewrite_script_object_meta(gm)
constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))

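The blocks above are moved, not rewritten: `replace_set_grad_with_hop_pass` now needs `export_graph_signature`, which is only constructed after `_sig_to_specs`, so the pass and the metadata scrub run later in `_export`. The scrub itself is unchanged; a self-contained toy version of the same loop (using `symbolic_trace` on a throwaway module instead of the export-produced graph) looks like this:

```python
import torch
from torch import fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = fx.symbolic_trace(Toy())
# Same loop as in the diff: placeholder/output nodes should not carry
# module-stack or stack-trace metadata in the final exported program.
for _mod in gm.modules():
    if not isinstance(_mod, fx.GraphModule):
        continue
    for node in _mod.graph.nodes:
        if node.op in ["placeholder", "output"]:
            node.meta.pop("nn_module_stack", None)
            node.meta.pop("stack_trace", None)
```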
7 changes: 0 additions & 7 deletions torch/export/exported_program.py
@@ -662,13 +662,6 @@ def update_arg(old_arg, new_ph):

_replace_sym_size_ops_pass(gm)

-    if len(new_range_constraints) > 0:
-        res = _AddRuntimeAssertionsForInlineConstraintsPass(new_range_constraints)(
-            gm
-        )
-        assert res is not None
-        gm = res.graph_module

exported_program = ExportedProgram(
root=gm,
graph=gm.graph,
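With the retraced graphs now carrying their own `_assert_scalar` nodes (see the `run_decompositions` expectation in the new test), re-injecting assertions for inline constraints here appears redundant, hence the deletion. A hypothetical end-to-end check, reusing the `UnbackedItem` sketch from the test section above and assuming `_assert_scalar` raises `RuntimeError` on failure:

```python
ep = torch.export.export(
    UnbackedItem(), (torch.tensor(4), torch.randn(5, 5)), strict=False
).run_decompositions()

ep.module()(torch.tensor(4), torch.randn(5, 5))  # fine: 2 < 4 < 6
try:
    ep.module()(torch.tensor(10), torch.randn(5, 5))  # violates u1 < 6
except RuntimeError as err:
    print("deferred runtime assertion fired:", err)
```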
