From feb67cc0b6ffd3769bce1ef84a3d38c30f09dcd7 Mon Sep 17 00:00:00 2001 From: Drew Leonard Date: Tue, 29 Oct 2024 19:13:29 +0000 Subject: [PATCH] Update `TiledDataset.plot()` (#441) Co-authored-by: Stuart Mumford --- changelog/441.trivial.rst | 1 + dkist/dataset/tiled_dataset.py | 80 +++++++++++++++++++++++++++++----- pyproject.toml | 2 +- 3 files changed, 72 insertions(+), 11 deletions(-) create mode 100644 changelog/441.trivial.rst diff --git a/changelog/441.trivial.rst b/changelog/441.trivial.rst new file mode 100644 index 000000000..ae78a7104 --- /dev/null +++ b/changelog/441.trivial.rst @@ -0,0 +1 @@ +Minor updates to `TiledDataset.plot()` for working with more complex arrangements of tiles. diff --git a/dkist/dataset/tiled_dataset.py b/dkist/dataset/tiled_dataset.py index cd9474e1b..3c315a3b1 100644 --- a/dkist/dataset/tiled_dataset.py +++ b/dkist/dataset/tiled_dataset.py @@ -152,13 +152,31 @@ def tiles_shape(self): """ return [[tile.data.shape for tile in row] for row in self] - def plot(self, slice_index: int, share_zscale=False, **kwargs): + def plot(self, slice_index, share_zscale=False, **kwargs): + """ + Plot a slice of each tile in the TiledDataset + + Parameters + ---------- + slice_index : `int`, sequence of `int`s or `numpy.s_` + Object representing a slice which will reduce each component dataset + of the TiledDataset to a 2D image. This is passed to + ``TiledDataset.slice_tiles`` + share_zscale : `bool` + Determines whether the color scale of the plots should be calculated + independently (``False``) or shared across all plots (``True``). + Defaults to False + """ + if isinstance(slice_index, int): + slice_index = (slice_index,) vmin, vmax = np.inf, 0 fig = plt.figure() - for i, tile in enumerate(self.flat): - ax = fig.add_subplot(self.shape[0], self.shape[1], i+1, projection=tile[0].wcs) - tile[slice_index].plot(axes=ax, **kwargs) + tiles = self.slice_tiles[slice_index].flat + for i, tile in enumerate(tiles): + ax = fig.add_subplot(self.shape[0], self.shape[1], i+1, projection=tile.wcs) + tile.plot(axes=ax, **kwargs) if i == 0: + # TODO: When we can depend on astropy >=7.0 we can remove these or statements xlabel = ax.coords[0].get_axislabel() or ax.coords[0]._get_default_axislabel() ylabel = ax.coords[1].get_axislabel() or ax.coords[1]._get_default_axislabel() for coord in ax.coords: @@ -174,12 +192,58 @@ def plot(self, slice_index: int, share_zscale=False, **kwargs): if share_zscale: for ax in fig.get_axes(): ax.get_images()[0].set_clim(vmin, vmax) - timestamp = self[0, 0].axis_world_coords("time")[-1].iso[slice_index] - fig.suptitle(f"{self.inventory['instrumentName']} Dataset ({self.inventory['datasetId']}) at time {timestamp} (slice={slice_index})", y=0.95) + title = f"{self.inventory['instrumentName']} Dataset ({self.inventory['datasetId']}) at " + for i, (coord, val) in enumerate(list(tiles[0].global_coords.items())[::-1]): + if coord == "time": + val = val.iso + if coord == "stokes": + val = val.symbol + title += f"{coord} {val}" + (", " if i != len(slice_index)-1 else " ") + title += f"(slice={(slice_index if len(slice_index) > 1 else slice_index[0])})".replace("slice(None, None, None)", ":") + fig.suptitle(title, y=0.95) return fig @property def slice_tiles(self): + """ + Returns a new TiledDataset with the given slice applied to each of the tiles. + + Examples + -------- + .. code-block:: python + + >>> from dkist import load_dataset + >>> from dkist.data.sample import VBI_AJQWW # doctest: +REMOTE_DATA + >>> ds = load_dataset(VBI_AJQWW) # doctest: +REMOTE_DATA + >>> ds.slice_tiles[0, 10:-10] # doctest: +REMOTE_DATA + + This VBI Dataset AJQWW is an array of (3, 3) Dataset objects and + consists of 9 frames. + Files are stored in ... + + Each Dataset has 2 pixel and 2 world dimensions. + + The data are represented by a object: + dask.array + + Array Dim Axis Name Data size Bounds + 0 helioprojective latitude 4076 None + 1 helioprojective longitude 4096 None + + World Dim Axis Name Physical Type Units + 1 helioprojective latitude custom:pos.helioprojective.lat arcsec + 0 helioprojective longitude custom:pos.helioprojective.lon arcsec + + Correlation between pixel and world axes: + + | PIXEL DIMENSIONS + | helioprojective | helioprojective + WORLD DIMENSIONS | longitude | latitude + ------------------------- | --------------- | --------------- + helioprojective longitude | x | x + helioprojective latitude | x | x + """ + return TiledDatasetSlicer(self._data, self.inventory) # TODO: def regrid() @@ -199,10 +263,6 @@ def files(self): """ A `~.FileManager` helper for interacting with the files backing the data in this ``Dataset``. """ - return self._file_manager - - @property - def _file_manager(self): fileuris = [[tile.files.filenames for tile in row] for row in self] dtype = self[0, 0].files.fileuri_array.dtype shape = self[0, 0].files.shape diff --git a/pyproject.toml b/pyproject.toml index f337271d7..d29b878d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ tests = [ "pytest-mpl", "pytest-httpserver", "pytest-filter-subpackage", - "pytest-benchmark", + "pytest-benchmark<5", "pytest-xdist", "hypothesis", "tox",