Skip to content

Commit bffb73b

Browse files
committed
Merge branch 'main' into c20_workflowefy
2 parents 2cff020 + d5b83a8 commit bffb73b

File tree

8 files changed

+106
-47
lines changed

8 files changed

+106
-47
lines changed

ci/base.Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,4 @@ RUN pyenv update && \
5454
ENV PATH="/root/.pyenv/shims:${PATH}"
5555

5656

57-
RUN pip install --upgrade pip setuptools wheel tox cupy-cuda11x
57+
RUN pip install --upgrade pip setuptools wheel tox cupy-cuda11x==12.3.0

docs/user/next/QuickstartGuide.md

+32-31
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ The following snippet imports the most commonly used features that are needed to
4646
import numpy as np
4747
4848
import gt4py.next as gtx
49-
from gt4py.next import float64, neighbor_sum, where
49+
from gt4py.next import float64, neighbor_sum, where, Dims
5050
```
5151

5252
#### Fields
@@ -91,11 +91,12 @@ Let's see an example for a field operator that adds two fields elementwise:
9191

9292
```{code-cell} ipython3
9393
@gtx.field_operator
94-
def add(a: gtx.Field[[CellDim, KDim], float64],
95-
b: gtx.Field[[CellDim, KDim], float64]) -> gtx.Field[[CellDim, KDim], float64]:
94+
def add(a: gtx.Field[gtx.Dims[CellDim, KDim], float64],
95+
b: gtx.Field[gtx.Dims[CellDim, KDim], float64]) -> gtx.Field[gtx.Dims[CellDim, KDim], float64]:
9696
return a + b
9797
```
9898

99+
_Note: for now, `Dims` is not mandatory, so the following type hint is also accepted: `gtx.Field[[CellDim, KDim], float64]`._
99100
You can call field operators from [programs](#Programs), other field operators, or directly. The code snippet below shows a direct call, in which case you have to supply two additional arguments: `out`, which is a field to write the return value to, and `offset_provider`, which is left empty for now. The result of the field operator is a field with all entries equal to 5, but for brevity, only the average and the standard deviation of the entries are printed:
100101

101102
```{code-cell} ipython3
@@ -115,9 +116,9 @@ This example program below calls the above elementwise addition field operator t
115116

116117
```{code-cell} ipython3
117118
@gtx.program
118-
def run_add(a : gtx.Field[[CellDim, KDim], float64],
119-
b : gtx.Field[[CellDim, KDim], float64],
120-
result : gtx.Field[[CellDim, KDim], float64]):
119+
def run_add(a : gtx.Field[gtx.Dims[CellDim, KDim], float64],
120+
b : gtx.Field[gtx.Dims[CellDim, KDim], float64],
121+
result : gtx.Field[gtx.Dims[CellDim, KDim], float64]):
121122
add(a, b, out=result)
122123
add(b, result, out=result)
123124
```
@@ -247,11 +248,11 @@ Pay attention to the syntax where the field offset `E2C` can be freely accessed
247248

248249
```{code-cell} ipython3
249250
@gtx.field_operator
250-
def nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]:
251+
def nearest_cell_to_edge(cell_values: gtx.Field[gtx.Dims[CellDim], float64]) -> gtx.Field[gtx.Dims[EdgeDim], float64]:
251252
return cell_values(E2C[0])
252253
253254
@gtx.program
254-
def run_nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64], out : gtx.Field[[EdgeDim], float64]):
255+
def run_nearest_cell_to_edge(cell_values: gtx.Field[gtx.Dims[CellDim], float64], out : gtx.Field[gtx.Dims[EdgeDim], float64]):
255256
nearest_cell_to_edge(cell_values, out=out)
256257
257258
run_nearest_cell_to_edge(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider})
@@ -273,12 +274,12 @@ Similarly to the previous example, the output is once again a field on edges. Th
273274

274275
```{code-cell} ipython3
275276
@gtx.field_operator
276-
def sum_adjacent_cells(cells : gtx.Field[[CellDim], float64]) -> gtx.Field[[EdgeDim], float64]:
277-
# type of cells(E2C) is gtx.Field[[CellDim, E2CDim], float64]
277+
def sum_adjacent_cells(cells : gtx.Field[gtx.Dims[CellDim], float64]) -> gtx.Field[gtx.Dims[EdgeDim], float64]:
278+
# type of cells(E2C) is gtx.Field[gtx.Dims[CellDim, E2CDim], float64]
278279
return neighbor_sum(cells(E2C), axis=E2CDim)
279280
280281
@gtx.program
281-
def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Field[[EdgeDim], float64]):
282+
def run_sum_adjacent_cells(cells : gtx.Field[gtx.Dims[CellDim], float64], out : gtx.Field[gtx.Dims[EdgeDim], float64]):
282283
sum_adjacent_cells(cells, out=out)
283284
284285
run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider})
@@ -302,7 +303,7 @@ This function takes 3 input arguments:
302303
- mask: a field with dtype boolean
303304
- true branch: a tuple, a field, or a scalar
304305
- false branch: a tuple, a field, or a scalar
305-
The mask can be directly a field of booleans (e.g. `gtx.Field[[CellDim], bool]`) or an expression evaluating to this type (e.g. `gtx.Field[[CellDim], float64] > 3`).
306+
The mask can be directly a field of booleans (e.g. `gtx.Field[gtx.Dims[CellDim], bool]`) or an expression evaluating to this type (e.g. `gtx.Field[gtx.Dims[CellDim], float64] > 3`).
306307
The `where` builtin loops over each entry of the mask and returns values corresponding to the same indexes of either the true or the false branch.
307308
In the case where the true and false branches are either fields or scalars, the resulting output will be a field including all dimensions from all inputs. For example:
308309

@@ -312,8 +313,8 @@ result_where = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
312313
b = 6.0
313314
314315
@gtx.field_operator
315-
def conditional(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float
316-
) -> gtx.Field[[CellDim, KDim], float64]:
316+
def conditional(mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: float
317+
) -> gtx.Field[gtx.Dims[CellDim, KDim], float64]:
317318
return where(mask, a, b)
318319
319320
conditional(mask, a, b, out=result_where, offset_provider={})
@@ -329,13 +330,13 @@ result_1 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
329330
result_2 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
330331
331332
@gtx.field_operator
332-
def _conditional_tuple(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float
333-
) -> tuple[gtx.Field[[CellDim, KDim], float64], gtx.Field[[CellDim, KDim], float64]]:
333+
def _conditional_tuple(mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: float
334+
) -> tuple[gtx.Field[gtx.Dims[CellDim, KDim], float64], gtx.Field[gtx.Dims[CellDim, KDim], float64]]:
334335
return where(mask, (a, b), (b, a))
335336
336337
@gtx.program
337-
def conditional_tuple(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: float,
338-
result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDim], float64]
338+
def conditional_tuple(mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: float,
339+
result_1: gtx.Field[gtx.Dims[CellDim, KDim], float64], result_2: gtx.Field[gtx.Dims[CellDim, KDim], float64]
339340
):
340341
_conditional_tuple(mask, a, b, out=(result_1, result_2))
341342
@@ -360,17 +361,17 @@ result_2 = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape))
360361
361362
@gtx.field_operator
362363
def _conditional_tuple_nested(
363-
mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: gtx.Field[[CellDim, KDim], float64], c: gtx.Field[[CellDim, KDim], float64], d: gtx.Field[[CellDim, KDim], float64]
364+
mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: gtx.Field[gtx.Dims[CellDim, KDim], float64], c: gtx.Field[gtx.Dims[CellDim, KDim], float64], d: gtx.Field[gtx.Dims[CellDim, KDim], float64]
364365
) -> tuple[
365-
tuple[gtx.Field[[CellDim, KDim], float64], gtx.Field[[CellDim, KDim], float64]],
366-
tuple[gtx.Field[[CellDim, KDim], float64], gtx.Field[[CellDim, KDim], float64]],
366+
tuple[gtx.Field[gtx.Dims[CellDim, KDim], float64], gtx.Field[gtx.Dims[CellDim, KDim], float64]],
367+
tuple[gtx.Field[gtx.Dims[CellDim, KDim], float64], gtx.Field[gtx.Dims[CellDim, KDim], float64]],
367368
]:
368369
return where(mask, ((a, b), (b, a)), ((c, d), (d, c)))
369370
370371
@gtx.program
371372
def conditional_tuple_nested(
372-
mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, KDim], float64], b: gtx.Field[[CellDim, KDim], float64], c: gtx.Field[[CellDim, KDim], float64], d: gtx.Field[[CellDim, KDim], float64],
373-
result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDim], float64]
373+
mask: gtx.Field[gtx.Dims[CellDim, KDim], bool], a: gtx.Field[gtx.Dims[CellDim, KDim], float64], b: gtx.Field[gtx.Dims[CellDim, KDim], float64], c: gtx.Field[gtx.Dims[CellDim, KDim], float64], d: gtx.Field[gtx.Dims[CellDim, KDim], float64],
374+
result_1: gtx.Field[gtx.Dims[CellDim, KDim], float64], result_2: gtx.Field[gtx.Dims[CellDim, KDim], float64]
374375
):
375376
_conditional_tuple_nested(mask, a, b, c, d, out=((result_1, result_2), (result_2, result_1)))
376377
@@ -425,19 +426,19 @@ The second lines first creates a temporary field using `edge_differences(C2E)`,
425426

426427
```{code-cell} ipython3
427428
@gtx.field_operator
428-
def pseudo_lap(cells : gtx.Field[[CellDim], float64],
429-
edge_weights : gtx.Field[[CellDim, C2EDim], float64]) -> gtx.Field[[CellDim], float64]:
430-
edge_differences = cells(E2C[0]) - cells(E2C[1]) # type: gtx.Field[[EdgeDim], float64]
429+
def pseudo_lap(cells : gtx.Field[gtx.Dims[CellDim], float64],
430+
edge_weights : gtx.Field[gtx.Dims[CellDim, C2EDim], float64]) -> gtx.Field[gtx.Dims[CellDim], float64]:
431+
edge_differences = cells(E2C[0]) - cells(E2C[1]) # type: gtx.Field[gtx.Dims[EdgeDim], float64]
431432
return neighbor_sum(edge_differences(C2E) * edge_weights, axis=C2EDim)
432433
```
433434

434435
The program itself is just a shallow wrapper over the `pseudo_lap` field operator. The significant part is how offset providers for both the edge-to-cell and cell-to-edge connectivities are supplied when the program is called:
435436

436437
```{code-cell} ipython3
437438
@gtx.program
438-
def run_pseudo_laplacian(cells : gtx.Field[[CellDim], float64],
439-
edge_weights : gtx.Field[[CellDim, C2EDim], float64],
440-
out : gtx.Field[[CellDim], float64]):
439+
def run_pseudo_laplacian(cells : gtx.Field[gtx.Dims[CellDim], float64],
440+
edge_weights : gtx.Field[gtx.Dims[CellDim, C2EDim], float64],
441+
out : gtx.Field[gtx.Dims[CellDim], float64]):
441442
pseudo_lap(cells, edge_weights, out=out)
442443
443444
result_pseudo_lap = gtx.as_field([CellDim], np.zeros(shape=(6,)))
@@ -454,7 +455,7 @@ As a closure, here is an example of chaining field operators, which is very simp
454455

455456
```{code-cell} ipython3
456457
@gtx.field_operator
457-
def pseudo_laplap(cells : gtx.Field[[CellDim], float64],
458-
edge_weights : gtx.Field[[CellDim, C2EDim], float64]) -> gtx.Field[[CellDim], float64]:
458+
def pseudo_laplap(cells : gtx.Field[gtx.Dims[CellDim], float64],
459+
edge_weights : gtx.Field[gtx.Dims[CellDim, C2EDim], float64]) -> gtx.Field[gtx.Dims[CellDim], float64]:
459460
return pseudo_lap(pseudo_lap(cells, edge_weights), edge_weights)
460461
```

src/gt4py/next/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,17 @@
2424
"""
2525

2626
from . import common, ffront, iterator, program_processors, type_inference
27-
from .common import Dimension, DimensionKind, Domain, Field, GridType, UnitRange, domain, unit_range
27+
from .common import (
28+
Dimension,
29+
DimensionKind,
30+
Dims,
31+
Domain,
32+
Field,
33+
GridType,
34+
UnitRange,
35+
domain,
36+
unit_range,
37+
)
2838
from .constructors import as_connectivity, as_field, empty, full, ones, zeros
2939
from .embedded import ( # Just for registering field implementations
3040
nd_array_field as _nd_array_field,

src/gt4py/next/common.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
TypeAlias,
4343
TypeGuard,
4444
TypeVar,
45+
TypeVarTuple,
46+
Unpack,
4547
cast,
4648
extended_runtime_checkable,
4749
overload,
@@ -51,9 +53,15 @@
5153

5254

5355
DimT = TypeVar("DimT", bound="Dimension") # , covariant=True)
54-
DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True)
56+
ShapeT = TypeVarTuple("ShapeT")
5557

5658

59+
class Dims(Generic[Unpack[ShapeT]]):
60+
shape: tuple[Unpack[ShapeT]]
61+
62+
63+
DimsT = TypeVar("DimsT", bound=Dims, covariant=True)
64+
5765
Tag: TypeAlias = str
5866

5967

src/gt4py/next/iterator/transforms/trace_shifts.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414
import dataclasses
1515
import enum
16+
import sys
1617
from collections.abc import Callable
1718
from typing import Any, Final, Iterable, Literal
1819

@@ -263,6 +264,8 @@ def _tuple_get(index, tuple_val):
263264
}
264265

265266

267+
# TODO(tehrengruber): This pass is unnecessarily inefficient and easily exceeds the default
268+
# recursion limit.
266269
@dataclasses.dataclass(frozen=True)
267270
class TraceShifts(PreserveLocationVisitor, NodeTranslator):
268271
shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder)
@@ -329,16 +332,22 @@ def visit_StencilClosure(self, node: ir.StencilClosure):
329332

330333
result = self.visit(node.stencil, ctx=_START_CTX)(*tracers)
331334
assert all(el is Sentinel.VALUE for el in _primitive_constituents(result))
335+
return node
332336

333337
@classmethod
334338
def apply(
335-
cls, node: ir.StencilClosure, *, inputs_only=True, save_to_annex=False
339+
cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False
336340
) -> (
337341
dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]]
338342
):
343+
old_recursionlimit = sys.getrecursionlimit()
344+
sys.setrecursionlimit(100000000)
345+
339346
instance = cls()
340347
instance.visit(node)
341348

349+
sys.setrecursionlimit(old_recursionlimit)
350+
342351
recorded_shifts = instance.shift_recorder.recorded_shifts
343352

344353
if save_to_annex:
@@ -348,6 +357,7 @@ def apply(
348357
ValidateRecordedShiftsAnnex().visit(node)
349358

350359
if inputs_only:
360+
assert isinstance(node, ir.StencilClosure)
351361
inputs_shifts = {}
352362
for inp in node.inputs:
353363
inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)]

src/gt4py/next/type_system/type_translation.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import functools
1818
import types
1919
import typing
20-
from typing import Any, ForwardRef, Optional, Union
20+
from typing import Any, ForwardRef, Optional
2121

2222
import numpy as np
2323
import numpy.typing as npt
@@ -106,16 +106,18 @@ def from_type_hint(
106106
case common.Field:
107107
if (n_args := len(args)) != 2:
108108
raise ValueError(f"Field type requires two arguments, got {n_args}: '{type_hint}'.")
109-
110-
dims: Union[Ellipsis, list[common.Dimension]] = []
109+
dims: list[common.Dimension] = []
111110
dim_arg, dtype_arg = args
111+
dim_arg = (
112+
list(typing.get_args(dim_arg))
113+
if typing.get_origin(dim_arg) is common.Dims
114+
else dim_arg
115+
)
112116
if isinstance(dim_arg, list):
113117
for d in dim_arg:
114118
if not isinstance(d, common.Dimension):
115119
raise ValueError(f"Invalid field dimension definition '{d}'.")
116120
dims.append(d)
117-
elif dim_arg is Ellipsis:
118-
dims = dim_arg
119121
else:
120122
raise ValueError(f"Invalid field dimensions '{dim_arg}'.")
121123

tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py

+19-7
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,15 @@
3434
KDim = Dimension("KDim")
3535

3636

37-
@pytest.fixture(params=nd_array_field._nd_array_implementations)
37+
def nd_array_implementation_params():
38+
for xp in nd_array_field._nd_array_implementations:
39+
if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp:
40+
yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu)
41+
else:
42+
yield pytest.param(xp, id=xp.__name__)
43+
44+
45+
@pytest.fixture(params=nd_array_implementation_params())
3846
def nd_array_implementation(request):
3947
yield request.param
4048

@@ -272,12 +280,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte
272280
assert np.allclose(op_result.ndarray, expected_result)
273281

274282

275-
@pytest.fixture(
276-
params=itertools.product(
277-
nd_array_field._nd_array_implementations, nd_array_field._nd_array_implementations
278-
),
279-
ids=lambda param: f"{param[0].__name__}-{param[1].__name__}",
280-
)
283+
def product_nd_array_implementation_params():
284+
for xp1 in nd_array_field._nd_array_implementations:
285+
for xp2 in nd_array_field._nd_array_implementations:
286+
marks = ()
287+
if any(hasattr(nd_array_field, "cp") and xp == nd_array_field.cp for xp in (xp1, xp2)):
288+
marks = pytest.mark.requires_gpu
289+
yield pytest.param((xp1, xp2), id=f"{xp1.__name__}-{xp2.__name__}", marks=marks)
290+
291+
292+
@pytest.fixture(params=product_nd_array_implementation_params())
281293
def product_nd_array_implementation(request):
282294
yield request.param
283295

tests/next_tests/unit_tests/type_system_tests/test_type_translation.py

+16
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,19 @@ def test_invalid_symbol_types():
166166
type_translation.from_type_hint(typing.Callable[[int], str])
167167
with pytest.raises(ValueError, match="Invalid callable annotations"):
168168
type_translation.from_type_hint(typing.Callable[[int], float])
169+
170+
171+
@pytest.mark.parametrize(
172+
"value, expected_dims",
173+
[
174+
(common.Dims[IDim, JDim], [IDim, JDim]),
175+
(common.Dims[IDim, np.float64], ValueError),
176+
(common.Dims["IDim"], ValueError),
177+
],
178+
)
179+
def test_generic_variadic_dims(value, expected_dims):
180+
if expected_dims == ValueError:
181+
with pytest.raises(ValueError, match="Invalid field dimension definition"):
182+
type_translation.from_type_hint(gtx.Field[value, np.int32])
183+
else:
184+
assert type_translation.from_type_hint(gtx.Field[value, np.int32]).dims == expected_dims

0 commit comments

Comments
 (0)