@@ -49,7 +49,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
49
49
add_colorbar = False
50
50
add_legend = False
51
51
else :
52
- if add_guide is True and funcname != "quiver" :
52
+ if add_guide is True and funcname not in ( "quiver" , "streamplot" ) :
53
53
raise ValueError ("Cannot set add_guide when hue is None." )
54
54
add_legend = False
55
55
add_colorbar = False
@@ -62,11 +62,23 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname):
62
62
hue_style = "continuous"
63
63
elif hue_style != "continuous" :
64
64
raise ValueError (
65
- "hue_style must be 'continuous' or None for .plot.quiver"
65
+ "hue_style must be 'continuous' or None for .plot.quiver or "
66
+ ".plot.streamplot"
66
67
)
67
68
else :
68
69
add_quiverkey = False
69
70
71
+ if (add_guide or add_guide is None ) and funcname == "streamplot" :
72
+ if hue :
73
+ add_colorbar = True
74
+ if not hue_style :
75
+ hue_style = "continuous"
76
+ elif hue_style != "continuous" :
77
+ raise ValueError (
78
+ "hue_style must be 'continuous' or None for .plot.quiver or "
79
+ ".plot.streamplot"
80
+ )
81
+
70
82
if hue_style is not None and hue_style not in ["discrete" , "continuous" ]:
71
83
raise ValueError ("hue_style must be either None, 'discrete' or 'continuous'." )
72
84
@@ -186,7 +198,7 @@ def _dsplot(plotfunc):
186
198
x, y : str
187
199
Variable names for x, y axis.
188
200
u, v : str, optional
189
- Variable names for quiver plots
201
+ Variable names for quiver or streamplot plots
190
202
hue: str, optional
191
203
Variable by which to color scattered points
192
204
hue_style: str, optional
@@ -338,8 +350,11 @@ def newplotfunc(
338
350
else :
339
351
cmap_params_subset = {}
340
352
341
- if (u is not None or v is not None ) and plotfunc .__name__ != "quiver" :
342
- raise ValueError ("u, v are only allowed for quiver plots." )
353
+ if (u is not None or v is not None ) and plotfunc .__name__ not in (
354
+ "quiver" ,
355
+ "streamplot" ,
356
+ ):
357
+ raise ValueError ("u, v are only allowed for quiver or streamplot plots." )
343
358
344
359
primitive = plotfunc (
345
360
ds = ds ,
@@ -383,7 +398,7 @@ def newplotfunc(
383
398
coordinates = "figure" ,
384
399
)
385
400
386
- if plotfunc .__name__ == "quiver" :
401
+ if plotfunc .__name__ in ( "quiver" , "streamplot" ) :
387
402
title = ds [u ]._title_for_slice ()
388
403
else :
389
404
title = ds [x ]._title_for_slice ()
@@ -526,3 +541,54 @@ def quiver(ds, x, y, ax, u, v, **kwargs):
526
541
kwargs .setdefault ("pivot" , "middle" )
527
542
hdl = ax .quiver (* args , ** kwargs , ** cmap_params )
528
543
return hdl
544
+
545
+
546
+ @_dsplot
547
+ def streamplot (ds , x , y , ax , u , v , ** kwargs ):
548
+ """ Quiver plot with Dataset variables."""
549
+ import matplotlib as mpl
550
+
551
+ if x is None or y is None or u is None or v is None :
552
+ raise ValueError ("Must specify x, y, u, v for streamplot plots." )
553
+
554
+ # Matplotlib's streamplot has strong restrictions on what x and y can be, so need to
555
+ # get arrays transposed the 'right' way around. 'x' cannot vary within 'rows', so
556
+ # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so
557
+ # the dimension of y must be the first dimension. If x and y are both 2d, assume the
558
+ # user has got them right already.
559
+ if len (ds [x ].dims ) == 1 :
560
+ xdim = ds [x ].dims [0 ]
561
+ if len (ds [y ].dims ) == 1 :
562
+ ydim = ds [y ].dims [0 ]
563
+ if xdim is not None and ydim is None :
564
+ ydim = set (ds [y ].dims ) - set ([xdim ])
565
+ if ydim is not None and xdim is None :
566
+ xdim = set (ds [x ].dims ) - set ([ydim ])
567
+
568
+ x , y , u , v = broadcast (ds [x ], ds [y ], ds [u ], ds [v ])
569
+
570
+ if xdim is not None and ydim is not None :
571
+ # Need to ensure the arrays are transposed correctly
572
+ x = x .transpose (ydim , xdim )
573
+ y = y .transpose (ydim , xdim )
574
+ u = u .transpose (ydim , xdim )
575
+ v = v .transpose (ydim , xdim )
576
+
577
+ args = [x .values , y .values , u .values , v .values ]
578
+ hue = kwargs .pop ("hue" )
579
+ cmap_params = kwargs .pop ("cmap_params" )
580
+
581
+ if hue :
582
+ kwargs ["color" ] = ds [hue ].values
583
+
584
+ # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params
585
+ if not cmap_params ["norm" ]:
586
+ cmap_params ["norm" ] = mpl .colors .Normalize (
587
+ cmap_params .pop ("vmin" ), cmap_params .pop ("vmax" )
588
+ )
589
+
590
+ kwargs .pop ("hue_style" )
591
+ hdl = ax .streamplot (* args , ** kwargs , ** cmap_params )
592
+
593
+ # Return .lines so colorbar creation works properly
594
+ return hdl .lines
0 commit comments