feat[next]: Enable tests for embedded with cupy #1372

Merged · 21 commits · Jan 29, 2024
Changes from 9 commits
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -340,7 +340,9 @@ markers = [
'uses_origin: tests that require backend support for domain origin',
'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
'uses_scan: tests that uses scan',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
'uses_sparse_fields: tests that require backend support for sparse fields',
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
@@ -349,7 +351,6 @@ markers = [
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_scan: tests that uses scan',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
6 changes: 3 additions & 3 deletions src/gt4py/next/common.py
@@ -844,7 +844,7 @@ def is_connectivity_field(


@functools.singledispatch
def field(
def _field(
definition: Any,
/,
*,
@@ -855,7 +855,7 @@ def field(


@functools.singledispatch
def connectivity(
def _connectivity(
definition: Any,
/,
codomain: Dimension,
@@ -980,7 +980,7 @@ def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar:
__getitem__ = restrict


connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)
_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)


@enum.unique
6 changes: 3 additions & 3 deletions src/gt4py/next/constructors.py
@@ -87,7 +87,7 @@ def empty(
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
res = common.field(buffer.ndarray, domain=domain)
res = common._field(buffer.ndarray, domain=domain)
assert common.is_mutable_field(res)
assert isinstance(res, nd_array_field.NdArrayField)
return res
@@ -356,9 +356,9 @@ def as_connectivity(
if (allocator is None) and (device is None) and xtyping.supports_dlpack(data):
device = core_defs.Device(*data.__dlpack_device__())
buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device)
# TODO(havogt): consider addin MutableNDArrayObject
# TODO(havogt): consider adding MutableNDArrayObject
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index]
connectivity_field = common.connectivity(
connectivity_field = common._connectivity(
buffer.ndarray, codomain=codomain, domain=actual_domain
)
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)
22 changes: 12 additions & 10 deletions src/gt4py/next/embedded/nd_array_field.py
@@ -95,9 +95,7 @@ class NdArrayField(
_domain: common.Domain
_ndarray: core_defs.NDArrayObject

array_ns: ClassVar[
ModuleType
] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol
array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol

@property
def domain(self) -> common.Domain:
@@ -197,7 +195,11 @@ def remap(
# finally, take the new array
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)

return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)
return self.__class__.from_array(
new_buffer,
domain=new_domain,
dtype=self.dtype,
)

__call__ = remap # type: ignore[assignment]

@@ -510,15 +512,15 @@ class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np


common.field.register(np.ndarray, NumPyArrayField.from_array)
common._field.register(np.ndarray, NumPyArrayField.from_array)


@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = np


common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)
common._connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)

# CuPy
if cp:
@@ -528,13 +530,13 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField):
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

common.field.register(cp.ndarray, CuPyArrayField.from_array)
common._field.register(cp.ndarray, CuPyArrayField.from_array)

@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayConnectivityField(NdArrayConnectivityField):
array_ns: ClassVar[ModuleType] = cp

common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)
common._connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)

# JAX
if jnp:
@@ -552,7 +554,7 @@ def __setitem__(
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.")

common.field.register(jnp.ndarray, JaxArrayField.from_array)
common._field.register(jnp.ndarray, JaxArrayField.from_array)


def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
@@ -565,7 +567,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
else:
domain_slice.append(np.newaxis)
named_ranges.append((dim, common.UnitRange.infinite()))
return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))


def _builtins_broadcast(
33 changes: 27 additions & 6 deletions src/gt4py/next/embedded/operators.py
@@ -13,11 +13,14 @@
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from types import ModuleType
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar

import numpy as np

from gt4py import eve
from gt4py._core import definitions as core_defs
from gt4py.next import common, constructors, errors, utils
from gt4py.next import common, errors, utils
from gt4py.next.embedded import common as embedded_common, context as embedded_context


@@ -43,7 +46,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
scan_range = embedded_context.closure_column_range.get()
assert self.axis == scan_range[0]
scan_axis = scan_range[0]
domain_intersection = _intersect_scan_args(*args, *kwargs.values())
all_args = [*args, *kwargs.values()]
domain_intersection = _intersect_scan_args(*all_args)
non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis])

out_domain = common.Domain(
@@ -53,7 +57,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
# even if the scan dimension is not in the input, we can scan over it
out_domain = common.Domain(*out_domain, (scan_range))

res = _construct_scan_array(out_domain)(self.init)
xp = _get_array_ns(*all_args)
res = _construct_scan_array(out_domain, xp)(self.init)

def scan_loop(hpos):
acc = self.init
@@ -128,7 +133,11 @@ def _tuple_assign_field(
):
@utils.tree_map
def impl(target: common.MutableField, source: common.Field):
target[domain] = source[domain]
if common.is_field(source):
target[domain] = source[domain]
else:
assert core_defs.is_scalar_type(source)
target[domain] = source

impl(target, source)
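
A minimal illustration of the scalar branch added above (plain NumPy as a hedged stand-in for the field buffer; the names are hypothetical):

    import numpy as np

    target = np.zeros(4)
    target[1:3] = 7.0  # assigning a scalar into a slice broadcasts it to every element,
    assert (target[1:3] == 7.0).all()  # which is what the new else-branch relies on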

@@ -141,10 +150,21 @@ def _intersect_scan_args(
)


def _construct_scan_array(domain: common.Domain):
def _get_array_ns(
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
) -> ModuleType:
for arg in utils.flatten_nested_tuple(args):
if hasattr(arg, "array_ns"):
return arg.array_ns
return np
Contributor Author


instead of returning np, maybe the safe thing would be to error for now.
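
As a sketch of the stricter behaviour suggested in this comment (not what the PR implements; it reuses the imports already present in this module, and the error type and message are assumptions):

    def _get_array_ns_strict(
        *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
    ) -> ModuleType:
        # Same lookup as above, but raise instead of silently defaulting to NumPy.
        for arg in utils.flatten_nested_tuple(args):
            if hasattr(arg, "array_ns"):
                return arg.array_ns
        raise TypeError("Cannot deduce array namespace: no field-like argument provided.")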



def _construct_scan_array(
domain: common.Domain, xp: ModuleType
): # TODO(havogt) introduce a NDArrayNamespace protocol
@utils.tree_map
def impl(init: core_defs.Scalar) -> common.Field:
return constructors.empty(domain, dtype=type(init))
return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain)

return impl

@@ -168,6 +188,7 @@ def _tuple_at(
@utils.tree_map
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
res = field[pos] if common.is_field(field) else field
res = res.item() if hasattr(res, "item") else res # extract scalar value from array
assert core_defs.is_scalar_type(res)
return res
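
For context on the added `.item()` call, a small hedged sketch: restricting a CuPy-backed field to a single position can yield a 0-d device array rather than a host scalar, and `.item()` converts it into a plain Python value that `is_scalar_type` accepts (shown here with NumPy, which behaves analogously):

    import numpy as np

    res = np.asarray(3.0)  # a 0-d array, as returned when indexing drops all dimensions
    assert res.ndim == 0
    assert isinstance(res.item(), float)  # .item() extracts the plain Python scalar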

8 changes: 2 additions & 6 deletions src/gt4py/next/ffront/fbuiltins.py
@@ -188,12 +188,8 @@ def broadcast(
assert core_defs.is_scalar_type(
field
) # default implementation for scalars, Fields are handled via dispatch
return common.field(
np.asarray(field)[
tuple([np.newaxis] * len(dims))
], # TODO(havogt) use FunctionField once available
domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))),
)
# TODO(havogt) implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions
return field # type: ignore[return-value] # see comment above
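
To illustrate the workaround in the TODO above: a scalar returned unchanged still combines correctly with fields through ordinary array broadcasting; only the compatible-dimension check is lost. A hedged sketch with NumPy standing in for field arithmetic:

    import numpy as np

    scalar = 3.0                # broadcast(3.0, (I, J)) now simply returns 3.0
    data = np.ones((4, 4))      # stands in for a field defined on (I, J)
    assert ((data + scalar) == 4.0).all()  # the scalar broadcasts automatically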


@WhereBuiltinFunction
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/embedded.py
@@ -1035,7 +1035,7 @@ def _maker(a) -> common.Field:
offset = origin.get(d, 0)
ranges.append(common.UnitRange(-offset, s - offset))

res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
res = common._field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
return res

return _maker
4 changes: 2 additions & 2 deletions tests/next_tests/__init__.py
@@ -12,10 +12,10 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

from . import exclusion_matrices
from . import definitions


__all__ = ["exclusion_matrices", "get_processor_id"]
__all__ = ["definitions", "get_processor_id"]


def get_processor_id(processor):
tests/next_tests/definitions.py (renamed, previously exclusion_matrices.py)
@@ -19,6 +19,8 @@

import pytest

from gt4py.next import allocators as next_allocators


# Skip definitions
XFAIL = pytest.xfail
@@ -44,6 +46,11 @@ def short_id(self, num_components: int = 2) -> str:
return ".".join(self.value.split(".")[-num_components:])


class _PythonObjectIdMixinForAllocator(_PythonObjectIdMixin):
def short_id(self, num_components: int = 1) -> str:
return "None-" + super().short_id(num_components)


class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn"
GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative"
@@ -55,6 +62,15 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend"


cpu_allocator = next_allocators.StandardCPUFieldBufferAllocator()
gpu_allocator = next_allocators.StandardGPUFieldBufferAllocator()


class AllocatorId(_PythonObjectIdMixinForAllocator, str, enum.Enum):
CPU_ALLOCATOR = "next_tests.definitions.cpu_allocator"
GPU_ALLOCATOR = "next_tests.definitions.gpu_allocator"
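
Derived from the mixin above, the resulting ids would look as follows (a hedged check; where these ids are consumed is not shown in this diff, and the "None-" prefix presumably marks runs without a compiled backend):

    from next_tests.definitions import AllocatorId

    # short_id keeps the last dotted component of the enum value and prefixes "None-".
    assert AllocatorId.CPU_ALLOCATOR.short_id() == "None-cpu_allocator"
    assert AllocatorId.GPU_ALLOCATOR.short_id() == "None-gpu_allocator"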


class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
DACE_CPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_cpu"
DACE_GPU = "gt4py.next.program_processors.runners.dace_iterator.run_dace_gpu"
@@ -93,7 +109,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_NEGATIVE_MODULO = "uses_negative_modulo"
USES_ORIGIN = "uses_origin"
USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions"
USES_SCAN = "uses_scan"
USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator"
USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args"
USES_SPARSE_FIELDS = "uses_sparse_fields"
USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output"
USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields"
@@ -103,7 +121,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields"
USES_CARTESIAN_SHIFT = "uses_cartesian_shift"
USES_UNSTRUCTURED_SHIFT = "uses_unstructured_shift"
USES_SCAN = "uses_scan"
CHECKS_SPECIFIC_ERROR = "checks_specific_error"

# Skip messages (available format keys: 'marker', 'backend')
@@ -144,7 +161,12 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
#: Skip matrix, contains for each backend processor a list of tuples with following fields:
#: (<test_marker>, <skip_definition, <skip_message>)
BACKEND_SKIP_TEST_MATRIX = {
None: EMBEDDED_SKIP_LIST,
AllocatorId.CPU_ALLOCATOR: EMBEDDED_SKIP_LIST,
AllocatorId.GPU_ALLOCATOR: EMBEDDED_SKIP_LIST
+ [
# we can't extract the type of the output field
(USES_SCAN_WITHOUT_FIELD_ARGS, XFAIL, UNSUPPORTED_MESSAGE)
],
OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST,
OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST
+ [