Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[cartesian]: unexpanded sdfg cleanups #1843

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,6 @@ def _set_tile_sizes(sdfg: dace.SDFG):
node.tile_sizes_interpretation = "strides"


def _to_device(sdfg: dace.SDFG, device: str) -> None:
"""Update sdfg in place."""
if device == "gpu":
for array in sdfg.arrays.values():
array.storage = dace.StorageType.GPU_Global
for node, _ in sdfg.all_nodes_recursive():
if isinstance(node, StencilComputation):
node.device = dace.DeviceType.GPU


def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map):
args_data = make_args_data_from_gtir(gtir_pipeline)

Expand All @@ -151,10 +141,6 @@ def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, la
sdfg.add_state(gtir_pipeline.gtir.name)
return sdfg

for array in sdfg.arrays.values():
if array.transient:
array.lifetime = dace.AllocationLifetime.Persistent

sdfg.simplify(validate=False)

_set_expansion_orders(sdfg)
Expand Down Expand Up @@ -351,9 +337,10 @@ def _unexpanded_sdfg(self):
"oir_pipeline", DefaultPipeline()
)
oir_node = oir_pipeline.run(base_oir)
sdfg = OirSDFGBuilder().visit(oir_node)
sdfg = OirSDFGBuilder().visit(
oir_node, device=self.builder.backend.storage_info["device"]
)

_to_device(sdfg, self.builder.backend.storage_info["device"])
_pre_expand_transformations(
self.builder.gtir_pipeline,
sdfg,
Expand Down
11 changes: 11 additions & 0 deletions src/gt4py/cartesian/gtc/dace/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

# StencilComputation in/out connector prefixes
CONNECTOR_PREFIX_IN: str = "__in_"
CONNECTOR_PREFIX_OUT: str = "__out_"
17 changes: 11 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import sympy

from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder
from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder

Expand Down Expand Up @@ -77,11 +78,11 @@ def _fix_context(
"""
# change connector names
for in_edge in parent_state.in_edges(node):
assert in_edge.dst_conn.startswith("__in_")
in_edge.dst_conn = in_edge.dst_conn[len("__in_") :]
assert in_edge.dst_conn.startswith(CONNECTOR_PREFIX_IN)
in_edge.dst_conn = in_edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)
for out_edge in parent_state.out_edges(node):
assert out_edge.src_conn.startswith("__out_")
out_edge.src_conn = out_edge.src_conn[len("__out_") :]
assert out_edge.src_conn.startswith(CONNECTOR_PREFIX_OUT)
out_edge.src_conn = out_edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)

# union input and output subsets
subsets = {}
Expand Down Expand Up @@ -125,9 +126,13 @@ def _get_parent_arrays(
) -> Dict[str, dace.data.Data]:
parent_arrays: Dict[str, dace.data.Data] = {}
for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None):
parent_arrays[edge.dst_conn[len("__in_") :]] = parent_sdfg.arrays[edge.data.data]
parent_arrays[edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)] = parent_sdfg.arrays[
edge.data.data
]
for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None):
parent_arrays[edge.src_conn[len("__out_") :]] = parent_sdfg.arrays[edge.data.data]
parent_arrays[edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)] = parent_sdfg.arrays[
edge.data.data
]
return parent_arrays

@staticmethod
Expand Down
10 changes: 4 additions & 6 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen
from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import make_dace_subset
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset


class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait):
Expand Down Expand Up @@ -268,13 +267,13 @@ def visit_ComputationState(
for memlet in computation.read_memlets:
if memlet.field not in read_acc_and_conn:
read_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
for memlet in computation.write_memlets:
if memlet.field not in write_acc_and_conn:
write_acc_and_conn[memlet.field] = (
sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)),
sdfg_ctx.state.add_access(memlet.field),
None,
)
node_ctx = StencilComputationSDFGBuilder.NodeContext(
Expand All @@ -298,7 +297,7 @@ def visit_FieldDecl(
dtype=data_type_to_dace_typeclass(node.dtype),
storage=node.storage.to_dace_storage(),
transient=node.name not in non_transients,
debuginfo=dace.DebugInfo(0),
debuginfo=get_dace_debuginfo(node),
)

def visit_SymbolDecl(
Expand Down Expand Up @@ -343,7 +342,6 @@ def visit_NestedSDFG(
inputs=node.input_connectors,
outputs=node.output_connectors,
symbol_mapping=symbol_mapping,
debuginfo=dace.DebugInfo(0),
)
self.visit(
node.read_memlets,
Expand Down
14 changes: 0 additions & 14 deletions src/gt4py/cartesian/gtc/dace/expansion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

from typing import TYPE_CHECKING, List

import dace
import dace.data
import dace.library
import dace.subsets

from gt4py import eve
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
Expand All @@ -25,15 +20,6 @@
from gt4py.cartesian.gtc.dace.nodes import StencilComputation


def get_dace_debuginfo(node: common.LocNode):
if node.loc is not None:
return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)
else:
return dace.dtypes.DebugInfo(0)


class HorizontalIntervalRemover(eve.NodeTranslator):
def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis):
mask_attrs = dict(i=node.i, j=node.j)
Expand Down
8 changes: 5 additions & 3 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from gt4py.cartesian.gtc import common, oir
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion
from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter
from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order
from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection

from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo
from .expansion_specification import ExpansionItem, make_expansion_order


def _set_expansion_order(
node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]]
Expand Down Expand Up @@ -119,6 +119,7 @@ def __init__(
extents: Optional[Dict[int, Extent]] = None,
declarations: Optional[Dict[str, Decl]] = None,
expansion_order=None,
device: Optional[dace.DeviceType] = None,
*args,
**kwargs,
):
Expand All @@ -137,6 +138,7 @@ def __init__(
self.oir_node = typing.cast(PickledDataclassProperty, oir_node)
self.extents = extents_dict # type: ignore
self.declarations = declarations # type: ignore
self.device = device
self.symbol_mapping = {
decl.name: dace.symbol(
decl.name,
Expand Down
56 changes: 41 additions & 15 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict
from typing import Dict, Literal

import dace
import dace.properties
Expand All @@ -18,16 +18,32 @@
import gt4py.cartesian.gtc.oir as oir
from gt4py import eve
from gt4py.cartesian.gtc.dace import daceir as dcir
from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT
from gt4py.cartesian.gtc.dace.nodes import StencilComputation
from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass
from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset
from gt4py.cartesian.gtc.dace.utils import (
compute_dcir_access_infos,
get_dace_debuginfo,
make_dace_subset,
)
from gt4py.cartesian.gtc.definitions import Extent
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import (
AccessCollector,
compute_horizontal_block_extents,
)


transient_storage_per_device: Dict[Literal["cpu", "gpu"], dace.StorageType] = {
"cpu": dace.StorageType.Default,
"gpu": dace.StorageType.GPU_Global,
}

device_type_per_device: Dict[Literal["cpu", "gpu"], dace.DeviceType] = {
"cpu": dace.DeviceType.CPU,
"gpu": dace.DeviceType.GPU,
}
Comment on lines +36 to +44
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have a named type for Literal["cpu", "gpu"]? I found it typed like in the LayoutInfo

class LayoutInfo(TypedDict):
alignment: int # measured in bytes
device: Literal["cpu", "gpu"]
layout_map: Callable[[Tuple[str, ...]], Tuple[Optional[int], ...]]
is_optimal_layout: Callable[[Any, Tuple[str, ...]], bool]



class OirSDFGBuilder(eve.NodeVisitor):
@dataclass
class SDFGContext:
Expand Down Expand Up @@ -94,8 +110,12 @@ def _make_dace_subset(self, local_access_info, field):
)

def visit_VerticalLoop(
self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs
):
self,
node: oir.VerticalLoop,
*,
ctx: OirSDFGBuilder.SDFGContext,
device: Literal["cpu", "gpu"],
) -> None:
declarations = {
acc.name: ctx.decls[acc.name]
for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess)
Expand All @@ -106,6 +126,7 @@ def visit_VerticalLoop(
extents=ctx.block_extents,
declarations=declarations,
oir_node=node,
device=device_type_per_device[device],
)

state = ctx.sdfg.add_state()
Expand All @@ -117,23 +138,25 @@ def visit_VerticalLoop(
access_collection = AccessCollector.apply(node)

for field in access_collection.read_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_in_connector("__in_" + field)
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
connector_name = CONNECTOR_PREFIX_IN + field
library_node.add_in_connector(connector_name)
subset = ctx.make_input_dace_subset(node, field)
state.add_edge(
access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset)
access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset)
)

for field in access_collection.write_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_out_connector("__out_" + field)
access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field]))
connector_name = CONNECTOR_PREFIX_OUT + field
library_node.add_out_connector(connector_name)
subset = ctx.make_output_dace_subset(node, field)
state.add_edge(
library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset)
library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset)
)

def visit_Stencil(self, node: oir.Stencil, **kwargs):
ctx = OirSDFGBuilder.SDFGContext(stencil=node)
def visit_Stencil(self, node: oir.Stencil, *, device: Literal["cpu", "gpu"]) -> dace.SDFG:
ctx = OirSDFGBuilder.SDFGContext(node)
for param in node.params:
if isinstance(param, oir.FieldDecl):
dim_strs = [d for i, d in enumerate("IJK") if param.dimensions[i]] + [
Expand All @@ -148,7 +171,8 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs):
],
dtype=data_type_to_dace_typeclass(param.dtype),
transient=False,
debuginfo=dace.DebugInfo(0),
storage=transient_storage_per_device[device],
debuginfo=get_dace_debuginfo(param),
)
else:
ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
Expand All @@ -166,8 +190,10 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs):
],
dtype=data_type_to_dace_typeclass(decl.dtype),
transient=True,
debuginfo=dace.DebugInfo(0),
lifetime=dace.AllocationLifetime.Persistent,
storage=transient_storage_per_device[device],
debuginfo=get_dace_debuginfo(decl),
)
self.generic_visit(node, ctx=ctx)
self.generic_visit(node, ctx=ctx, device=device)
ctx.sdfg.validate()
return ctx.sdfg
9 changes: 9 additions & 0 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents


def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo:
if node.loc is None:
return dace.dtypes.DebugInfo(0)

return dace.dtypes.DebugInfo(
node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename
)


def array_dimensions(array: dace.data.Array):
dims = [
any(
Expand Down