Skip to content

Commit bf0c531

Browse files
Add deprecation flag on base backend.
Update unit tests Lint
1 parent b7c43b5 commit bf0c531

File tree

5 files changed

+33
-5
lines changed

5 files changed

+33
-5
lines changed

src/gt4py/cartesian/backend/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:
228228

229229
class BaseBackend(Backend):
230230
MODULE_GENERATOR_CLASS: ClassVar[Type["BaseModuleGenerator"]]
231+
deprecated: bool = False
231232

232233
def load(self) -> Optional[Type["StencilObject"]]:
233234
build_info = self.builder.options.build_info

src/gt4py/cartesian/backend/cuda_backend.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,14 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin):
147147
PYEXT_GENERATOR_CLASS = CudaExtGenerator # type: ignore
148148
MODULE_GENERATOR_CLASS = CUDAPyExtModuleGenerator
149149
GT_BACKEND_T = "gpu"
150+
deprecated = not GT4PY_GTC_CUDA_USE
150151

151152
def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
152153
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True)
153154

154155
def generate(self) -> Type["StencilObject"]:
155-
if GT4PY_GTC_CUDA_USE:
156+
# We push for hard deprecation here by raising by default and warning if use has been forced.
157+
if not self.deprecated:
156158
warnings.warn(
157159
"cuda backend is deprecated, feature developed after February 2024 will not be available",
158160
DeprecationWarning,
@@ -161,7 +163,7 @@ def generate(self) -> Type["StencilObject"]:
161163
else:
162164
raise NotImplementedError(
163165
"cuda backend is no longer maintained (February 2024)."
164-
"You can still force the use of the backend by defining GT4PY_GTC_CUDA_USE=1"
166+
"You can still force the use of the backend by defining GT4PY_GTC_CUDA_USE=1."
165167
)
166168

167169
self.check_options(self.builder.options)

src/gt4py/cartesian/backend/gtc_common.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,10 @@ def make_extension(
286286
gt_pyext_sources: Dict[str, Any]
287287
if not self.builder.options._impl_opts.get("disable-code-generation", False):
288288
gt_pyext_files = self.make_extension_sources(stencil_ir=stencil_ir)
289-
gt_pyext_sources = {**gt_pyext_files["computation"], **gt_pyext_files["bindings"]}
289+
gt_pyext_sources = {
290+
**gt_pyext_files["computation"],
291+
**gt_pyext_files["bindings"],
292+
}
290293
else:
291294
# Pass NOTHING to the self.builder means try to reuse the source code files
292295
gt_pyext_files = {}

tests/cartesian_tests/definitions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _get_backends_with_storage_info(storage_info_kind: str):
4444
for name in _ALL_BACKEND_NAMES:
4545
backend = gt4pyc.backend.from_name(name)
4646
if backend is not None:
47-
if backend.storage_info["device"] == storage_info_kind:
47+
if backend.storage_info["device"] == storage_info_kind and not backend.deprecated:
4848
res.append(_backend_name_as_param(name))
4949
return res
5050

tests/cartesian_tests/unit_tests/backend_tests/test_backend.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ def stencil_def(
4848
out = pa * fa + pb * fb - pc * fc # type: ignore # noqa
4949

5050

51-
field_info_val = {0: ("out", "fa"), 1: ("out", "fa", "fb"), 2: ("out", "fa", "fb", "fc")}
51+
field_info_val = {
52+
0: ("out", "fa"),
53+
1: ("out", "fa", "fb"),
54+
2: ("out", "fa", "fb", "fc"),
55+
}
5256
parameter_info_val = {0: ("pa",), 1: ("pa", "pb"), 2: ("pa", "pb", "pc")}
5357
unreferenced_val = {0: ("pb", "fb", "pc", "fc"), 1: ("pc", "fc"), 2: ()}
5458

@@ -168,5 +172,23 @@ def test_toolchain_profiling(backend_name: str, mode: int, rebuild: bool):
168172
assert build_info["load_time"] > 0.0
169173

170174

175+
@pytest.mark.parametrize("backend_name", ["cuda"])
176+
def test_deprecation_gtc_cuda(backend_name: str):
177+
# Default deprecation, raise an error
178+
build_info: Dict[str, Any] = {}
179+
builder = (
180+
StencilBuilder(cast(StencilFunc, stencil_def))
181+
.with_backend(backend_name)
182+
.with_externals({"MODE": 2})
183+
.with_options(
184+
name=stencil_def.__name__,
185+
module=stencil_def.__module__,
186+
build_info=build_info,
187+
)
188+
)
189+
with pytest.raises(NotImplementedError):
190+
builder.build()
191+
192+
171193
if __name__ == "__main__":
172194
pytest.main([__file__])

0 commit comments

Comments
 (0)