Skip to content

Commit

Permalink
feat[next]: Separates ITIR -> SDFG translation from running (#1379)
Browse files Browse the repository at this point in the history
Before it was only possible to translate ITIR to SDFG and execute it and it was not possible to extract the SDFG.
This commits splits this task into two parts and thus allows to perform the ITIR to SDFG translation without executing it.
  • Loading branch information
philip-paul-mueller authored Dec 5, 2023
1 parent d7cf10f commit 9f2ed1e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 48 deletions.
117 changes: 69 additions & 48 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,58 +185,85 @@ def get_cache_id(
return m.hexdigest()


def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Optional[dace.SDFG]:
def build_sdfg_from_itir(
program: itir.FencilDefinition,
*args,
offset_provider: dict[str, Any],
auto_optimize: bool = False,
on_gpu: bool = False,
column_axis: Optional[Dimension] = None,
lift_mode: LiftMode = LiftMode.FORCE_INLINE,
) -> dace.SDFG:
"""Translate a Fencil into an SDFG.
Args:
program: The Fencil that should be translated.
*args: Arguments for which the fencil should be called.
offset_provider: The set of offset providers that should be used.
auto_optimize: Apply DaCe's `auto_optimize` heuristic.
on_gpu: Performs the translation for GPU, defaults to `False`.
column_axis: The column axis to be used, defaults to `None`.
lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`.
Notes:
Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored.
"""
# TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
# `lift_more` to `FORCE_INLINE` mode.
lift_mode = LiftMode.FORCE_INLINE

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu)
sdfg = sdfg_genenerator.visit(program)
sdfg.simplify()

# run DaCe auto-optimization heuristics
if auto_optimize:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)

return sdfg


def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
# build parameters
auto_optimize = kwargs.get("auto_optimize", False)
build_cache = kwargs.get("build_cache", None)
build_type = kwargs.get("build_type", "RelWithDebInfo")
run_on_gpu = kwargs.get("run_on_gpu", False)
# Return parameter
return_sdfg = kwargs.get("return_sdfg", False)
run_sdfg = kwargs.get("run_sdfg", True)
on_gpu = kwargs.get("on_gpu", False)
auto_optimize = kwargs.get("auto_optimize", False)
lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE)
# ITIR parameters
column_axis = kwargs.get("column_axis", None)
lift_mode = (
LiftMode.FORCE_INLINE
) # TODO(edopao): make it configurable once temporaries are supported in DaCe backend
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if run_on_gpu else dace.DeviceType.CPU
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
neighbor_tables = filter_neighbor_tables(offset_provider)

cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
sdfg: Optional[dace.SDFG] = None
if build_cache is not None and cache_id in build_cache:
# retrieve SDFG program from build cache
sdfg_program = build_cache[cache_id]
sdfg = sdfg_program.sdfg

else:
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, run_on_gpu)
sdfg = sdfg_genenerator.visit(program)

# All arguments required by the SDFG, regardless if explicit and implicit, are added
# as positional arguments. In the front are all arguments to the Fencil, in that
# order, they are followed by the arguments created by the translation process,
# their order is determined by DaCe and unspecific.
assert len(sdfg.arg_names) == 0
arg_list = [str(a) for a in program.params]
sig_list = sdfg.signature_arglist(with_types=False)
implicit_args = set(sig_list) - set(arg_list)
call_params = arg_list + [ia for ia in sig_list if ia in implicit_args]
sdfg.arg_names = call_params

sdfg.simplify()

# run DaCe auto-optimization heuristics
if auto_optimize:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=run_on_gpu)

# compile SDFG and retrieve SDFG program
sdfg = build_sdfg_from_itir(
program,
*args,
offset_provider=offset_provider,
auto_optimize=auto_optimize,
on_gpu=on_gpu,
column_axis=column_axis,
lift_mode=lift_mode,
)

sdfg.build_folder = cache._session_cache_dir_path / ".dacecache"
with dace.config.temporary_config():
dace.config.Config.set("compiler", "build_type", value=build_type)
Expand Down Expand Up @@ -271,16 +298,10 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> Option
if key in sdfg.signature_arglist(with_types=False)
}

if run_sdfg:
with dace.config.temporary_config():
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
dace.config.Config.set("frontend", "check_args", value=True)
sdfg_program(**expected_args)
#

if return_sdfg:
return sdfg
return None
with dace.config.temporary_config():
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
dace.config.Config.set("frontend", "check_args", value=True)
sdfg_program(**expected_args)


def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
Expand All @@ -290,7 +311,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
**kwargs,
build_cache=_build_cache_cpu,
build_type=_build_type,
run_on_gpu=False,
on_gpu=False,
)


Expand All @@ -308,7 +329,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
**kwargs,
build_cache=_build_cache_gpu,
build_type=_build_type,
run_on_gpu=True,
on_gpu=True,
)

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
access_node = last_state.add_access(inner_name)
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

# Create the call signature for the SDFG.
# All arguments required by the SDFG, regardless if explicit and implicit, are added
# as positional arguments. In the front are all arguments to the Fencil, in that
# order, they are followed by the arguments created by the translation process,
arg_list = [str(a) for a in node.params]
sig_list = program_sdfg.signature_arglist(with_types=False)
implicit_args = set(sig_list) - set(arg_list)
call_params = arg_list + [ia for ia in sig_list if ia in implicit_args]
program_sdfg.arg_names = call_params

program_sdfg.validate()
return program_sdfg

Expand Down

0 comments on commit 9f2ed1e

Please sign in to comment.