diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 7e0e060834..011ca4d92b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -168,7 +168,7 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_domain__(self) -> common.Domain: ... + def dims(self) -> tuple[common.Dimension, ...]: ... # TODO(havogt): define generic Protocol to provide a concrete return type @abc.abstractmethod @@ -176,7 +176,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any: ... @property def __gt_origin__(self) -> tuple[int, ...]: - return tuple([0] * len(self.__gt_domain__.dims)) + return tuple([0] * len(self.dims)) @runtime_checkable @@ -678,18 +678,12 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: def _get_axes( field_or_tuple: LocatedField | tuple, ) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField - return _get_domain(field_or_tuple).dims - - -def _get_domain( - field_or_tuple: LocatedField | tuple, -) -> common.Domain: # arbitrary nesting of tuples of LocatedField if isinstance(field_or_tuple, tuple): - first = _get_domain(field_or_tuple[0]) - assert all(first == _get_domain(f) for f in field_or_tuple) + first = _get_axes(field_or_tuple[0]) + assert all(first == _get_axes(f) for f in field_or_tuple) return first else: - return field_or_tuple.__gt_domain__ + return field_or_tuple.dims def _single_vertical_idx( @@ -900,8 +894,8 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): _ndarrayfield: common.Field @property - def __gt_domain__(self) -> common.Domain: - return self._ndarrayfield.__gt_domain__ + def dims(self) -> tuple[common.Dimension, ...]: + return self._ndarrayfield.__gt_domain__.dims def _translate_named_indices( self, _named_indices: NamedFieldIndices @@ -1452,7 +1446,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): self.data = data - self.__gt_domain__ = _get_domain(data) + self.dims = _get_axes(data) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: return _build_tuple_result(self.data, named_indices) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index add772e7ef..925ad33e86 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -311,7 +311,9 @@ def test_tuple_field_input(program_processor): ) inp2 = gtx.as_field( [IDim, JDim, KDim], - rng.normal(size=(shape[0], shape[1], shape[2])), + rng.normal( + size=(shape[0], shape[1], shape[2] + 1) + ), # TODO(havogt) currently we allow different sizes, needed for icon4py compatibility ) out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) @@ -323,7 +325,7 @@ def test_tuple_field_input(program_processor): } run_processor(tuple_input[dom], program_processor, (inp1, inp2), out=out, offset_provider={}) if validate: - assert np.allclose(inp1.asnumpy() + inp2.asnumpy(), out.asnumpy()) + assert np.allclose(inp1.asnumpy() + inp2.asnumpy()[:, :, :-1], out.asnumpy()) @pytest.mark.xfail(reason="Implement wrapper for extradim as tuple")