Skip to content

Commit

Permalink
feat[next][dace]: Modified the file caching. (#1434)
Browse files Browse the repository at this point in the history
In PR #1422 @edopao introduced a mechanism to skip the SDFG translation.
This PR moves this cache from the `run_dace_iterator()` function into the `build_sdfg_from_itir()` function.
  • Loading branch information
philip-paul-mueller authored Feb 1, 2024
1 parent 75d23d0 commit d6dfd6f
Showing 1 changed file with 41 additions and 23 deletions.
64 changes: 41 additions & 23 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,21 +254,36 @@ def build_sdfg_from_itir(
on_gpu: bool = False,
column_axis: Optional[common.Dimension] = None,
lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE,
load_sdfg_from_file: bool = False,
cache_id: Optional[str] = None,
save_sdfg: bool = True,
) -> dace.SDFG:
"""Translate a Fencil into an SDFG.
Args:
program: The Fencil that should be translated.
*args: Arguments for which the fencil should be called.
offset_provider: The set of offset providers that should be used.
auto_optimize: Apply DaCe's `auto_optimize` heuristic.
on_gpu: Performs the translation for GPU, defaults to `False`.
column_axis: The column axis to be used, defaults to `None`.
lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`.
program: The Fencil that should be translated.
*args: Arguments for which the fencil should be called.
offset_provider: The set of offset providers that should be used.
auto_optimize: Apply DaCe's `auto_optimize` heuristic.
on_gpu: Performs the translation for GPU, defaults to `False`.
column_axis: The column axis to be used, defaults to `None`.
lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`.
load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only.
cache_id: The id of the cache entry, used to disambiguate stored sdfgs.
save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`.
Notes:
Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored.
"""
# Test if we can go through the cache?
sdfg_filename = (
f"_dacegraphs/gt4py/{cache_id if cache_id is not None else '.'}/{program.id}.sdfg"
)
if load_sdfg_from_file and Path(sdfg_filename).exists():
sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename)
sdfg.validate()
return sdfg

# TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
# `lift_more` to `FORCE_INLINE` mode.
lift_mode = itir_transforms.LiftMode.FORCE_INLINE
Expand All @@ -277,7 +292,7 @@ def build_sdfg_from_itir(
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg: dace.SDFG = sdfg_genenerator.visit(program)
sdfg = sdfg_genenerator.visit(program)
if sdfg is None:
raise RuntimeError(f"Visit failed for program {program.id}.")

Expand Down Expand Up @@ -311,6 +326,10 @@ def build_sdfg_from_itir(
if on_gpu:
sdfg.apply_gpu_transformations()

# Store the sdfg such that we can later reuse it.
if save_sdfg:
sdfg.save(sdfg_filename)

return sdfg


Expand All @@ -326,7 +345,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
column_axis = kwargs.get("column_axis", None)
offset_provider = kwargs["offset_provider"]
# debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run
skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False)
load_sdfg_from_file = kwargs.get("load_sdfg_from_file", False)
save_sdfg = kwargs.get("save_sdfg", True)

arg_types = [type_translation.from_value(arg) for arg in args]

Expand All @@ -336,20 +356,18 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
sdfg_program = build_cache[cache_id]
sdfg = sdfg_program.sdfg
else:
sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg"
if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
sdfg = build_sdfg_from_itir(
program,
*args,
offset_provider=offset_provider,
auto_optimize=auto_optimize,
on_gpu=on_gpu,
column_axis=column_axis,
lift_mode=lift_mode,
)
sdfg.save(sdfg_filename)
else:
sdfg = dace.SDFG.from_file(sdfg_filename)
sdfg = build_sdfg_from_itir(
program,
*args,
offset_provider=offset_provider,
auto_optimize=auto_optimize,
on_gpu=on_gpu,
column_axis=column_axis,
lift_mode=lift_mode,
load_sdfg_from_file=load_sdfg_from_file,
cache_id=cache_id,
save_sdfg=save_sdfg,
)

sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"
with dace.config.temporary_config():
Expand Down

0 comments on commit d6dfd6f

Please sign in to comment.