Skip to content

Commit 087fe45

Browse files
Illviljanpre-commit-ci[bot]headtr1ck
authored
Use shape and dtype as typevars in NamedArray (#8294)
* Add from_array function * Update core.py * some fixes * Update test_namedarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * fixes * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * more * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_namedarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * fixes * fkxes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_namedarray.py * Update test_namedarray.py * Update test_namedarray.py * Update test_namedarray.py * Update test_namedarray.py * Update test_namedarray.py * Update test_namedarray.py * move to NDArray instead * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rename and align more with numpy typing * Add duck type testing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docstring * Update test_namedarray.py * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * fixes * final * Follow numpy's example more with typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update utils.py * Update utils.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Create _array_api.py * Create _typing.py * Update core.py * Update utils.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update _typing.py * Update core.py * Update utils.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Will this make pre-commit happy? * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update _array_api.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more * Update core.py * fixes * Update test_namedarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * Use Self becuase Variable subclasses * fixes * Update test_namedarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update core.py * Update core.py * Update variable.py * Update variable.py * fix array api, add docstrings * Fix typing so that a different array gets correct typing * add _new with correct typing in variable * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * shape usually stays the same when copying * Update variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_namedarray.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_namedarray.py * same shape when astyping * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Delete test_namedarray_sketching.py * typos * remove any typing for now * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update indexing.py * add namespace to some explicitindexing stuff * Update variable.py * Update duck_array_ops.py * Update duck_array_ops.py * fixes * Update variable.py * Fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_variable.py * Revert "Update test_variable.py" This reverts commit 6572abe. * Update _array_api.py * Update _array_api.py * Update _array_api.py * as_compatible_data lose the typing * Update indexing.py * Update core.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update variable.py * Update variable.py * Update indexing.py * Update xarray/core/variable.py * cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update xarray/core/variable.py Co-authored-by: Michael Niklas <[email protected]> * Apply suggestions from code review Co-authored-by: Michael Niklas <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas <[email protected]>
1 parent c25c825 commit 087fe45

File tree

7 files changed

+1114
-337
lines changed

7 files changed

+1114
-337
lines changed

xarray/core/variable.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,25 @@ def __init__(
369369
if encoding is not None:
370370
self.encoding = encoding
371371

372+
def _new(
373+
self,
374+
dims=_default,
375+
data=_default,
376+
attrs=_default,
377+
):
378+
dims_ = copy.copy(self._dims) if dims is _default else dims
379+
380+
if attrs is _default:
381+
attrs_ = None if self._attrs is None else self._attrs.copy()
382+
else:
383+
attrs_ = attrs
384+
385+
if data is _default:
386+
return type(self)(dims_, copy.copy(self._data), attrs_)
387+
else:
388+
cls_ = type(self)
389+
return cls_(dims_, data, attrs_)
390+
372391
@property
373392
def _in_memory(self):
374393
return isinstance(
@@ -905,16 +924,17 @@ def _copy(
905924
ndata = data_old
906925
else:
907926
# don't share caching between copies
908-
ndata = indexing.MemoryCachedArray(data_old.array)
927+
# TODO: MemoryCachedArray doesn't match the array api:
928+
ndata = indexing.MemoryCachedArray(data_old.array) # type: ignore[assignment]
909929

910930
if deep:
911931
ndata = copy.deepcopy(ndata, memo)
912932

913933
else:
914934
ndata = as_compatible_data(data)
915-
if self.shape != ndata.shape:
935+
if self.shape != ndata.shape: # type: ignore[attr-defined]
916936
raise ValueError(
917-
f"Data shape {ndata.shape} must match shape of object {self.shape}"
937+
f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
918938
)
919939

920940
attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
@@ -1054,7 +1074,8 @@ def chunk(
10541074
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
10551075
# different indexing types in an explicit way:
10561076
# https://github.com/dask/dask/issues/2883
1057-
ndata = indexing.ImplicitToExplicitIndexingAdapter(
1077+
# TODO: ImplicitToExplicitIndexingAdapter doesn't match the array api:
1078+
ndata = indexing.ImplicitToExplicitIndexingAdapter( # type: ignore[assignment]
10581079
data_old, indexing.OuterIndexer
10591080
)
10601081

@@ -2608,6 +2629,9 @@ class IndexVariable(Variable):
26082629

26092630
__slots__ = ()
26102631

2632+
# TODO: PandasIndexingAdapter doesn't match the array api:
2633+
_data: PandasIndexingAdapter # type: ignore[assignment]
2634+
26112635
def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
26122636
super().__init__(dims, data, attrs, encoding, fastpath)
26132637
if self.ndim != 1:
@@ -2756,9 +2780,9 @@ def copy(self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None):
27562780

27572781
else:
27582782
ndata = as_compatible_data(data)
2759-
if self.shape != ndata.shape:
2783+
if self.shape != ndata.shape: # type: ignore[attr-defined]
27602784
raise ValueError(
2761-
f"Data shape {ndata.shape} must match shape of object {self.shape}"
2785+
f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
27622786
)
27632787

27642788
attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs)

xarray/namedarray/_array_api.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from types import ModuleType
2+
from typing import Any
3+
4+
import numpy as np
5+
6+
from xarray.namedarray._typing import (
7+
_arrayapi,
8+
_DType,
9+
_ScalarType,
10+
_ShapeType,
11+
_SupportsImag,
12+
_SupportsReal,
13+
)
14+
from xarray.namedarray.core import NamedArray
15+
16+
17+
def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
18+
if isinstance(x._data, _arrayapi):
19+
return x._data.__array_namespace__()
20+
else:
21+
return np
22+
23+
24+
def astype(
25+
x: NamedArray[_ShapeType, Any], dtype: _DType, /, *, copy: bool = True
26+
) -> NamedArray[_ShapeType, _DType]:
27+
"""
28+
Copies an array to a specified data type irrespective of Type Promotion Rules rules.
29+
30+
Parameters
31+
----------
32+
x : NamedArray
33+
Array to cast.
34+
dtype : _DType
35+
Desired data type.
36+
copy : bool, optional
37+
Specifies whether to copy an array when the specified dtype matches the data
38+
type of the input array x.
39+
If True, a newly allocated array must always be returned.
40+
If False and the specified dtype matches the data type of the input array,
41+
the input array must be returned; otherwise, a newly allocated array must be
42+
returned. Default: True.
43+
44+
Returns
45+
-------
46+
out : NamedArray
47+
An array having the specified data type. The returned array must have the
48+
same shape as x.
49+
50+
Examples
51+
--------
52+
>>> narr = NamedArray(("x",), np.array([1.5, 2.5]))
53+
>>> astype(narr, np.dtype(int)).data
54+
array([1, 2])
55+
"""
56+
if isinstance(x._data, _arrayapi):
57+
xp = x._data.__array_namespace__()
58+
return x._new(data=xp.astype(x, dtype, copy=copy))
59+
60+
# np.astype doesn't exist yet:
61+
return x._new(data=x._data.astype(dtype, copy=copy)) # type: ignore[attr-defined]
62+
63+
64+
def imag(
65+
x: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], / # type: ignore[type-var]
66+
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
67+
"""
68+
Returns the imaginary component of a complex number for each element x_i of the
69+
input array x.
70+
71+
Parameters
72+
----------
73+
x : NamedArray
74+
Input array. Should have a complex floating-point data type.
75+
76+
Returns
77+
-------
78+
out : NamedArray
79+
An array containing the element-wise results. The returned array must have a
80+
floating-point data type with the same floating-point precision as x
81+
(e.g., if x is complex64, the returned array must have the floating-point
82+
data type float32).
83+
84+
Examples
85+
--------
86+
>>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
87+
>>> imag(narr).data
88+
array([2., 4.])
89+
"""
90+
xp = _get_data_namespace(x)
91+
out = x._new(data=xp.imag(x._data))
92+
return out
93+
94+
95+
def real(
96+
x: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], / # type: ignore[type-var]
97+
) -> NamedArray[_ShapeType, np.dtype[_ScalarType]]:
98+
"""
99+
Returns the real component of a complex number for each element x_i of the
100+
input array x.
101+
102+
Parameters
103+
----------
104+
x : NamedArray
105+
Input array. Should have a complex floating-point data type.
106+
107+
Returns
108+
-------
109+
out : NamedArray
110+
An array containing the element-wise results. The returned array must have a
111+
floating-point data type with the same floating-point precision as x
112+
(e.g., if x is complex64, the returned array must have the floating-point
113+
data type float32).
114+
115+
Examples
116+
--------
117+
>>> narr = NamedArray(("x",), np.array([1 + 2j, 2 + 4j]))
118+
>>> real(narr).data
119+
array([1., 2.])
120+
"""
121+
xp = _get_data_namespace(x)
122+
return x._new(data=xp.real(x._data))

0 commit comments

Comments
 (0)