Skip to content

Commit f0986bb

Browse files
havogtegparedes
andauthored
feat[next]: Enable tests for embedded with cupy (#1372)
Introduces mechanism in tests for having different allocators for the same (`None`) backend. Fixes: - The resulting buffer for scan is deduced from the buffer type of the arguments, if there are no arguments we fallback to numpy (maybe break). We need to find a mechanism for this corner case. Currently these tests are excluded with `pytest.mark.uses_scan_without_field_args` for cupy embedded execution. Refactoring: - make common.field and common.connectivity private - rename next_tests.exclusion_matrices to definitions TODOs for later: - `broadcast` of scalar ignores the broadcast --------- Co-authored-by: Enrique González Paredes <[email protected]>
1 parent 8c3b3d7 commit f0986bb

29 files changed

+274
-194
lines changed

docs/development/ADRs/0015-Test_Exclusion_Matrices.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
---
2-
tags: []
2+
tags: [testing]
33
---
44

55
# Test-Exclusion Matrices
66

77
- **Status**: valid
88
- **Authors**: Edoardo Paone (@edopao), Enrique G. Paredes (@egparedes)
99
- **Created**: 2023-09-21
10-
- **Updated**: 2023-09-21
10+
- **Updated**: 2024-01-25
1111

1212
In the context of Field View testing, lacking support for specific ITIR features while a certain backend
1313
is being developed, we decided to use `pytest` fixtures to exclude unsupported tests.
@@ -22,7 +22,7 @@ the supported backends, while keeping the test code clean.
2222
## Decision
2323

2424
It was decided to apply fixtures and markers from `pytest` module. The fixture is the same used to execute the test
25-
on different backends (`fieldview_backend` and `program_processor`), but it is extended with a check on the available feature markers.
25+
on different backends (`exec_alloc_descriptor` and `program_processor`), but it is extended with a check on the available feature markers.
2626
If a test is annotated with a feature marker, the fixture will check if this feature is supported on the selected backend.
2727
If no marker is specified, the test is supposed to run on all backends.
2828

@@ -33,7 +33,7 @@ In the example below, `test_offset_field` requires the backend to support dynami
3333
def test_offset_field(cartesian_case):
3434
```
3535

36-
In order to selectively enable the backends, the dictionary `next_tests.exclusion_matrices.BACKEND_SKIP_TEST_MATRIX`
36+
In order to selectively enable the backends, the dictionary `next_tests.definitions.BACKEND_SKIP_TEST_MATRIX`
3737
lists for each backend the features that are not supported.
3838
The fixture will check if the annotated feature is present in the exclusion-matrix for the selected backend.
3939
If so, the exclusion matrix will also specify the action `pytest` should take (e.g. `SKIP` or `XFAIL`).

pyproject.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,11 @@ markers = [
340340
'uses_origin: tests that require backend support for domain origin',
341341
'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
342342
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
343+
'uses_scan: tests that uses scan',
343344
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
345+
'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments',
346+
'uses_scan_nested: tests that use nested scans',
347+
'uses_scan_requiring_projector: tests need a projector implementation in gtfn',
344348
'uses_sparse_fields: tests that require backend support for sparse fields',
345349
'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
346350
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
@@ -349,7 +353,7 @@ markers = [
349353
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
350354
'uses_cartesian_shift: tests that use a Cartesian connectivity',
351355
'uses_unstructured_shift: tests that use a unstructured connectivity',
352-
'uses_scan: tests that uses scan',
356+
'uses_max_over: tests that use the max_over builtin',
353357
'checks_specific_error: tests that rely on the backend to produce a specific error message'
354358
]
355359
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']

src/gt4py/next/allocators.py

+1
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def __init__(self) -> None:
231231

232232
device_allocators[core_defs.DeviceType.CPU] = StandardCPUFieldBufferAllocator()
233233

234+
234235
assert is_field_allocator(device_allocators[core_defs.DeviceType.CPU])
235236

236237

src/gt4py/next/common.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -843,8 +843,10 @@ def is_connectivity_field(
843843
return isinstance(v, ConnectivityField) # type: ignore[misc] # we use extended_runtime_checkable
844844

845845

846+
# Utility function to construct a `Field` from different buffer representations.
847+
# Consider removing this function and using `Field` constructor directly. See also `_connectivity`.
846848
@functools.singledispatch
847-
def field(
849+
def _field(
848850
definition: Any,
849851
/,
850852
*,
@@ -854,8 +856,9 @@ def field(
854856
raise NotImplementedError
855857

856858

859+
# See comment for `_field`.
857860
@functools.singledispatch
858-
def connectivity(
861+
def _connectivity(
859862
definition: Any,
860863
/,
861864
codomain: Dimension,
@@ -980,7 +983,7 @@ def restrict(self, index: AnyIndexSpec) -> core_defs.IntegralScalar:
980983
__getitem__ = restrict
981984

982985

983-
connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)
986+
_connectivity.register(numbers.Integral, CartesianConnectivity.from_offset)
984987

985988

986989
@enum.unique

src/gt4py/next/constructors.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def empty(
8787
buffer = next_allocators.allocate(
8888
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
8989
)
90-
res = common.field(buffer.ndarray, domain=domain)
90+
res = common._field(buffer.ndarray, domain=domain)
9191
assert common.is_mutable_field(res)
9292
assert isinstance(res, nd_array_field.NdArrayField)
9393
return res
@@ -356,9 +356,9 @@ def as_connectivity(
356356
if (allocator is None) and (device is None) and xtyping.supports_dlpack(data):
357357
device = core_defs.Device(*data.__dlpack_device__())
358358
buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device)
359-
# TODO(havogt): consider addin MutableNDArrayObject
359+
# TODO(havogt): consider adding MutableNDArrayObject
360360
buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index]
361-
connectivity_field = common.connectivity(
361+
connectivity_field = common._connectivity(
362362
buffer.ndarray, codomain=codomain, domain=actual_domain
363363
)
364364
assert isinstance(connectivity_field, nd_array_field.NdArrayConnectivityField)

src/gt4py/next/embedded/nd_array_field.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,7 @@ class NdArrayField(
9595
_domain: common.Domain
9696
_ndarray: core_defs.NDArrayObject
9797

98-
array_ns: ClassVar[
99-
ModuleType
100-
] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol
98+
array_ns: ClassVar[ModuleType] # TODO(havogt) introduce a NDArrayNamespace protocol
10199

102100
@property
103101
def domain(self) -> common.Domain:
@@ -197,7 +195,11 @@ def remap(
197195
# finally, take the new array
198196
new_buffer = xp.take(self._ndarray, new_idx_array, axis=dim_idx)
199197

200-
return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)
198+
return self.__class__.from_array(
199+
new_buffer,
200+
domain=new_domain,
201+
dtype=self.dtype,
202+
)
201203

202204
__call__ = remap # type: ignore[assignment]
203205

@@ -510,15 +512,15 @@ class NumPyArrayField(NdArrayField):
510512
array_ns: ClassVar[ModuleType] = np
511513

512514

513-
common.field.register(np.ndarray, NumPyArrayField.from_array)
515+
common._field.register(np.ndarray, NumPyArrayField.from_array)
514516

515517

516518
@dataclasses.dataclass(frozen=True, eq=False)
517519
class NumPyArrayConnectivityField(NdArrayConnectivityField):
518520
array_ns: ClassVar[ModuleType] = np
519521

520522

521-
common.connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)
523+
common._connectivity.register(np.ndarray, NumPyArrayConnectivityField.from_array)
522524

523525
# CuPy
524526
if cp:
@@ -528,13 +530,13 @@ class NumPyArrayConnectivityField(NdArrayConnectivityField):
528530
class CuPyArrayField(NdArrayField):
529531
array_ns: ClassVar[ModuleType] = cp
530532

531-
common.field.register(cp.ndarray, CuPyArrayField.from_array)
533+
common._field.register(cp.ndarray, CuPyArrayField.from_array)
532534

533535
@dataclasses.dataclass(frozen=True, eq=False)
534536
class CuPyArrayConnectivityField(NdArrayConnectivityField):
535537
array_ns: ClassVar[ModuleType] = cp
536538

537-
common.connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)
539+
common._connectivity.register(cp.ndarray, CuPyArrayConnectivityField.from_array)
538540

539541
# JAX
540542
if jnp:
@@ -552,7 +554,7 @@ def __setitem__(
552554
# TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)`
553555
raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.")
554556

555-
common.field.register(jnp.ndarray, JaxArrayField.from_array)
557+
common._field.register(jnp.ndarray, JaxArrayField.from_array)
556558

557559

558560
def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
@@ -565,7 +567,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]
565567
else:
566568
domain_slice.append(np.newaxis)
567569
named_ranges.append((dim, common.UnitRange.infinite()))
568-
return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
570+
return common._field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges))
569571

570572

571573
def _builtins_broadcast(

src/gt4py/next/embedded/operators.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414

1515
import dataclasses
16+
from types import ModuleType
1617
from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar
1718

19+
import numpy as np
20+
1821
from gt4py import eve
1922
from gt4py._core import definitions as core_defs
20-
from gt4py.next import common, constructors, errors, utils
23+
from gt4py.next import common, errors, utils
2124
from gt4py.next.embedded import common as embedded_common, context as embedded_context
2225

2326

@@ -43,7 +46,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
4346
scan_range = embedded_context.closure_column_range.get()
4447
assert self.axis == scan_range[0]
4548
scan_axis = scan_range[0]
46-
domain_intersection = _intersect_scan_args(*args, *kwargs.values())
49+
all_args = [*args, *kwargs.values()]
50+
domain_intersection = _intersect_scan_args(*all_args)
4751
non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis])
4852

4953
out_domain = common.Domain(
@@ -53,7 +57,8 @@ def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Fiel
5357
# even if the scan dimension is not in the input, we can scan over it
5458
out_domain = common.Domain(*out_domain, (scan_range))
5559

56-
res = _construct_scan_array(out_domain)(self.init)
60+
xp = _get_array_ns(*all_args)
61+
res = _construct_scan_array(out_domain, xp)(self.init)
5762

5863
def scan_loop(hpos):
5964
acc = self.init
@@ -128,7 +133,11 @@ def _tuple_assign_field(
128133
):
129134
@utils.tree_map
130135
def impl(target: common.MutableField, source: common.Field):
131-
target[domain] = source[domain]
136+
if common.is_field(source):
137+
target[domain] = source[domain]
138+
else:
139+
assert core_defs.is_scalar_type(source)
140+
target[domain] = source
132141

133142
impl(target, source)
134143

@@ -141,10 +150,21 @@ def _intersect_scan_args(
141150
)
142151

143152

144-
def _construct_scan_array(domain: common.Domain):
153+
def _get_array_ns(
154+
*args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...]
155+
) -> ModuleType:
156+
for arg in utils.flatten_nested_tuple(args):
157+
if hasattr(arg, "array_ns"):
158+
return arg.array_ns
159+
return np
160+
161+
162+
def _construct_scan_array(
163+
domain: common.Domain, xp: ModuleType
164+
): # TODO(havogt) introduce a NDArrayNamespace protocol
145165
@utils.tree_map
146166
def impl(init: core_defs.Scalar) -> common.Field:
147-
return constructors.empty(domain, dtype=type(init))
167+
return common._field(xp.empty(domain.shape, dtype=type(init)), domain=domain)
148168

149169
return impl
150170

@@ -168,6 +188,7 @@ def _tuple_at(
168188
@utils.tree_map
169189
def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar:
170190
res = field[pos] if common.is_field(field) else field
191+
res = res.item() if hasattr(res, "item") else res # extract scalar value from array
171192
assert core_defs.is_scalar_type(res)
172193
return res
173194

src/gt4py/next/ffront/fbuiltins.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,8 @@ def broadcast(
188188
assert core_defs.is_scalar_type(
189189
field
190190
) # default implementation for scalars, Fields are handled via dispatch
191-
return common.field(
192-
np.asarray(field)[
193-
tuple([np.newaxis] * len(dims))
194-
], # TODO(havogt) use FunctionField once available
195-
domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))),
196-
)
191+
# TODO(havogt) implement with FunctionField, the workaround is to ignore broadcasting on scalars as they broadcast automatically, but we lose the check for compatible dimensions
192+
return field # type: ignore[return-value] # see comment above
197193

198194

199195
@WhereBuiltinFunction

src/gt4py/next/iterator/embedded.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ def _maker(a) -> common.Field:
10351035
offset = origin.get(d, 0)
10361036
ranges.append(common.UnitRange(-offset, s - offset))
10371037

1038-
res = common.field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
1038+
res = common._field(a, domain=common.Domain(dims=tuple(axes), ranges=tuple(ranges)))
10391039
return res
10401040

10411041
return _maker

tests/next_tests/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
#
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414

15-
from . import exclusion_matrices
15+
from . import definitions
1616

1717

18-
__all__ = ["exclusion_matrices", "get_processor_id"]
18+
__all__ = ["definitions", "get_processor_id"]
1919

2020

2121
def get_processor_id(processor):

0 commit comments

Comments
 (0)