Skip to content

Commit 628a33b

Browse files
authored
feat[next][dace]: Add gt4py workflow for the DaCe backend (#1477)
Integrate the DaCe backend with the workflow API.
1 parent d7b6562 commit 628a33b

File tree

4 files changed

+276
-14
lines changed

4 files changed

+276
-14
lines changed

src/gt4py/next/otf/languages.py

+5
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class Python(LanguageTag):
5757
...
5858

5959

60+
class SDFG(LanguageTag):
61+
settings_class = LanguageSettings
62+
...
63+
64+
6065
class NanobindSrcL(LanguageTag): ...
6166

6267

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# GT4Py - GridTools Framework
2+
#
3+
# Copyright (c) 2014-2023, ETH Zurich
4+
# All rights reserved.
5+
#
6+
# This file is part of the GT4Py project and the GridTools framework.
7+
# GT4Py is free software: you can redistribute it and/or modify it under
8+
# the terms of the GNU General Public License as published by the
9+
# Free Software Foundation, either version 3 of the License, or any later
10+
# version. See the LICENSE.txt file at the top-level directory of this
11+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
12+
#
13+
# SPDX-License-Identifier: GPL-3.0-or-later
14+
15+
import functools
16+
17+
import factory
18+
19+
import gt4py._core.definitions as core_defs
20+
from gt4py.next import config
21+
from gt4py.next.otf import recipes, stages
22+
from gt4py.next.program_processors.runners.dace_iterator.workflow import (
23+
DaCeCompilationStepFactory,
24+
DaCeTranslationStepFactory,
25+
convert_args,
26+
)
27+
from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory
28+
29+
30+
def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource:
31+
return stages.CompilableSource(program_source=inp, binding_source=None)
32+
33+
34+
class DaCeWorkflowFactory(factory.Factory):
35+
class Meta:
36+
model = recipes.OTFCompileWorkflow
37+
38+
class Params:
39+
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
40+
cmake_build_type: config.CMakeBuildType = factory.LazyFunction(
41+
lambda: config.CMAKE_BUILD_TYPE
42+
)
43+
use_field_canonical_representation: bool = False
44+
45+
translation = factory.SubFactory(
46+
DaCeTranslationStepFactory,
47+
device_type=factory.SelfAttribute("..device_type"),
48+
use_field_canonical_representation=factory.SelfAttribute(
49+
"..use_field_canonical_representation"
50+
),
51+
)
52+
bindings = _no_bindings
53+
compilation = factory.SubFactory(
54+
DaCeCompilationStepFactory,
55+
cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME),
56+
cmake_build_type=factory.SelfAttribute("..cmake_build_type"),
57+
)
58+
decoration = factory.LazyAttribute(
59+
lambda o: functools.partial(
60+
convert_args,
61+
device=o.device_type,
62+
use_field_canonical_representation=o.use_field_canonical_representation,
63+
)
64+
)
65+
66+
67+
class DaCeBackendFactory(GTFNBackendFactory):
68+
class Params:
69+
otf_workflow = factory.SubFactory(
70+
DaCeWorkflowFactory,
71+
device_type=factory.SelfAttribute("..device_type"),
72+
use_field_canonical_representation=factory.SelfAttribute(
73+
"..use_field_canonical_representation"
74+
),
75+
)
76+
name = factory.LazyAttribute(
77+
lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}"
78+
)
79+
auto_optimize = factory.Trait(
80+
otf_workflow__translation__auto_optimize=True,
81+
name_temps="_opt",
82+
)
83+
use_field_canonical_representation: bool = False
84+
85+
86+
run_dace_cpu = DaCeBackendFactory(cached=True, auto_optimize=True)
87+
88+
run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# GT4Py - GridTools Framework
2+
#
3+
# Copyright (c) 2014-2023, ETH Zurich
4+
# All rights reserved.
5+
#
6+
# This file is part of the GT4Py project and the GridTools framework.
7+
# GT4Py is free software: you can redistribute it and/or modify it under
8+
# the terms of the GNU General Public License as published by the
9+
# Free Software Foundation, either version 3 of the License, or any later
10+
# version. See the LICENSE.txt file at the top-level directory of this
11+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
12+
#
13+
# SPDX-License-Identifier: GPL-3.0-or-later
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import Callable, Optional, cast
19+
20+
import dace
21+
import factory
22+
from dace.codegen.compiled_sdfg import CompiledSDFG
23+
24+
from gt4py._core import definitions as core_defs
25+
from gt4py.next import common, config
26+
from gt4py.next.common import Dimension
27+
from gt4py.next.iterator import ir as itir
28+
from gt4py.next.iterator.transforms import LiftMode
29+
from gt4py.next.otf import languages, stages, step_types, workflow
30+
from gt4py.next.otf.binding import interface
31+
from gt4py.next.otf.compilation import cache
32+
from gt4py.next.otf.languages import LanguageSettings
33+
from gt4py.next.type_system import type_translation as tt
34+
35+
from . import build_sdfg_from_itir, get_sdfg_args
36+
37+
38+
@dataclasses.dataclass(frozen=True)
39+
class DaCeTranslator(
40+
workflow.ChainableWorkflowMixin[
41+
stages.ProgramCall,
42+
stages.ProgramSource[languages.SDFG, languages.LanguageSettings],
43+
],
44+
step_types.TranslationStep[languages.SDFG, languages.LanguageSettings],
45+
):
46+
auto_optimize: bool = False
47+
lift_mode: LiftMode = LiftMode.FORCE_INLINE
48+
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
49+
temporary_extraction_heuristics: Optional[
50+
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]]
51+
] = None
52+
use_field_canonical_representation: bool = False
53+
54+
def _language_settings(self) -> languages.LanguageSettings:
55+
return languages.LanguageSettings(
56+
formatter_key="",
57+
formatter_style="",
58+
file_extension="sdfg",
59+
)
60+
61+
def __call__(
62+
self,
63+
inp: stages.ProgramCall,
64+
) -> stages.ProgramSource[languages.SDFG, LanguageSettings]:
65+
"""Generate DaCe SDFG file from the ITIR definition."""
66+
program: itir.FencilDefinition = inp.program
67+
on_gpu = True if self.device_type == core_defs.DeviceType.CUDA else False
68+
69+
# ITIR parameters
70+
column_axis: Optional[Dimension] = inp.kwargs.get("column_axis", None)
71+
offset_provider = inp.kwargs["offset_provider"]
72+
73+
sdfg = build_sdfg_from_itir(
74+
program,
75+
*inp.args,
76+
offset_provider=offset_provider,
77+
auto_optimize=self.auto_optimize,
78+
on_gpu=on_gpu,
79+
column_axis=column_axis,
80+
lift_mode=self.lift_mode,
81+
load_sdfg_from_file=False,
82+
save_sdfg=False,
83+
use_field_canonical_representation=self.use_field_canonical_representation,
84+
)
85+
86+
arg_types = tuple(
87+
interface.Parameter(param, tt.from_value(arg))
88+
for param, arg in zip(sdfg.arg_names, inp.args)
89+
)
90+
91+
module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = (
92+
stages.ProgramSource(
93+
entry_point=interface.Function(program.id, arg_types),
94+
source_code=sdfg.to_json(),
95+
library_deps=tuple(),
96+
language=languages.SDFG,
97+
language_settings=self._language_settings(),
98+
)
99+
)
100+
return module
101+
102+
103+
class DaCeTranslationStepFactory(factory.Factory):
104+
class Meta:
105+
model = DaCeTranslator
106+
107+
108+
@dataclasses.dataclass(frozen=True)
109+
class DaCeCompiler(
110+
workflow.ChainableWorkflowMixin[
111+
stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python],
112+
stages.CompiledProgram,
113+
],
114+
workflow.ReplaceEnabledWorkflowMixin[
115+
stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python],
116+
stages.CompiledProgram,
117+
],
118+
step_types.CompilationStep[languages.SDFG, languages.LanguageSettings, languages.Python],
119+
):
120+
"""Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``."""
121+
122+
cache_lifetime: config.BuildCacheLifetime
123+
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
124+
cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG
125+
126+
def __call__(
127+
self,
128+
inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python],
129+
) -> stages.CompiledProgram:
130+
sdfg = dace.SDFG.from_json(inp.program_source.source_code)
131+
132+
src_dir = cache.get_cache_folder(inp, self.cache_lifetime)
133+
sdfg.build_folder = src_dir / ".dacecache"
134+
135+
with dace.config.temporary_config():
136+
dace.config.Config.set("compiler", "build_type", value=self.cmake_build_type.value)
137+
if self.device_type == core_defs.DeviceType.CPU:
138+
compiler_args = dace.config.Config.get("compiler", "cpu", "args")
139+
# disable finite-math-only in order to support isfinite/isinf/isnan builtins
140+
if "-ffast-math" in compiler_args:
141+
compiler_args += " -fno-finite-math-only"
142+
if "-ffinite-math-only" in compiler_args:
143+
compiler_args.replace("-ffinite-math-only", "")
144+
145+
dace.config.Config.set("compiler", "cpu", "args", value=compiler_args)
146+
sdfg_program = sdfg.compile(validate=False)
147+
148+
return sdfg_program
149+
150+
151+
class DaCeCompilationStepFactory(factory.Factory):
152+
class Meta:
153+
model = DaCeCompiler
154+
155+
156+
def convert_args(
157+
inp: stages.CompiledProgram,
158+
device: core_defs.DeviceType = core_defs.DeviceType.CPU,
159+
use_field_canonical_representation: bool = False,
160+
) -> stages.CompiledProgram:
161+
sdfg_program = cast(CompiledSDFG, inp)
162+
on_gpu = True if device == core_defs.DeviceType.CUDA else False
163+
sdfg = sdfg_program.sdfg
164+
165+
def decorated_program(
166+
*args, offset_provider: dict[str, common.Connectivity | common.Dimension]
167+
):
168+
sdfg_args = get_sdfg_args(
169+
sdfg,
170+
*args,
171+
check_args=False,
172+
offset_provider=offset_provider,
173+
on_gpu=on_gpu,
174+
use_field_canonical_representation=use_field_canonical_representation,
175+
)
176+
177+
with dace.config.temporary_config():
178+
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
179+
return inp(**sdfg_args)
180+
181+
return decorated_program

tests/next_tests/definitions.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,8 @@ class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum):
8383

8484

8585
class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
86-
DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu"
87-
DACE_GPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_gpu"
88-
89-
90-
class ProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum):
91-
GTFN_CPU_EXECUTOR = f"{ProgramBackendId.GTFN_CPU}.executor"
92-
GTFN_CPU_IMPERATIVE_EXECUTOR = f"{ProgramBackendId.GTFN_CPU_IMPERATIVE}.executor"
93-
GTFN_CPU_WITH_TEMPORARIES = f"{ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES}.executor"
94-
ROUNDTRIP = f"{ProgramBackendId.ROUNDTRIP}.executor"
95-
DOUBLE_ROUNDTRIP = f"{ProgramBackendId.DOUBLE_ROUNDTRIP}.executor"
96-
97-
98-
class OptionalProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum):
99-
DACE_CPU_EXECUTOR = f"{OptionalProgramBackendId.DACE_CPU}.executor"
86+
DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu"
87+
DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu"
10088

10189

10290
class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):

0 commit comments

Comments
 (0)