diff --git a/src/hats/catalog/healpix_dataset/healpix_dataset.py b/src/hats/catalog/healpix_dataset/healpix_dataset.py index 3fa32e58..4a9f63ab 100644 --- a/src/hats/catalog/healpix_dataset/healpix_dataset.py +++ b/src/hats/catalog/healpix_dataset/healpix_dataset.py @@ -14,6 +14,7 @@ from hats.catalog.dataset.table_properties import TableProperties from hats.catalog.partition_info import PartitionInfo from hats.inspection import plot_pixels +from hats.inspection.visualize_catalog import plot_moc from hats.pixel_math import HealpixPixel from hats.pixel_tree import PixelAlignment, PixelAlignmentType from hats.pixel_tree.moc_filter import filter_by_moc @@ -170,6 +171,17 @@ def plot_pixels(self, **kwargs): """Create a visual map of the pixel density of the catalog. Args: - kwargs: Additional args to pass to `hipscat.inspection.visualize_catalog.plot_healpix_map` + kwargs: Additional args to pass to `hats.inspection.visualize_catalog.plot_healpix_map` """ return plot_pixels(self, **kwargs) + + def plot_moc(self, **kwargs): + """Create a visual map of the coverage of the catalog. + + Args: + kwargs: Additional args to pass to `hats.inspection.visualize_catalog.plot_moc` + """ + default_title = f"Coverage MOC of {self.catalog_name}" + plot_args = {"title": default_title} + plot_args.update(kwargs) + return plot_moc(self.moc, **plot_args) diff --git a/src/hats/inspection/visualize_catalog.py b/src/hats/inspection/visualize_catalog.py index 622143d8..227a9425 100644 --- a/src/hats/inspection/visualize_catalog.py +++ b/src/hats/inspection/visualize_catalog.py @@ -14,10 +14,10 @@ import numpy as np from astropy.coordinates import ICRS, Angle, SkyCoord from astropy.units import Quantity +from astropy.visualization.wcsaxes import WCSAxes from astropy.visualization.wcsaxes.frame import BaseFrame, EllipticalFrame from astropy.wcs.utils import pixel_to_skycoord, skycoord_to_pixel from matplotlib import pyplot as plt -from matplotlib.axes import Axes from matplotlib.collections import PathCollection from matplotlib.colors import Colormap, Normalize from matplotlib.figure import Figure @@ -109,6 +109,73 @@ def plot_pixel_list(pixels: List[HealpixPixel], plot_title: str = "", projection return fig, ax +def plot_moc( + moc: MOC, + *, + projection: str = "MOL", + title: str = "", + fov: Quantity | Tuple[Quantity, Quantity] = None, + center: SkyCoord | None = None, + wcs: astropy.wcs.WCS = None, + frame_class: Type[BaseFrame] | None = None, + ax: WCSAxes | None = None, + fig: Figure | None = None, + **kwargs, +) -> Tuple[Figure, WCSAxes]: + """Plots a moc + + By default, a new matplotlib figure and axis will be created, and the projection will be a Molleweide + projection across the whole sky. + + Args: + moc (mocpy.MOC): MOC to plot + projection (str): The projection to use in the WCS. Available projections listed at + https://docs.astropy.org/en/stable/wcs/supported_projections.html + title (str): The title of the plot + fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an + astropy Quantity with an angular unit, or a tuple of quantities for different longitude and \ + latitude FOVs (Default covers the full sky) + center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0)) + wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters + are ignored and the parameters from the WCS object is used. + frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized with. + if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for full + sky projection. If FOV is set, RectangularFrame is used) + ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be used. If + specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set with the WCS + object used in the axes. (Default: None) + fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created, unless + ax is specified (Default: None) + **kwargs: Additional kwargs to pass to `mocpy.MOC.fill` + + Returns: + Tuple[Figure, WCSAxes] - The figure and axes used to plot the healpix map + """ + fig, ax, wcs = initialize_wcs_axes( + projection=projection, + fov=fov, + center=center, + wcs=wcs, + frame_class=frame_class, + ax=ax, + fig=fig, + figsize=(9, 5), + ) + + mocpy_args = {"alpha": 0.5, "fill": True, "color": "teal"} + mocpy_args.update(**kwargs) + + moc.fill(ax, wcs, **mocpy_args) + + ax.coords[0].set_format_unit("deg") + + plt.grid() + plt.ylabel("Dec") + plt.xlabel("RA") + plt.title(title) + return fig, ax + + def cull_to_fov(depth_ipix_d: Dict[int, Tuple[np.ndarray, np.ndarray]], wcs): """Culls a mapping of ipix to values to pixels that are inside the plot window defined by a WCS @@ -305,7 +372,7 @@ def plot_healpix_map( center: SkyCoord | None = None, wcs: astropy.wcs.WCS = None, frame_class: Type[BaseFrame] | None = None, - ax: Axes | None = None, + ax: WCSAxes | None = None, fig: Figure | None = None, **kwargs, ): @@ -348,25 +415,76 @@ def plot_healpix_map( frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized with. if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for full sky projection. If FOV is set, RectangularFrame is used) - ax (Axes | None): The matplotlib axes to plot onto. If None, an axes will be created to be used. If - specified, the axes must be initialized with a WCS for the projection, and passed to the method - with the WCS parameter. (Default: None) + ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be used. If + specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set with the WCS + object used in the axes. (Default: None) fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created, unless ax is specified (Default: None) **kwargs: Additional kwargs to pass to creating the matplotlib `PathCollection` artist Returns: - Tuple[Figure, Axes] - The figure and axes used to plot the healpix map + Tuple[Figure, WCSAxes] - The figure and axes used to plot the healpix map """ if ipix is None or depth is None: order = int(np.ceil(np.log2(len(healpix_map) / 12) / 2)) ipix = np.arange(len(healpix_map)) depth = np.full(len(healpix_map), fill_value=order) + + fig, ax, wcs = initialize_wcs_axes( + projection=projection, + fov=fov, + center=center, + wcs=wcs, + frame_class=frame_class, + ax=ax, + fig=fig, + figsize=(10, 5), + ) + + _plot_healpix_value_map(ipix, depth, healpix_map, ax, wcs, cmap=cmap, norm=norm, cbar=cbar, **kwargs) + plt.grid() + plt.ylabel("Dec") + plt.xlabel("RA") + plt.title(title) + return fig, ax + + +def initialize_wcs_axes( + projection: str = "MOL", + fov: Quantity | Tuple[Quantity, Quantity] = None, + center: SkyCoord | None = None, + wcs: astropy.wcs.WCS = None, + frame_class: Type[BaseFrame] | None = None, + ax: WCSAxes | None = None, + fig: Figure | None = None, + **kwargs, +): + """Initializes matplotlib Figure and WCSAxes if they do not exist + + Args: + projection (str): The projection to use in the WCS. Available projections listed at + https://docs.astropy.org/en/stable/wcs/supported_projections.html + fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an + astropy Quantity with an angular unit, or a tuple of quantities for different longitude and \ + latitude FOVs (Default covers the full sky) + center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0)) + wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters + are ignored and the parameters from the WCS object is used. + frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized with. + if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for full + sky projection. If FOV is set, RectangularFrame is used) + ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be used. If + specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set with the WCS + object used in the axes. (Default: None) + fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created, unless + ax is specified (Default: None) + kwargs: additional kwargs to pass to figure initialization + """ if fig is None: if ax is not None: fig = ax.get_figure() else: - fig = plt.figure(figsize=(10, 5)) + fig = plt.figure(**kwargs) if frame_class is None and fov is None and wcs is None: frame_class = EllipticalFrame if fov is None: @@ -388,12 +506,7 @@ def plot_healpix_map( raise ValueError( "if ax is provided, wcs must also be provided with the projection used in initializing ax" ) - _plot_healpix_value_map(ipix, depth, healpix_map, ax, wcs, cmap=cmap, norm=norm, cbar=cbar, **kwargs) - plt.grid() - plt.ylabel("Dec") - plt.xlabel("RA") - plt.title(title) - return fig, ax + return fig, ax, wcs def _plot_healpix_value_map(ipix, depth, values, ax, wcs, cmap="viridis", norm=None, cbar=True, **kwargs): diff --git a/tests/hats/inspection/test_visualize_catalog.py b/tests/hats/inspection/test_visualize_catalog.py index a57f90f5..762dd4f8 100644 --- a/tests/hats/inspection/test_visualize_catalog.py +++ b/tests/hats/inspection/test_visualize_catalog.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import astropy.units as u import matplotlib.pyplot as plt import numpy as np @@ -12,7 +14,7 @@ from mocpy.moc.plot.utils import build_plotting_moc from hats.inspection import plot_pixels -from hats.inspection.visualize_catalog import cull_from_pixel_map, cull_to_fov, plot_healpix_map +from hats.inspection.visualize_catalog import cull_from_pixel_map, cull_to_fov, plot_healpix_map, plot_moc # pylint: disable=no-member @@ -669,3 +671,21 @@ def test_catalog_plot(small_sky_order1_catalog): np.testing.assert_array_equal(path.codes, codes) np.testing.assert_array_equal(col.get_array(), np.array(order_3_orders)) assert ax.get_title() == f"Catalog pixel density map - {small_sky_order1_catalog.catalog_name}" + + +def test_plot_moc(small_sky_order1_catalog): + small_sky_order1_catalog.moc.fill = MagicMock() + _, ax = plot_moc(small_sky_order1_catalog.moc) + small_sky_order1_catalog.moc.fill.assert_called_once() + assert small_sky_order1_catalog.moc.fill.call_args[0][0] is ax + wcs = ax.wcs + assert small_sky_order1_catalog.moc.fill.call_args[0][1] is wcs + + +def test_plot_moc_catalog(small_sky_order1_catalog): + small_sky_order1_catalog.moc.fill = MagicMock() + _, ax = small_sky_order1_catalog.plot_moc() + small_sky_order1_catalog.moc.fill.assert_called_once() + assert small_sky_order1_catalog.moc.fill.call_args[0][0] is ax + wcs = ax.wcs + assert small_sky_order1_catalog.moc.fill.call_args[0][1] is wcs