Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Enable tests for embedded with cupy #1372

Merged
merged 21 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ def is_connectivity_field(


@functools.singledispatch
def field(
def _field(
definition: Any,
/,
*,
Expand All @@ -769,7 +769,7 @@ def field(


@functools.singledispatch
def connectivity(
def _connectivity(
definition: Any,
/,
codomain: Dimension,
Expand Down Expand Up @@ -898,7 +898,7 @@ def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar:
__getitem__ = restrict


connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)
_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)


@enum.unique
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def empty(
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
res = common.field(buffer.ndarray, domain=domain)
res = common._field(buffer.ndarray, domain=domain)
assert common.is_mutable_field(res)
assert isinstance(res, nd_array_field.NdArrayField)
return res
Expand Down Expand Up @@ -357,7 +357,7 @@ def as_connectivity(
device = core_defs.Device(*data.__dlpack_device__())
buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device)
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] # TODO(havogt): consider addin MutableNDArrayObject
connectivity_field = common.connectivity(
connectivity_field = common._connectivity(
buffer.ndarray, codomain=codomain, domain=actual_domain
)
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,15 @@ class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np


common.field.register(np.ndarray, NumPyArrayField.from_array)
common._field.register(np.ndarray, NumPyArrayField.from_array)


@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = np


common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)
common._connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)

# CuPy
if cp:
Expand All @@ -527,13 +527,13 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField):
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

common.field.register(cp.ndarray, CuPyArrayField.from_array)
common._field.register(cp.ndarray, CuPyArrayField.from_array)

@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = cp

common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)
common._connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)

# JAX
if jnp:
Expand All @@ -551,7 +551,7 @@ def __setitem__(
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.")

common.field.register(jnp.ndarray, JaxArrayField.from_array)
common._field.register(jnp.ndarray, JaxArrayField.from_array)


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
Expand All @@ -566,7 +566,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
named_ranges.append(
(dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive()))
)
return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))


def _builtins_broadcast(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def broadcast(
assert core_defs.is_scalar_type(
field
) # default implementation for scalars, Fields are handled via dispatch
return common.field(
return common._field(
np.asarray(field)[
tuple([np.newaxis] * len(dims))
], # TODO(havogt) use FunctionField once available
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def _maker(a) -> common.Field:
offset = origin.get(d, 0)
ranges.append(common.UnitRange(-offset, s - offset))

res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
res = common._field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
return res

return _maker
Expand Down
19 changes: 18 additions & 1 deletion tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import pytest

from gt4py.next import allocators as next_allocators


# Skip definitions
XFAIL = pytest.xfail
Expand All @@ -44,6 +46,11 @@ def short_id(self, num_components: int = 2) -> str:
return ".".join(self.value.split(".")[-num_components:])


class _PythonObjectIdMixinForAllocator(_PythonObjectIdMixin):
def short_id(self, num_components: int = 1) -> str:
return "None-" + super().short_id(num_components)


class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn"
GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative"
Expand All @@ -55,6 +62,15 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend"


cpu_allocator = next_allocators.StandardCPUFieldBufferAllocator()
gpu_allocator = next_allocators.StandardGPUFieldBufferAllocator()


class AllocatorId(_PythonObjectIdMixinForAllocator, str, enum.Enum):
CPU_ALLOCATOR = "next_tests.exclusion_matrices.cpu_allocator"
GPU_ALLOCATOR = "next_tests.exclusion_matrices.gpu_allocator"


class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu"
DACE_GPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_gpu"
Expand Down Expand Up @@ -146,7 +162,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
#: Skip matrix, contains for each backend processor a list of tuples with following fields:
#: (<test_marker>, <skip_definition, <skip_message>)
BACKEND_SKIP_TEST_MATRIX = {
None: EMBEDDED_SKIP_LIST,
AllocatorId.CPU_ALLOCATOR: EMBEDDED_SKIP_LIST,
AllocatorId.GPU_ALLOCATOR: EMBEDDED_SKIP_LIST,
OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST
+ [
Expand Down
26 changes: 14 additions & 12 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import extended_typing as xtyping
from gt4py.eve.extended_typing import Self
from gt4py.next import common, constructors, utils
from gt4py.next import allocators as next_allocators, common, constructors, utils
from gt4py.next.ffront import decorator
from gt4py.next.program_processors import processor_interface as ppi
from gt4py.next.type_system import type_specifications as ts, type_translation
Expand Down Expand Up @@ -103,7 +103,7 @@ def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue:

def field(
self,
backend: ppi.ProgramProcessor,
allocator: next_allocators.FieldBufferAllocatorProtocol,
sizes: dict[gtx.Dimension, int],
dtype: np.typing.DTypeLike,
) -> FieldValue:
Expand Down Expand Up @@ -137,15 +137,15 @@ def scalar_value(self) -> ScalarValue:

def field(
self,
backend: ppi.ProgramExecutor,
allocator: next_allocators.FieldBufferAllocatorProtocol,
sizes: dict[gtx.Dimension, int],
dtype: np.typing.DTypeLike,
) -> FieldValue:
return constructors.full(
domain=common.domain(sizes),
fill_value=self.value,
dtype=dtype,
allocator=backend,
allocator=allocator,
)


Expand All @@ -166,7 +166,7 @@ def scalar_value(self) -> ScalarValue:

def field(
self,
backend: ppi.ProgramExecutor,
allocator: next_allocators.FieldBufferAllocatorProtocol,
sizes: dict[gtx.Dimension, int],
dtype: np.typing.DTypeLike,
) -> FieldValue:
Expand All @@ -176,7 +176,7 @@ def field(
)
n_data = list(sizes.values())[0]
return constructors.as_field(
domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=backend
domain=common.domain(sizes), data=np.arange(0, n_data, dtype=dtype), allocator=allocator
)

def from_case(
Expand Down Expand Up @@ -207,7 +207,7 @@ def scalar_value(self) -> ScalarValue:

def field(
self,
backend: ppi.ProgramProcessor,
allocator: next_allocators.FieldBufferAllocatorProtocol,
sizes: dict[gtx.Dimension, int],
dtype: np.typing.DTypeLike,
) -> FieldValue:
Expand All @@ -218,7 +218,7 @@ def field(
return constructors.as_field(
common.domain(sizes),
np.arange(start, start + n_data, dtype=dtype).reshape(svals),
allocator=backend,
allocator=allocator,
)

def from_case(
Expand Down Expand Up @@ -482,10 +482,11 @@ def verify_with_default_data(
@pytest.fixture
def cartesian_case(fieldview_backend): # noqa: F811 # fixtures
yield Case(
fieldview_backend,
fieldview_backend if isinstance(fieldview_backend, ppi.ProgramExecutor) else None,
offset_provider={"Ioff": IDim, "Joff": JDim, "Koff": KDim},
default_sizes={IDim: 10, JDim: 10, KDim: 10},
grid_type=common.GridType.CARTESIAN,
allocator=fieldview_backend,
)


Expand Down Expand Up @@ -516,7 +517,7 @@ def _allocate_from_type(
match arg_type:
case ts.FieldType(dims=dims, dtype=arg_dtype):
return strategy.field(
backend=case.backend,
allocator=case.allocator,
sizes={dim: sizes[dim] for dim in dims},
dtype=dtype or arg_dtype.kind.name.lower(),
)
Expand Down Expand Up @@ -601,11 +602,12 @@ def get_default_data(
class Case:
"""Parametrizable components for single feature integration tests."""

backend: ppi.ProgramProcessor
backend: Optional[ppi.ProgramProcessor]
offset_provider: dict[str, common.Connectivity | gtx.Dimension]
default_sizes: dict[gtx.Dimension, int]
grid_type: common.GridType
allocator: next_allocators.FieldBufferAllocatorFactoryProtocol

@property
def as_field(self):
return constructors.as_field.partial(allocator=self.backend)
return constructors.as_field.partial(allocator=self.allocator)
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> Non
definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE,
definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES,
pytest.param(definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu),
None,
# will use the default (embedded) execution, but input/output allocated with the provided allocator
definitions.AllocatorId.CPU_ALLOCATOR,
definitions.AllocatorId.GPU_ALLOCATOR,
]
+ OPTIONAL_PROCESSORS,
ids=lambda p: p.short_id() if p is not None else "None",
Expand All @@ -69,18 +71,18 @@ def fieldview_backend(request):
Notes:
Check ADR 15 for details on the test-exclusion matrices.
"""
backend_id = request.param
backend = None if backend_id is None else backend_id.load()
backend_or_allocator_id = request.param
backend_or_allocator = backend_or_allocator_id.load()

for marker, skip_mark, msg in next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX.get(
backend_id, []
backend_or_allocator_id, []
):
if request.node.get_closest_marker(marker):
skip_mark(msg.format(marker=marker, backend=backend_id))
skip_mark(msg.format(marker=marker, backend=backend_or_allocator_id))

backup_backend = decorator.DEFAULT_BACKEND
decorator.DEFAULT_BACKEND = no_backend
yield backend
yield backend_or_allocator
decorator.DEFAULT_BACKEND = backup_backend


Expand Down
Loading