feat[next]: Enable tests for embedded with cupy #1372

Merged — 21 commits, Jan 29, 2024
8 changes: 4 additions & 4 deletions docs/development/ADRs/0015-Test_Exclusion_Matrices.md
@@ -1,13 +1,13 @@
---
tags: []
tags: [testing]
---

# Test-Exclusion Matrices

- **Status**: valid
- **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes)
- **Created**: 2023-09-21
- **Updated**: 2023-09-21
- **Updated**: 2024-01-25

In the context of Field View testing, while a backend under development still lacks support for specific ITIR features,
we decided to use `pytest` fixtures to exclude unsupported tests.
@@ -22,7 +22,7 @@ the supported backends, while keeping the test code clean.
## Decision

It was decided to apply fixtures and markers from the `pytest` module. The fixture is the same one used to execute the test
on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the available feature markers.
on different backends (`exec_alloc_descriptor` and `program_processor`), but it is extended with a check on the available feature markers.
If a test is annotated with a feature marker, the fixture will check if this feature is supported on the selected backend.
If no marker is specified, the test is supposed to run on all backends.
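
The marker check described above can be sketched as a plain function (the dictionary and function names below are illustrative assumptions, not the actual gt4py-next fixture code):

```python
# Hedged sketch of the per-backend feature check performed by the fixture.
# Names and matrix contents are illustrative, not the real definitions.
BACKEND_UNSUPPORTED_FEATURES = {
    "embedded_numpy": set(),
    "embedded_cupy": {"uses_scan_requiring_projector"},
}


def unsupported_markers(backend: str, test_markers: set[str]) -> set[str]:
    """Return the feature markers on a test that the selected backend lacks.

    An empty result means the test runs; a test with no markers runs on
    all backends, matching the behavior described in the ADR.
    """
    return test_markers & BACKEND_UNSUPPORTED_FEATURES.get(backend, set())
```

In the real fixture, a non-empty result triggers the configured `pytest` outcome (skip or expected failure) for the current test.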

@@ -33,7 +33,7 @@ In the example below, `test_offset_field` requires the backend to support dynamic offsets:
```python
def test_offset_field(cartesian_case):
```

In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX`
In order to selectively enable the backends, the dictionary `next_tests.definitions.BACKEND_SKIP_TEST_MATRIX`
lists for each backend the features that are not supported.
The fixture will check if the annotated feature is present in the exclusion-matrix for the selected backend.
If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`).
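
As a hedged sketch of that lookup (the dictionary shape and entries are assumptions for illustration, not the exact layout of `next_tests.definitions.BACKEND_SKIP_TEST_MATRIX`):

```python
# Assumed shape: backend name -> list of (feature_marker, action, reason).
# The fixture maps "SKIP" to pytest.skip and "XFAIL" to pytest.xfail.
BACKEND_SKIP_TEST_MATRIX = {
    "embedded_cupy": [
        ("uses_scan_requiring_projector", "SKIP", "no projector in embedded"),
    ],
}


def resolve_action(backend: str, marker: str):
    """Return (action, reason) if the marker is excluded for this backend.

    Returns None when the feature is supported, so the test runs normally.
    """
    for feature, action, reason in BACKEND_SKIP_TEST_MATRIX.get(backend, []):
        if feature == marker:
            return action, reason
    return None
```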
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -340,7 +340,11 @@ markers = [
'uses_origin: tests that require backend support for domain origin',
'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for reduction with only sparse fields',
'uses_scan: tests that uses scan',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
'uses_scan_nested: tests that use nested scans',
'uses_scan_requiring_projector: tests need a projector implementation in gtfn',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
@@ -349,7 +353,7 @@ markers = [
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use an unstructured connectivity',
'uses_scan: tests that uses scan',
'uses_max_over: tests that use the max_over builtin',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
1 change: 1 addition & 0 deletions src/gt4py/next/allocators.py
@@ -231,6 +231,7 @@ def __init__(self) -> None:

device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator()


assert is_field_allocator(device_allocators[core_defs.DeviceType.CPU])


9 changes: 6 additions & 3 deletions src/gt4py/next/common.py
@@ -843,8 +843,10 @@ def is_connectivity_field(
return isinstance(v, ConnectivityField) # type: ignore[misc] # we use extended_runtime_checkable


# Utility function to construct a `Field` from different buffer representations.
# Consider removing this function and using `Field` constructor directly. See also `_connectivity`.
@functools.singledispatch
def field(
def _field(
definition: Any,
/,
*,
@@ -854,8 +856,9 @@ def field(
raise NotImplementedError


# See comment for `_field`.
@functools.singledispatch
def connectivity(
def _connectivity(
definition: Any,
/,
codomain: Dimension,
@@ -980,7 +983,7 @@ def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar:
__getitem__ = restrict


connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)
_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)


@enum.unique
6 changes: 3 additions & 3 deletions src/gt4py/next/constructors.py
@@ -87,7 +87,7 @@ def empty(
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
res = common.field(buffer.ndarray, domain=domain)
res = common._field(buffer.ndarray, domain=domain)
assert common.is_mutable_field(res)
assert isinstance(res, nd_array_field.NdArrayField)
return res
@@ -356,9 +356,9 @@ def as_connectivity(
if (allocator is None) and (device is None) and xtyping.supports_dlpack(data):
device = core_defs.Device(*data.__dlpack_device__())
buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device)
# TODO(havogt): consider addin MutableNDArrayObject
# TODO(havogt): consider adding MutableNDArrayObject
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index]
connectivity_field = common.connectivity(
connectivity_field = common._connectivity(
buffer.ndarray, codomain=codomain, domain=actual_domain
)
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)
22 changes: 12 additions & 10 deletions src/gt4py/next/embedded/nd_array_field.py
@@ -95,9 +95,7 @@ class NdArrayField(
_domain: common.Domain
_ndarray: core_defs.NDArrayObject

array_ns: ClassVar[
ModuleType
] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol
array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol

@property
def domain(self) -> common.Domain:
@@ -197,7 +195,11 @@ def remap(
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)

return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)
return self.__class__.from_array(
new_buffer,
domain=new_domain,
dtype=self.dtype,
)

__call__ = remap # type: ignore[assignment]

@@ -510,15 +512,15 @@ class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np


common.field.register(np.ndarray, NumPyArrayField.from_array)
common._field.register(np.ndarray, NumPyArrayField.from_array)


@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = np


common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)
common._connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)

# CuPy
if cp:
@@ -528,13 +530,13 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField):
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

common.field.register(cp.ndarray, CuPyArrayField.from_array)
common._field.register(cp.ndarray, CuPyArrayField.from_array)

@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = cp

common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)
common._connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)

# JAX
if jnp:
@@ -552,7 +554,7 @@ def __setitem__(
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.")

common.field.register(jnp.ndarray, JaxArrayField.from_array)
common._field.register(jnp.ndarray, JaxArrayField.from_array)


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
@@ -565,7 +567,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
else:
domain_slice.append(np.newaxis)
named_ranges.append((dim, common.UnitRange.infinite()))
return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))


def _builtins_broadcast(
33 changes: 27 additions & 6 deletions src/gt4py/next/embedded/operators.py
@@ -13,11 +13,14 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from types import ModuleType
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar

import numpy as np

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, constructors, errors, utils
from gt4py.next import common, errors, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context


@@ -43,7 +46,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]
scan_axis = scan_range[0]
domain_intersection = _intersect_scan_args(*args, *kwargs.values())
all_args = [*args, *kwargs.values()]
domain_intersection = _intersect_scan_args(*all_args)
non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis])

out_domain = common.Domain(
@@ -53,7 +57,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
# even if the scan dimension is not in the input, we can scan over it
out_domain = common.Domain(*out_domain, (scan_range))

res = _construct_scan_array(out_domain)(self.init)
xp = _get_array_ns(*all_args)
res = _construct_scan_array(out_domain, xp)(self.init)

def scan_loop(hpos):
acc = self.init
@@ -128,7 +133,11 @@ def _tuple_assign_field(
):
@utils.tree_map
def impl(target: common.MutableField, source: common.Field):
target[domain] = source[domain]
if common.is_field(source):
target[domain] = source[domain]
else:
assert core_defs.is_scalar_type(source)
target[domain] = source

impl(target, source)

@@ -141,10 +150,21 @@ def _intersect_scan_args(
)


def _construct_scan_array(domain: common.Domain):
def _get_array_ns(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
) -> ModuleType:
for arg in utils.flatten_nested_tuple(args):
if hasattr(arg, "array_ns"):
return arg.array_ns
return np
Review comment (Contributor Author): instead of returning `np`, maybe the safe thing would be to error for now.



def _construct_scan_array(
domain: common.Domain, xp: ModuleType
): # TODO(havogt) introduce a NDArrayNamespace protocol
@utils.tree_map
def impl(init: core_defs.Scalar) -> common.Field:
return constructors.empty(domain, dtype=type(init))
return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain)

return impl

@@ -168,6 +188,7 @@ def _tuple_at(
@utils.tree_map
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
res = field[pos] if common.is_field(field) else field
res = res.item() if hasattr(res, "item") else res # extract scalar value from array
assert core_defs.is_scalar_type(res)
return res

8 changes: 2 additions & 6 deletions src/gt4py/next/ffront/fbuiltins.py
@@ -188,12 +188,8 @@ def broadcast(
assert core_defs.is_scalar_type(
field
) # default implementation for scalars, Fields are handled via dispatch
return common.field(
np.asarray(field)[
tuple([np.newaxis] * len(dims))
], # TODO(havogt) use FunctionField once available
domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))),
)
# TODO(havogt) implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions
return field # type: ignore[return-value] # see comment above


@WhereBuiltinFunction
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/embedded.py
@@ -1035,7 +1035,7 @@ def _maker(a) -> common.Field:
offset = origin.get(d, 0)
ranges.append(common.UnitRange(-offset, s - offset))

res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
res = common._field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
return res

return _maker
4 changes: 2 additions & 2 deletions tests/next_tests/__init__.py
@@ -12,10 +12,10 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from . import exclusion_matrices
from . import definitions


__all__ = ["exclusion_matrices", "get_processor_id"]
__all__ = ["definitions", "get_processor_id"]


def get_processor_id(processor):