diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5b822a1ab5..c84e1061c4 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -132,16 +132,6 @@ def _set_tile_sizes(sdfg: dace.SDFG): node.tile_sizes_interpretation = "strides" -def _to_device(sdfg: dace.SDFG, device: str) -> None: - """Update sdfg in place.""" - if device == "gpu": - for array in sdfg.arrays.values(): - array.storage = dace.StorageType.GPU_Global - for node, _ in sdfg.all_nodes_recursive(): - if isinstance(node, StencilComputation): - node.device = dace.DeviceType.GPU - - def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): args_data = make_args_data_from_gtir(gtir_pipeline) @@ -151,10 +141,6 @@ def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, la sdfg.add_state(gtir_pipeline.gtir.name) return sdfg - for array in sdfg.arrays.values(): - if array.transient: - array.lifetime = dace.AllocationLifetime.Persistent - sdfg.simplify(validate=False) _set_expansion_orders(sdfg) @@ -351,9 +337,10 @@ def _unexpanded_sdfg(self): "oir_pipeline", DefaultPipeline() ) oir_node = oir_pipeline.run(base_oir) - sdfg = OirSDFGBuilder().visit(oir_node) + sdfg = OirSDFGBuilder().visit( + oir_node, device=self.builder.backend.storage_info["device"] + ) - _to_device(sdfg, self.builder.backend.storage_info["device"]) _pre_expand_transformations( self.builder.gtir_pipeline, sdfg, diff --git a/src/gt4py/cartesian/gtc/dace/constants.py b/src/gt4py/cartesian/gtc/dace/constants.py new file mode 100644 index 0000000000..5ebf86b769 --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/constants.py @@ -0,0 +1,11 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# StencilComputation in/out connector prefixes +CONNECTOR_PREFIX_IN: str = "__in_" +CONNECTOR_PREFIX_OUT: str = "__out_" diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 27f55d451d..20d7743661 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -18,6 +18,7 @@ import sympy from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder @@ -77,11 +78,11 @@ def _fix_context( """ # change connector names for in_edge in parent_state.in_edges(node): - assert in_edge.dst_conn.startswith("__in_") - in_edge.dst_conn = in_edge.dst_conn[len("__in_") :] + assert in_edge.dst_conn.startswith(CONNECTOR_PREFIX_IN) + in_edge.dst_conn = in_edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN) for out_edge in parent_state.out_edges(node): - assert out_edge.src_conn.startswith("__out_") - out_edge.src_conn = out_edge.src_conn[len("__out_") :] + assert out_edge.src_conn.startswith(CONNECTOR_PREFIX_OUT) + out_edge.src_conn = out_edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT) # union input and output subsets subsets = {} @@ -125,9 +126,13 @@ def _get_parent_arrays( ) -> Dict[str, dace.data.Data]: parent_arrays: Dict[str, dace.data.Data] = {} for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None): - parent_arrays[edge.dst_conn[len("__in_") :]] = parent_sdfg.arrays[edge.data.data] + parent_arrays[edge.dst_conn.removeprefix(CONNECTOR_PREFIX_IN)] = parent_sdfg.arrays[ + edge.data.data + ] for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None): - parent_arrays[edge.src_conn[len("__out_") :]] = parent_sdfg.arrays[edge.data.data] + parent_arrays[edge.src_conn.removeprefix(CONNECTOR_PREFIX_OUT)] = parent_sdfg.arrays[ + edge.data.data + ] return parent_arrays @staticmethod diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 6728ccaa7d..3aeda7a484 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -20,9 +20,8 @@ from gt4py import eve from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen -from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import make_dace_subset +from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait): @@ -268,13 +267,13 @@ def visit_ComputationState( for memlet in computation.read_memlets: if memlet.field not in read_acc_and_conn: read_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + sdfg_ctx.state.add_access(memlet.field), None, ) for memlet in computation.write_memlets: if memlet.field not in write_acc_and_conn: write_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + sdfg_ctx.state.add_access(memlet.field), None, ) node_ctx = StencilComputationSDFGBuilder.NodeContext( @@ -298,7 +297,7 @@ def visit_FieldDecl( dtype=data_type_to_dace_typeclass(node.dtype), storage=node.storage.to_dace_storage(), transient=node.name not in non_transients, - debuginfo=dace.DebugInfo(0), + debuginfo=get_dace_debuginfo(node), ) def visit_SymbolDecl( @@ -343,7 +342,6 @@ def visit_NestedSDFG( inputs=node.input_connectors, outputs=node.output_connectors, symbol_mapping=symbol_mapping, - debuginfo=dace.DebugInfo(0), ) self.visit( node.read_memlets, diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 919ec02996..7a29ec99a6 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -10,11 +10,6 @@ from typing import TYPE_CHECKING, List -import dace -import dace.data -import dace.library -import dace.subsets - from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.dace import daceir as dcir @@ -25,15 +20,6 @@ from gt4py.cartesian.gtc.dace.nodes import StencilComputation -def get_dace_debuginfo(node: common.LocNode): - if node.loc is not None: - return dace.dtypes.DebugInfo( - node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename - ) - else: - return dace.dtypes.DebugInfo(0) - - class HorizontalIntervalRemover(eve.NodeTranslator): def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis): mask_attrs = dict(i=node.i, j=node.j) diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 34401e18b9..5f47881db1 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -23,12 +23,12 @@ from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion +from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter +from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order +from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection -from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo -from .expansion_specification import ExpansionItem, make_expansion_order - def _set_expansion_order( node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]] @@ -119,6 +119,7 @@ def __init__( extents: Optional[Dict[int, Extent]] = None, declarations: Optional[Dict[str, Decl]] = None, expansion_order=None, + device: Optional[dace.DeviceType] = None, *args, **kwargs, ): @@ -137,6 +138,7 @@ def __init__( self.oir_node = typing.cast(PickledDataclassProperty, oir_node) self.extents = extents_dict # type: ignore self.declarations = declarations # type: ignore + self.device = device self.symbol_mapping = { decl.name: dace.symbol( decl.name, diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 14448bb08e..e5c81f199b 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -9,7 +9,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Dict +from typing import Dict, Literal import dace import dace.properties @@ -18,9 +18,14 @@ import gt4py.cartesian.gtc.oir as oir from gt4py import eve from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace.constants import CONNECTOR_PREFIX_IN, CONNECTOR_PREFIX_OUT from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset +from gt4py.cartesian.gtc.dace.utils import ( + compute_dcir_access_infos, + get_dace_debuginfo, + make_dace_subset, +) from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.passes.oir_optimizations.utils import ( AccessCollector, @@ -28,6 +33,17 @@ ) +transient_storage_per_device: Dict[Literal["cpu", "gpu"], dace.StorageType] = { + "cpu": dace.StorageType.Default, + "gpu": dace.StorageType.GPU_Global, +} + +device_type_per_device: Dict[Literal["cpu", "gpu"], dace.DeviceType] = { + "cpu": dace.DeviceType.CPU, + "gpu": dace.DeviceType.GPU, +} + + class OirSDFGBuilder(eve.NodeVisitor): @dataclass class SDFGContext: @@ -94,8 +110,12 @@ def _make_dace_subset(self, local_access_info, field): ) def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs - ): + self, + node: oir.VerticalLoop, + *, + ctx: OirSDFGBuilder.SDFGContext, + device: Literal["cpu", "gpu"], + ) -> None: declarations = { acc.name: ctx.decls[acc.name] for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess) @@ -106,6 +126,7 @@ def visit_VerticalLoop( extents=ctx.block_extents, declarations=declarations, oir_node=node, + device=device_type_per_device[device], ) state = ctx.sdfg.add_state() @@ -117,23 +138,25 @@ def visit_VerticalLoop( access_collection = AccessCollector.apply(node) for field in access_collection.read_fields(): - access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) - library_node.add_in_connector("__in_" + field) + access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) + connector_name = CONNECTOR_PREFIX_IN + field + library_node.add_in_connector(connector_name) subset = ctx.make_input_dace_subset(node, field) state.add_edge( - access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) + access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset) ) for field in access_collection.write_fields(): - access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) - library_node.add_out_connector("__out_" + field) + access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) + connector_name = CONNECTOR_PREFIX_OUT + field + library_node.add_out_connector(connector_name) subset = ctx.make_output_dace_subset(node, field) state.add_edge( - library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) + library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset) ) - def visit_Stencil(self, node: oir.Stencil, **kwargs): - ctx = OirSDFGBuilder.SDFGContext(stencil=node) + def visit_Stencil(self, node: oir.Stencil, *, device: Literal["cpu", "gpu"]) -> dace.SDFG: + ctx = OirSDFGBuilder.SDFGContext(node) for param in node.params: if isinstance(param, oir.FieldDecl): dim_strs = [d for i, d in enumerate("IJK") if param.dimensions[i]] + [ @@ -148,7 +171,8 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs): ], dtype=data_type_to_dace_typeclass(param.dtype), transient=False, - debuginfo=dace.DebugInfo(0), + storage=transient_storage_per_device[device], + debuginfo=get_dace_debuginfo(param), ) else: ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype)) @@ -166,8 +190,10 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs): ], dtype=data_type_to_dace_typeclass(decl.dtype), transient=True, - debuginfo=dace.DebugInfo(0), + lifetime=dace.AllocationLifetime.Persistent, + storage=transient_storage_per_device[device], + debuginfo=get_dace_debuginfo(decl), ) - self.generic_visit(node, ctx=ctx) + self.generic_visit(node, ctx=ctx, device=device) ctx.sdfg.validate() return ctx.sdfg diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index bd65861a49..4e8a0f0c7b 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -23,6 +23,15 @@ from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents +def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo: + if node.loc is None: + return dace.dtypes.DebugInfo(0) + + return dace.dtypes.DebugInfo( + node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename + ) + + def array_dimensions(array: dace.data.Array): dims = [ any(