[PT2][Optimus] Read the patterns from the config instead of hard-coded passes (pytorch#125136)

Summary: Due to a compatibility issue, we previously hard-coded which passes run pattern optimization. Since those changes have now been in production packages for a while, we revisit that approach: we instead read from the config to decide whether to run a specific pattern optimization, which makes adding follow-up patterns easier.
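
For example, a pass is now enabled by naming it in the fusion-options dicts instead of flipping a dedicated flag. A minimal sketch of the new usage, mirroring the test changes below (my_model and example_inputs are assumed placeholders):

    import torch

    # An empty dict as the value means "enable this pass with default options".
    with torch._inductor.config.patch(
        pre_grad_fusion_options={
            "split_cat_pass": {},
            "unbind_stack_pass": {},
        },
        post_grad_fusion_options={
            "decompose_mm_pass": {},
        },
    ):
        compiled = torch.compile(my_model)  # my_model: any torch.nn.Module (assumed)
        out = compiled(*example_inputs)     # example_inputs: assumed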

Differential Revision: D56659934

Pull Request resolved: pytorch#125136
Approved by: https://github.com/jackiexu1992
mengluy0125 authored and andoorve committed May 1, 2024
1 parent 92ccece commit 874d7aa
Showing 9 changed files with 87 additions and 92 deletions.
13 changes: 13 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -326,6 +326,19 @@ def forward(self, x, y):
         with config.patch({"freezing": True}):
             self.check_model(Model(self.device), example_inputs)
 
+    @torch._inductor.config.patch(
+        pre_grad_fusion_options={
+            "normalization_pass": {},
+            "remove_split_with_size_one_pass": {},
+            "merge_getitem_cat_pass": {},
+            "merge_stack_tahn_unbind_pass": {},
+            "merge_splits_pass": {},
+            "mutate_cat_pass": {},
+            "split_cat_pass": {},
+            "unbind_stack_pass": {},
+        },
+        post_grad_fusion_options={},
+    )
     def test_simple_split(self):
         class Model(torch.nn.Module):
             def __init__(self):
4 changes: 3 additions & 1 deletion test/inductor/test_decompose_mem_bound_mm.py
@@ -50,7 +50,9 @@ def forward(self, input1, input2):
 
 @requires_cuda
 @torch._inductor.config.patch(
-    decompose_mem_bound_mm=True,
+    post_grad_fusion_options={
+        "decompose_mm_pass": {},
+    }
 )
 @instantiate_parametrized_tests
 class TestDecomposeMemMM(TestCase):
17 changes: 13 additions & 4 deletions test/inductor/test_pattern_matcher.py
@@ -520,7 +520,7 @@ def fn(a, b, c):
             torch.randn(16, 16, device="cuda"),
             torch.randn(16, 16, device="cuda"),
         ]
-        self.common(fn, args, 2, 5)
+        self.common(fn, args, 1, 4)
 
     def test_cat_addmm(self):
         def fn(a, b, c):
@@ -538,7 +538,7 @@ def fn(a, b, c):
             torch.randn(16, 16, device="cuda"),
             torch.randn(16, 16, device="cuda"),
         ]
-        self.common(fn, args, 2, 5)
+        self.common(fn, args, 1, 4)
 
     def test_cat_slice_cat_cuda(self):
         def fn(a, b):
@@ -839,7 +839,9 @@ def foo(x, y):
 
     def test_match_with_mutation(self):
         counter = 0
-        test_pass = PatternMatcherPass(prevent_match_across_mutations=True)
+        test_pass = PatternMatcherPass(
+            prevent_match_across_mutations=True, pass_name="test"
+        )
 
         @register_graph_pattern(
             CallFunction(
@@ -892,7 +894,14 @@ def fn5(x, y):
         ]
 
         with unittest.mock.patch(
-            "torch._inductor.fx_passes.pre_grad.pattern_matcher_passes", [test_pass]
+            "torch._inductor.fx_passes.pre_grad.config.pre_grad_fusion_options",
+            {"test": {}},
+        ), unittest.mock.patch(
+            "torch._inductor.fx_passes.pre_grad.PRE_GRAD_FUSIONS",
+            [],
+        ), unittest.mock.patch(
+            "torch._inductor.fx_passes.pre_grad.PRE_GRAD_PATTERNS",
+            {"test": test_pass},
         ):
             for fn in (fn0, fn1, fn2, fn3, fn4, fn5):
                 counter = 0
19 changes: 17 additions & 2 deletions test/inductor/test_split_cat_fx_passes.py
@@ -13,7 +13,19 @@
 
 
 def patch(f):
-    f = torch._inductor.config.patch(split_cat_fx_passes=True)(f)
+    f = torch._inductor.config.patch(
+        pre_grad_fusion_options={
+            "normalization_pass": {},
+            "remove_split_with_size_one_pass": {},
+            "merge_getitem_cat_pass": {},
+            "merge_stack_tahn_unbind_pass": {},
+            "merge_splits_pass": {},
+            "mutate_cat_pass": {},
+            "split_cat_pass": {},
+            "unbind_stack_pass": {},
+        },
+        post_grad_fusion_options={},
+    )(f)
     return f
 
 
@@ -605,7 +617,10 @@ def multi_split_cat(x1, x2):
         )
         counters.clear()
 
-    @torch._inductor.config.patch(split_cat_fx_passes=False)
+    @torch._inductor.config.patch(
+        pre_grad_fusion_options={},
+        post_grad_fusion_options={},
+    )
     def test_config_flag_is_respected(self):
         def split_with_cat(x):
             fs = torch.split(x, [4, 4, 24], dim=-1)
5 changes: 1 addition & 4 deletions torch/_inductor/fx_passes/decompose_mem_bound_mm.py
@@ -8,7 +8,7 @@
 from .. import config
 
 from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern
-from .split_cat import construct_pattern_matcher_pass, get_config_flag
+from .split_cat import construct_pattern_matcher_pass
 
 aten = torch.ops.aten
 log = logging.getLogger(__name__)
@@ -94,7 +94,6 @@ def print_decompose_pattern(match: Match, inputs: List[torch.fx.Node]):
 @register_graph_pattern(
     CallFunction(aten.bmm, Arg(), Arg()),
     pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
-    extra_check=get_config_flag("decompose_mm_pass", "decompose_mem_bound_mm"),
 )
 def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
     def repl(mat1, mat2):
@@ -111,7 +110,6 @@ def repl(mat1, mat2):
 @register_graph_pattern(
     CallFunction(aten.addmm, Arg(), Arg(), Arg()),
     pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
-    extra_check=get_config_flag("decompose_mm_pass", "decompose_mem_bound_mm"),
 )
 def decompose_addmm(
     match: Match,
@@ -133,7 +131,6 @@ def repl(mat1, mat2, mat3):
 @register_graph_pattern(
     CallFunction(aten.mm, Arg(), Arg()),
     pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"),
-    extra_check=get_config_flag("decompose_mm_pass", "decompose_mem_bound_mm"),
 )
 def decompose_mm(
     match: Match,
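
Because passes are now looked up by name from the config, a follow-up pattern only needs to be registered under a new pass name and then listed in the config. A hedged sketch using the same registration APIs as the file above ("my_new_pass" and the handler body are hypothetical, not part of this commit):

    import torch
    from torch._inductor.fx_passes.split_cat import construct_pattern_matcher_pass
    from torch._inductor.pattern_matcher import Arg, CallFunction, Match, register_graph_pattern

    aten = torch.ops.aten

    # Hypothetical follow-up pass that would rewrite matched aten.mm calls.
    # Assuming "my_new_pass" is known to construct_pattern_matcher_pass, enabling
    # it is just post_grad_fusion_options={"my_new_pass": {}}; no hard-coded flag.
    @register_graph_pattern(
        CallFunction(aten.mm, Arg(), Arg()),
        pass_dict=construct_pattern_matcher_pass("my_new_pass"),
    )
    def my_new_mm_rewrite(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node):
        ...  # graph rewrite logic would go here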
3 changes: 0 additions & 3 deletions torch/_inductor/fx_passes/group_batch_fusion.py
@@ -302,7 +302,6 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
 
         if all(bias is None for bias in group_biases):
             group_biases = None  # type: ignore[assignment]
-        group_biases: Optional[List[Any]]
 
         with graph.inserting_before(subset[0]):
             fused_mm = graph.call_function(
@@ -649,10 +648,8 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]):
 
         if all(bias is None for bias in group_biases):
             group_biases = None  # type: ignore[assignment]
-        group_biases: Optional[List[Any]]
         if all(weight is None for weight in group_weights):
             group_weights = None  # type: ignore[assignment]
-        group_weights: Optional[List[Any]]
         assert all(
             eps == group_epss[0] for eps in group_epss
         ), "all epsilon values must be equal"
9 changes: 6 additions & 3 deletions torch/_inductor/fx_passes/post_grad.py
@@ -44,7 +44,7 @@
 from ..utils import decode_device, is_pointwise_use
 from ..virtualized import V
 from .ddp_fusion import fuse_ddp_communication
-from .group_batch_fusion import group_batch_fusion_passes
+from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS
 from .pre_grad import is_same_dict, save_inductor_dict
 from .reinplace import reinplace_inplaceable_ops
 from .split_cat import POST_GRAD_PATTERNS
@@ -54,7 +54,6 @@
 aten = torch.ops.aten
 prims = torch.ops.prims
 
-pattern_matcher_passes = POST_GRAD_PATTERNS.values()
 # First pass_patterns[0] are applied, then [1], then [2]
 pass_patterns = [
     PatternMatcherPass(),
@@ -89,7 +88,11 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         remove_noop_ops(gm.graph)
         for patterns in pass_patterns:
             patterns.apply(gm.graph)  # type: ignore[arg-type]
-        for pattern_matcher_pass in pattern_matcher_passes:
+        for pass_name in config.post_grad_fusion_options:
+            # skip all patterns for group batch fusions
+            if pass_name in POST_GRAD_FUSIONS:
+                continue
+            pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
             inductor_before_change = save_inductor_dict(
                 [pattern_matcher_pass.pass_name]
             )
20 changes: 10 additions & 10 deletions torch/_inductor/fx_passes/pre_grad.py
@@ -24,7 +24,7 @@
     stable_topological_sort,
 )
 from ..utils import is_cpu_device, pass_execution_and_save
-from .group_batch_fusion import group_batch_fusion_passes
+from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS
 from .misc_patterns import numpy_compat_normalization
 from .split_cat import PRE_GRAD_PATTERNS
 
@@ -85,12 +85,6 @@ def remove_split_ops(graph, shape_prop):
     return None
 
 
-# split_cat related fusions
-pattern_matcher_passes = list(PRE_GRAD_PATTERNS.values())
-# non-split_cat related fusions
-# TODO: move them to the fusions dict too.
-pattern_matcher_passes.append(efficient_conv_bn_eval_pass)
-
 pattern_matcher_passes_aten: List[PatternMatcherPass] = [
     remove_split_with_size_one_pass_aten,
     merge_getitem_cat_pass_aten,
@@ -134,6 +128,7 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None):
     def shape_prop(mod) -> None:
         ShapeProp(
             gm=mod,
+            # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
             fake_mode=detect_fake_mode(example_inputs),
         ).propagate(*example_inputs)
 
@@ -202,10 +197,13 @@ def shape_prop(mod) -> None:
     if example_inputs is not None:
         gm = fuse_fx(gm, example_inputs)
     numpy_compat_normalization(gm.graph)
-
     optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph)
     group_batch_fusion_passes(gm.graph, pre_grad=True)
-    for pattern_matcher_pass in pattern_matcher_passes:
+    for pass_name in config.pre_grad_fusion_options:
+        # skip all patterns for group batch fusions
+        if pass_name in PRE_GRAD_FUSIONS:
+            continue
+        pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name]
         inductor_before_change = save_inductor_dict(
             [pattern_matcher_pass.pass_name]
         )
@@ -214,6 +212,8 @@ def shape_prop(mod) -> None:
         optimus_scuba_log[
             f"{pattern_matcher_pass.pass_name}_pre_grad"
         ] = upload_graph(gm.graph)
+    # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
+    efficient_conv_bn_eval_pass.apply(gm.graph)  # type: ignore[arg-type]
 
     if config.pre_grad_custom_pass is not None:
         config.pre_grad_custom_pass(gm.graph)
@@ -249,7 +249,7 @@ def shape_prop(mod) -> None:
 
 def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule:
     is_cpu = is_cpu_device(example_inputs)
-
+    # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode`
     fake_mode = detect_fake_mode(example_inputs)
 
     gm = sink_cat_after_pointwise(gm)
