Skip to content

Commit 338d3f4

Browse files
committed
Merge improve_collapse_tuple
2 parents f5d9146 + 00bbf2b commit 338d3f4

File tree

9 files changed

+213
-150
lines changed

9 files changed

+213
-150
lines changed

src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]:
2727
)
2828

2929

30-
def is_let(node: itir.Node) -> bool:
30+
def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]:
3131
"""Match expression of the form `(λ(...) → ...)(...)`."""
3232
return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda)
3333

3434

35-
def is_if_call(node: itir.Expr):
35+
def is_if_call(node: itir.Expr) -> TypeGuard[itir.FunCall]:
3636
"""Match expression of the form `if_(cond, true_branch, false_branch)`."""
3737
return isinstance(node, itir.FunCall) and node.fun == im.ref("if_")

src/gt4py/next/iterator/ir_utils/ir_makers.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,10 @@ class let:
250250
"""
251251

252252
@typing.overload
253-
def __init__(self, var: str | itir.Sym, init_form: itir.Expr):
254-
...
253+
def __init__(self, var: str | itir.Sym, init_form: itir.Expr): ...
255254

256255
@typing.overload
257-
def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]):
258-
...
256+
def __init__(self, *args: Iterable[tuple[str | itir.Sym, itir.Expr]]): ...
259257

260258
def __init__(self, *args):
261259
if all(isinstance(arg, tuple) and len(arg) == 2 for arg in args):
@@ -268,7 +266,7 @@ def __init__(self, *args):
268266
self.init_forms = [args[1]]
269267
else:
270268
raise TypeError(
271-
"Invalid arguments. Expected a variable name and an init form or a list thereof."
269+
"Invalid arguments: expected a variable name and an init form or a list thereof."
272270
)
273271

274272
def __call__(self, form):

src/gt4py/next/iterator/ir_utils/misc.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,16 @@ def visit_SymRef(self, node: itir.SymRef, *, sym_map: dict[str, str]):
5757
return im.ref(sym_map[node.id]) if node.id in sym_map else node
5858

5959

60-
def is_provable_equal(a: itir.Expr, b: itir.Expr):
60+
def is_equal(a: itir.Expr, b: itir.Expr):
6161
"""
62-
Return true if, bot not only if, two expression (with equal scope) have the same value.
62+
Return true if, but not only if, two expression (with equal scope) have the same value.
63+
64+
Be aware that this function might return false even though the two expression have the same
65+
value.
6366
6467
>>> testee1 = im.lambda_("a")(im.plus("a", "b"))
6568
>>> testee2 = im.lambda_("c")(im.plus("c", "b"))
66-
>>> assert is_provable_equal(testee1, testee2)
69+
>>> assert is_equal(testee1, testee2)
6770
"""
6871
return a == b or (
6972
CannonicalizeBoundSymbolNames.apply(a) == CannonicalizeBoundSymbolNames.apply(b)

src/gt4py/next/iterator/transforms/collapse_tuple.py

+124-106
Large diffs are not rendered by default.

src/gt4py/next/iterator/transforms/inline_lambdas.py

-2
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ class InlineLambdas(PreserveLocationVisitor, NodeTranslator):
137137

138138
force_inline_trivial_lift_args: bool
139139

140-
force_inline_lambda_args: bool
141-
142140
@classmethod
143141
def apply(
144142
cls,

src/gt4py/next/iterator/transforms/pass_manager.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,17 @@ def main_transforms(
165165
inlined = ConstantFolding.apply(inlined)
166166
# This pass is required to be in the loop such that when an `if_` call with tuple arguments
167167
# is constant-folded the surrounding tuple_get calls can be removed.
168-
if stage == 1:
169-
inlined = CollapseTuple.apply(
170-
inlined,
171-
# to limit number of times global type inference is executed, only in the last iterations.
172-
# use_global_type_inference=inlined == ir,
173-
ignore_tuple_size=True, # possibly dangerous
174-
use_global_type_inference=False,
175-
)
176-
inlined = PropagateDeref.apply(inlined) # todo: document
168+
inlined = CollapseTuple.apply(
169+
inlined,
170+
# to limit number of times global type inference is executed, only in the last iterations.
171+
# use_global_type_inference=inlined == ir,
172+
ignore_tuple_size=True, # possibly dangerous
173+
use_global_type_inference=False,
174+
)
175+
# This pass is required such that a deref outside of a
176+
# `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the
177+
# `tuple_get` is removed by the `CollapseTuple` pass.
178+
inlined = PropagateDeref.apply(inlined)
177179

178180
if inlined == ir:
179181
stage += 1

src/gt4py/next/program_processors/runners/gtfn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from gt4py.next.otf import languages, recipes, stages, step_types, workflow
2727
from gt4py.next.otf.binding import nanobind
2828
from gt4py.next.otf.compilation import cache, compiler
29-
from gt4py.next.otf.compilation.build_systems import cmake, compiledb
29+
from gt4py.next.otf.compilation.build_systems import compiledb
3030
from gt4py.next.program_processors import otf_compile_executor
3131
from gt4py.next.program_processors.codegens.gtfn import gtfn_module
3232
from gt4py.next.type_system.type_translation import from_value
@@ -130,7 +130,7 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
130130
)
131131

132132
GTFN_DEFAULT_COMPILE_STEP: step_types.CompilationStep = compiler.Compiler(
133-
cache_strategy=cache.Strategy.PERSISTENT, builder_factory=cmake.CMakeFactory()
133+
cache_strategy=cache.Strategy.SESSION, builder_factory=compiledb.CompiledbFactory()
134134
)
135135

136136

tests/next_tests/definitions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
148148
COMMON_SKIP_TEST_LIST = [
149149
(REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE),
150150
(USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE),
151-
# (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
151+
(USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE),
152152
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
153153
(USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
154154
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),

tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py

+64-20
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ def test_simple_make_tuple_tuple_get():
2020
tuple_of_size_2 = im.make_tuple("first", "second")
2121
testee = im.make_tuple(im.tuple_get(0, tuple_of_size_2), im.tuple_get(1, tuple_of_size_2))
2222

23-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET)
23+
actual = CollapseTuple.apply(
24+
testee,
25+
remove_letified_make_tuple_elements=False,
26+
flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
27+
)
2428

2529
expected = tuple_of_size_2
2630
assert actual == expected
@@ -32,7 +36,11 @@ def test_nested_make_tuple_tuple_get():
3236
im.tuple_get(0, tup_of_size2_from_lambda), im.tuple_get(1, tup_of_size2_from_lambda)
3337
)
3438

35-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET)
39+
actual = CollapseTuple.apply(
40+
testee,
41+
remove_letified_make_tuple_elements=False,
42+
flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
43+
)
3644

3745
assert actual == tup_of_size2_from_lambda
3846

@@ -42,21 +50,33 @@ def test_different_tuples_make_tuple_tuple_get():
4250
t1 = im.make_tuple("foo1", "bar1")
4351
testee = im.make_tuple(im.tuple_get(0, t0), im.tuple_get(1, t1))
4452

45-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET)
53+
actual = CollapseTuple.apply(
54+
testee,
55+
remove_letified_make_tuple_elements=False,
56+
flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
57+
)
4658

4759
assert actual == testee # did nothing
4860

4961

5062
def test_incompatible_order_make_tuple_tuple_get():
5163
tuple_of_size_2 = im.make_tuple("first", "second")
5264
testee = im.make_tuple(im.tuple_get(1, tuple_of_size_2), im.tuple_get(0, tuple_of_size_2))
53-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET)
65+
actual = CollapseTuple.apply(
66+
testee,
67+
remove_letified_make_tuple_elements=False,
68+
flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
69+
)
5470
assert actual == testee # did nothing
5571

5672

5773
def test_incompatible_size_make_tuple_tuple_get():
5874
testee = im.make_tuple(im.tuple_get(0, im.make_tuple("first", "second")))
59-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET)
75+
actual = CollapseTuple.apply(
76+
testee,
77+
remove_letified_make_tuple_elements=False,
78+
flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET,
79+
)
6080
assert actual == testee # did nothing
6181

6282

@@ -71,14 +91,22 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get():
7191
def test_simple_tuple_get_make_tuple():
7292
expected = im.ref("bar")
7393
testee = im.tuple_get(1, im.make_tuple("foo", expected))
74-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE)
94+
actual = CollapseTuple.apply(
95+
testee,
96+
remove_letified_make_tuple_elements=False,
97+
flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE,
98+
)
7599
assert expected == actual
76100

77101

78102
def test_propagate_tuple_get():
79103
expected = im.let(("el1", 1), ("el2", 2))(im.tuple_get(0, im.make_tuple("el1", "el2")))
80104
testee = im.tuple_get(0, im.let(("el1", 1), ("el2", 2))(im.make_tuple("el1", "el2")))
81-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET)
105+
actual = CollapseTuple.apply(
106+
testee,
107+
remove_letified_make_tuple_elements=False,
108+
flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET,
109+
)
82110
assert expected == actual
83111

84112

@@ -89,23 +117,35 @@ def test_letify_make_tuple_elements():
89117
im.make_tuple("_tuple_el_1", "_tuple_el_2")
90118
)
91119

92-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS)
120+
actual = CollapseTuple.apply(
121+
testee,
122+
remove_letified_make_tuple_elements=False,
123+
flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS,
124+
)
93125
assert actual == expected
94126

95127

96128
def test_letify_make_tuple_with_trivial_elements():
97129
testee = im.let(("a", 1), ("b", 2))(im.make_tuple("a", "b"))
98130
expected = testee # did nothing
99131

100-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS)
132+
actual = CollapseTuple.apply(
133+
testee,
134+
remove_letified_make_tuple_elements=False,
135+
flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS,
136+
)
101137
assert actual == expected
102138

103139

104140
def test_inline_trivial_make_tuple():
105141
testee = im.let("tup", im.make_tuple("a", "b"))("tup")
106142
expected = im.make_tuple("a", "b")
107143

108-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE)
144+
actual = CollapseTuple.apply(
145+
testee,
146+
remove_letified_make_tuple_elements=False,
147+
flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE,
148+
)
109149
assert actual == expected
110150

111151

@@ -114,7 +154,11 @@ def test_propagate_to_if_on_tuples():
114154
expected = im.if_(
115155
"cond", im.tuple_get(0, im.make_tuple(1, 2)), im.tuple_get(0, im.make_tuple(3, 4))
116156
)
117-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES)
157+
actual = CollapseTuple.apply(
158+
testee,
159+
remove_letified_make_tuple_elements=False,
160+
flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
161+
)
118162
assert actual == expected
119163

120164

@@ -127,28 +171,28 @@ def test_propagate_to_if_on_tuples_with_let():
127171
)
128172
actual = CollapseTuple.apply(
129173
testee,
174+
remove_letified_make_tuple_elements=True,
130175
flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES
131-
| CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS
132-
| CollapseTuple.Flag.REMOVE_LETIFIED_MAKE_TUPLE_ELEMENTS,
176+
| CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS,
133177
)
134178
assert actual == expected
135179

136180

137181
def test_propagate_nested_lift():
138182
testee = im.let("a", im.let("b", 1)("a_val"))("a")
139183
expected = im.let("b", 1)(im.let("a", "a_val")("a"))
140-
actual = CollapseTuple.apply(testee, flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET)
184+
actual = CollapseTuple.apply(
185+
testee,
186+
remove_letified_make_tuple_elements=False,
187+
flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET,
188+
)
141189
assert actual == expected
142190

143191

144-
def test_collapse_complicated_():
145-
# TODO: fuse with test_propagate_to_if_on_tuples_with_let
192+
def test_if_on_tuples_with_let():
146193
testee = im.let("val", im.if_("cond", im.make_tuple(1, 2), im.make_tuple(3, 4)))(
147194
im.tuple_get(0, "val")
148195
)
149196
expected = im.if_("cond", 1, 3)
150-
actual = CollapseTuple.apply(
151-
testee,
152-
# flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES
153-
)
197+
actual = CollapseTuple.apply(testee, remove_letified_make_tuple_elements=False)
154198
assert actual == expected

0 commit comments

Comments
 (0)