Skip to content

Commit 6b7515c

Browse files
Fix typing errors using mypy 1.2 (#7752)
* test newest mypy * Update ci-additional.yaml * remove ignores * add typing * Use ClassVar * Generalize data_vars typing concat. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use a normal method to retrieve a type of Variable * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ignore plotfunc error * force reinstall * remove outdated comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b889208 commit 6b7515c

File tree

8 files changed

+30
-28
lines changed

8 files changed

+30
-28
lines changed

.github/workflows/ci-additional.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ jobs:
119119
python xarray/util/print_versions.py
120120
- name: Install mypy
121121
run: |
122-
python -m pip install 'mypy<0.990'
122+
python -m pip install mypy --force-reinstall
123123
124124
- name: Run mypy
125125
run: |
@@ -173,7 +173,7 @@ jobs:
173173
python xarray/util/print_versions.py
174174
- name: Install mypy
175175
run: |
176-
python -m pip install 'mypy<0.990'
176+
python -m pip install mypy --force-reinstall
177177
178178
- name: Run mypy
179179
run: |

xarray/core/combine.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -369,9 +369,8 @@ def _nested_combine(
369369
return combined
370370

371371

372-
# Define type for arbitrarily-nested list of lists recursively
373-
# Currently mypy cannot handle this but other linters can (https://stackoverflow.com/a/53845083/3154101)
374-
DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] # type: ignore[misc]
372+
# Define type for arbitrarily-nested list of lists recursively:
373+
DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]]
375374

376375

377376
def combine_nested(

xarray/core/concat.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Hashable, Iterable
4-
from typing import TYPE_CHECKING, Any, cast, overload
4+
from typing import TYPE_CHECKING, Any, Union, cast, overload
55

66
import pandas as pd
77

@@ -27,12 +27,14 @@
2727
JoinOptions,
2828
)
2929

30+
T_DataVars = Union[ConcatOptions, Iterable[Hashable]]
31+
3032

3133
@overload
3234
def concat(
3335
objs: Iterable[T_Dataset],
3436
dim: Hashable | T_DataArray | pd.Index,
35-
data_vars: ConcatOptions | list[Hashable] = "all",
37+
data_vars: T_DataVars = "all",
3638
coords: ConcatOptions | list[Hashable] = "different",
3739
compat: CompatOptions = "equals",
3840
positions: Iterable[Iterable[int]] | None = None,
@@ -47,7 +49,7 @@ def concat(
4749
def concat(
4850
objs: Iterable[T_DataArray],
4951
dim: Hashable | T_DataArray | pd.Index,
50-
data_vars: ConcatOptions | list[Hashable] = "all",
52+
data_vars: T_DataVars = "all",
5153
coords: ConcatOptions | list[Hashable] = "different",
5254
compat: CompatOptions = "equals",
5355
positions: Iterable[Iterable[int]] | None = None,
@@ -61,7 +63,7 @@ def concat(
6163
def concat(
6264
objs,
6365
dim,
64-
data_vars="all",
66+
data_vars: T_DataVars = "all",
6567
coords="different",
6668
compat: CompatOptions = "equals",
6769
positions=None,
@@ -291,7 +293,7 @@ def _calc_concat_dim_index(
291293
return dim, index
292294

293295

294-
def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat):
296+
def _calc_concat_over(datasets, dim, dim_names, data_vars: T_DataVars, coords, compat):
295297
"""
296298
Determine which dataset variables need to be concatenated in the result,
297299
"""
@@ -445,7 +447,7 @@ def _parse_datasets(
445447
def _dataset_concat(
446448
datasets: list[T_Dataset],
447449
dim: str | T_DataArray | pd.Index,
448-
data_vars: str | list[str],
450+
data_vars: T_DataVars,
449451
coords: str | list[str],
450452
compat: CompatOptions,
451453
positions: Iterable[Iterable[int]] | None,
@@ -665,7 +667,7 @@ def get_indexes(name):
665667
def _dataarray_concat(
666668
arrays: Iterable[T_DataArray],
667669
dim: str | T_DataArray | pd.Index,
668-
data_vars: str | list[str],
670+
data_vars: T_DataVars,
669671
coords: str | list[str],
670672
compat: CompatOptions,
671673
positions: Iterable[Iterable[int]] | None,

xarray/core/rolling.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def _construct(
376376
window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim}
377377

378378
window_dims = self._mapping_to_list(
379-
window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506
379+
window_dim, allow_default=False, allow_allsame=False
380380
)
381381
strides = self._mapping_to_list(stride, default=1)
382382

@@ -753,7 +753,7 @@ def construct(
753753
window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim}
754754

755755
window_dims = self._mapping_to_list(
756-
window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506
756+
window_dim, allow_default=False, allow_allsame=False
757757
)
758758
strides = self._mapping_to_list(stride, default=1)
759759

xarray/core/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index):
113113
dtype = np.dtype("O")
114114
elif hasattr(array, "categories"):
115115
# category isn't a real numpy dtype
116-
dtype = array.categories.dtype # type: ignore[union-attr]
116+
dtype = array.categories.dtype
117117
elif not is_valid_numpy_dtype(array.dtype):
118118
dtype = np.dtype("O")
119119
else:

xarray/tests/test_concat.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,7 @@ def test_concat_data_vars_typing(self) -> None:
538538
actual = concat(objs, dim="x", data_vars="minimal")
539539
assert_identical(data, actual)
540540

541-
def test_concat_data_vars(self):
542-
# TODO: annotating this func fails
541+
def test_concat_data_vars(self) -> None:
543542
data = Dataset({"foo": ("x", np.random.randn(10))})
544543
objs: list[Dataset] = [data.isel(x=slice(5)), data.isel(x=slice(5, None))]
545544
for data_vars in ["minimal", "different", "all", [], ["foo"]]:

xarray/tests/test_plot.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2042,7 +2042,7 @@ def test_seaborn_palette_as_cmap(self) -> None:
20422042
def test_convenient_facetgrid(self) -> None:
20432043
a = easy_array((10, 15, 4))
20442044
d = DataArray(a, dims=["y", "x", "z"])
2045-
g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2)
2045+
g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015
20462046

20472047
assert_array_equal(g.axs.shape, [2, 2])
20482048
for (y, x), ax in np.ndenumerate(g.axs):
@@ -2051,7 +2051,7 @@ def test_convenient_facetgrid(self) -> None:
20512051
assert "x" == ax.get_xlabel()
20522052

20532053
# Inferring labels
2054-
g = self.plotfunc(d, col="z", col_wrap=2)
2054+
g = self.plotfunc(d, col="z", col_wrap=2) # type: ignore[arg-type] # https://github.com/python/mypy/issues/15015
20552055
assert_array_equal(g.axs.shape, [2, 2])
20562056
for (y, x), ax in np.ndenumerate(g.axs):
20572057
assert ax.has_data()

xarray/tests/test_variable.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from abc import ABC, abstractmethod
45
from copy import copy, deepcopy
56
from datetime import datetime, timedelta
67
from textwrap import dedent
@@ -61,8 +62,10 @@ def var():
6162
return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5))
6263

6364

64-
class VariableSubclassobjects:
65-
cls: staticmethod[Variable]
65+
class VariableSubclassobjects(ABC):
66+
@abstractmethod
67+
def cls(self, *args, **kwargs) -> Variable:
68+
raise NotImplementedError
6669

6770
def test_properties(self):
6871
data = 0.5 * np.arange(10)
@@ -1056,7 +1059,8 @@ def test_rolling_window_errors(self, dim, window, window_dim, center):
10561059

10571060

10581061
class TestVariable(VariableSubclassobjects):
1059-
cls = staticmethod(Variable)
1062+
def cls(self, *args, **kwargs) -> Variable:
1063+
return Variable(*args, **kwargs)
10601064

10611065
@pytest.fixture(autouse=True)
10621066
def setup(self):
@@ -2228,13 +2232,10 @@ def test_coarsen_keep_attrs(self, operation="mean"):
22282232
assert new.attrs == _attrs
22292233

22302234

2231-
def _init_dask_variable(*args, **kwargs):
2232-
return Variable(*args, **kwargs).chunk()
2233-
2234-
22352235
@requires_dask
22362236
class TestVariableWithDask(VariableSubclassobjects):
2237-
cls = staticmethod(_init_dask_variable)
2237+
def cls(self, *args, **kwargs) -> Variable:
2238+
return Variable(*args, **kwargs).chunk()
22382239

22392240
def test_chunk(self):
22402241
unblocked = Variable(["dim_0", "dim_1"], np.ones((3, 4)))
@@ -2346,7 +2347,8 @@ def test_as_sparse(self):
23462347

23472348

23482349
class TestIndexVariable(VariableSubclassobjects):
2349-
cls = staticmethod(IndexVariable)
2350+
def cls(self, *args, **kwargs) -> IndexVariable:
2351+
return IndexVariable(*args, **kwargs)
23502352

23512353
def test_init(self):
23522354
with pytest.raises(ValueError, match=r"must be 1-dimensional"):

0 commit comments

Comments
 (0)