feat[next][dace]: Support for sparse fields and reductions over lift expressions (#1377)

This PR adds support in the DaCe backend for sparse fields and for reductions over lift expressions.
edopao authored Dec 4, 2023
1 parent 6e13354 · commit b1f9c9a
Showing 5 changed files with 142 additions and 56 deletions.
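For context, a sketch of the kind of field-view code this commit lets the DaCe backend handle, using hypothetical dimension and connectivity names in the usual GT4Py unstructured setup: the shifted access edge_vals(V2E) lowers to a lift expression in iterator IR, neighbor_sum reduces over it, and weights is a sparse field carrying one value per vertex-to-edge neighbor.

import gt4py.next as gtx
from gt4py.next import neighbor_sum

Edge = gtx.Dimension("Edge")
Vertex = gtx.Dimension("Vertex")
V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL)
V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim))

@gtx.field_operator
def weighted_sum(
    edge_vals: gtx.Field[[Edge], gtx.float64],
    weights: gtx.Field[[Vertex, V2EDim], gtx.float64],  # sparse field: one value per neighbor
) -> gtx.Field[[Vertex], gtx.float64]:
    # edge_vals(V2E) becomes a reduction over a lift expression in iterator IR;
    # multiplying by the sparse weights exercises the new sparse-field path
    return neighbor_sum(edge_vals(V2E) * weights, axis=V2EDim)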
@@ -150,7 +150,7 @@ def get_output_nodes(

def visit_FencilDefinition(self, node: itir.FencilDefinition):
program_sdfg = dace.SDFG(name=node.id)
last_state = program_sdfg.add_state("program_entry")
last_state = program_sdfg.add_state("program_entry", True)
self.node_types = itir_typing.infer_all(node)

# Filter neighbor tables from offset providers.
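The new second argument to add_state is dace's is_start_state flag: it marks the state explicitly as the SDFG's entry state instead of relying on the first-added state becoming the start state implicitly. A minimal sketch of the pattern:

import dace

sdfg = dace.SDFG("example")
entry = sdfg.add_state("program_entry", True)  # is_start_state=True
assert sdfg.start_state is entry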
@@ -216,7 +216,7 @@ def visit_StencilClosure(
# Create the closure's nested SDFG and single state.
closure_sdfg = dace.SDFG(name="closure")
closure_state = closure_sdfg.add_state("closure_entry")
- closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init")
+ closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True)

input_names = [str(inp.id) for inp in node.inputs]
neighbor_tables = filter_neighbor_tables(self.offset_provider)
@@ -423,7 +423,7 @@ def _visit_scan_stencil_closure(
scan_sdfg = dace.SDFG(name="scan")

# create a state machine for lambda call over the scan dimension
start_state = scan_sdfg.add_state("start")
start_state = scan_sdfg.add_state("start", True)
lambda_state = scan_sdfg.add_state("lambda_compute")
end_state = scan_sdfg.add_state("end")
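The loop edges that wire these states together are outside this hunk; one plausible wiring, assuming dace's add_loop helper with an illustrative loop variable and bound, would execute lambda_state once per level of the scan dimension:

import dace

scan_sdfg = dace.SDFG("scan")
start_state = scan_sdfg.add_state("start", True)
lambda_state = scan_sdfg.add_state("lambda_compute")
end_state = scan_sdfg.add_state("end")
# hypothetical loop variable "_scan_k" and symbolic size "_scan_size"
scan_sdfg.add_loop(
    start_state, lambda_state, end_state,
    "_scan_k", "0", "_scan_k < _scan_size", "_scan_k + 1",
)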

@@ -18,6 +18,7 @@

import dace
import numpy as np
+ from dace import subsets
from dace.transformation.dataflow import MapFusion
from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols

@@ -39,6 +40,7 @@
filter_neighbor_tables,
flatten_list,
map_nested_sdfg_symbols,
+ new_array_symbols,
unique_name,
unique_var_name,
)
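new_array_symbols is a new helper in the backend's utility module. Its body is not part of this diff, but judging from the inline code it replaces further down, a plausible sketch is:

import dace

def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]:
    # one symbolic size and one symbolic stride per array dimension,
    # named after the array so that the SDFG symbols stay readable
    shape = [dace.symbol(f"{name}_shape{i}", dace.int64) for i in range(ndim)]
    strides = [dace.symbol(f"{name}_stride{i}", dace.int64) for i in range(ndim)]
    return shape, strides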
@@ -131,9 +133,13 @@ def get_reduce_identity_value(op_name_: str, type_: Any):
}


+ # Define type of variables used for field indexing
+ _INDEX_DTYPE = _TYPE_MAPPING["int64"]


@dataclasses.dataclass
class SymbolExpr:
- value: str | dace.symbolic.sympy.Basic
+ value: dace.symbolic.SymbolicType
dtype: dace.typeclass


@@ -226,7 +232,7 @@ def builtin_neighbors(
outputs={"__result"},
)
idx_name = unique_var_name()
- sdfg.add_scalar(idx_name, dace.int64, transient=True)
+ sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True)
state.add_memlet_path(
state.add_access(table_name),
me,
@@ -283,10 +289,12 @@ def builtin_can_deref(
assert shift_callable.fun.id == "shift"
iterator = transformer._visit_shift(can_deref_callable)

+ # this iterator is accessing a neighbor table, so it should return an index
+ assert iterator.dtype in dace.dtypes.INTEGER_TYPES
# create tasklet to check that field indices are non-negative (-1 is invalid)
- args = [ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.dimensions]
+ args = [ValueExpr(access_node, iterator.dtype) for access_node in iterator.indices.values()]
internals = [f"{arg.value.data}_v" for arg in args]
- expr_code = " && ".join([f"{v} >= 0" for v in internals])
+ expr_code = " and ".join([f"{v} >= 0" for v in internals])

# TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution
return transformer.add_expr_tasklet(
@@ -309,6 +317,26 @@ def builtin_if(
return transformer.add_expr_tasklet(expr_args, expr, type_, "if")


+ def builtin_list_get(
+ transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
+ ) -> list[ValueExpr]:
+ args = list(itertools.chain(*transformer.visit(node_args)))
+ assert len(args) == 2
+ # index node
+ assert isinstance(args[0], (SymbolExpr, ValueExpr))
+ # 1D-array node
+ assert isinstance(args[1], ValueExpr)
+ # source node should be a 1D array
+ assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1
+
+ expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)]
+ internals = [
+ arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args
+ ]
+ expr = f"{internals[1]}[{internals[0]}]"
+ return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get")
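The list_get builtin extracts one element from a local list, for example the neighbor values of a sparse field. The dataflow it emits boils down to a tasklet indexing a 1D array; a standalone sketch of the equivalent SDFG, with hypothetical container names:

import dace

sdfg = dace.SDFG("list_get_sketch")
state = sdfg.add_state("main", True)
sdfg.add_array("lst", (5,), dace.float64)
sdfg.add_scalar("idx", dace.int64)
sdfg.add_scalar("out", dace.float64, transient=True)

# the tasklet code mirrors the generated expr f"{internals[1]}[{internals[0]}]"
tasklet = state.add_tasklet("list_get", {"lst_v", "idx_v"}, {"out_v"}, "out_v = lst_v[idx_v]")
state.add_edge(state.add_access("lst"), None, tasklet, "lst_v",
               dace.Memlet.from_array("lst", sdfg.arrays["lst"]))
state.add_edge(state.add_access("idx"), None, tasklet, "idx_v",
               dace.Memlet.simple("idx", "0"))
state.add_edge(tasklet, "out_v", state.add_access("out"), None,
               dace.Memlet.simple("out", "0"))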


def builtin_cast(
transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
) -> list[ValueExpr]:
@@ -340,16 +368,13 @@ def builtin_tuple_get(
raise ValueError("Tuple can only be subscripted with compile-time constants")


- def builtin_undefined(*args: Any) -> Any:
- raise NotImplementedError()


_GENERAL_BUILTIN_MAPPING: dict[
str, Callable[["PythonTaskletCodegen", itir.Expr, list[itir.Expr]], list[ValueExpr]]
] = {
"can_deref": builtin_can_deref,
"cast_": builtin_cast,
"if_": builtin_if,
"list_get": builtin_list_get,
"make_tuple": builtin_make_tuple,
"neighbors": builtin_neighbors,
"tuple_get": builtin_tuple_get,
@@ -387,16 +412,11 @@ def _add_symbol(self, param, arg):
elif isinstance(arg, IteratorExpr):
# create storage in lambda sdfg
ndims = len(arg.dimensions)
- shape = tuple(
- dace.symbol(unique_var_name() + "__shp", dace.int64) for _ in range(ndims)
- )
- strides = tuple(
- dace.symbol(unique_var_name() + "__strd", dace.int64) for _ in range(ndims)
- )
+ shape, strides = new_array_symbols(param, ndims)
self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype)
index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()}
for _, index_name in index_names.items():
- self._sdfg.add_scalar(index_name, dtype=dace.int64)
+ self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE)
# update table of lambda symbol
field = self._state.add_access(param)
indices = {
@@ -513,14 +533,7 @@ def visit_Lambda(

# Add connectivities as arrays
for name in connectivity_names:
- shape = (
- dace.symbol(unique_var_name() + "__shp", dace.int64),
- dace.symbol(unique_var_name() + "__shp", dace.int64),
- )
- strides = (
- dace.symbol(unique_var_name() + "__strd", dace.int64),
- dace.symbol(unique_var_name() + "__strd", dace.int64),
- )
+ shape, strides = new_array_symbols(name, ndim=2)
dtype = self.context.body.arrays[name].dtype
lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)

@@ -542,11 +555,9 @@
result_name = unique_var_name()
lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True)
result_access = lambda_state.add_access(result_name)
- lambda_state.add_edge(
+ lambda_state.add_nedge(
expr.value,
- None,
result_access,
- None,
# in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution
dace.Memlet.simple(result_access.data, "0", wcr_str=self.context.reduce_wcr),
)
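add_nedge is shorthand for add_edge with None connectors on both sides; the behavior-carrying part is the memlet, whose wcr_str attaches a write-conflict-resolution function so that concurrent writes are combined rather than overwritten. A minimal sketch for a sum reduction:

import dace

# each write through this memlet is merged with the stored value by the
# wcr lambda, which is how the reduction accumulates its result
wcr_memlet = dace.Memlet.simple("acc", "0", wcr_str="lambda x, y: x + y")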
@@ -587,12 +598,13 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr:
return self._visit_reduce(node)

if isinstance(node.fun, itir.SymRef):
- if str(node.fun.id) in _MATH_BUILTINS_MAPPING:
+ builtin_name = str(node.fun.id)
+ if builtin_name in _MATH_BUILTINS_MAPPING:
return self._visit_numeric_builtin(node)
- elif str(node.fun.id) in _GENERAL_BUILTIN_MAPPING:
+ elif builtin_name in _GENERAL_BUILTIN_MAPPING:
return self._visit_general_builtin(node)
else:
- raise NotImplementedError()
+ raise NotImplementedError(f"{builtin_name} not implemented")
return self._visit_call(node)

def _visit_call(self, node: itir.FunCall):
@@ -697,7 +709,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
for dim in sorted_dims
]
args = [ValueExpr(iterator.field, iterator.dtype)] + [
- ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices
+ ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in iterator.indices
]
internals = [f"{arg.value.data}_v" for arg in args]

@@ -726,14 +738,88 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:

return [ValueExpr(value=result_access, dtype=iterator.dtype)]

- else:
+ elif all([dim in iterator.indices for dim in iterator.dimensions]):
+ # The deref iterator has index values on all dimensions: the result will be a scalar
args = [ValueExpr(iterator.field, iterator.dtype)] + [
- ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims
+ ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims
]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]}[{', '.join(internals[1:])}]"
return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref")

+ else:
+ # Not all dimensions are included in the deref index list:
+ # this means the ND-field will be sliced along one or more dimensions and the result will be an array
+ field_array = self.context.body.arrays[iterator.field.data]
+ result_shape = tuple(
+ dim_size
+ for dim, dim_size in zip(sorted_dims, field_array.shape)
+ if dim not in iterator.indices
+ )
+ result_name = unique_var_name()
+ self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True)
+ result_array = self.context.body.arrays[result_name]
+ result_node = self.context.state.add_access(result_name)
+
+ deref_connectors = ["_inp"] + [
+ f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices
+ ]
+ deref_nodes = [iterator.field] + [
+ iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices
+ ]
+ deref_memlets = [dace.Memlet.from_array(iterator.field.data, field_array)] + [
+ dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:]
+ ]
+
+ # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset
+ deref_sdfg = dace.SDFG("deref")
+ deref_sdfg.add_array(
+ "_inp", field_array.shape, iterator.dtype, strides=field_array.strides
+ )
+ for connector in deref_connectors[1:]:
+ deref_sdfg.add_scalar(connector, _INDEX_DTYPE)
+ deref_sdfg.add_array("_out", result_shape, iterator.dtype)
+ deref_init_state = deref_sdfg.add_state("init", True)
+ deref_access_state = deref_sdfg.add_state("access")
+ deref_sdfg.add_edge(
+ deref_init_state,
+ deref_access_state,
+ dace.InterstateEdge(
+ assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]}
+ ),
+ )
+ # we access the size in source field shape as symbols set on the nested sdfg
+ source_subset = tuple(
+ f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}"
+ for dim, size in zip(sorted_dims, field_array.shape)
+ )
+ deref_access_state.add_nedge(
+ deref_access_state.add_access("_inp"),
+ deref_access_state.add_access("_out"),
+ dace.Memlet(
+ data="_out",
+ subset=subsets.Range.from_array(result_array),
+ other_subset=",".join(source_subset),
+ ),
+ )
+
+ deref_node = self.context.state.add_nested_sdfg(
+ deref_sdfg,
+ self.context.body,
+ inputs=set(deref_connectors),
+ outputs={"_out"},
+ )
+ for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets):
+ self.context.state.add_edge(node, None, deref_node, connector, memlet)
+ self.context.state.add_edge(
+ deref_node,
+ "_out",
+ result_node,
+ None,
+ dace.Memlet.from_array(result_name, result_array),
+ )
+ return [ValueExpr(result_node, iterator.dtype)]
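The nested SDFG above exists only to promote the scalar index values to symbols through interstate-edge assignments, since memlet subsets must be symbolic expressions rather than runtime scalars. The same pattern in isolation, as a sketch with made-up shapes:

import dace

sdfg = dace.SDFG("slice_by_scalar")
sdfg.add_scalar("i", dace.int64)
sdfg.add_array("A", (10, 20), dace.float64)
sdfg.add_array("row", (20,), dace.float64)

init = sdfg.add_state("init", True)
access = sdfg.add_state("access")
# the assignment on the interstate edge turns the runtime scalar "i"
# into the symbol "_sym_i", which may appear in a memlet subset
sdfg.add_edge(init, access, dace.InterstateEdge(assignments={"_sym_i": "i"}))
access.add_nedge(
    access.add_access("A"),
    access.add_access("row"),
    dace.Memlet(data="row", subset="0:20", other_subset="_sym_i, 0:20"),
)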

def _split_shift_args(
self, args: list[itir.Expr]
) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]:
@@ -760,6 +846,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr:
offset_dim = tail[0].value
assert isinstance(offset_dim, str)
offset_node = self.visit(tail[1])[0]
+ assert offset_node.dtype in dace.dtypes.INTEGER_TYPES

if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider):
offset_provider = self.offset_provider[offset_dim]
@@ -769,7 +856,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr:
target_dim = offset_provider.neighbor_axis.value
args = [
ValueExpr(connectivity, offset_provider.table.dtype),
- ValueExpr(iterator.indices[shifted_dim], dace.int64),
+ ValueExpr(iterator.indices[shifted_dim], offset_node.dtype),
offset_node,
]
internals = [f"{arg.value.data}_v" for arg in args]
@@ -780,7 +867,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr:
shifted_dim = offset_provider.origin_axis.value
target_dim = offset_provider.neighbor_axis.value
args = [
- ValueExpr(iterator.indices[shifted_dim], dace.int64),
+ ValueExpr(iterator.indices[shifted_dim], offset_node.dtype),
offset_node,
]
internals = [f"{arg.value.data}_v" for arg in args]
@@ -791,14 +878,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr:
shifted_dim = self.offset_provider[offset_dim].value
target_dim = shifted_dim
args = [
- ValueExpr(iterator.indices[shifted_dim], dace.int64),
+ ValueExpr(iterator.indices[shifted_dim], offset_node.dtype),
offset_node,
]
internals = [f"{arg.value.data}_v" for arg in args]
expr = f"{internals[0]} + {internals[1]}"

shifted_value = self.add_expr_tasklet(
- list(zip(args, internals)), expr, dace.dtypes.int64, "shift"
+ list(zip(args, internals)), expr, offset_node.dtype, "shift"
)[0].value

shifted_index = {dim: value for dim, value in iterator.indices.items()}
@@ -811,7 +898,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]:
offset = node.value
assert isinstance(offset, int)
offset_var = unique_var_name()
- self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True)
+ self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True)
offset_node = self.context.state.add_access(offset_var)
tasklet_node = self.context.state.add_tasklet(
"get_offset", {}, {"__out"}, f"__out = {offset}"
@@ -906,7 +993,7 @@ def _visit_reduce(self, node: itir.FunCall):

# initialize the reduction result based on type of operation
init_value = get_reduce_identity_value(op_name.id, result_dtype)
- init_state = self.context.body.add_state_before(self.context.state, "init")
+ init_state = self.context.body.add_state_before(self.context.state, "init", True)
init_tasklet = init_state.add_tasklet(
"init_reduce", {}, {"__out"}, f"__out = {init_value}"
)
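get_reduce_identity_value (whose body is hidden in this view) supplies the neutral element of the combine operator, which the accumulator is initialized with before the WCR-guarded updates run. Illustratively, assuming the usual identities:

# hypothetical mapping, for illustration only
_IDENTITY = {
    "plus": 0,             # x + 0 == x
    "multiplies": 1,       # x * 1 == x
    "minimum": float("inf"),
    "maximum": -float("inf"),
}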
@@ -1044,29 +1131,24 @@ def closure_to_tasklet_sdfg(
node_types: dict[int, next_typing.Type],
) -> tuple[Context, Sequence[ValueExpr]]:
body = dace.SDFG("tasklet_toplevel")
- state = body.add_state("tasklet_toplevel_entry")
+ state = body.add_state("tasklet_toplevel_entry", True)
symbol_map: dict[str, TaskletExpr] = {}

idx_accesses = {}
for dim, idx in domain.items():
name = f"{idx}_value"
- body.add_scalar(name, dtype=dace.int64, transient=True)
+ body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True)
tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}")
access = state.add_access(name)
idx_accesses[dim] = access
state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0"))
for name, ty in inputs:
if isinstance(ty, ts.FieldType):
ndim = len(ty.dims)
- shape = [
- dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(ndim)
- ]
- stride = [
- dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(ndim)
- ]
+ shape, strides = new_array_symbols(name, ndim)
dims = [dim.value for dim in ty.dims]
dtype = as_dace_type(ty.dtype)
- body.add_array(name, shape=shape, strides=stride, dtype=dtype)
+ body.add_array(name, shape=shape, strides=strides, dtype=dtype)
field = state.add_access(name)
indices = {dim: idx_accesses[dim] for dim in domain.keys()}
symbol_map[name] = IteratorExpr(field, indices, dtype, dims)
@@ -1076,9 +1158,8 @@ def closure_to_tasklet_sdfg(
body.add_scalar(name, dtype=dtype)
symbol_map[name] = ValueExpr(state.add_access(name), dtype)
for arr, name in connectivities:
- shape = [dace.symbol(f"{unique_var_name()}_shp{i}", dtype=dace.int64) for i in range(2)]
- stride = [dace.symbol(f"{unique_var_name()}_strd{i}", dtype=dace.int64) for i in range(2)]
- body.add_array(name, shape=shape, strides=stride, dtype=arr.dtype)
+ shape, strides = new_array_symbols(name, ndim=2)
+ body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype)

context = Context(body, state, symbol_map)
translator = PythonTaskletCodegen(offset_provider, context, node_types)