Skip to content

Commit

Permalink
add config helper for cubic data like cubeviz and rampviz
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Jul 30, 2024
1 parent 682e302 commit 407915a
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 81 deletions.
2 changes: 1 addition & 1 deletion jdaviz/configs/cubeviz/cubeviz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ tray:
- g-data-quality
- g-subset-plugin
- g-markers
- cubeviz-slice
- cube-slice
- g-unit-conversion
- cubeviz-spectral-extraction
- g-gaussian-smooth
Expand Down
77 changes: 4 additions & 73 deletions jdaviz/configs/cubeviz/helper.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
import numpy as np
from astropy.io import fits
from astropy.io import registry as io_registry
from specutils import Spectrum1D
from specutils.io.registers import _astropy_has_priorities

from jdaviz.core.events import SnackbarMessage
from jdaviz.core.helpers import ImageConfigHelper
from jdaviz.configs.default.plugins.line_lists.line_list_mixin import LineListMixin
from jdaviz.configs.specviz import Specviz
from jdaviz.core.events import (AddDataMessage,
SliceSelectSliceMessage)
from jdaviz.core.events import AddDataMessage, SnackbarMessage
from jdaviz.core.helpers import CubeConfigHelper

__all__ = ['Cubeviz']

Expand All @@ -18,7 +10,7 @@
"Wavenumber", "Velocity", "Energy"]


class Cubeviz(ImageConfigHelper, LineListMixin):
class Cubeviz(CubeConfigHelper, LineListMixin):
"""Cubeviz Helper class"""
_default_configuration = 'cubeviz'
_default_spectrum_viewer_reference_name = "spectrum-viewer"
Expand Down Expand Up @@ -112,8 +104,7 @@ def select_wavelength(self, wavelength):
"""
if not isinstance(wavelength, (int, float)):
raise TypeError("wavelength must be a float or int")
msg = SliceSelectSliceMessage(value=wavelength, sender=self)
self.app.hub.broadcast(msg)
self.select_slice(wavelength)

@property
def specviz(self):
Expand Down Expand Up @@ -166,63 +157,3 @@ def get_aperture_photometry_results(self):
"""
return self.plugins['Aperture Photometry']._obj.export_table()


# TODO: We can remove this when specutils supports it, i.e.,
# https://github.com/astropy/specutils/issues/592 and
# https://github.com/astropy/specutils/pull/1009
# NOTE: Cannot use custom_write decorator from specutils because
# that involves asking user to manually put something in
# their ~/.specutils directory.

def jdaviz_cube_fitswriter(spectrum, file_name, **kwargs):
"""This is a custom writer for Spectrum1D data cube.
This writer is specifically targetting data cube
generated from Cubeviz plugins (e.g., cube fitting)
with FITS WCS. It writes out data in the following format
(with MASK only exist when applicable)::
No. Name Ver Type
0 PRIMARY 1 PrimaryHDU
1 SCI 1 ImageHDU (float32)
2 MASK 1 ImageHDU (uint16)
The FITS file generated by this writer does not need a
custom reader to be read back into Spectrum1D.
Examples
--------
To write out a Spectrum1D cube using this writer:
>>> spec.write("my_output.fits", format="jdaviz-cube", overwrite=True) # doctest: +SKIP
"""
pri_hdu = fits.PrimaryHDU()

flux = spectrum.flux
sci_hdu = fits.ImageHDU(flux.value.astype(np.float32))
sci_hdu.name = "SCI"
sci_hdu.header.update(spectrum.meta)
sci_hdu.header.update(spectrum.wcs.to_header())
sci_hdu.header['BUNIT'] = flux.unit.to_string(format='fits')

hlist = [pri_hdu, sci_hdu]

# https://specutils.readthedocs.io/en/latest/spectrum1d.html#including-masks
# Good: False or 0
# Bad: True or non-zero
if spectrum.mask is not None:
mask_hdu = fits.ImageHDU(spectrum.mask.astype(np.uint16))
mask_hdu.name = "MASK"
hlist.append(mask_hdu)

hdulist = fits.HDUList(hlist)
hdulist.writeto(file_name, **kwargs)


if _astropy_has_priorities():
kwargs = {"priority": 0}
else: # pragma: no cover
kwargs = {}
io_registry.register_writer(
"jdaviz-cube", Spectrum1D, jdaviz_cube_fitswriter, force=False, **kwargs)
2 changes: 1 addition & 1 deletion jdaviz/configs/cubeviz/plugins/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
__all__ = ['Slice']


@tray_registry('cubeviz-slice', label="Slice", viewer_requirements='spectrum')
@tray_registry('cube-slice', label="Slice", viewer_requirements='spectrum')
class Slice(PluginTemplateMixin):
"""
See the :ref:`Slice Plugin Documentation <slice>` for more details.
Expand Down
93 changes: 93 additions & 0 deletions jdaviz/configs/rampviz/helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from jdaviz.core.helpers import CubeConfigHelper
from jdaviz.core.events import SliceSelectSliceMessage

__all__ = ['Rampviz']


class Rampviz(CubeConfigHelper):
"""Rampviz Helper class"""
_default_configuration = 'rampviz'
_default_profile_viewer_reference_name = "integration-viewer"
_default_diff_viewer_reference_name = "diff-viewer"
_default_group_viewer_reference_name = "group-viewer"
_default_image_viewer_reference_name = "image-viewer"

_loaded_flux_cube = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def load_data(self, data, data_label=None, override_cube_limit=False, **kwargs):
"""
Load and parse a data cube with Cubeviz.
(Note that only one cube may be loaded per Cubeviz instance.)
Parameters
----------
data : str, `~astropy.io.fits.HDUList`, or ndarray
A string file path, astropy FITS object pointing to the
data cube, a spectrum object, or a Numpy array cube.
If plain array is given, axes order must be ``(x, y, z)``.
data_label : str or `None`
Data label to go with the given data. If not given,
one will be automatically generated.
override_cube_limit : bool
Override internal cube count limitation and load the data anyway.
Setting this to `True` is not recommended unless you know what
you are doing.
**kwargs : dict
Extra keywords accepted by Jdaviz application-level parser.
"""
if not override_cube_limit and len(self.app.state.data_items) != 0:
raise RuntimeError('Only one cube may be loaded per Cubeviz instance')
if data_label:
kwargs['data_label'] = data_label

super().load_data(data, parser_reference="cubeviz-data-parser", **kwargs)

def select_group(self, group_index):
"""
Select the slice closest to the provided wavelength.
Parameters
----------
group_index : float
Group index to select in units of the x-axis of the integration.
The nearest group will be selected if "snap to slice" is enabled
in the slice plugin.
"""
if not isinstance(group_index, int):
raise TypeError("group_index must be an integer")
if slice < 0:
raise ValueError("group_index must be positive")

msg = SliceSelectSliceMessage(value=group_index, sender=self)
self.app.hub.broadcast(msg)

def get_data(self, data_label=None, spatial_subset=None,
temporal_subset=None, cls=None, use_display_units=False):
"""
Returns data with name equal to ``data_label`` of type ``cls`` with subsets applied from
``spectral_subset``, if applicable.
Parameters
----------
data_label : str, optional
Provide a label to retrieve a specific data set from data_collection.
spatial_subset : str, optional
Spatial subset applied to data. Only applicable if ``data_label`` points to a cube or
image. To extract a spectrum from a cube, use the spectral extraction plugin instead.
temporal_subset : str, optional
cls : `~specutils.Spectrum1D`, `~astropy.nddata.CCDData`, optional
The type that data will be returned as.
Returns
-------
data : cls
Data is returned as type cls with subsets applied.
"""
return self._get_data(data_label=data_label, spatial_subset=spatial_subset,
temporal_subset=temporal_subset,
cls=cls, use_display_units=use_display_units)
177 changes: 177 additions & 0 deletions jdaviz/configs/rampviz/plugins/parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import os
import numpy as np
from astropy.io import fits

from jdaviz.utils import download_uri_to_path
from jdaviz.configs.cubeviz.plugins.parsers import (
_parse_ndarray, _parse_hdulist, _parse_gif
)
from jdaviz.core.registries import data_parser_registry

try:
from roman_datamodels import datamodels as rdd
except ImportError:
HAS_ROMAN_DATAMODELS = False
else:
HAS_ROMAN_DATAMODELS = True


@data_parser_registry("rampviz-data-parser")
def parse_data(app, file_obj, data_type=None, data_label=None,
parent=None, cache=None, local_path=None, timeout=None):
"""
Attempts to parse a data file and auto-populate available viewers in
rampviz.
Parameters
----------
app : `~jdaviz.app.Application`
The application-level object used to reference the viewers.
file_obj : str
The path to a cube-like data file.
data_type : str, {'flux', 'mask', 'uncert'}
The data type used to explicitly differentiate parsed data.
data_label : str, optional
The label to be applied to the Glue data component.
parent : str, optional
Data label for "parent" data to associate with the loaded data as "child".
cache : None, bool, or str
Cache the downloaded file if the data are retrieved by a query
to a URL or URI.
local_path : str, optional
Cache remote files to this path. This is only used if data is
requested from `astroquery.mast`.
timeout : float, optional
If downloading from a remote URI, set the timeout limit for
remote requests in seconds (passed to
`~astropy.utils.data.download_file` or
`~astroquery.mast.Conf.timeout`).
"""

flux_viewer_reference_name = app._jdaviz_helper._default_group_viewer_reference_name
uncert_viewer_reference_name = app._jdaviz_helper._default_diff_viewer_reference_name
spectrum_viewer_reference_name = app._jdaviz_helper._default_profile_viewer_reference_name

if data_type is not None and data_type.lower() not in ('flux', 'mask', 'uncert'):
raise TypeError("Data type must be one of 'flux', 'mask', or 'uncert' "
f"but got '{data_type}'")

# If the file object is an hdulist or a string, use the generic parser for
# fits files.
# TODO: this currently only supports fits files. We will want to make this
# generic enough to work with other file types (e.g. ASDF). For now, this
# supports MaNGA and JWST data.
if isinstance(file_obj, fits.hdu.hdulist.HDUList):
_parse_hdulist(
app, file_obj, file_name=data_label,
flux_viewer_reference_name=flux_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)
elif isinstance(file_obj, str):
if file_obj.lower().endswith('.gif'): # pragma: no cover
_parse_gif(app, file_obj, data_label,
flux_viewer_reference_name=flux_viewer_reference_name)
return
elif file_obj.lower().endswith('.asdf'):
if not HAS_ROMAN_DATAMODELS:
raise ImportError(
"ASDF detected but roman-datamodels is not installed."
)
with rdd.open(file_obj) as pf:
_roman_3d_to_glue_data(
app, pf, data_label,
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)
return

# try parsing file_obj as a URI/URL:
file_obj = download_uri_to_path(
file_obj, cache=cache, local_path=local_path, timeout=timeout
)

file_name = os.path.basename(file_obj)

with fits.open(file_obj) as hdulist:
_parse_hdulist(
app, hdulist, file_name=data_label or file_name,
flux_viewer_reference_name=flux_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)

elif isinstance(file_obj, np.ndarray) and file_obj.ndim == 3:
_parse_ndarray(app, file_obj, data_label=data_label, data_type=data_type,
flux_viewer_reference_name=flux_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name)

app.get_tray_item_from_name("Spectral Extraction").disabled_msg = ""

elif HAS_ROMAN_DATAMODELS and isinstance(file_obj, rdd.DataModel):
with rdd.open(file_obj) as pf:
_roman_3d_to_glue_data(
app, pf, data_label,
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)

else:
raise NotImplementedError(f'Unsupported data format: {file_obj}')


def _roman_3d_to_glue_data(
app, file_obj, data_label,
flux_viewer_reference_name=None,
diff_viewer_reference_name=None,
integration_viewer_reference_name=None,
):
"""
Parse a Roman 3D ramp cube file (Level 1),
usually with suffix '_uncal.asdf'.
"""
def _swap_axes(x):
# swap axes per the conventions of Roman cubes
# (group axis comes first) and the default in
# Cubeviz (wavelength axis expected last)
return np.swapaxes(x, 0, -1)

# update viewer reference names for Roman ramp cubes:
# app._update_viewer_reference_name()

data = file_obj.data

if data_label is None:
data_label = app.return_data_label(file_obj)

# last axis is the group axis, first two are spatial axes:
diff_data = np.vstack([
# begin with a group of zeros, so
# that `diff_data.ndim == data.ndim`
np.zeros((1, *data[0].shape)),
np.diff(data, axis=0)
])

# load the `data` cube into what's usually the "flux-viewer"
_parse_ndarray(
app,
file_obj=_swap_axes(data),
data_label=f"{data_label}[DATA]",
data_type="flux",
flux_viewer_reference_name=flux_viewer_reference_name,
)

# load the diff of the data cube
# into what's usually the "uncert-viewer"
_parse_ndarray(
app,
file_obj=_swap_axes(diff_data),
data_type="uncert",
data_label=f"{data_label}[DIFF]",
uncert_viewer_reference_name=diff_viewer_reference_name,
)

# the default collapse function in the profile viewer is "sum",
# but for ramp files, "median" is more useful:
viewer = app.get_viewer('integration-viewer')
viewer.state.function = 'median'
Loading

0 comments on commit 407915a

Please sign in to comment.