
Commit 5a9ff0b

dcherian, pre-commit-ci[bot], and max-sixty authored
Optimize polyfit (#9766)
* Optimize polyfit

  Closes #5629

  1. Use Variable instead of DataArray
  2. Use `reshape_blockwise` when possible following #5629 (comment)

* clean up little more

* more clean up

* Add one comment

* Update doc/whats-new.rst

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix whats-new

* Update doc/whats-new.rst

  Co-authored-by: Maximilian Roos <[email protected]>

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Maximilian Roos <[email protected]>
1 parent b16a104 commit 5a9ff0b

File tree

7 files changed: +122 -47 lines changed


doc/whats-new.rst

+3
@@ -29,6 +29,9 @@ New Features
 - Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
   (:issue:`2852`, :issue:`757`).
   By `Deepak Cherian <https://github.com/dcherian>`_.
+- Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with
+  arrays with more than two dimensions.
+  (:issue:`5629`). By `Deepak Cherian <https://github.com/dcherian>`_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
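
For reference, a minimal sketch of the kind of call this entry describes: fitting along one dimension of a chunked array with more than two dimensions. The sizes and chunk sizes below are illustrative and mirror the new test_polyfit_nd_dask test added later in this commit; it assumes xarray with a working dask backend.

import numpy as np
import xarray as xr

# A 3-D dask-backed DataArray: "time" is the fit dimension, "lat"/"lon"
# are extra dimensions that previously forced a stack/unstack round trip.
da = (
    xr.DataArray(np.arange(120.0), dims="time", coords={"time": np.arange(120)})
    .chunk({"time": 20})
    .expand_dims(lat=5, lon=5)
    .chunk({"lat": 2, "lon": 2})
)

# Lazy fit; the coefficients keep the extra dimensions.
fit = da.polyfit("time", deg=1, skipna=False)
print(fit.polyfit_coefficients.dims)  # ('degree', 'lat', 'lon')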

xarray/core/dask_array_compat.py

+16
@@ -0,0 +1,16 @@
+from typing import Any
+
+from xarray.namedarray.utils import module_available
+
+
+def reshape_blockwise(
+    x: Any,
+    shape: int | tuple[int, ...],
+    chunks: tuple[tuple[int, ...], ...] | None = None,
+):
+    if module_available("dask", "2024.08.2"):
+        from dask.array import reshape_blockwise
+
+        return reshape_blockwise(x, shape=shape, chunks=chunks)
+    else:
+        return x.reshape(shape)
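
A rough usage sketch of the shim above, assuming dask >= 2024.08.2 is installed; the array and chunk sizes are illustrative. Note that the blockwise reshape may order elements differently from a plain reshape, which polyfit does not care about because the transform is undone afterwards.

import dask.array as da

from xarray.core.dask_array_compat import reshape_blockwise

# A chunked 3-D array whose leading axis plays the role of the core
# (fit) dimension in polyfit.
x = da.ones((100, 6, 4), chunks=(25, 3, 2))

# Collapse the trailing axes into one. With a new enough dask this maps
# onto dask.array.reshape_blockwise, which reshapes block by block and
# avoids the data movement of a plain ``x.reshape(100, 24)``; on older
# dask versions the shim silently falls back to that plain reshape.
flat = reshape_blockwise(x, shape=(100, 24))
print(flat.shape)  # (100, 24)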

xarray/core/dask_array_ops.py

+30
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import math
+
 from xarray.core import dtypes, nputils
 
 
@@ -19,6 +21,23 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
 def least_squares(lhs, rhs, rcond=None, skipna=False):
     import dask.array as da
 
+    from xarray.core.dask_array_compat import reshape_blockwise
+
+    # The trick here is that the core dimension is axis 0.
+    # All other dimensions need to be reshaped down to one axis for `lstsq`
+    # (which only accepts 2D input)
+    # and this needs to be undone after running `lstsq`
+    # The order of values in the reshaped axes is irrelevant.
+    # There are big gains to be had by simply reshaping the blocks on a blockwise
+    # basis, and then undoing that transform.
+    # We use a specific `reshape_blockwise` method in dask for this optimization
+    if rhs.ndim > 2:
+        out_shape = rhs.shape
+        reshape_chunks = rhs.chunks
+        rhs = reshape_blockwise(rhs, (rhs.shape[0], math.prod(rhs.shape[1:])))
+    else:
+        out_shape = None
+
     lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))
     if skipna:
         added_dim = rhs.ndim == 1
@@ -42,6 +61,17 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
     # See issue dask/dask#6516
     coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
+
+    if out_shape is not None:
+        coeffs = reshape_blockwise(
+            coeffs,
+            shape=(coeffs.shape[0], *out_shape[1:]),
+            chunks=((coeffs.shape[0],), *reshape_chunks[1:]),
+        )
+        residuals = reshape_blockwise(
+            residuals, shape=out_shape[1:], chunks=reshape_chunks[1:]
+        )
+
     return coeffs, residuals
 
 
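
The comment block above describes the core trick. A pure-NumPy analogue of the same flatten -> lstsq -> restore pattern (with illustrative shapes and no dask) is sketched below; the nputils.py change later in this commit implements this directly.

import numpy as np

# lhs: (N, order) Vandermonde-style design matrix; rhs: data with the
# fit dimension first and arbitrary extra dimensions after it.
N, order = 50, 3
lhs = np.vander(np.linspace(0.0, 1.0, N), order)
rhs = np.random.default_rng(0).normal(size=(N, 4, 5))

# 1. Collapse everything except the core axis so lstsq sees 2-D input.
out_shape = rhs.shape
flat = rhs.reshape(N, -1)                        # (50, 20)
coeffs, residuals, rank, _ = np.linalg.lstsq(lhs, flat, rcond=None)

# 2. Undo the reshape on the outputs.
coeffs = coeffs.reshape(order, *out_shape[1:])   # (3, 4, 5)
residuals = residuals.reshape(*out_shape[1:])    # (4, 5)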

xarray/core/dataset.py

+50 -46
@@ -9086,15 +9086,14 @@ def polyfit(
         numpy.polyval
         xarray.polyval
         """
-        from xarray.core.dataarray import DataArray
-
-        variables = {}
+        variables: dict[Hashable, Variable] = {}
         skipna_da = skipna
 
         x = np.asarray(_ensure_numeric(self.coords[dim]).astype(np.float64))
 
         xname = f"{self[dim].name}_"
         order = int(deg) + 1
+        degree_coord_values = np.arange(order)[::-1]
         lhs = np.vander(x, order)
 
         if rcond is None:
@@ -9120,46 +9119,48 @@ def polyfit(
         rank = np.linalg.matrix_rank(lhs)
 
         if full:
-            rank = DataArray(rank, name=xname + "matrix_rank")
-            variables[rank.name] = rank
+            rank = Variable(dims=(), data=rank)
+            variables[xname + "matrix_rank"] = rank
             _sing = np.linalg.svd(lhs, compute_uv=False)
-            sing = DataArray(
-                _sing,
+            variables[xname + "singular_values"] = Variable(
                 dims=(degree_dim,),
-                coords={degree_dim: np.arange(rank - 1, -1, -1)},
-                name=xname + "singular_values",
+                data=np.concatenate([np.full((order - rank.data,), np.nan), _sing]),
             )
-            variables[sing.name] = sing
 
         # If we have a coordinate get its underlying dimension.
-        true_dim = self.coords[dim].dims[0]
+        (true_dim,) = self.coords[dim].dims
 
-        for name, da in self.data_vars.items():
-            if true_dim not in da.dims:
+        other_coords = {
+            dim: self._variables[dim]
+            for dim in set(self.dims) - {true_dim}
+            if dim in self._variables
+        }
+        present_dims: set[Hashable] = set()
+        for name, var in self._variables.items():
+            if name in self._coord_names or name in self.dims:
+                continue
+            if true_dim not in var.dims:
                 continue
 
-            if is_duck_dask_array(da.data) and (
+            if is_duck_dask_array(var._data) and (
                 rank != order or full or skipna is None
             ):
                 # Current algorithm with dask and skipna=False neither supports
                 # deficient ranks nor does it output the "full" info (issue dask/dask#6516)
                 skipna_da = True
            elif skipna is None:
-                skipna_da = bool(np.any(da.isnull()))
-
-            dims_to_stack = [dimname for dimname in da.dims if dimname != true_dim]
-            stacked_coords: dict[Hashable, DataArray] = {}
-            if dims_to_stack:
-                stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
-                rhs = da.transpose(true_dim, *dims_to_stack).stack(
-                    {stacked_dim: dims_to_stack}
-                )
-                stacked_coords = {stacked_dim: rhs[stacked_dim]}
-                scale_da = scale[:, np.newaxis]
+                skipna_da = bool(np.any(var.isnull()))
+
+            if var.ndim > 1:
+                rhs = var.transpose(true_dim, ...)
+                other_dims = rhs.dims[1:]
+                scale_da = scale.reshape(-1, *((1,) * len(other_dims)))
             else:
-                rhs = da
+                rhs = var
                 scale_da = scale
+                other_dims = ()
 
+            present_dims.update(other_dims)
             if w is not None:
                 rhs = rhs * w[:, np.newaxis]
 
@@ -9179,42 +9180,45 @@ def polyfit(
                 # Thus a ReprObject => polyfit was called on a DataArray
                 name = ""
 
-            coeffs = DataArray(
-                coeffs / scale_da,
-                dims=[degree_dim] + list(stacked_coords.keys()),
-                coords={degree_dim: np.arange(order)[::-1], **stacked_coords},
-                name=name + "polyfit_coefficients",
+            variables[name + "polyfit_coefficients"] = Variable(
+                data=coeffs / scale_da, dims=(degree_dim,) + other_dims
             )
-            if dims_to_stack:
-                coeffs = coeffs.unstack(stacked_dim)
-            variables[coeffs.name] = coeffs
 
            if full or (cov is True):
-                residuals = DataArray(
-                    residuals if dims_to_stack else residuals.squeeze(),
-                    dims=list(stacked_coords.keys()),
-                    coords=stacked_coords,
-                    name=name + "polyfit_residuals",
+                variables[name + "polyfit_residuals"] = Variable(
+                    data=residuals if var.ndim > 1 else residuals.squeeze(),
+                    dims=other_dims,
                )
-                if dims_to_stack:
-                    residuals = residuals.unstack(stacked_dim)
-                variables[residuals.name] = residuals
 
            if cov:
                Vbase = np.linalg.inv(np.dot(lhs.T, lhs))
                Vbase /= np.outer(scale, scale)
+                if TYPE_CHECKING:
+                    fac: int | Variable
                if cov == "unscaled":
                    fac = 1
                else:
                    if x.shape[0] <= order:
                        raise ValueError(
                            "The number of data points must exceed order to scale the covariance matrix."
                        )
-                    fac = residuals / (x.shape[0] - order)
-                covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
-                variables[name + "polyfit_covariance"] = covariance
+                    fac = variables[name + "polyfit_residuals"] / (x.shape[0] - order)
+                variables[name + "polyfit_covariance"] = (
+                    Variable(data=Vbase, dims=("cov_i", "cov_j")) * fac
+                )
 
-        return type(self)(data_vars=variables, attrs=self.attrs.copy())
+        return type(self)(
+            data_vars=variables,
+            coords={
+                degree_dim: degree_coord_values,
+                **{
+                    name: coord
+                    for name, coord in other_coords.items()
+                    if name in present_dims
+                },
+            },
+            attrs=self.attrs.copy(),
+        )
 
     def pad(
         self,
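
One detail of the rewrite above: instead of stacking the extra dimensions into a temporary dimension, the per-degree scale is reshaped so that it broadcasts directly against the multi-dimensional coefficients. A small NumPy sketch of that broadcasting, with illustrative shapes:

import numpy as np

order, other_dims = 3, (4, 5)           # degree axis plus two extra dims
coeffs = np.ones((order, *other_dims))  # what least_squares hands back
scale = np.linspace(1.0, 2.0, order)    # per-degree column scaling of lhs

# Reshape the 1-D scale to (order, 1, 1) so the division broadcasts over
# the extra dimensions, replacing the old stack/unstack round trip.
scale_da = scale.reshape(-1, *((1,) * len(other_dims)))
rescaled = coeffs / scale_da            # shape (3, 4, 5)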

xarray/core/nputils.py

+10
@@ -255,6 +255,12 @@ def warn_on_deficient_rank(rank, order):
 
 
 def least_squares(lhs, rhs, rcond=None, skipna=False):
+    if rhs.ndim > 2:
+        out_shape = rhs.shape
+        rhs = rhs.reshape(rhs.shape[0], -1)
+    else:
+        out_shape = None
+
     if skipna:
         added_dim = rhs.ndim == 1
         if added_dim:
@@ -281,6 +287,10 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
         if residuals.size == 0:
             residuals = coeffs[0] * np.nan
     warn_on_deficient_rank(rank, lhs.shape[1])
+
+    if out_shape is not None:
+        coeffs = coeffs.reshape(-1, *out_shape[1:])
+        residuals = residuals.reshape(*out_shape[1:])
     return coeffs, residuals
 
 

xarray/tests/test_dataarray.py

+12
@@ -4308,6 +4308,18 @@ def test_polyfit(self, use_dask, use_datetime) -> None:
         out = da.polyfit("x", 8, full=True)
         np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False])
 
+    @requires_dask
+    def test_polyfit_nd_dask(self) -> None:
+        da = (
+            DataArray(np.arange(120), dims="time", coords={"time": np.arange(120)})
+            .chunk({"time": 20})
+            .expand_dims(lat=5, lon=5)
+            .chunk({"lat": 2, "lon": 2})
+        )
+        actual = da.polyfit("time", 1, skipna=False)
+        expected = da.compute().polyfit("time", 1, skipna=False)
+        assert_allclose(actual, expected)
+
     def test_pad_constant(self) -> None:
         ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5))
         actual = ar.pad(dim_0=(1, 3))

xarray/tests/test_dataset.py

+1 -1
@@ -6698,7 +6698,7 @@ def test_polyfit_coord(self) -> None:
 
         out = ds.polyfit("numbers", 2, full=False)
         assert "var3_polyfit_coefficients" in out
-        assert "dim1" in out
+        assert "dim1" in out.dims
         assert "dim2" not in out
         assert "dim3" not in out
 
