Skip to content

Commit 6fa162e

Browse files
frank-weipytorchmergebot
authored andcommitted
Reland: [aotinductor] Replicate split_cat from torch IR to predispatch IR (pytorch#118590)
Summary: This is part of the pass migration effort. The final target is removing the acc tracer in AOTI. In this diff, I did a few things: 1. Copied and modified the `fx_passes/split_cat.py` passes to work on predispatch IR. 2. Verified correctness by copying `test_split_cat_fx_passes.py` into a new file, `test_split_cat_fx_passes_aten_fb.py`, which is executed in AOTI and checks the counters. 3. Created a util function that executes a pass and compares the before/after graphs to give the user more information, such as the pass's effect and the time spent. It creates logs like ``` [2024-01-25 20:26:48,997] torch._inductor.utils: [INFO] [Pre grad(predispatch IR)]Apply split_cat, index: 0, save before/after graph to /tmp/tmpvlpwrklp, graph before/after are the same = False, time elapsed = 0:00:00.001585 [2024-01-25 20:26:49,000] torch._inductor.utils: [INFO] [Pre grad(predispatch IR)]Apply split_cat, index: 1, save before/after graph to /tmp/tmpz_onjfeu, graph before/after are the same = False, time elapsed = 0:00:00.001873 [2024-01-25 20:26:49,002] torch._inductor.utils: [INFO] [Pre grad(predispatch IR)]Apply split_cat, index: 2, save before/after graph to /tmp/tmpgkck8yko, graph before/after are the same = True, time elapsed = 0:00:00.000269 [2024-01-25 20:26:49,007] torch._inductor.utils: [INFO] [Pre grad(predispatch IR)]Apply split_cat, index: 3, save before/after graph to /tmp/tmpquenq06y, graph before/after are the same = False, time elapsed = 0:00:00.003621 [2024-01-25 20:26:49,009] torch._inductor.utils: [INFO] [Pre grad(predispatch IR)]Apply split_cat, index: 4, save before/after graph to /tmp/tmpi8fia0dv, graph before/after are the same = True, time elapsed = 0:00:00.000190 ``` Differential Revision: D53171027 Pull Request resolved: pytorch#118590 Approved by: https://github.com/kflu, https://github.com/khabinov, https://github.com/chenyang78
1 parent 7761ceb commit 6fa162e

File tree

4 files changed

+75
-11
lines changed

4 files changed

+75
-11
lines changed

test/inductor/test_aot_inductor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import sys
55
import tempfile
6+
import types
67
import unittest
78
from typing import Dict, Tuple
89

@@ -78,7 +79,8 @@ def check_model(
7879
}
7980
):
8081
torch.manual_seed(0)
81-
model = model.to(self.device)
82+
if not isinstance(model, types.FunctionType):
83+
model = model.to(self.device)
8284
ref_model = copy.deepcopy(model)
8385
ref_inputs = copy.deepcopy(example_inputs)
8486
expected = ref_model(*ref_inputs)

torch/_inductor/fx_passes/pre_grad.py

+36-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from torch.fx.passes.shape_prop import ShapeProp
1414
from torch.nn import functional as F
1515
from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights
16-
1716
from .. import config
1817

1918
from ..fx_utils import matches_module_function_pattern
@@ -22,7 +21,7 @@
2221
PatternMatcherPass,
2322
stable_topological_sort,
2423
)
25-
from ..utils import is_cpu_device
24+
from ..utils import is_cpu_device, pass_execution_and_save
2625
from .group_batch_fusion import group_batch_fusion_passes
2726
from .misc_patterns import numpy_compat_normalization
2827

@@ -35,6 +34,12 @@
3534
efficient_conv_bn_eval_pass = PatternMatcherPass(prevent_match_across_mutations=True)
3635
merge_getitem_cat_pass = PatternMatcherPass(prevent_match_across_mutations=True)
3736
predispatch_pass = PatternMatcherPass(prevent_match_across_mutations=True)
37+
# based on predispatch aten IR
38+
normalization_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
39+
merge_splits_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
40+
split_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
41+
unbind_stack_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
42+
merge_getitem_cat_pass_aten = PatternMatcherPass(prevent_match_across_mutations=True)
3843

3944
pattern_matcher_passes: List[PatternMatcherPass] = [
4045
normalization_pass,
@@ -44,6 +49,13 @@
4449
unbind_stack_pass,
4550
efficient_conv_bn_eval_pass,
4651
]
52+
pattern_matcher_passes_aten: List[PatternMatcherPass] = [
53+
normalization_pass_aten,
54+
merge_getitem_cat_pass_aten,
55+
merge_splits_pass_aten,
56+
split_cat_pass_aten,
57+
unbind_stack_pass_aten,
58+
]
4759

4860

4961
@init_once_fakemode
@@ -66,7 +78,6 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
6678
Consider adding a new pass to post_grad.py or joint_graph.py which
6779
are after functionalization and normalization.
6880
"""
69-
7081
if config.pattern_matcher:
7182
lazy_init()
7283
if hasattr(
@@ -75,8 +86,28 @@ def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):
7586
gm_before_fx_passes = gm.__copy__()
7687
# explicitly run with predispatch atenIR based passes
7788
if config.is_predispatch:
78-
group_batch_fusion_passes(gm.graph, pre_grad=True)
79-
predispatch_pass.apply(gm.graph) # type: ignore[arg-type]
89+
pass_execution_and_save(
90+
group_batch_fusion_passes,
91+
gm,
92+
"[Pre grad(predispatch IR)] Apply group_batch_fusion",
93+
)
94+
pass_execution_and_save(
95+
predispatch_pass.apply,
96+
gm,
97+
"[Pre grad(predispatch IR)] Apply predispatch_pass",
98+
)
99+
log.debug(
100+
"[Pre grad(predispatch IR)]Before split cat in pre grad pass. graph: %s",
101+
gm.graph,
102+
)
103+
for ind, pattern_matcher_pass_aten in enumerate(
104+
pattern_matcher_passes_aten
105+
):
106+
pass_execution_and_save(
107+
pattern_matcher_pass_aten.apply,
108+
gm,
109+
f"[Pre grad(predispatch IR)]Apply split_cat, index: {ind}",
110+
)
80111
else:
81112
gm = fuse_fx(gm, example_inputs)
82113
numpy_compat_normalization(gm.graph)

torch/_inductor/fx_passes/split_cat.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,9 @@ class TorchSplit(CallFunction):
299299
splits are unique getitems.
300300
"""
301301

302-
def __init__(self, arg, sizes):
302+
def __init__(self, arg, sizes, func=torch.split):
303303
# using KeywordArg("dim") for `dim` checks they all match
304-
super().__init__(
305-
torch.split, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")
306-
)
304+
super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim"))
307305

308306
def _match(self, node: torch.fx.Node, ctx: MatchContext):
309307
m = super()._match(node, ctx)

torch/_inductor/utils.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import functools
77
import getpass
88
import inspect
9+
import io
910
import itertools
1011
import logging
1112
import math
@@ -19,6 +20,7 @@
1920
import textwrap
2021
import time
2122
import unittest
23+
from datetime import datetime
2224
from io import StringIO
2325
from typing import (
2426
Any,
@@ -45,7 +47,6 @@
4547
from torch.autograd import DeviceType
4648
from torch.autograd.profiler_util import EventList
4749
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
48-
4950
from . import config
5051

5152
log = logging.getLogger(__name__)
@@ -1225,3 +1226,35 @@ class Placeholder(enum.Enum):
12251226
# The descriptive name of the triton kernel; when unique_kernel_names = False, this
12261227
# placeholder will be replaced with a string with more information.
12271228
DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
1229+
1230+
1231+
def pass_execution_and_save(func, gm, msg):
    """Run a graph pass on ``gm`` and log a before/after summary.

    The textual form of ``gm.graph`` is written to a temporary file both
    before and after ``func`` runs. The file is created with
    ``delete=False`` so it outlives this call and can be inspected later;
    its path is included in the log record.

    Args:
        func: Callable invoked as ``func(gm.graph)``; expected to mutate the
            graph in place. Its return value is ignored.
        gm: Object exposing ``graph`` and ``recompile()`` (a
            ``torch.fx.GraphModule`` at the call sites in this change).
        msg: Prefix for the log record identifying the pass.

    Returns:
        None. Results are reported through ``log.info`` and the dump file.
    """
    # Local import: `datetime.timedelta` is only needed here to keep the
    # "H:MM:SS.ffffff" elapsed-time format in the log record.
    from datetime import timedelta

    from .pattern_matcher import stable_topological_sort

    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        delete=False,  # keep the dump around so the logged path stays valid
    ) as f:
        # Render the graph once and reuse the text for both the dump file
        # and the before/after comparison.
        graph_before = str(gm.graph)
        print(f"Before:\n{graph_before}", file=f)

        # Monotonic clock: unaffected by wall-clock adjustments mid-pass.
        start = time.perf_counter()
        func(gm.graph)
        time_elapsed = timedelta(seconds=time.perf_counter() - start)

        # Re-sort, lint and recompile so `gm` reflects the mutated graph.
        stable_topological_sort(gm.graph)
        gm.graph.lint()
        gm.recompile()

        graph_after = str(gm.graph)
        print(f"After:\n{graph_after}", file=f)

    # Log only after the file is closed, so the path in the message points
    # at a fully flushed dump.
    log.info(
        "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
        msg,
        f.name,
        graph_before == graph_after,
        time_elapsed,
    )

0 commit comments

Comments
 (0)