Skip to content

Commit

Permalink
Update TiledDataset.plot() (#441)
Browse files Browse the repository at this point in the history
Co-authored-by: Stuart Mumford <[email protected]>
  • Loading branch information
SolarDrew and Cadair authored Oct 29, 2024
1 parent c696a85 commit feb67cc
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 11 deletions.
1 change: 1 addition & 0 deletions changelog/441.trivial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Minor updates to `TiledDataset.plot()` for working with more complex arrangements of tiles.
80 changes: 70 additions & 10 deletions dkist/dataset/tiled_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,31 @@ def tiles_shape(self):
"""
return [[tile.data.shape for tile in row] for row in self]

def plot(self, slice_index: int, share_zscale=False, **kwargs):
def plot(self, slice_index, share_zscale=False, **kwargs):
"""
Plot a slice of each tile in the TiledDataset
Parameters
----------
slice_index : `int`, sequence of `int`s or `numpy.s_`
Object representing a slice which will reduce each component dataset
of the TiledDataset to a 2D image. This is passed to
``TiledDataset.slice_tiles``
share_zscale : `bool`
Determines whether the color scale of the plots should be calculated
independently (``False``) or shared across all plots (``True``).
Defaults to False
"""
if isinstance(slice_index, int):
slice_index = (slice_index,)
vmin, vmax = np.inf, 0
fig = plt.figure()
for i, tile in enumerate(self.flat):
ax = fig.add_subplot(self.shape[0], self.shape[1], i+1, projection=tile[0].wcs)
tile[slice_index].plot(axes=ax, **kwargs)
tiles = self.slice_tiles[slice_index].flat
for i, tile in enumerate(tiles):
ax = fig.add_subplot(self.shape[0], self.shape[1], i+1, projection=tile.wcs)
tile.plot(axes=ax, **kwargs)
if i == 0:
# TODO: When we can depend on astropy >=7.0 we can remove these or statements
xlabel = ax.coords[0].get_axislabel() or ax.coords[0]._get_default_axislabel()
ylabel = ax.coords[1].get_axislabel() or ax.coords[1]._get_default_axislabel()
for coord in ax.coords:
Expand All @@ -174,12 +192,58 @@ def plot(self, slice_index: int, share_zscale=False, **kwargs):
if share_zscale:
for ax in fig.get_axes():
ax.get_images()[0].set_clim(vmin, vmax)
timestamp = self[0, 0].axis_world_coords("time")[-1].iso[slice_index]
fig.suptitle(f"{self.inventory['instrumentName']} Dataset ({self.inventory['datasetId']}) at time {timestamp} (slice={slice_index})", y=0.95)
title = f"{self.inventory['instrumentName']} Dataset ({self.inventory['datasetId']}) at "
for i, (coord, val) in enumerate(list(tiles[0].global_coords.items())[::-1]):
if coord == "time":
val = val.iso
if coord == "stokes":
val = val.symbol
title += f"{coord} {val}" + (", " if i != len(slice_index)-1 else " ")
title += f"(slice={(slice_index if len(slice_index) > 1 else slice_index[0])})".replace("slice(None, None, None)", ":")
fig.suptitle(title, y=0.95)
return fig

@property
def slice_tiles(self):
"""
Returns a new TiledDataset with the given slice applied to each of the tiles.
Examples
--------
.. code-block:: python
>>> from dkist import load_dataset
>>> from dkist.data.sample import VBI_AJQWW # doctest: +REMOTE_DATA
>>> ds = load_dataset(VBI_AJQWW) # doctest: +REMOTE_DATA
>>> ds.slice_tiles[0, 10:-10] # doctest: +REMOTE_DATA
<dkist.dataset.tiled_dataset.TiledDataset object at ...>
This VBI Dataset AJQWW is an array of (3, 3) Dataset objects and
consists of 9 frames.
Files are stored in ...
Each Dataset has 2 pixel and 2 world dimensions.
The data are represented by a <class 'dask.array.core.Array'> object:
dask.array<getitem, shape=(4076, 4096), dtype=float32, chunksize=(4076, 4096), chunktype=numpy.ndarray>
Array Dim Axis Name Data size Bounds
0 helioprojective latitude 4076 None
1 helioprojective longitude 4096 None
World Dim Axis Name Physical Type Units
1 helioprojective latitude custom:pos.helioprojective.lat arcsec
0 helioprojective longitude custom:pos.helioprojective.lon arcsec
Correlation between pixel and world axes:
| PIXEL DIMENSIONS
| helioprojective | helioprojective
WORLD DIMENSIONS | longitude | latitude
------------------------- | --------------- | ---------------
helioprojective longitude | x | x
helioprojective latitude | x | x
"""

return TiledDatasetSlicer(self._data, self.inventory)

# TODO: def regrid()
Expand All @@ -199,10 +263,6 @@ def files(self):
"""
A `~.FileManager` helper for interacting with the files backing the data in this ``Dataset``.
"""
return self._file_manager

@property
def _file_manager(self):
fileuris = [[tile.files.filenames for tile in row] for row in self]
dtype = self[0, 0].files.fileuri_array.dtype
shape = self[0, 0].files.shape
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ tests = [
"pytest-mpl",
"pytest-httpserver",
"pytest-filter-subpackage",
"pytest-benchmark",
"pytest-benchmark<5",
"pytest-xdist",
"hypothesis",
"tox",
Expand Down

0 comments on commit feb67cc

Please sign in to comment.