Skip to content

Commit

Permalink
bug[next]: allow fields of different sizes in tuple in itir embedded (#…
Browse files Browse the repository at this point in the history
…1442)

Undo an unintended change in #1202 to re-enable an icon4py pattern.

Longer term, probably, only transposable tuples of fields make sense, e.g. by intersecting.
  • Loading branch information
havogt authored Feb 2, 2024
1 parent d6dfd6f commit 58ec4dd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
22 changes: 8 additions & 14 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ class LocatedField(Protocol):

@property
@abc.abstractmethod
def __gt_domain__(self) -> common.Domain: ...
def dims(self) -> tuple[common.Dimension, ...]: ...

# TODO(havogt): define generic Protocol to provide a concrete return type
@abc.abstractmethod
def field_getitem(self, indices: NamedFieldIndices) -> Any: ...

@property
def __gt_origin__(self) -> tuple[int, ...]:
return tuple([0] * len(self.__gt_domain__.dims))
return tuple([0] * len(self.dims))


@runtime_checkable
Expand Down Expand Up @@ -678,18 +678,12 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]:
def _get_axes(
field_or_tuple: LocatedField | tuple,
) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField
return _get_domain(field_or_tuple).dims


def _get_domain(
field_or_tuple: LocatedField | tuple,
) -> common.Domain: # arbitrary nesting of tuples of LocatedField
if isinstance(field_or_tuple, tuple):
first = _get_domain(field_or_tuple[0])
assert all(first == _get_domain(f) for f in field_or_tuple)
first = _get_axes(field_or_tuple[0])
assert all(first == _get_axes(f) for f in field_or_tuple)
return first
else:
return field_or_tuple.__gt_domain__
return field_or_tuple.dims


def _single_vertical_idx(
Expand Down Expand Up @@ -900,8 +894,8 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField):
_ndarrayfield: common.Field

@property
def __gt_domain__(self) -> common.Domain:
return self._ndarrayfield.__gt_domain__
def dims(self) -> tuple[common.Dimension, ...]:
return self._ndarrayfield.__gt_domain__.dims

def _translate_named_indices(
self, _named_indices: NamedFieldIndices
Expand Down Expand Up @@ -1452,7 +1446,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices:
class TupleOfFields(TupleField):
def __init__(self, data):
self.data = data
self.__gt_domain__ = _get_domain(data)
self.dims = _get_axes(data)

def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
return _build_tuple_result(self.data, named_indices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ def test_tuple_field_input(program_processor):
)
inp2 = gtx.as_field(
[IDim, JDim, KDim],
rng.normal(size=(shape[0], shape[1], shape[2])),
rng.normal(
size=(shape[0], shape[1], shape[2] + 1)
), # TODO(havogt) currently we allow different sizes, needed for icon4py compatibility
)

out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape))
Expand All @@ -323,7 +325,7 @@ def test_tuple_field_input(program_processor):
}
run_processor(tuple_input[dom], program_processor, (inp1, inp2), out=out, offset_provider={})
if validate:
assert np.allclose(inp1.asnumpy() + inp2.asnumpy(), out.asnumpy())
assert np.allclose(inp1.asnumpy() + inp2.asnumpy()[:, :, :-1], out.asnumpy())


@pytest.mark.xfail(reason="Implement wrapper for extradim as tuple")
Expand Down

0 comments on commit 58ec4dd

Please sign in to comment.