Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[cartesian]: gt4py/dace bridge cleanup #1895

Merged
merged 7 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/gt4py/cartesian/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]:
Returns
-------
Dict[str, str | Dict] of source file names / directories -> contents:
If a key's value is a string it is interpreted as a file name and the value as the
source code of that file
If a key's value is a Dict, it is interpreted as a directory name and it's
If a key's value is a string, it is interpreted as a file name and its value as the
source code of that file.
If a key's value is a Dict, it is interpreted as a directory name and its
value as a nested file hierarchy to which the same rules are applied recursively.
The root path is relative to the build directory.

Expand Down Expand Up @@ -222,7 +222,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:

Returns
-------
Analog to :py:meth:`generate_computation` but containing bindings source code, The
Analog to :py:meth:`generate_computation` but containing bindings source code. The
dictionary contains a tree of directories with leaves being a mapping from filename to
source code pairs, relative to the build directory.

Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir
from gt4py.cartesian.gtc import common, gtir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder
from gt4py.cartesian.gtc.dace.transformations import (
Expand Down Expand Up @@ -119,8 +120,6 @@ def _set_expansion_orders(sdfg: dace.SDFG):


def _set_tile_sizes(sdfg: dace.SDFG):
import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import

for node, _ in filter(
lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive()
):
Expand Down
32 changes: 16 additions & 16 deletions src/gt4py/cartesian/gtc/dace/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def tile_symbol(self) -> eve.SymbolRef:
return eve.SymbolRef("__tile_" + self.lower())

@staticmethod
def dims_3d() -> Generator["Axis", None, None]:
def dims_3d() -> Generator[Axis, None, None]:
yield from [Axis.I, Axis.J, Axis.K]

@staticmethod
def dims_horizontal() -> Generator["Axis", None, None]:
def dims_horizontal() -> Generator[Axis, None, None]:
yield from [Axis.I, Axis.J]

def to_idx(self) -> int:
Expand Down Expand Up @@ -357,7 +357,7 @@ def free_symbols(self) -> Set[eve.SymbolRef]:


class GridSubset(eve.Node):
intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]]
intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]]

def __iter__(self):
for axis in Axis.dims_3d():
Expand Down Expand Up @@ -429,10 +429,10 @@ def from_gt4py_extent(cls, extent: gt4py.cartesian.gtc.definitions.Extent):
@classmethod
def from_interval(
cls,
interval: Union[oir.Interval, TileInterval, DomainInterval, IndexWithExtent],
interval: Union[DomainInterval, IndexWithExtent, oir.Interval, TileInterval],
axis: Axis,
):
res_interval: Union[IndexWithExtent, TileInterval, DomainInterval]
res_interval: Union[DomainInterval, IndexWithExtent, TileInterval]
if isinstance(interval, (DomainInterval, oir.Interval)):
res_interval = DomainInterval(
start=AxisBound(
Expand All @@ -441,7 +441,7 @@ def from_interval(
end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K),
)
else:
assert isinstance(interval, (TileInterval, IndexWithExtent))
assert isinstance(interval, (IndexWithExtent, TileInterval))
res_interval = interval

return cls(intervals={axis: res_interval})
Expand All @@ -464,7 +464,7 @@ def full_domain(cls, axes=None):
return GridSubset(intervals=res_subsets)

def tile(self, tile_sizes: Dict[Axis, int]):
res_intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] = {}
res_intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] = {}
for axis, interval in self.intervals.items():
if isinstance(interval, DomainInterval) and axis in tile_sizes:
if axis == Axis.K:
Expand Down Expand Up @@ -505,15 +505,15 @@ def union(self, other):
intervals[axis] = interval1.union(interval2)
else:
assert (
isinstance(interval2, (TileInterval, DomainInterval))
and isinstance(interval1, (IndexWithExtent, DomainInterval))
isinstance(interval2, (DomainInterval, TileInterval))
and isinstance(interval1, (DomainInterval, IndexWithExtent))
) or (
isinstance(interval1, (TileInterval, DomainInterval))
isinstance(interval1, (DomainInterval, TileInterval))
and isinstance(interval2, IndexWithExtent)
)
intervals[axis] = (
interval1
if isinstance(interval1, (TileInterval, DomainInterval))
if isinstance(interval1, (DomainInterval, TileInterval))
else interval2
)
return GridSubset(intervals=intervals)
Expand Down Expand Up @@ -747,7 +747,7 @@ class IndexAccess(common.FieldAccess, Expr):
offset: Optional[Union[common.CartesianOffset, VariableKOffset]]


class AssignStmt(common.AssignStmt[Union[ScalarAccess, IndexAccess], Expr], Stmt):
class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt):
_dtype_validation = common.assign_stmt_dtype_validation(strict=True)


Expand Down Expand Up @@ -851,14 +851,14 @@ class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait):
class DomainMap(ComputationNode, IterationNode):
index_ranges: List[Range]
schedule: MapSchedule
computations: List[Union[Tasklet, DomainMap, NestedSDFG]]
computations: List[Union[DomainMap, NestedSDFG, Tasklet]]


class ComputationState(IterationNode):
computations: List[Union[Tasklet, DomainMap]]
computations: List[Union[DomainMap, Tasklet]]


class DomainLoop(IterationNode, ComputationNode):
class DomainLoop(ComputationNode, IterationNode):
axis: Axis
index_range: Range
loop_states: List[Union[ComputationState, DomainLoop]]
Expand All @@ -868,7 +868,7 @@ class NestedSDFG(ComputationNode, eve.SymbolTableTrait):
label: eve.Coerced[eve.SymbolRef]
field_decls: List[FieldDecl]
symbol_decls: List[SymbolDecl]
states: List[Union[DomainLoop, ComputationState]]
states: List[Union[ComputationState, DomainLoop]]


# There are circular type references with string placeholders. These statements let datamodels resolve those.
Expand Down
12 changes: 10 additions & 2 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class SDFGContext:
decls: Dict[str, oir.Decl]
block_extents: Dict[int, Extent]
access_infos: Dict[str, dcir.FieldAccessInfo]
loop_counter: int = 0

def __init__(self, stencil: oir.Stencil):
self.sdfg = dace.SDFG(stencil.name)
Expand Down Expand Up @@ -98,14 +99,21 @@ def _make_dace_subset(self, local_access_info, field):
global_access_info, local_access_info, self.decls[field].data_dims
)

def _vloop_name(self, node: oir.VerticalLoop, ctx: OirSDFGBuilder.SDFGContext) -> str:
sdfg_name = ctx.sdfg.name
counter = ctx.loop_counter
ctx.loop_counter += 1

return f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}"

def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext):
declarations = {
acc.name: ctx.decls[acc.name]
for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess)
if acc.name in ctx.decls
}
library_node = StencilComputation(
name=f"{ctx.sdfg.name}_computation_{id(node)}",
name=self._vloop_name(node, ctx),
extents=ctx.block_extents,
declarations=declarations,
oir_node=node,
Expand Down Expand Up @@ -174,6 +182,6 @@ def visit_Stencil(self, node: oir.Stencil):
lifetime=dace.AllocationLifetime.Persistent,
debuginfo=get_dace_debuginfo(decl),
)
self.generic_visit(node, ctx=ctx)
self.visit(node.vertical_loops, ctx=ctx)
ctx.sdfg.validate()
return ctx.sdfg
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest

from gt4py.cartesian import gtscript, testing as gt_testing
from gt4py.cartesian.gtscript import (
Expand All @@ -25,7 +24,6 @@
from .stencil_definitions import optional_field, two_optional_fields


# ---- Identity stencil ----
class TestIdentity(gt_testing.StencilTestSuite):
"""Identity stencil."""

Expand All @@ -43,7 +41,6 @@ def validation(field_a, domain=None, origin=None):
pass


# ---- Copy stencil ----
class TestCopy(gt_testing.StencilTestSuite):
"""Copy stencil."""

Expand Down Expand Up @@ -86,7 +83,6 @@ def validation(field_a, field_b, domain=None, origin=None):
field_b[...] = (field_b[...] - 1.0) / 2.0


# ---- Scale stencil ----
class TestGlobalScale(gt_testing.StencilTestSuite):
"""Scale stencil using a global global_name."""

Expand All @@ -108,7 +104,6 @@ def validation(field_a, domain, origin, **kwargs):
field_a[...] = SCALE_FACTOR * field_a # noqa: F821 [undefined-name]


# ---- Parametric scale stencil -----
class TestParametricScale(gt_testing.StencilTestSuite):
"""Scale stencil using a parameter."""

Expand All @@ -128,7 +123,6 @@ def validation(field_a, *, scale, domain, origin, **kwargs):
field_a[...] = scale * field_a


# --- Parametric-mix stencil ----
class TestParametricMix(gt_testing.StencilTestSuite):
"""Linear combination of input fields using several parameters."""

Expand Down