
Commit 357a444

Support non-str Hashables in DataArray (#8559)
* support hashable dims in DataArray
* add whats-new
* remove unnecessary except ImportErrors
* improve some typing
1 parent 08c8f9a commit 357a444

6 files changed, +91 -49 lines

doc/whats-new.rst (+2)

@@ -63,6 +63,8 @@ Deprecations
 Bug fixes
 ~~~~~~~~~

+- Support non-string hashable dimensions in :py:class:`xarray.DataArray` (:issue:`8546`, :pull:`8559`).
+  By `Michael Niklas <https://github.com/headtr1ck>`_.
 - Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
   By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 - Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
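
As the new entry above notes, the DataArray constructor now accepts any hashable object as a dimension name, not just strings. A minimal sketch of what this enables (the enum is illustrative, mirroring the new test file added later in this commit; requires an xarray build that includes this change):

    from enum import Enum

    import xarray as xr

    class Dim(Enum):
        x = "x"

    # a non-string hashable used as a dimension name
    da = xr.DataArray([1, 2, 3], dims=[Dim.x])
    print(da.dims)  # (<Dim.x: 'x'>,)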

xarray/core/dataarray.py (+26 -33)

@@ -11,6 +11,8 @@
     Generic,
     Literal,
     NoReturn,
+    TypeVar,
+    Union,
     overload,
 )

@@ -61,6 +63,7 @@
     ReprObject,
     _default,
     either_dict_or_kwargs,
+    hashable,
 )
 from xarray.core.variable import (
     IndexVariable,
@@ -73,23 +76,11 @@
 from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

 if TYPE_CHECKING:
-    from typing import TypeVar, Union
-
+    from dask.dataframe import DataFrame as DaskDataFrame
+    from dask.delayed import Delayed
+    from iris.cube import Cube as iris_Cube
     from numpy.typing import ArrayLike

-    try:
-        from dask.dataframe import DataFrame as DaskDataFrame
-    except ImportError:
-        DaskDataFrame = None
-    try:
-        from dask.delayed import Delayed
-    except ImportError:
-        Delayed = None  # type: ignore[misc,assignment]
-    try:
-        from iris.cube import Cube as iris_Cube
-    except ImportError:
-        iris_Cube = None
-
     from xarray.backends import ZarrStore
     from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes
     from xarray.core.groupby import DataArrayGroupBy
@@ -140,7 +131,9 @@ def _check_coords_dims(shape, coords, dim):


 def _infer_coords_and_dims(
-    shape, coords, dims
+    shape: tuple[int, ...],
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
+    dims: str | Iterable[Hashable] | None,
 ) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]:
     """All the logic for creating a new DataArray"""

@@ -157,8 +150,7 @@ def _infer_coords_and_dims(

     if isinstance(dims, str):
         dims = (dims,)
-
-    if dims is None:
+    elif dims is None:
         dims = [f"dim_{n}" for n in range(len(shape))]
         if coords is not None and len(coords) == len(shape):
             # try to infer dimensions from coords
@@ -168,16 +160,15 @@ def _infer_coords_and_dims(
                 for n, (dim, coord) in enumerate(zip(dims, coords)):
                     coord = as_variable(coord, name=dims[n]).to_index_variable()
                     dims[n] = coord.name
-        dims = tuple(dims)
-    elif len(dims) != len(shape):
+    dims_tuple = tuple(dims)
+    if len(dims_tuple) != len(shape):
         raise ValueError(
             "different number of dimensions on data "
-            f"and dims: {len(shape)} vs {len(dims)}"
+            f"and dims: {len(shape)} vs {len(dims_tuple)}"
         )
-    else:
-        for d in dims:
-            if not isinstance(d, str):
-                raise TypeError(f"dimension {d} is not a string")
+    for d in dims_tuple:
+        if not hashable(d):
+            raise TypeError(f"Dimension {d} is not hashable")

     new_coords: Mapping[Hashable, Any]

@@ -189,17 +180,21 @@ def _infer_coords_and_dims(
         for k, v in coords.items():
             new_coords[k] = as_variable(v, name=k)
     elif coords is not None:
-        for dim, coord in zip(dims, coords):
+        for dim, coord in zip(dims_tuple, coords):
             var = as_variable(coord, name=dim)
             var.dims = (dim,)
             new_coords[dim] = var.to_index_variable()

-    _check_coords_dims(shape, new_coords, dims)
+    _check_coords_dims(shape, new_coords, dims_tuple)

-    return new_coords, dims
+    return new_coords, dims_tuple


-def _check_data_shape(data, coords, dims):
+def _check_data_shape(
+    data: Any,
+    coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None,
+    dims: str | Iterable[Hashable] | None,
+) -> Any:
     if data is dtypes.NA:
         data = np.nan
     if coords is not None and utils.is_scalar(data, include_0d=False):
@@ -405,10 +400,8 @@ class DataArray(
     def __init__(
         self,
         data: Any = dtypes.NA,
-        coords: Sequence[Sequence[Any] | pd.Index | DataArray]
-        | Mapping[Any, Any]
-        | None = None,
-        dims: Hashable | Sequence[Hashable] | None = None,
+        coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
+        dims: str | Iterable[Hashable] | None = None,
         name: Hashable | None = None,
         attrs: Mapping | None = None,
         # internal parameters
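
With the reworked validation above, `_infer_coords_and_dims` only requires dimension names to be hashable instead of insisting on strings; unhashable objects are still rejected, now with a matching error message. A small sketch of the resulting behaviour (assuming an xarray build that includes this change):

    import numpy as np
    import xarray as xr

    data = np.zeros((2, 3))

    # tuples are hashable, so this is now accepted as a dimension name
    xr.DataArray(data, dims=[("a", "b"), "y"])

    # a list is not hashable and is rejected
    try:
        xr.DataArray(data, dims=["x", []])
    except TypeError as err:
        print(err)  # Dimension [] is not hashable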

xarray/core/dataset.py (+2 -9)

@@ -130,6 +130,8 @@
 from xarray.util.deprecation_helpers import _deprecate_positional_args

 if TYPE_CHECKING:
+    from dask.dataframe import DataFrame as DaskDataFrame
+    from dask.delayed import Delayed
     from numpy.typing import ArrayLike

     from xarray.backends import AbstractDataStore, ZarrStore
@@ -164,15 +166,6 @@
     )
     from xarray.core.weighted import DatasetWeighted

-    try:
-        from dask.delayed import Delayed
-    except ImportError:
-        Delayed = None  # type: ignore[misc,assignment]
-    try:
-        from dask.dataframe import DataFrame as DaskDataFrame
-    except ImportError:
-        DaskDataFrame = None
-

 # list of attributes of pd.DatetimeIndex that are ndarrays of time info
 _DATETIMEINDEX_COMPONENTS = [
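
The try/except ImportError fallbacks removed here are redundant inside an `if TYPE_CHECKING:` block: that block never runs at runtime, and the imported names are only needed by static type checkers, which resolve optional dependencies themselves. A minimal sketch of the pattern (the function is illustrative, not part of this diff):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # evaluated only by type checkers, never at runtime,
        # so the optional dependency can be imported directly
        from dask.delayed import Delayed

    def as_delayed_name(obj: Delayed) -> str:
        # annotations are not evaluated at runtime, so no import is needed here
        return str(obj)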

xarray/tests/__init__.py (+6 -5)

@@ -59,7 +59,11 @@ def _importorskip(
                 raise ImportError("Minimum version not satisfied")
     except ImportError:
         has = False
-    func = pytest.mark.skipif(not has, reason=f"requires {modname}")
+
+    reason = f"requires {modname}"
+    if minversion is not None:
+        reason += f">={minversion}"
+    func = pytest.mark.skipif(not has, reason=reason)
     return has, func


@@ -122,10 +126,7 @@ def _importorskip(
     not has_pandas_version_two, reason="requires pandas 2.0.0"
 )
 has_numpy_array_api, requires_numpy_array_api = _importorskip("numpy", "1.26.0")
-has_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")
-requires_h5netcdf_ros3 = pytest.mark.skipif(
-    not has_h5netcdf_ros3[0], reason="requires h5netcdf 1.3.0"
-)
+has_h5netcdf_ros3, requires_h5netcdf_ros3 = _importorskip("h5netcdf", "1.3.0")

 has_netCDF4_1_6_2_or_above, requires_netCDF4_1_6_2_or_above = _importorskip(
     "netCDF4", "1.6.2"

xarray/tests/test_dataarray.py (+2 -2)

@@ -401,8 +401,8 @@ def test_constructor_invalid(self) -> None:
         with pytest.raises(ValueError, match=r"not a subset of the .* dim"):
             DataArray(data, {"x": [0, 1, 2]})

-        with pytest.raises(TypeError, match=r"is not a string"):
-            DataArray(data, dims=["x", None])
+        with pytest.raises(TypeError, match=r"is not hashable"):
+            DataArray(data, dims=["x", []])  # type: ignore[list-item]

         with pytest.raises(ValueError, match=r"conflicting sizes for dim"):
             DataArray([1, 2, 3], coords=[("x", [0, 1])])

xarray/tests/test_hashable.py (+53, new file)

@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import TYPE_CHECKING, Union
+
+import pytest
+
+from xarray import DataArray, Dataset, Variable
+
+if TYPE_CHECKING:
+    from xarray.core.types import TypeAlias
+
+    DimT: TypeAlias = Union[int, tuple, "DEnum", "CustomHashable"]
+
+
+class DEnum(Enum):
+    dim = "dim"
+
+
+class CustomHashable:
+    def __init__(self, a: int) -> None:
+        self.a = a
+
+    def __hash__(self) -> int:
+        return self.a
+
+
+parametrize_dim = pytest.mark.parametrize(
+    "dim",
+    [
+        pytest.param(5, id="int"),
+        pytest.param(("a", "b"), id="tuple"),
+        pytest.param(DEnum.dim, id="enum"),
+        pytest.param(CustomHashable(3), id="HashableObject"),
+    ],
+)
+
+
+@parametrize_dim
+def test_hashable_dims(dim: DimT) -> None:
+    v = Variable([dim], [1, 2, 3])
+    da = DataArray([1, 2, 3], dims=[dim])
+    Dataset({"a": ([dim], [1, 2, 3])})
+
+    # alternative constructors
+    DataArray(v)
+    Dataset({"a": v})
+    Dataset({"a": da})
+
+
+@parametrize_dim
+def test_dataset_variable_hashable_names(dim: DimT) -> None:
+    Dataset({dim: ("x", [1, 2, 3])})
