From 15a7bd627d9fc818befd5f6ff6e795868563ff37 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Dec 2023 08:43:41 +0100 Subject: [PATCH] fix[next][dace]: Fix memlet for array slicing (#1399) Implementation of array slicing in DaCe backend changed to a mapped tasklet. Tested on GPU. CUDA code generation did not support the previous implementation, based on memlet in nested-SDFG. --- .../runners/dace_iterator/itir_to_tasklet.py | 66 ++++++------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d08476847f..4c202b1fe8 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -18,7 +18,6 @@ import dace import numpy as np -from dace import subsets from dace.transformation.dataflow import MapFusion import gt4py.eve.codegen @@ -754,52 +753,29 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:] ] - # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset - deref_sdfg = dace.SDFG("deref") - deref_sdfg.add_array( - "_inp", field_array.shape, iterator.dtype, strides=field_array.strides - ) - for connector in deref_connectors[1:]: - deref_sdfg.add_scalar(connector, _INDEX_DTYPE) - deref_sdfg.add_array("_out", result_shape, iterator.dtype) - deref_init_state = deref_sdfg.add_state("init", True) - deref_access_state = deref_sdfg.add_state("access") - deref_sdfg.add_edge( - deref_init_state, - deref_access_state, - dace.InterstateEdge( - assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]} - ), - ) - # we access the size in source field shape as symbols set on the nested sdfg - source_subset = tuple( - f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}" + # we create a mapped tasklet for array slicing + map_ranges = { + f"_i_{dim}": f"0:{size}" for dim, size in zip(sorted_dims, field_array.shape) + if dim not in iterator.indices + } + src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims]) + dst_subset = ",".join( + [f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices] ) - deref_access_state.add_nedge( - deref_access_state.add_access("_inp"), - deref_access_state.add_access("_out"), - dace.Memlet( - data="_out", - subset=subsets.Range.from_array(result_array), - other_subset=",".join(source_subset), - ), - ) - - deref_node = self.context.state.add_nested_sdfg( - deref_sdfg, - self.context.body, - inputs=set(deref_connectors), - outputs={"_out"}, - ) - for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets): - self.context.state.add_edge(node, None, deref_node, connector, memlet) - self.context.state.add_edge( - deref_node, - "_out", - result_node, - None, - dace.Memlet.from_array(result_name, result_array), + self.context.state.add_mapped_tasklet( + "deref", + map_ranges, + inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, + outputs={ + "_out": dace.Memlet.from_array(result_name, result_array), + }, + code=f"_out[{dst_subset}] = _inp[{src_subset}]", + external_edges=True, + input_nodes={node.data: node for node in deref_nodes}, + output_nodes={ + result_name: result_node, + }, ) return [ValueExpr(result_node, iterator.dtype)]