Skip to content

Commit f281756

Browse files
committed
Merge origin/main
2 parents 5bf44f3 + 2970575 commit f281756

File tree

14 files changed

+87
-41
lines changed

14 files changed

+87
-41
lines changed

ci/cscs-ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ stages:
1919
- test
2020

2121
build py38 baseimage:
22-
extends: .container-builder
22+
extends: .container-builder-cscs-zen2
2323
stage: baseimage
2424
# we create a tag that depends on the SHA value of ci/base.Dockerfile, this way
2525
# a new base image is only built when the SHA of this file changes
@@ -52,7 +52,7 @@ build py310 baseimage:
5252
<<: *py310
5353

5454
build py38 image:
55-
extends: .container-builder
55+
extends: .container-builder-cscs-zen2
5656
needs: ["build py38 baseimage"]
5757
stage: image
5858
variables:

src/gt4py/_core/definitions.py

+2
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,8 @@ def shape(self) -> tuple[int, ...]: ...
444444
@property
445445
def dtype(self) -> Any: ...
446446

447+
def item(self) -> Any: ...
448+
447449
def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ...
448450

449451
def __getitem__(self, item: Any) -> NDArrayObject: ...

src/gt4py/next/allocators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = (
4545
None
4646
if not cp
47-
else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA)
47+
else (core_defs.DeviceType.ROCM if cp.cuda.runtime.is_hip else core_defs.DeviceType.CUDA)
4848
)
4949

5050

src/gt4py/next/common.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -623,14 +623,17 @@ def asnumpy(self) -> np.ndarray: ...
623623
def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...
624624

625625
@abc.abstractmethod
626-
def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ...
626+
def restrict(self, item: AnyIndexSpec) -> Field: ...
627+
628+
@abc.abstractmethod
629+
def as_scalar(self) -> core_defs.ScalarT: ...
627630

628631
# Operators
629632
@abc.abstractmethod
630633
def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ...
631634

632635
@abc.abstractmethod
633-
def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ...
636+
def __getitem__(self, item: AnyIndexSpec) -> Field: ...
634637

635638
@abc.abstractmethod
636639
def __abs__(self) -> Field: ...
@@ -896,6 +899,9 @@ def ndarray(self) -> Never:
896899
def asnumpy(self) -> Never:
897900
raise NotImplementedError()
898901

902+
def as_scalar(self) -> Never:
903+
raise NotImplementedError()
904+
899905
@functools.cached_property
900906
def domain(self) -> Domain:
901907
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))
@@ -947,9 +953,7 @@ def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Conne
947953

948954
__call__ = remap
949955

950-
def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar:
951-
if is_int_index(index):
952-
return index + self.offset
956+
def restrict(self, index: AnyIndexSpec) -> Never:
953957
raise NotImplementedError() # we could possibly implement with a FunctionField, but we don't have a use-case
954958

955959
__getitem__ = restrict

src/gt4py/next/embedded/nd_array_field.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ def asnumpy(self) -> np.ndarray:
120120
else:
121121
return np.asarray(self._ndarray)
122122

123+
def as_scalar(self) -> core_defs.ScalarT:
124+
if self.domain.ndim != 0:
125+
raise ValueError(
126+
"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'."
127+
)
128+
return self.ndarray.item()
129+
123130
@property
124131
def codomain(self) -> type[core_defs.ScalarT]:
125132
return self.dtype.scalar_type
@@ -204,15 +211,11 @@ def remap(
204211

205212
__call__ = remap # type: ignore[assignment]
206213

207-
def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT:
214+
def restrict(self, index: common.AnyIndexSpec) -> common.Field:
208215
new_domain, buffer_slice = self._slice(index)
209-
210216
new_buffer = self.ndarray[buffer_slice]
211-
if len(new_domain) == 0:
212-
# TODO: assert core_defs.is_scalar_type(new_buffer), new_buffer
213-
return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here
214-
else:
215-
return self.__class__.from_array(new_buffer, domain=new_domain)
217+
new_buffer = self.__class__.array_ns.asarray(new_buffer)
218+
return self.__class__.from_array(new_buffer, domain=new_domain)
216219

217220
__getitem__ = restrict
218221

@@ -433,7 +436,7 @@ def inverse_image(
433436

434437
return new_dims
435438

436-
def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.IntegralScalar:
439+
def restrict(self, index: common.AnyIndexSpec) -> common.Field:
437440
cache_key = (id(self.ndarray), self.domain, index)
438441

439442
if (restricted_connectivity := self._cache.get(cache_key, None)) is None:

src/gt4py/next/embedded/operators.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def _tuple_at(
187187
) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]:
188188
@utils.tree_map
189189
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
190-
res = field[pos] if common.is_field(field) else field
191-
res = res.item() if hasattr(res, "item") else res # extract scalar value from array
190+
res = field[pos].as_scalar() if common.is_field(field) else field
192191
assert core_defs.is_scalar_type(res)
193192
return res
194193

src/gt4py/next/ffront/func_to_foast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript:
289289
index = self._match_index(node.slice)
290290
except ValueError:
291291
raise errors.DSLError(
292-
self.get_location(node.slice), "eXpected an integral index."
292+
self.get_location(node.slice), "Expected an integral index."
293293
) from None
294294

295295
return foast.Subscript(

src/gt4py/next/ffront/past_passes/type_deduction.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,12 @@ def visit_Call(self, node: past.Call, **kwargs):
217217
f"'{new_kwargs['out'].type}'."
218218
)
219219
elif new_func.id in ["minimum", "maximum"]:
220-
if new_args[0].type != new_args[1].type:
220+
if arg_types[0] != arg_types[1]:
221221
raise ValueError(
222222
f"First and second argument in '{new_func.id}' must be of the same type."
223-
f"Got '{new_args[0].type}' and '{new_args[1].type}'."
223+
f"Got '{arg_types[0]}' and '{arg_types[1]}'."
224224
)
225-
return_type = new_args[0].type
225+
return_type = arg_types[0]
226226
else:
227227
raise AssertionError(
228228
"Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed."

src/gt4py/next/ffront/past_to_itir.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def _visit_stencil_call_out_arg(
332332
) -> tuple[itir.Expr, itir.FunCall]:
333333
if isinstance(out_arg, past.Subscript):
334334
# as the ITIR does not support slicing a field we have to do a deeper
335-
# inspection of the PAST to emulate the behaviour
335+
# inspection of the PAST to emulate the behaviour
336336
out_field_name: past.Name = out_arg.value
337337
return (
338338
self._construct_itir_out_arg(out_field_name),
@@ -409,12 +409,11 @@ def visit_BinOp(self, node: past.BinOp, **kwargs) -> itir.FunCall:
409409
)
410410

411411
def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall:
412-
if node.func.id in ["maximum", "minimum"] and len(node.args) == 2:
412+
if node.func.id in ["maximum", "minimum"]:
413+
assert len(node.args) == 2
413414
return itir.FunCall(
414415
fun=itir.SymRef(id=node.func.id),
415416
args=[self.visit(node.args[0]), self.visit(node.args[1])],
416417
)
417418
else:
418-
raise AssertionError(
419-
"Only 'minimum' and 'maximum' builtins supported supported currently."
420-
)
419+
raise NotImplementedError("Only 'minimum', and 'maximum' builtins supported currently.")

src/gt4py/next/iterator/embedded.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -919,7 +919,7 @@ def _translate_named_indices(
919919
return tuple(domain_slice)
920920

921921
def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
922-
return self._ndarrayfield[self._translate_named_indices(named_indices)]
922+
return self._ndarrayfield[self._translate_named_indices(named_indices)].as_scalar()
923923

924924
def field_setitem(self, named_indices: NamedFieldIndices, value: Any):
925925
if common.is_mutable_field(self._ndarrayfield):
@@ -1040,6 +1040,7 @@ class IndexField(common.Field):
10401040
"""
10411041

10421042
_dimension: common.Dimension
1043+
_cur_index: Optional[core_defs.IntegralScalar] = None
10431044

10441045
@property
10451046
def __gt_domain__(self) -> common.Domain:
@@ -1055,7 +1056,10 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override
10551056

10561057
@property
10571058
def domain(self) -> common.Domain:
1058-
return common.Domain((self._dimension, common.UnitRange.infinite()))
1059+
if self._cur_index is None:
1060+
return common.Domain((self._dimension, common.UnitRange.infinite()))
1061+
else:
1062+
return common.Domain()
10591063

10601064
@property
10611065
def codomain(self) -> type[core_defs.int32]:
@@ -1072,16 +1076,24 @@ def ndarray(self) -> core_defs.NDArrayObject:
10721076
def asnumpy(self) -> np.ndarray:
10731077
raise NotImplementedError()
10741078

1079+
def as_scalar(self) -> core_defs.IntegralScalar:
1080+
if self.domain.ndim != 0:
1081+
raise ValueError(
1082+
"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'."
1083+
)
1084+
assert self._cur_index is not None
1085+
return self._cur_index
1086+
10751087
def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -> common.Field:
10761088
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
10771089
raise NotImplementedError()
10781090

1079-
def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.int32:
1091+
def restrict(self, item: common.AnyIndexSpec) -> common.Field:
10801092
if common.is_absolute_index_sequence(item) and all(common.is_named_index(e) for e in item): # type: ignore[arg-type] # we don't want to pollute the typing of `is_absolute_index_sequence` for this temporary code # fmt: off
10811093
d, r = item[0]
10821094
assert d == self._dimension
1083-
assert isinstance(r, int)
1084-
return self.dtype.scalar_type(r)
1095+
assert isinstance(r, core_defs.INTEGRAL_TYPES)
1096+
return self.__class__(self._dimension, r) # type: ignore[arg-type] # not sure why the assert above does not work
10851097
# TODO set a domain...
10861098
raise NotImplementedError()
10871099

@@ -1195,8 +1207,12 @@ def remap(self, index_field: common.ConnectivityField | fbuiltins.FieldOffset) -
11951207
# TODO can be implemented by constructing and ndarray (but do we know of which kind?)
11961208
raise NotImplementedError()
11971209

1198-
def restrict(self, item: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT:
1210+
def restrict(self, item: common.AnyIndexSpec) -> common.Field:
11991211
# TODO set a domain...
1212+
return self
1213+
1214+
def as_scalar(self) -> core_defs.ScalarT:
1215+
assert self.domain.ndim == 0
12001216
return self._value
12011217

12021218
__call__ = remap

tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py

+15
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,21 @@ def testee(inp: gtx.Field[[KDim], float]) -> gtx.Field[[KDim], float]:
321321
cases.verify(cartesian_case, testee, inp, out=out, ref=expected)
322322

323323

324+
def test_single_value_field(cartesian_case):
325+
@gtx.field_operator
326+
def testee_fo(a: cases.IKField) -> cases.IKField:
327+
return a
328+
329+
@gtx.program
330+
def testee_prog(a: cases.IKField):
331+
testee_fo(a, out=a[1:2, 3:4])
332+
333+
a = cases.allocate(cartesian_case, testee_prog, "a")()
334+
ref = a[1, 3]
335+
336+
cases.verify(cartesian_case, testee_prog, a, inout=a[1, 3], ref=ref)
337+
338+
324339
def test_astype_int(cartesian_case): # noqa: F811 # fixtures
325340
@gtx.field_operator
326341
def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]:

tests/next_tests/integration_tests/feature_tests/iterator_tests/test_horizontal_indirection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_simple_indirection(program_processor):
7070

7171
ref = np.zeros(shape, dtype=inp.dtype)
7272
for i in range(shape[0]):
73-
ref[i] = inp.ndarray[i + 1 - 1] if cond[i] < 0.0 else inp.ndarray[i + 1 + 1]
73+
ref[i] = inp.asnumpy()[i + 1 - 1] if cond.asnumpy()[i] < 0.0 else inp.asnumpy()[i + 1 + 1]
7474

7575
run_processor(
7676
conditional_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))],
@@ -101,7 +101,7 @@ def test_direct_offset_for_indirection(program_processor):
101101

102102
ref = np.zeros(shape)
103103
for i in range(shape[0]):
104-
ref[i] = inp[i + cond[i]]
104+
ref[i] = inp.asnumpy()[i + cond.asnumpy()[i]]
105105

106106
run_processor(
107107
direct_indirection[cartesian_domain(named_range(IDim, 0, shape[0]))],

tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,15 @@ def fencil(x, y, z, out, inp):
6565
def naive_lap(inp):
6666
shape = [inp.shape[0] - 2, inp.shape[1] - 2, inp.shape[2]]
6767
out = np.zeros(shape)
68+
inp_data = inp.asnumpy()
6869
for i in range(1, shape[0] + 1):
6970
for j in range(1, shape[1] + 1):
7071
for k in range(0, shape[2]):
71-
out[i - 1, j - 1, k] = -4 * inp[i, j, k] + (
72-
inp[i + 1, j, k] + inp[i - 1, j, k] + inp[i, j + 1, k] + inp[i, j - 1, k]
72+
out[i - 1, j - 1, k] = -4 * inp_data[i, j, k] + (
73+
inp_data[i + 1, j, k]
74+
+ inp_data[i - 1, j, k]
75+
+ inp_data[i, j + 1, k]
76+
+ inp_data[i, j - 1, k]
7377
)
7478
return out
7579

tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,11 @@ def test_absolute_indexing_value_return():
468468
field = common._field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain)
469469

470470
named_index = ((IDim, 12), (JDim, 6))
471+
assert common.is_field(field)
471472
value = field[named_index]
472473

473-
assert isinstance(value, np.int32)
474-
assert value == 21
474+
assert common.is_field(value)
475+
assert value.as_scalar() == 21
475476

476477

477478
@pytest.mark.parametrize(
@@ -568,14 +569,17 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain):
568569

569570
@pytest.mark.parametrize(
570571
"index, expected_value",
571-
[((1, 0), 10), ((0, 1), 1)],
572+
[
573+
((1, 0), 10),
574+
((0, 1), 1),
575+
],
572576
)
573577
def test_relative_indexing_value_return(index, expected_value):
574578
domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(2, 12)))
575579
field = common._field(np.reshape(np.arange(100, dtype=int), (10, 10)), domain=domain)
576580
indexed_field = field[index]
577581

578-
assert indexed_field == expected_value
582+
assert indexed_field.as_scalar() == expected_value
579583

580584

581585
@pytest.mark.parametrize("lazy_slice", [lambda f: f[13], lambda f: f[:5, :3, :2]])

0 commit comments

Comments
 (0)