Skip to content

Commit ba47216

Browse files
johnomotanikeewis
andauthored
Add Dataset.plot.streamplot() method (#5003)
Co-authored-by: keewis <[email protected]>
1 parent 821479d commit ba47216

File tree

5 files changed

+151
-6
lines changed

5 files changed

+151
-6
lines changed

doc/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ Plotting
243243

244244
Dataset.plot.scatter
245245
Dataset.plot.quiver
246+
Dataset.plot.streamplot
246247

247248
DataArray
248249
=========

doc/user-guide/plotting.rst

+20
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,26 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto
787787
788788
``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.
789789

790+
Streamplot
791+
~~~~~~~~~~
792+
793+
Visualizing vector fields is also supported with streamline plots:
794+
795+
.. ipython:: python
796+
:okwarning:
797+
798+
@savefig ds_simple_streamplot.png
799+
ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B")
800+
801+
802+
where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible:
803+
804+
.. ipython:: python
805+
:okwarning:
806+
807+
@savefig ds_facet_streamplot.png
808+
ds.plot.streamplot(x="x", y="y", u="A", v="B", col="w", row="z")
809+
790810
.. _plot-maps:
791811

792812
Maps

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ New Features
3131
- Support for `dask.graph_manipulation
3232
<https://docs.dask.org/en/latest/graph_manipulation.html>`_ (requires dask >=2021.3)
3333
By `Guido Imperiale <https://github.com/crusaderky>`_
34+
- Add :py:meth:`Dataset.plot.streamplot` for streamplot plots with :py:class:`Dataset`
35+
variables (:pull:`5003`).
36+
By `John Omotani <https://github.com/johnomotani>`_.
3437
- Many of the arguments for the :py:attr:`DataArray.str` methods now support
3538
providing an array-like input. In this case, the array provided to the
3639
arguments is broadcast against the original array and applied elementwise.

xarray/plot/dataset_plot.py

+72-6
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
4949
add_colorbar = False
5050
add_legend = False
5151
else:
52-
if add_guide is True and funcname != "quiver":
52+
if add_guide is True and funcname not in ("quiver", "streamplot"):
5353
raise ValueError("Cannot set add_guide when hue is None.")
5454
add_legend = False
5555
add_colorbar = False
@@ -62,11 +62,23 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
6262
hue_style = "continuous"
6363
elif hue_style != "continuous":
6464
raise ValueError(
65-
"hue_style must be 'continuous' or None for .plot.quiver"
65+
"hue_style must be 'continuous' or None for .plot.quiver or "
66+
".plot.streamplot"
6667
)
6768
else:
6869
add_quiverkey = False
6970

71+
if (add_guide or add_guide is None) and funcname == "streamplot":
72+
if hue:
73+
add_colorbar = True
74+
if not hue_style:
75+
hue_style = "continuous"
76+
elif hue_style != "continuous":
77+
raise ValueError(
78+
"hue_style must be 'continuous' or None for .plot.quiver or "
79+
".plot.streamplot"
80+
)
81+
7082
if hue_style is not None and hue_style not in ["discrete", "continuous"]:
7183
raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.")
7284

@@ -186,7 +198,7 @@ def _dsplot(plotfunc):
186198
x, y : str
187199
Variable names for x, y axis.
188200
u, v : str, optional
189-
Variable names for quiver plots
201+
Variable names for quiver or streamplot plots
190202
hue: str, optional
191203
Variable by which to color scattered points
192204
hue_style: str, optional
@@ -338,8 +350,11 @@ def newplotfunc(
338350
else:
339351
cmap_params_subset = {}
340352

341-
if (u is not None or v is not None) and plotfunc.__name__ != "quiver":
342-
raise ValueError("u, v are only allowed for quiver plots.")
353+
if (u is not None or v is not None) and plotfunc.__name__ not in (
354+
"quiver",
355+
"streamplot",
356+
):
357+
raise ValueError("u, v are only allowed for quiver or streamplot plots.")
343358

344359
primitive = plotfunc(
345360
ds=ds,
@@ -383,7 +398,7 @@ def newplotfunc(
383398
coordinates="figure",
384399
)
385400

386-
if plotfunc.__name__ == "quiver":
401+
if plotfunc.__name__ in ("quiver", "streamplot"):
387402
title = ds[u]._title_for_slice()
388403
else:
389404
title = ds[x]._title_for_slice()
@@ -526,3 +541,54 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
526541
kwargs.setdefault("pivot", "middle")
527542
hdl = ax.quiver(*args, **kwargs, **cmap_params)
528543
return hdl
544+
545+
546+
@_dsplot
547+
def streamplot(ds, x, y, ax, u, v, **kwargs):
548+
""" Quiver plot with Dataset variables."""
549+
import matplotlib as mpl
550+
551+
if x is None or y is None or u is None or v is None:
552+
raise ValueError("Must specify x, y, u, v for streamplot plots.")
553+
554+
# Matplotlib's streamplot has strong restrictions on what x and y can be, so need to
555+
# get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so
556+
# the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so
557+
# the dimension of y must be the first dimension. If x and y are both 2d, assume the
558+
# user has got them right already.
559+
if len(ds[x].dims) == 1:
560+
xdim = ds[x].dims[0]
561+
if len(ds[y].dims) == 1:
562+
ydim = ds[y].dims[0]
563+
if xdim is not None and ydim is None:
564+
ydim = set(ds[y].dims) - set([xdim])
565+
if ydim is not None and xdim is None:
566+
xdim = set(ds[x].dims) - set([ydim])
567+
568+
x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v])
569+
570+
if xdim is not None and ydim is not None:
571+
# Need to ensure the arrays are transposed correctly
572+
x = x.transpose(ydim, xdim)
573+
y = y.transpose(ydim, xdim)
574+
u = u.transpose(ydim, xdim)
575+
v = v.transpose(ydim, xdim)
576+
577+
args = [x.values, y.values, u.values, v.values]
578+
hue = kwargs.pop("hue")
579+
cmap_params = kwargs.pop("cmap_params")
580+
581+
if hue:
582+
kwargs["color"] = ds[hue].values
583+
584+
# TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
585+
if not cmap_params["norm"]:
586+
cmap_params["norm"] = mpl.colors.Normalize(
587+
cmap_params.pop("vmin"), cmap_params.pop("vmax")
588+
)
589+
590+
kwargs.pop("hue_style")
591+
hdl = ax.streamplot(*args, **kwargs, **cmap_params)
592+
593+
# Return .lines so colorbar creation works properly
594+
return hdl.lines

xarray/tests/test_plot.py

+55
Original file line numberDiff line numberDiff line change
@@ -2221,6 +2221,61 @@ def test_facetgrid(self):
22212221
self.ds.plot.quiver(x="x", y="y", u="u", v="v", row="row", col="col")
22222222

22232223

2224+
@requires_matplotlib
2225+
class TestDatasetStreamplotPlots(PlotTestCase):
2226+
@pytest.fixture(autouse=True)
2227+
def setUp(self):
2228+
das = [
2229+
DataArray(
2230+
np.random.randn(3, 3, 2, 2),
2231+
dims=["x", "y", "row", "col"],
2232+
coords=[range(k) for k in [3, 3, 2, 2]],
2233+
)
2234+
for _ in [1, 2]
2235+
]
2236+
ds = Dataset({"u": das[0], "v": das[1]})
2237+
ds.x.attrs["units"] = "xunits"
2238+
ds.y.attrs["units"] = "yunits"
2239+
ds.col.attrs["units"] = "colunits"
2240+
ds.row.attrs["units"] = "rowunits"
2241+
ds.u.attrs["units"] = "uunits"
2242+
ds.v.attrs["units"] = "vunits"
2243+
ds["mag"] = np.hypot(ds.u, ds.v)
2244+
self.ds = ds
2245+
2246+
def test_streamline(self):
2247+
with figure_context():
2248+
hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v")
2249+
assert isinstance(hdl, mpl.collections.LineCollection)
2250+
with raises_regex(ValueError, "specify x, y, u, v"):
2251+
self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u")
2252+
2253+
with raises_regex(ValueError, "hue_style"):
2254+
self.ds.isel(row=0, col=0).plot.streamplot(
2255+
x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete"
2256+
)
2257+
2258+
def test_facetgrid(self):
2259+
with figure_context():
2260+
fg = self.ds.plot.streamplot(
2261+
x="x", y="y", u="u", v="v", row="row", col="col", hue="mag"
2262+
)
2263+
for handle in fg._mappables:
2264+
assert isinstance(handle, mpl.collections.LineCollection)
2265+
2266+
with figure_context():
2267+
fg = self.ds.plot.streamplot(
2268+
x="x",
2269+
y="y",
2270+
u="u",
2271+
v="v",
2272+
row="row",
2273+
col="col",
2274+
hue="mag",
2275+
add_guide=False,
2276+
)
2277+
2278+
22242279
@requires_matplotlib
22252280
class TestDatasetScatterPlots(PlotTestCase):
22262281
@pytest.fixture(autouse=True)

0 commit comments

Comments
 (0)