diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index 0839e95b5b..f4b5285add 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -445,7 +445,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: """ Create an `as_fieldop` call. @@ -454,7 +454,9 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - return call( + from gt4py.next.iterator.ir_utils import domain_utils + + result = call( call("as_fieldop")( *( ( @@ -467,6 +469,14 @@ def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> cal ) ) + def _populate_domain_annex_wrapper(*args, **kwargs): + node = result(*args, **kwargs) + if domain: + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + return _populate_domain_annex_wrapper + def op_as_fieldop( op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0a0cf6d37e..a22ce1c7e1 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -144,7 +144,7 @@ def all(self) -> CollapseTuple.Flag: ignore_tuple_size: bool flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] - PRESERVED_ANNEX_ATTRS = ("type",) + PRESERVED_ANNEX_ATTRS = ("type", "domain") @classmethod def apply( @@ -261,6 +261,7 @@ def transform_collapse_make_tuple_tuple_get( # tuple argument differs, just continue with the rest of the tree return None + 
itir_type_inference.reinfer(first_expr) # type is needed so reinfer on-demand assert self.ignore_tuple_size or isinstance( first_expr.type, (ts.TupleType, ts.DeferredType) ) @@ -281,7 +282,7 @@ def transform_collapse_tuple_get_make_tuple( and isinstance(node.args[0], ir.Literal) ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - assert type_info.is_integer(node.args[0].type) + assert not node.args[0].type or type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) assert idx < len( diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..8802d0dd84 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,6 +12,11 @@ class ConstantFolding(PreserveLocationVisitor, NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index cc42896f2b..0ca1c57642 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -13,11 +13,16 @@ from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import ( inline_center_deref_lift_vars, inline_lambdas, inline_lifts, + merge_let, trace_shifts, ) from gt4py.next.iterator.type_system import inference as type_inference @@ -50,7 +55,6 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): 
stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -80,7 +84,12 @@ def _inline_as_fieldop_arg( for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) + if inner_arg.id in extracted_args: + assert extracted_args[inner_arg.id] == inner_arg + alias = stencil_params[list(extracted_args.keys()).index(inner_arg.id)] + stencil_body = im.let(inner_param, im.ref(alias.id))(stencil_body) + else: + stencil_params.append(inner_param) extracted_args[inner_arg.id] = inner_arg elif isinstance(inner_arg, itir.Literal): # note: only literals, not all scalar expressions are required as it doesn't make sense @@ -100,12 +109,59 @@ def _inline_as_fieldop_arg( ), extracted_args +def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... 
) + >>> stencil, back_trafo = _unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s + + def fuse_as_fieldop( expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator ) -> itir.Expr: - assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert cpm.is_applied_as_fieldop(expr) stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + stencil, restore_scan = _unwrap_scan(stencil) + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop args: list[itir.Expr] = expr.args @@ -119,9 +175,7 @@ def fuse_as_fieldop( pass elif cpm.is_call_to(arg, "if_"): # TODO(tehrengruber): revisit if we want to inline 
if_ - type_ = arg.type arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ elif _is_tuple_expr_of_literals(arg): arg = im.op_as_fieldop(im.lambda_()(arg))() else: @@ -134,6 +188,7 @@ def fuse_as_fieldop( new_args = _merge_arguments(new_args, extracted_args) else: # just a safety check if typing information is available + type_inference.reinfer(arg) if arg.type and not isinstance(arg.type, ts.DeferredType): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) @@ -148,24 +203,68 @@ def fuse_as_fieldop( new_param = stencil_param.id new_args = _merge_arguments(new_args, {new_param: arg}) - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) + stencil = im.lambda_(*new_args.keys())(new_stencil_body) + stencil = restore_scan(stencil) # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node + new_stencil = inline_lambdas.InlineLambdas.apply( + stencil, opcount_preserving=True, force_inline_lift_args=False + ) + new_stencil = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_stencil, is_stencil=True, uids=uids ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True + new_stencil = merge_let.MergeLet().visit(new_stencil) + new_stencil = inline_lambdas.InlineLambdas.apply( + new_stencil, opcount_preserving=True, force_inline_lift_args=True ) - new_node = inline_lifts.InlineLifts().visit(new_node) + new_stencil = inline_lifts.InlineLifts().visit(new_stencil) - type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) + new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) return new_node +def _arg_inline_predicate(node: itir.Expr, shifts): + if _is_tuple_expr_of_literals(node): + return True + + if ( + is_applied_fieldop := 
cpm.is_applied_as_fieldop(node) + and not cpm.is_call_to(node.fun.args[0], "scan") # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) or cpm.is_call_to(node, "if_"): + # always inline arg if it is an applied fieldop with only a single arg + if is_applied_fieldop and len(node.args) == 1: + return True + # argument is never used, will be removed when inlined + if len(shifts) == 0: + return True + # applied fieldop with list return type must always be inlined as no backend supports this + type_inference.reinfer(node) + assert isinstance(node.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type) + if isinstance(dtype, ts.ListType): + return True + # only accessed at the center location + if shifts in [set(), {()}]: + return True + # TODO(tehrengruber): Disabled as the InlineCenterDerefLiftVars does not support this yet + # and it would increase the size of the tree otherwise. + # if len(shifts) == 1 and not any( + # trace_shifts.Sentinel.ALL_NEIGHBORS in access for access in shifts + # ): + # return True # noqa: ERA001 [commented-out-code] + + return False + + +def _make_tuple_element_inline_predicate(node: itir.Expr): + if cpm.is_applied_as_fieldop(node): # field, or tuple of fields + return True + if isinstance(node.type, ts.FieldType) and isinstance(node, itir.SymRef): + return True + return False + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -194,6 +293,8 @@ class FuseAsFieldOp(eve.NodeTranslator): as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) """ # noqa: RUF002 # ignore ambiguous multiplication character + PRESERVED_ANNEX_ATTRS = ("domain",) + uids: eve_utils.UIDGenerator @classmethod @@ -204,6 +305,7 @@ def apply( offset_provider_type: common.OffsetProviderType, uids: Optional[eve_utils.UIDGenerator] = None, allow_undeclared_symbols=False, + within_set_at_expr: Optional[bool] = None, ): node = type_inference.infer( 
node, @@ -211,41 +313,133 @@ def apply( allow_undeclared_symbols=allow_undeclared_symbols, ) + if within_set_at_expr is None: + within_set_at_expr = not isinstance(node, itir.Program) + if not uids: uids = eve_utils.UIDGenerator() - return cls(uids=uids).visit(node) + return cls(uids=uids).visit(node, within_set_at_expr=within_set_at_expr) + + def visit(self, node, **kwargs): + new_node = super().visit(node, **kwargs) + if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): + new_node.annex.domain = node.annex.domain + return new_node + + def visit_SetAt(self, node: itir.SetAt, **kwargs): + return itir.SetAt( + expr=self.visit(node.expr, **kwargs | {"within_set_at_expr": True}), + domain=node.domain, + target=node.target, + ) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + if not kwargs.get("within_set_at_expr"): + return node + + # inline all fields with list dtype. This needs to happen before the children are visited + # such that the `as_fieldop` can be fused. + # TODO(tehrengruber): what should we do in case the field with list dtype is a let itself? + # This could duplicate other expressions which we did not intend to duplicate. 
+ if cpm.is_let(node): + for arg in node.args: + type_inference.reinfer(arg) + eligible_els = [ + isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, ts.ListType) + for arg in node.args + ] + if any(eligible_els): + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_els) + return self.visit(node, **kwargs) + + if cpm.is_applied_as_fieldop(node): # don't descend in stencil + node = im.as_fieldop(*node.fun.args)(*self.generic_visit(node.args, **kwargs)) # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + elif kwargs.get("recurse", True): + node = self.generic_visit(node, **kwargs) + + if cpm.is_call_to(node, "make_tuple"): + for arg in node.args: + type_inference.reinfer(arg) + assert not isinstance(arg.type, ts.FieldType) or ( + hasattr(arg.annex, "domain") + and isinstance(arg.annex.domain, domain_utils.SymbolicDomain) + ) - def visit_FunCall(self, node: itir.FunCall): - node = self.generic_visit(node) + eligible_els = [_make_tuple_element_inline_predicate(arg) for arg in node.args] + field_args = [arg for i, arg in enumerate(node.args) if eligible_els[i]] + distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args) + if len(distinct_domains) != len(field_args): + new_els: list[itir.Expr | None] = [None for _ in node.args] + field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {} + for i, arg in enumerate(node.args): + if eligible_els[i]: + assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain) + domain = arg.annex.domain.as_expr() + field_args_by_domain.setdefault(domain, []) + field_args_by_domain[domain].append((i, arg)) + else: + new_els[i] = arg # keep as is + + if len(field_args_by_domain) == 1 and all(eligible_els): + # if we only have a single domain covering all args we don't need to create an + # unnecessary let + ((domain, inner_field_args),) = field_args_by_domain.items() + new_node = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg 
in inner_field_args) + ) + new_node = self.visit(new_node, **{**kwargs, "recurse": False}) + else: + let_vars = {} + for domain, inner_field_args in field_args_by_domain.items(): + if len(inner_field_args) > 1: + var = self.uids.sequential_id(prefix="__fasfop") + fused_args = im.op_as_fieldop( + lambda *args: im.make_tuple(*args), domain + )(*(arg for _, arg in inner_field_args)) + type_inference.reinfer(arg) + # don't recurse into nested args, but only consider newly created `as_fieldop` + # note: this will always inline (as we inline center accessed) + let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) + for outer_tuple_idx, (inner_tuple_idx, _) in enumerate( + inner_field_args + ): + new_el = im.tuple_get(outer_tuple_idx, var) + new_el.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + new_els[inner_tuple_idx] = new_el + else: + i, arg = inner_field_args[0] + new_els[i] = arg + assert not any(el is None for el in new_els) + assert let_vars + new_node = im.let(*let_vars.items())(im.make_tuple(*new_els)) + new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True) + return new_node if cpm.is_call_to(node.fun, "as_fieldop"): node = _canonicalize_as_fieldop(node) - if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): - stencil: itir.Lambda = node.fun.args[0] - args: list[itir.Expr] = node.args - shifts = trace_shifts.trace_stencil(stencil) + # when multiple `as_fieldop` calls are fused that use the same argument, this argument + # might become referenced once only. In order to be able to continue fusing such arguments + # try inlining here. 
+ if cpm.is_let(node): + new_node = inline_lambdas.inline_lambda(node, opcount_preserving=True) + if new_node is not node: # something has been inlined + return self.visit(new_node, **kwargs) - eligible_args = [] - for arg, arg_shifts in zip(args, shifts, strict=True): - assert isinstance(arg.type, ts.TypeSpec) - dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) - # TODO(tehrengruber): make this configurable - eligible_args.append( - _is_tuple_expr_of_literals(arg) - or ( - isinstance(arg, itir.FunCall) - and ( - ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - ) - or cpm.is_call_to(arg, "if_") - ) - and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) - ) + if cpm.is_call_to(node.fun, "as_fieldop"): + stencil = node.fun.args[0] + assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") + args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil, num_args=len(args)) + + eligible_els = [ + _arg_inline_predicate(arg, arg_shifts) + for arg, arg_shifts in zip(args, shifts, strict=True) + ] + if any(eligible_els): + return self.visit( + fuse_as_fieldop(node, eligible_els, uids=self.uids), + **{**kwargs, "recurse": False}, ) - - return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 95c761d7ba..7bd26d0f19 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import ClassVar, Optional +from typing import ClassVar, Optional, TypeVar import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve @@ -23,6 +23,9 @@ def is_center_derefed_only(node: itir.Node) -> bool: return
hasattr(node.annex, "recorded_shifts") and node.annex.recorded_shifts in [set(), {()}] +T = TypeVar("T", bound=itir.Program | itir.Lambda) + + @dataclasses.dataclass class InlineCenterDerefLiftVars(eve.NodeTranslator): """ @@ -45,14 +48,19 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): Note: This pass uses and preserves the `recorded_shifts` annex. """ - PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("recorded_shifts",) + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts") uids: eve_utils.UIDGenerator @classmethod - def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + def apply( + cls, node: T, *, is_stencil=False, uids: Optional[eve_utils.UIDGenerator] = None + ) -> T: if not uids: uids = eve_utils.UIDGenerator() + if is_stencil: + assert isinstance(node, itir.Expr) + trace_shifts.trace_stencil(node, save_to_annex=True) return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index f27dbbb74c..07d116555d 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -8,7 +8,7 @@ import dataclasses import enum -from typing import Callable, Optional +from typing import Callable, ClassVar, Optional import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits @@ -112,6 +112,8 @@ class InlineLifts( function nodes. 
""" + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + class Flag(enum.IntEnum): #: `shift(...)(lift(f)(args...))` -> `lift(f)(shift(...)(args)...)` PROPAGATE_SHIFT = 1 diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py index 87b576d14d..dd6470630e 100644 --- a/src/gt4py/next/iterator/transforms/inline_scalar.py +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -16,6 +16,8 @@ class InlineScalar(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + @classmethod def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): program = itir_inference.infer(program, offset_provider_type=offset_provider_type) diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 0e7d74e594..9c0c25bd49 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import ClassVar import gt4py.eve as eve from gt4py.next.iterator import ir as itir @@ -26,6 +27,8 @@ class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): This can significantly reduce the depth of the tree and its readability. 
""" + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 6906f81e3f..e27d38183d 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -70,6 +70,10 @@ def apply_common_transforms( ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) + # TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g. + # test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere. + ir = inline_lifts.InlineLifts().visit(ir) + # note: this increases the size of the tree # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 68346b6622..b2ab49deea 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -14,15 +14,14 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ValidateRecordedShiftsAnnex(eve.NodeVisitor): """Ensure every applied lift and its arguments have the `recorded_shifts` annex populated.""" def visit_FunCall(self, node: ir.FunCall): - if is_applied_lift(node): + if cpm.is_applied_lift(node): assert hasattr(node.annex, "recorded_shifts") if len(node.annex.recorded_shifts) == 0: @@ -334,8 
+333,11 @@ def trace_stencil( if isinstance(stencil, ir.Lambda): assert num_args is None or num_args == len(stencil.params) num_args = len(stencil.params) + elif cpm.is_call_to(stencil, "scan"): + assert isinstance(stencil.args[0], ir.Lambda) + num_args = len(stencil.args[0].params) - 1 if not isinstance(num_args, int): - raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") + raise ValueError("Stencil must be an 'itir.Lambda', scan, or `num_args` is given.") assert isinstance(num_args, int) args = [im.ref(f"__arg{i}") for i in range(num_args)] diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py index 168e9490e0..dd8b931960 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -5,19 +5,27 @@ # # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +import copy from typing import Callable, Optional from gt4py import next as gtx from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import fuse_as_fieldop +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import fuse_as_fieldop, collapse_tuple from gt4py.next.type_system import type_specifications as ts IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) +def _with_domain_annex(node: itir.Expr, domain: itir.Expr): + node = copy.deepcopy(node) + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + def test_trivial(): d = im.domain("cartesian_domain", {IDim: (0, 1)}) testee = im.op_as_fieldop("plus", d)( @@ -46,6 +54,25 @@ def test_trivial_literal(): assert actual == expected +def test_trivial_same_arg_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + # note: inp1 occurs twice here + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp1", field_type)), + im.ref("inp2", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp1")), im.deref("inp2")) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_tuple_arg(): d = im.domain("cartesian_domain", {}) testee = im.op_as_fieldop("plus", d)( @@ -99,19 +126,166 @@ def test_no_inline(): im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) ), d1, - )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type))) + )(im.op_as_fieldop("plus", 
d2)(im.ref("inp1", field_type), im.ref("inp2", field_type))) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == testee +def test_staged_inlining(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "tmp", im.op_as_fieldop("plus", d)(im.ref("a", field_type), im.ref("b", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 1)), d)("tmp"), + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 2)), d)("tmp"), + ) + ) + expected = im.as_fieldop( + im.lambda_("a", "b")( + im.let("_icdlv_1", im.plus(im.deref("a"), im.deref("b")))( + im.plus(im.plus("_icdlv_1", 1), im.plus("_icdlv_1", 2)) + ) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_make_tuple_fusion_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)` + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + _with_domain_annex(im.ref("b", field_type), d), + ) + expected = im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + d, + 
)(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref_same_ref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + _with_domain_annex(im.ref("a", field_type), d), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_nested(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + _with_domain_annex(im.ref("a", field_type), d), + im.make_tuple( + _with_domain_annex(im.ref("b", field_type), d), + _with_domain_annex(im.ref("c", field_type), d), + ), + ) + expected = im.as_fieldop( + im.lambda_("a", "b", "c")( + im.make_tuple(im.deref("a"), im.make_tuple(im.deref("b"), im.deref("c"))) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_different_domains(): + d1 = 
im.domain("cartesian_domain", {IDim: (0, 1)}) + d2 = im.domain("cartesian_domain", {JDim: (0, 1)}) + field_i_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + field_j_type = ts.FieldType(dims=[JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + testee = im.make_tuple( + im.as_fieldop("deref", d1)(im.ref("a", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("b", field_j_type)), + im.as_fieldop("deref", d1)(im.ref("c", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("d", field_j_type)), + ) + expected = im.let( + ( + "__fasfop_1", + im.as_fieldop(im.lambda_("a", "c")(im.make_tuple(im.deref("a"), im.deref("c"))), d1)( + "a", "c" + ), + ), + ( + "__fasfop_2", + im.as_fieldop(im.lambda_("b", "d")(im.make_tuple(im.deref("b"), im.deref("d"))), d2)( + "b", "d" + ), + ), + )( + im.make_tuple( + im.tuple_get(0, "__fasfop_1"), + im.tuple_get(0, "__fasfop_2"), + im.tuple_get(1, "__fasfop_1"), + im.tuple_get(1, "__fasfop_2"), + ) + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + def test_partial_inline(): d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) testee = im.as_fieldop( # first argument read at multiple locations -> not inlined - # second argument only reat at a single location -> inlined + # second argument only read at a single location -> inlined im.lambda_("a", "b")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), @@ -120,19 +294,88 @@ def test_partial_inline(): ), d1, )( - im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), - im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), ) expected = 
im.as_fieldop( - im.lambda_("a", "inp1")( + im.lambda_("a", "inp1", "inp2")( im.plus( im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), - im.deref("inp1"), + im.plus(im.deref("inp1"), im.deref("inp2")), ) ), d1, - )(im.as_fieldop(im.lambda_("inp1")(im.deref("inp1")), d2)(im.ref("inp1", field_type)), "inp1") + )( + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + "inp1", + "inp2", + ) actual = fuse_as_fieldop.FuseAsFieldOp.apply( testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True ) assert actual == expected + + +def test_chained_fusion(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "a", im.op_as_fieldop("plus", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.let("_icdlv_1", im.plus(im.deref("inp1"), im.deref("inp2")))( + im.plus("_icdlv_1", "_icdlv_1") + ) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_inline_as_fieldop_with_list_dtype(): + list_field_type = ts.FieldType( + dims=[IDim], dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32)) + ) + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)( + im.as_fieldop("deref")(im.ref("inp", list_field_type)) + ) + expected = im.as_fieldop(im.lambda_("inp")(im.call("reduce")(im.deref("inp"), 0)), d)( + im.ref("inp", list_field_type) + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def 
test_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) + testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) + expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan_stencil = im.call("scan")( + im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0 + ) + scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) + testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == testee