Skip to content

Commit

Permalink
fix[next][dace]: Fix memlet for array slicing (#1399)
Browse files Browse the repository at this point in the history
The implementation of array slicing in the DaCe backend now uses a mapped tasklet. Tested on GPU. CUDA code generation did not support the previous implementation, which was based on a memlet inside a nested SDFG.
  • Loading branch information
edopao authored Dec 19, 2023
1 parent 315d920 commit 15a7bd6
Showing 1 changed file with 21 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import dace
import numpy as np
from dace import subsets
from dace.transformation.dataflow import MapFusion

import gt4py.eve.codegen
Expand Down Expand Up @@ -754,52 +753,29 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:]
]

# we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset
deref_sdfg = dace.SDFG("deref")
deref_sdfg.add_array(
"_inp", field_array.shape, iterator.dtype, strides=field_array.strides
)
for connector in deref_connectors[1:]:
deref_sdfg.add_scalar(connector, _INDEX_DTYPE)
deref_sdfg.add_array("_out", result_shape, iterator.dtype)
deref_init_state = deref_sdfg.add_state("init", True)
deref_access_state = deref_sdfg.add_state("access")
deref_sdfg.add_edge(
deref_init_state,
deref_access_state,
dace.InterstateEdge(
assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]}
),
)
# we access the size in source field shape as symbols set on the nested sdfg
source_subset = tuple(
f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}"
# we create a mapped tasklet for array slicing
map_ranges = {
f"_i_{dim}": f"0:{size}"
for dim, size in zip(sorted_dims, field_array.shape)
if dim not in iterator.indices
}
src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims])
dst_subset = ",".join(
[f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices]
)
deref_access_state.add_nedge(
deref_access_state.add_access("_inp"),
deref_access_state.add_access("_out"),
dace.Memlet(
data="_out",
subset=subsets.Range.from_array(result_array),
other_subset=",".join(source_subset),
),
)

deref_node = self.context.state.add_nested_sdfg(
deref_sdfg,
self.context.body,
inputs=set(deref_connectors),
outputs={"_out"},
)
for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets):
self.context.state.add_edge(node, None, deref_node, connector, memlet)
self.context.state.add_edge(
deref_node,
"_out",
result_node,
None,
dace.Memlet.from_array(result_name, result_array),
self.context.state.add_mapped_tasklet(
"deref",
map_ranges,
inputs={k: v for k, v in zip(deref_connectors, deref_memlets)},
outputs={
"_out": dace.Memlet.from_array(result_name, result_array),
},
code=f"_out[{dst_subset}] = _inp[{src_subset}]",
external_edges=True,
input_nodes={node.data: node for node in deref_nodes},
output_nodes={
result_name: result_node,
},
)
return [ValueExpr(result_node, iterator.dtype)]

Expand Down

0 comments on commit 15a7bd6

Please sign in to comment.