Skip to content

Commit

Permalink
Plot StreamObjects with line segments (plus an API change for the plo…
Browse files Browse the repository at this point in the history
…tting) (#130)

Resolves #111

This uses StreamObject.xy to construct stream segments and then plots them with a LineCollection from matplotlib.

This API follows the guidelines in

https://matplotlib.org/stable/users/explain/figure/api_interfaces.html#third-party-library-data-object-interfaces
  • Loading branch information
wkearn authored Jan 24, 2025
1 parent 1daba86 commit 3689447
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 33 deletions.
7 changes: 5 additions & 2 deletions docs/tutorial/stream.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"outputs": [],
"source": [
"import topotoolbox as tt3\n",
"import matplotlib.pyplot as plt\n",
"\n",
"dem = tt3.load_dem('tibet')\n",
"fd = tt3.FlowObject(dem);\n",
Expand All @@ -33,7 +34,7 @@
"source": [
"## Plot the stream network\n",
"\n",
"The stream network can be plotted using the `show` method on `StreamObject`. By passing the `dem` to the `overlay` parameter, the stream network is displayed on top of the DEM."
"The stream network can be plotted using the `plot` method on `StreamObject`."
]
},
{
Expand All @@ -43,7 +44,9 @@
"metadata": {},
"outputs": [],
"source": [
"s.show(overlay=dem)"
"fig, ax = plt.subplots()\n",
"ax.imshow(dem,cmap=\"terrain\")\n",
"s.plot(ax=ax,color='k');"
]
}
],
Expand Down
56 changes: 25 additions & 31 deletions src/topotoolbox/stream_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection

from .flow_object import FlowObject

Expand Down Expand Up @@ -288,41 +289,34 @@ def xy(self):

return segments

def show(self, cmap='hot', overlay: GridObject | None = None,
overlay_cmap: str = 'binary', alpha: float = 0.8) -> None:
"""
Display the StreamObject instance as an image using Matplotlib.
def plot(self, ax=None, **kwargs):
"""Plot the StreamObject
Stream segments as computed by StreamObject.xy are plotted
using a LineCollection. Note that collections are not used in
autoscaling the provided axis. If the axis limits are not
already set, by another underlying plot, for example, call
ax.autoscale_view() on the returned axes to show the plot.
Parameters
----------
cmap : str, optional
Matplotlib colormap that will be used for the stream.
overlay_cmap : str, optional
Matplotlib colormap that will be used in the background plot.
overlay : GridObject | None, optional
To overlay the stream over a dem to better visualize the stream.
alpha : float, optional
When using an dem to overlay, this controls the opacity of the dem.
ax: matplotlib.axes.Axes, optional
The axes in which to plot the StreamObject. If no axes are
given, the current axes are used.
**kwargs
Additional keyword arguments are forwarded to LineCollection
Returns
-------
matplotlib.axes.Axes
The axes into which the StreamObject has been plotted.
"""
stream = np.zeros(shape=self.shape, dtype=np.int64, order='F')
stream[np.unravel_index(self.stream,self.shape,order='F')] = 1

if overlay is not None:
if self.shape == overlay.shape:
plt.imshow(overlay, cmap=overlay_cmap, alpha=alpha)
plt.imshow(stream, cmap=cmap,
alpha=stream.astype(np.float32))
plt.show()
else:
err = (f"Shape mismatch: Stream shape {self.shape} does not "
f"match overlay shape {overlay.shape}.")
raise ValueError(err) from None
else:
plt.imshow(stream, cmap=cmap)
plt.title(self.name)
plt.colorbar()
plt.tight_layout()
plt.show()
if ax is None:
ax = plt.gca()
collection = LineCollection(self.xy(), **kwargs)
ax.add_collection(collection)
return ax

def chitransform(self,
upstream_area : GridObject | np.ndarray,
Expand Down

0 comments on commit 3689447

Please sign in to comment.