From 5bda60274a0486665ce1b9d37cabb905a4ed30be Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 17 Feb 2025 12:51:03 +0100 Subject: [PATCH 1/7] Just visit vertical loops In the oir -> dace lowering (the one with the unexpanded library nodes, don't use the generic visit function. Instead, visit the vertical loops directly. --- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 9dd66bac82..db1c369d4b 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -174,6 +174,6 @@ def visit_Stencil(self, node: oir.Stencil): lifetime=dace.AllocationLifetime.Persistent, debuginfo=get_dace_debuginfo(decl), ) - self.generic_visit(node, ctx=ctx) + self.visit(node.vertical_loops, ctx=ctx) ctx.sdfg.validate() return ctx.sdfg From 62b94231e5f9edb8bc91de36db8d6a77dd0701b6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 17 Feb 2025 13:28:45 +0100 Subject: [PATCH 2/7] Rename library nodes Add a bit more information to the library node name. This facilitates debugging in that it is easier to associate the orgininal vertical loops with the nodes. --- src/gt4py/cartesian/gtc/dace/oir_to_dace.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index db1c369d4b..bd06da7d8f 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -41,6 +41,7 @@ class SDFGContext: decls: Dict[str, oir.Decl] block_extents: Dict[int, Extent] access_infos: Dict[str, dcir.FieldAccessInfo] + loop_counter: int = 0 def __init__(self, stencil: oir.Stencil): self.sdfg = dace.SDFG(stencil.name) @@ -98,6 +99,13 @@ def _make_dace_subset(self, local_access_info, field): global_access_info, local_access_info, self.decls[field].data_dims ) + def _vloop_name(self, node: oir.VerticalLoop, ctx: OirSDFGBuilder.SDFGContext) -> str: + sdfg_name = ctx.sdfg.name + counter = ctx.loop_counter + ctx.loop_counter += 1 + + return f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}" + def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext): declarations = { acc.name: ctx.decls[acc.name] @@ -105,7 +113,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG if acc.name in ctx.decls } library_node = StencilComputation( - name=f"{ctx.sdfg.name}_computation_{id(node)}", + name=self._vloop_name(node, ctx), extents=ctx.block_extents, declarations=declarations, oir_node=node, From 79ed1092e91c449bdef0f6afcda02ad0c076c956 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 3 Mar 2025 10:59:11 +0100 Subject: [PATCH 3/7] Fix typos in comment --- src/gt4py/cartesian/backend/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 5bab0453a9..571f86b527 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -172,9 +172,9 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]: Returns ------- Dict[str, str | Dict] of source file names / directories -> contents: - If a key's value is a string it is interpreted as a file name and the value as the - source code of that file - If a key's value is a Dict, it is interpreted as a directory name and it's + If a key's value is a string, it is interpreted as a file name and its value as the + source code of that file. + If a key's value is a Dict, it is interpreted as a directory name and its value as a nested file hierarchy to which the same rules are applied recursively. The root path is relative to the build directory. @@ -222,7 +222,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: Returns ------- - Analog to :py:meth:`generate_computation` but containing bindings source code, The + Analog to :py:meth:`generate_computation` but containing bindings source code. The dictionary contains a tree of directories with leaves being a mapping from filename to source code pairs, relative to the build directory. From 630ce9622e216716cd8f4334b547fc266a172a96 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 21 Feb 2025 16:26:45 +0100 Subject: [PATCH 4/7] Cleanups in testsuites - Import is unused - Comments are redundant with doc strings --- .../integration_tests/multi_feature_tests/test_suites.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 10d8999565..032dc3bb5e 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np -import pytest from gt4py.cartesian import gtscript, testing as gt_testing from gt4py.cartesian.gtscript import ( @@ -25,7 +24,6 @@ from .stencil_definitions import optional_field, two_optional_fields -# ---- Identity stencil ---- class TestIdentity(gt_testing.StencilTestSuite): """Identity stencil.""" @@ -43,7 +41,6 @@ def validation(field_a, domain=None, origin=None): pass -# ---- Copy stencil ---- class TestCopy(gt_testing.StencilTestSuite): """Copy stencil.""" @@ -86,7 +83,6 @@ def validation(field_a, field_b, domain=None, origin=None): field_b[...] = (field_b[...] - 1.0) / 2.0 -# ---- Scale stencil ---- class TestGlobalScale(gt_testing.StencilTestSuite): """Scale stencil using a global global_name.""" @@ -108,7 +104,6 @@ def validation(field_a, domain, origin, **kwargs): field_a[...] = SCALE_FACTOR * field_a # noqa: F821 [undefined-name] -# ---- Parametric scale stencil ----- class TestParametricScale(gt_testing.StencilTestSuite): """Scale stencil using a parameter.""" @@ -128,7 +123,6 @@ def validation(field_a, *, scale, domain, origin, **kwargs): field_a[...] = scale * field_a -# --- Parametric-mix stencil ---- class TestParametricMix(gt_testing.StencilTestSuite): """Linear combination of input fields using several parameters.""" From 3f4ad9883fc0c6cfa9372f97754c6080a2cd375b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 24 Jan 2025 11:49:47 +0100 Subject: [PATCH 5/7] Move import since there's no circular import here --- src/gt4py/cartesian/backend/dace_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 8ca18705c9..a36a9824bd 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -31,6 +31,7 @@ ) from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir from gt4py.cartesian.gtc import common, gtir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder from gt4py.cartesian.gtc.dace.transformations import ( @@ -119,8 +120,6 @@ def _set_expansion_orders(sdfg: dace.SDFG): def _set_tile_sizes(sdfg: dace.SDFG): - import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import - for node, _ in filter( lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() ): From 19874db0d056a2ed9ab324cabb09552bae1379f2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 3 Mar 2025 13:58:26 +0100 Subject: [PATCH 6/7] Modern annotations don't need to be quoted --- src/gt4py/cartesian/gtc/dace/daceir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 43a33fdd6d..23f7c36200 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -51,11 +51,11 @@ def tile_symbol(self) -> eve.SymbolRef: return eve.SymbolRef("__tile_" + self.lower()) @staticmethod - def dims_3d() -> Generator["Axis", None, None]: + def dims_3d() -> Generator[Axis, None, None]: yield from [Axis.I, Axis.J, Axis.K] @staticmethod - def dims_horizontal() -> Generator["Axis", None, None]: + def dims_horizontal() -> Generator[Axis, None, None]: yield from [Axis.I, Axis.J] def to_idx(self) -> int: From eabe32558578352ab26fe7db84fddccf1533b29c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 3 Mar 2025 14:08:18 +0100 Subject: [PATCH 7/7] Alphabetical order in union types --- src/gt4py/cartesian/gtc/dace/daceir.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 23f7c36200..90c0649940 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -357,7 +357,7 @@ def free_symbols(self) -> Set[eve.SymbolRef]: class GridSubset(eve.Node): - intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] + intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] def __iter__(self): for axis in Axis.dims_3d(): @@ -429,10 +429,10 @@ def from_gt4py_extent(cls, extent: gt4py.cartesian.gtc.definitions.Extent): @classmethod def from_interval( cls, - interval: Union[oir.Interval, TileInterval, DomainInterval, IndexWithExtent], + interval: Union[DomainInterval, IndexWithExtent, oir.Interval, TileInterval], axis: Axis, ): - res_interval: Union[IndexWithExtent, TileInterval, DomainInterval] + res_interval: Union[DomainInterval, IndexWithExtent, TileInterval] if isinstance(interval, (DomainInterval, oir.Interval)): res_interval = DomainInterval( start=AxisBound( @@ -441,7 +441,7 @@ def from_interval( end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), ) else: - assert isinstance(interval, (TileInterval, IndexWithExtent)) + assert isinstance(interval, (IndexWithExtent, TileInterval)) res_interval = interval return cls(intervals={axis: res_interval}) @@ -464,7 +464,7 @@ def full_domain(cls, axes=None): return GridSubset(intervals=res_subsets) def tile(self, tile_sizes: Dict[Axis, int]): - res_intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] = {} + res_intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] = {} for axis, interval in self.intervals.items(): if isinstance(interval, DomainInterval) and axis in tile_sizes: if axis == Axis.K: @@ -505,15 +505,15 @@ def union(self, other): intervals[axis] = interval1.union(interval2) else: assert ( - isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, (IndexWithExtent, DomainInterval)) + isinstance(interval2, (DomainInterval, TileInterval)) + and isinstance(interval1, (DomainInterval, IndexWithExtent)) ) or ( - isinstance(interval1, (TileInterval, DomainInterval)) + isinstance(interval1, (DomainInterval, TileInterval)) and isinstance(interval2, IndexWithExtent) ) intervals[axis] = ( interval1 - if isinstance(interval1, (TileInterval, DomainInterval)) + if isinstance(interval1, (DomainInterval, TileInterval)) else interval2 ) return GridSubset(intervals=intervals) @@ -747,7 +747,7 @@ class IndexAccess(common.FieldAccess, Expr): offset: Optional[Union[common.CartesianOffset, VariableKOffset]] -class AssignStmt(common.AssignStmt[Union[ScalarAccess, IndexAccess], Expr], Stmt): +class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt): _dtype_validation = common.assign_stmt_dtype_validation(strict=True) @@ -851,14 +851,14 @@ class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait): class DomainMap(ComputationNode, IterationNode): index_ranges: List[Range] schedule: MapSchedule - computations: List[Union[Tasklet, DomainMap, NestedSDFG]] + computations: List[Union[DomainMap, NestedSDFG, Tasklet]] class ComputationState(IterationNode): - computations: List[Union[Tasklet, DomainMap]] + computations: List[Union[DomainMap, Tasklet]] -class DomainLoop(IterationNode, ComputationNode): +class DomainLoop(ComputationNode, IterationNode): axis: Axis index_range: Range loop_states: List[Union[ComputationState, DomainLoop]] @@ -868,7 +868,7 @@ class NestedSDFG(ComputationNode, eve.SymbolTableTrait): label: eve.Coerced[eve.SymbolRef] field_decls: List[FieldDecl] symbol_decls: List[SymbolDecl] - states: List[Union[DomainLoop, ComputationState]] + states: List[Union[ComputationState, DomainLoop]] # There are circular type references with string placeholders. These statements let datamodels resolve those.