Skip to content

Commit 3024655

Browse files
authored
Only use necessary dims when creating temporary dataarray (#9206)
* Only use necessary dims when creating temporary dataarray * Update dataset_plot.py * Can't check only data_vars all corrds are no longer added by default * Update dataset_plot.py * Add tests * Update whats-new.rst * Update dataset_plot.py
1 parent 179c670 commit 3024655

File tree

3 files changed

+52
-5
lines changed

3 files changed

+52
-5
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Deprecations
3737

3838
Bug fixes
3939
~~~~~~~~~
40+
- Fix scatter plot broadcasting unneccesarily. (:issue:`9129`, :pull:`9206`)
41+
By `Jimmy Westling <https://github.com/illviljan>`_.
4042
- Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`)
4143
By `Justus Magin <https://github.com/keewis>`_.
4244
- Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`).

xarray/plot/dataset_plot.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,8 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr
721721
"""Create a temporary datarray with extra coords."""
722722
from xarray.core.dataarray import DataArray
723723

724-
# Base coords:
725-
coords = dict(ds.coords)
724+
coords = dict(ds[y].coords)
725+
dims = set(ds[y].dims)
726726

727727
# Add extra coords to the DataArray from valid kwargs, if using all
728728
# kwargs there is a risk that we add unnecessary dataarrays as
@@ -732,12 +732,17 @@ def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataAr
732732
coord_kwargs = locals_.keys() & valid_coord_kwargs
733733
for k in coord_kwargs:
734734
key = locals_[k]
735-
if ds.data_vars.get(key) is not None:
736-
coords[key] = ds[key]
735+
darray = ds.get(key)
736+
if darray is not None:
737+
coords[key] = darray
738+
dims.update(darray.dims)
739+
740+
# Trim dataset from unneccessary dims:
741+
ds_trimmed = ds.drop_dims(ds.sizes.keys() - dims) # TODO: Use ds.dims in the future
737742

738743
# The dataarray has to include all the dims. Broadcast to that shape
739744
# and add the additional coords:
740-
_y = ds[y].broadcast_like(ds)
745+
_y = ds[y].broadcast_like(ds_trimmed)
741746

742747
return DataArray(_y, coords=coords)
743748

xarray/tests/test_plot.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3416,3 +3416,43 @@ def test_9155() -> None:
34163416
data = xr.DataArray([1, 2, 3], dims=["x"])
34173417
fig, ax = plt.subplots(ncols=1, nrows=1)
34183418
data.plot(ax=ax)
3419+
3420+
3421+
@requires_matplotlib
3422+
def test_temp_dataarray() -> None:
3423+
from xarray.plot.dataset_plot import _temp_dataarray
3424+
3425+
x = np.arange(1, 4)
3426+
y = np.arange(4, 6)
3427+
var1 = np.arange(x.size * y.size).reshape((x.size, y.size))
3428+
var2 = np.arange(x.size * y.size).reshape((x.size, y.size))
3429+
ds = xr.Dataset(
3430+
{
3431+
"var1": (["x", "y"], var1),
3432+
"var2": (["x", "y"], 2 * var2),
3433+
"var3": (["x"], 3 * x),
3434+
},
3435+
coords={
3436+
"x": x,
3437+
"y": y,
3438+
"model": np.arange(7),
3439+
},
3440+
)
3441+
3442+
# No broadcasting:
3443+
y_ = "var1"
3444+
locals_ = {"x": "var2"}
3445+
da = _temp_dataarray(ds, y_, locals_)
3446+
assert da.shape == (3, 2)
3447+
3448+
# Broadcast from 1 to 2dim:
3449+
y_ = "var3"
3450+
locals_ = {"x": "var1"}
3451+
da = _temp_dataarray(ds, y_, locals_)
3452+
assert da.shape == (3, 2)
3453+
3454+
# Ignore non-valid coord kwargs:
3455+
y_ = "var3"
3456+
locals_ = dict(x="x", extend="var2")
3457+
da = _temp_dataarray(ds, y_, locals_)
3458+
assert da.shape == (3,)

0 commit comments

Comments
 (0)