Skip to content

Commit d7b6562

Browse files
authored
refactor[next]: Refactor workflow based executors and move the backend class. (#1470)
## Changed - `OTFCompileExecutor` and `CachedOTFCompileExecutor` merged into `ModularExecutor` ## Moved - `OTFBackend` moved and renamed to `next.backend.Backend` ## Reasoning The two `*CompileExecutor` classes were identical, except for type hints. Since they are now almost exclusively used inside `OTFBackend`, which does not retain typing information about which of them it contains, this distinction is no longer helpful for static type checking. `OTFBackend` has always been more general than it's naming and location suggested. It is the de-facto definition of a backend within `gt4py.next`: a wrapper around an executor and an allocator. This change makes the status quo visible.
1 parent ae9c203 commit d7b6562

File tree

9 files changed

+110
-105
lines changed

9 files changed

+110
-105
lines changed

docs/development/ADRs/0015-Test_Exclusion_Matrices.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ The following backend processors are defined:
4949
```python
5050
DACE_CPU = "dace_iterator.run_dace_cpu"
5151
DACE_GPU = "dace_iterator.run_dace_gpu"
52-
GTFN_CPU = "otf_compile_executor.run_gtfn"
53-
GTFN_CPU_IMPERATIVE = "otf_compile_executor.run_gtfn_imperative"
54-
GTFN_CPU_WITH_TEMPORARIES = "otf_compile_executor.run_gtfn_with_temporaries"
52+
GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn"
53+
GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative"
54+
GTFN_CPU_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries"
5555
GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu"
5656
```
5757

src/gt4py/next/backend.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# GT4Py - GridTools Framework
2+
#
3+
# Copyright (c) 2014-2023, ETH Zurich
4+
# All rights reserved.
5+
#
6+
# This file is part of the GT4Py project and the GridTools framework.
7+
# GT4Py is free software: you can redistribute it and/or modify it under
8+
# the terms of the GNU General Public License as published by the
9+
# Free Software Foundation, either version 3 of the License, or any later
10+
# version. See the LICENSE.txt file at the top-level directory of this
11+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
12+
#
13+
# SPDX-License-Identifier: GPL-3.0-or-later
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import Any, Generic
19+
20+
from gt4py._core import definitions as core_defs
21+
from gt4py.next import allocators as next_allocators
22+
from gt4py.next.iterator import ir as itir
23+
from gt4py.next.program_processors import processor_interface as ppi
24+
25+
26+
@dataclasses.dataclass(frozen=True)
27+
class Backend(Generic[core_defs.DeviceTypeT]):
28+
executor: ppi.ProgramExecutor
29+
allocator: next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
30+
31+
def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None:
32+
self.executor.__call__(program, *args, **kwargs)
33+
34+
@property
35+
def __name__(self) -> str:
36+
return getattr(self.executor, "__name__", None) or repr(self)
37+
38+
@property
39+
def kind(self) -> type[ppi.ProgramExecutor]:
40+
return self.executor.kind
41+
42+
@property
43+
def __gt_allocator__(
44+
self,
45+
) -> next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]:
46+
return self.allocator
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# GT4Py - GridTools Framework
2+
#
3+
# Copyright (c) 2014-2023, ETH Zurich
4+
# All rights reserved.
5+
#
6+
# This file is part of the GT4Py project and the GridTools framework.
7+
# GT4Py is free software: you can redistribute it and/or modify it under
8+
# the terms of the GNU General Public License as published by the
9+
# Free Software Foundation, either version 3 of the License, or any later
10+
# version. See the LICENSE.txt file at the top-level directory of this
11+
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
12+
#
13+
# SPDX-License-Identifier: GPL-3.0-or-later
14+
15+
from __future__ import annotations
16+
17+
import dataclasses
18+
from typing import Any, Optional, TypeVar
19+
20+
import gt4py.next.iterator.ir as itir
21+
import gt4py.next.program_processors.processor_interface as ppi
22+
from gt4py.next.otf import languages, stages, workflow
23+
24+
25+
SrcL = TypeVar("SrcL", bound=languages.NanobindSrcL)
26+
TgtL = TypeVar("TgtL", bound=languages.LanguageTag)
27+
LS = TypeVar("LS", bound=languages.LanguageSettings)
28+
HashT = TypeVar("HashT")
29+
30+
31+
@dataclasses.dataclass(frozen=True)
32+
class ModularExecutor(ppi.ProgramExecutor):
33+
otf_workflow: workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]
34+
name: Optional[str] = None
35+
36+
def __call__(self, program: itir.FencilDefinition, *args, **kwargs: Any) -> None:
37+
self.otf_workflow(stages.ProgramCall(program, args, kwargs))(
38+
*args, offset_provider=kwargs["offset_provider"]
39+
)
40+
41+
@property
42+
def __name__(self) -> str:
43+
return self.name or repr(self)

src/gt4py/next/program_processors/otf_compile_executor.py

-83
This file was deleted.

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@
2525

2626
import gt4py.next.allocators as next_allocators
2727
import gt4py.next.iterator.ir as itir
28-
import gt4py.next.program_processors.otf_compile_executor as otf_exec
2928
import gt4py.next.program_processors.processor_interface as ppi
30-
from gt4py.next import common
29+
from gt4py.next import backend, common
3130
from gt4py.next.iterator import transforms as itir_transforms
3231
from gt4py.next.otf.compilation import cache as compilation_cache
3332
from gt4py.next.type_system import type_specifications as ts, type_translation
@@ -437,7 +436,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
437436
)
438437

439438

440-
run_dace_cpu = otf_exec.OTFBackend(
439+
run_dace_cpu = backend.Backend(
441440
executor=ppi.program_executor(_run_dace_cpu, name="run_dace_cpu"),
442441
allocator=next_allocators.StandardCPUFieldBufferAllocator(),
443442
)
@@ -460,7 +459,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
460459
raise RuntimeError("Missing 'cupy' dependency for GPU execution.")
461460

462461

463-
run_dace_gpu = otf_exec.OTFBackend(
462+
run_dace_gpu = backend.Backend(
464463
executor=ppi.program_executor(_run_dace_gpu, name="run_dace_gpu"),
465464
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
466465
)

src/gt4py/next/program_processors/runners/double_roundtrip.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
from typing import TYPE_CHECKING, Any
1818

19-
import gt4py.next.program_processors.otf_compile_executor as otf_compile_executor
2019
import gt4py.next.program_processors.processor_interface as ppi
21-
import gt4py.next.program_processors.runners.roundtrip as roundtrip
20+
from gt4py.next import backend as next_backend
21+
from gt4py.next.program_processors.runners import roundtrip
2222

2323

2424
if TYPE_CHECKING:
@@ -30,7 +30,7 @@ def executor(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None:
3030
roundtrip.execute_roundtrip(program, *args, dispatch_backend=roundtrip.executor, **kwargs)
3131

3232

33-
backend = otf_compile_executor.OTFBackend(
33+
backend = next_backend.Backend(
3434
executor=executor,
3535
allocator=roundtrip.backend.allocator,
3636
)

src/gt4py/next/program_processors/runners/gtfn.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
import gt4py._core.definitions as core_defs
2323
import gt4py.next.allocators as next_allocators
2424
from gt4py.eve.utils import content_hash
25-
from gt4py.next import common, config
25+
from gt4py.next import backend, common, config
2626
from gt4py.next.iterator import transforms
2727
from gt4py.next.iterator.transforms import global_tmps
2828
from gt4py.next.otf import recipes, stages, workflow
2929
from gt4py.next.otf.binding import nanobind
3030
from gt4py.next.otf.compilation import compiler
3131
from gt4py.next.otf.compilation.build_systems import compiledb
32-
from gt4py.next.program_processors import otf_compile_executor
32+
from gt4py.next.program_processors import modular_executor
3333
from gt4py.next.program_processors.codegens.gtfn import gtfn_module
3434
from gt4py.next.type_system.type_translation import from_value
3535

@@ -145,7 +145,7 @@ class Params:
145145

146146
class GTFNBackendFactory(factory.Factory):
147147
class Meta:
148-
model = otf_compile_executor.OTFBackend
148+
model = backend.Backend
149149

150150
class Params:
151151
name_device = "cpu"
@@ -159,7 +159,7 @@ class Params:
159159
)
160160
cached = factory.Trait(
161161
executor=factory.LazyAttribute(
162-
lambda o: otf_compile_executor.CachedOTFCompileExecutor(
162+
lambda o: modular_executor.ModularExecutor(
163163
otf_workflow=workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function),
164164
name=o.name,
165165
)
@@ -181,7 +181,7 @@ class Params:
181181
)
182182

183183
executor = factory.LazyAttribute(
184-
lambda o: otf_compile_executor.OTFCompileExecutor(otf_workflow=o.otf_workflow, name=o.name)
184+
lambda o: modular_executor.ModularExecutor(otf_workflow=o.otf_workflow, name=o.name)
185185
)
186186
allocator = next_allocators.StandardCPUFieldBufferAllocator()
187187

src/gt4py/next/program_processors/runners/roundtrip.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@
2121
from collections.abc import Callable, Iterable
2222
from typing import Any, Optional
2323

24-
import gt4py.eve.codegen as codegen
2524
import gt4py.next.allocators as next_allocators
2625
import gt4py.next.common as common
2726
import gt4py.next.iterator.embedded as embedded
2827
import gt4py.next.iterator.ir as itir
2928
import gt4py.next.iterator.transforms as itir_transforms
3029
import gt4py.next.iterator.transforms.global_tmps as gtmps_transform
31-
import gt4py.next.program_processors.otf_compile_executor as otf_compile_executor
3230
import gt4py.next.program_processors.processor_interface as ppi
31+
from gt4py.eve import codegen
3332
from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako
33+
from gt4py.next import backend as next_backend
3434

3535

3636
def _create_tmp(axes, origin, shape, dtype):
@@ -229,6 +229,6 @@ def execute_roundtrip(
229229

230230
executor = ppi.program_executor(execute_roundtrip) # type: ignore[arg-type]
231231

232-
backend = otf_compile_executor.OTFBackend(
232+
backend = next_backend.Backend(
233233
executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator()
234234
)

tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from numpy import int32, int64
1717

1818
from gt4py import next as gtx
19-
from gt4py.next import common
19+
from gt4py.next import backend, common
2020
from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
21-
from gt4py.next.program_processors import otf_compile_executor
21+
from gt4py.next.program_processors import modular_executor
2222
from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries
2323

2424
from next_tests.integration_tests import cases
@@ -38,8 +38,8 @@
3838

3939
@pytest.fixture
4040
def run_gtfn_with_temporaries_and_symbolic_sizes():
41-
return otf_compile_executor.OTFBackend(
42-
executor=otf_compile_executor.OTFCompileExecutor(
41+
return backend.Backend(
42+
executor=modular_executor.ModularExecutor(
4343
name="run_gtfn_with_temporaries_and_sizes",
4444
otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
4545
translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(

0 commit comments

Comments
 (0)