|
24 | 24 | import typing
|
25 | 25 | import warnings
|
26 | 26 | from collections.abc import Callable
|
27 |
| -from typing import Generator, Generic, TypeVar |
| 27 | +from typing import Generic, TypeVar |
28 | 28 |
|
29 | 29 | from gt4py import eve
|
30 | 30 | from gt4py._core import definitions as core_defs
|
|
72 | 72 | DEFAULT_BACKEND: Callable = None
|
73 | 73 |
|
74 | 74 |
|
75 |
| -def _field_constituents_shape_and_dims( |
76 |
| - arg, arg_type: ts.FieldType | ts.ScalarType | ts.TupleType |
77 |
| -) -> Generator[tuple[tuple[int, ...], list[Dimension]]]: |
78 |
| - if isinstance(arg_type, ts.TupleType): |
79 |
| - for el, el_type in zip(arg, arg_type.types): |
80 |
| - yield from _field_constituents_shape_and_dims(el, el_type) |
81 |
| - elif isinstance(arg_type, ts.FieldType): |
82 |
| - dims = type_info.extract_dims(arg_type) |
83 |
| - if hasattr(arg, "shape"): |
84 |
| - assert len(arg.shape) == len(dims) |
85 |
| - yield (arg.shape, dims) |
86 |
| - else: |
87 |
| - yield (None, dims) |
88 |
| - elif isinstance(arg_type, ts.ScalarType): |
89 |
| - yield (None, []) |
90 |
| - else: |
91 |
| - raise ValueError("Expected 'FieldType' or 'TupleType' thereof.") |
92 |
| - |
93 |
| - |
94 | 75 | # TODO(tehrengruber): Decide if and how programs can call other programs. As a
|
95 | 76 | # result Program could become a GTCallable.
|
96 | 77 | @dataclasses.dataclass(frozen=True)
|
@@ -179,7 +160,7 @@ def with_backend(self, backend: ppi.ProgramExecutor) -> Program:
|
179 | 160 | def with_grid_type(self, grid_type: GridType) -> Program:
|
180 | 161 | return dataclasses.replace(self, grid_type=grid_type)
|
181 | 162 |
|
182 |
| - def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: |
| 163 | + def with_bound_args(self, **kwargs: Any) -> ProgramWithBoundArgs: |
183 | 164 | """
|
184 | 165 | Bind scalar, i.e. non field, program arguments.
|
185 | 166 |
|
@@ -229,7 +210,7 @@ def itir(self) -> itir.FencilDefinition:
|
229 | 210 | )
|
230 | 211 | ).program
|
231 | 212 |
|
232 |
| - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None: |
| 213 | + def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: |
233 | 214 | if self.backend is None:
|
234 | 215 | warnings.warn(
|
235 | 216 | UserWarning(
|
@@ -376,7 +357,9 @@ def program(
|
376 | 357 |
|
377 | 358 | def program_inner(definition: types.FunctionType) -> Program:
|
378 | 359 | return Program.from_function(
|
379 |
| - definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type |
| 360 | + definition, |
| 361 | + DEFAULT_BACKEND if backend is eve.NOTHING else backend, |
| 362 | + grid_type, |
380 | 363 | )
|
381 | 364 |
|
382 | 365 | return program_inner if definition is None else program_inner(definition)
|
@@ -634,9 +617,13 @@ def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None):
|
634 | 617 | ... ...
|
635 | 618 | """
|
636 | 619 |
|
637 |
| - def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]: |
| 620 | + def field_operator_inner( |
| 621 | + definition: types.FunctionType, |
| 622 | + ) -> FieldOperator[foast.FieldOperator]: |
638 | 623 | return FieldOperator.from_function(
|
639 |
| - definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type |
| 624 | + definition, |
| 625 | + DEFAULT_BACKEND if backend is eve.NOTHING else backend, |
| 626 | + grid_type, |
640 | 627 | )
|
641 | 628 |
|
642 | 629 | return field_operator_inner if definition is None else field_operator_inner(definition)
|
|
0 commit comments