diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 97dd90eb54..7fd4794e57 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -207,6 +207,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: neighbor_tables = filter_neighbor_tables(offset_provider) device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + sdfg_sig = sdfg.signature_arglist(with_types=False) dace_args = get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_connectivity_args(neighbor_tables, device) @@ -224,11 +225,8 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: **dace_conn_strides, **dace_offsets, } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = {key: all_args[key] for key in sdfg_sig} + return expected_args @@ -258,21 +256,22 @@ def build_sdfg_from_itir( # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. lift_mode = itir_transforms.LiftMode.FORCE_INLINE - arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) + # TODO: According to Lex one should build the SDFG first in a general mannor. + # Generalisation to a particular device should happen only at the end. sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu) sdfg = sdfg_genenerator.visit(program) sdfg.simplify() # run DaCe auto-optimization heuristics if auto_optimize: - # TODO Investigate how symbol definitions improve autoopt transformations, - # in which case the cache table should take the symbols map into account. + # TODO: Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. symbols: dict[str, int] = {} + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index b3e6662623..e3b5ddf2ac 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -209,14 +209,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) # Create the call signature for the SDFG. - # All arguments required by the SDFG, regardless if explicit and implicit, are added - # as positional arguments. In the front are all arguments to the Fencil, in that - # order, they are followed by the arguments created by the translation process, - arg_list = [str(a) for a in node.params] - sig_list = program_sdfg.signature_arglist(with_types=False) - implicit_args = set(sig_list) - set(arg_list) - call_params = arg_list + [ia for ia in sig_list if ia in implicit_args] - program_sdfg.arg_names = call_params + # Only the arguments requiered by the Fencil, i.e. `node.params` are added as poitional arguments. + # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. + program_sdfg.arg_names = [str(a) for a in node.params] program_sdfg.validate() return program_sdfg