diff --git a/docs/conf.py b/docs/conf.py index 7a668af0..ea4f7f0b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -174,7 +174,8 @@ { 'astropy': ('https://docs.astropy.org/en/stable/', None), 'ccdproc': ('https://ccdproc.readthedocs.io/en/stable/', None), - 'specutils': ('https://specutils.readthedocs.io/en/stable/', None) + 'specutils': ('https://specutils.readthedocs.io/en/stable/', None), + 'gwcs': ('https://gwcs.readthedocs.io/en/stable/', None) } ) # diff --git a/docs/index.rst b/docs/index.rst index e05ac72b..785b4397 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -50,6 +50,7 @@ Calibration .. toctree:: :maxdepth: 1 + wavelength_calibration.rst extinction.rst specphot_standards.rst diff --git a/docs/wavelength_calibration.rst b/docs/wavelength_calibration.rst new file mode 100644 index 00000000..f34d7dec --- /dev/null +++ b/docs/wavelength_calibration.rst @@ -0,0 +1,53 @@ +.. _wavelength_calibration: + +Wavelength Calibration +====================== + +Wavelength calibration is currently supported for 1D spectra. Given a list of spectral +lines with known wavelengths and estimated pixel positions on an input calibration +spectrum, you can currently use ``specreduce`` to: + +#. Fit an ``astropy`` model to the wavelength/pixel pairs to generate a spectral WCS + solution for the dispersion. +#. Apply the generated spectral WCS to other `~specutils.Spectrum1D` objects. + +1D Wavelength Calibration +------------------------- + +The `~specreduce.wavelength_calibration.WavelengthCalibration1D` class can be used +to fit a dispersion model to a list of line positions and wavelengths. Future development +will implement catalogs of known lamp spectra for use in matching observed lines. In the +example below, the line positions (``pixel_centers``) have already been extracted from +``lamp_spectrum``:: + + import astropy.units as u + from specreduce import WavelengthCalibration1D + pixel_centers = [10, 22, 31, 43] + wavelengths = [5340, 5410, 5476, 5543]*u.AA + test_cal = WavelengthCalibration1D(lamp_spectrum, line_pixels=pixel_centers, + line_wavelengths=wavelengths) + calibrated_spectrum = test_cal.apply_to_spectrum(science_spectrum) + +The example above uses the default model (`~astropy.modeling.functional_models.Linear1D`) +to fit the input spectral lines, and then applies the calculated WCS solution to a second +spectrum (``science_spectrum``). Any other 1D ``astropy`` model can be provided as the +input ``model`` parameter to the `~specreduce.wavelength_calibration.WavelengthCalibration1D`. +In the above example, the model fit and WCS construction is all done as part of the +``apply_to_spectrum()`` call, but you could also access the `~gwcs.wcs.WCS` object itself +by calling:: + + test_cal.wcs + +The calculated WCS is a cached property that will be cleared if the ``line_list``, ``model``, +or ``input_spectrum`` properties are updated, since these will alter the calculated dispersion +fit. + +You can also provide the input pixel locations and wavelengths of the lines as an +`~astropy.table.QTable` with (at minimum) columns ``pixel_center`` and ``wavelength``, +using the ``matched_line_list`` input argument:: + + from astropy.table import QTable + pixels = [10, 20, 30, 40]*u.pix + wavelength = [5340, 5410, 5476, 5543]*u.AA + line_list = QTable([pixels, wavelength], names=["pixel_center", "wavelength"]) + test_cal = WavelengthCalibration1D(lamp_spectrum, matched_line_list=line_list) \ No newline at end of file diff --git a/specreduce/__init__.py b/specreduce/__init__.py index 6bff38bb..5bf2362e 100644 --- a/specreduce/__init__.py +++ b/specreduce/__init__.py @@ -7,3 +7,4 @@ # ---------------------------------------------------------------------------- from specreduce.core import * # noqa +from specreduce.wavelength_calibration import * # noqa diff --git a/specreduce/conftest.py b/specreduce/conftest.py index 672b2733..87ee89f1 100644 --- a/specreduce/conftest.py +++ b/specreduce/conftest.py @@ -6,6 +6,10 @@ import os from astropy.version import version as astropy_version +import astropy.units as u +import numpy as np +import pytest +from specutils import Spectrum1D # For Astropy 3.0 and later, we can use the standalone pytest plugin if astropy_version < '3.0': @@ -20,6 +24,37 @@ ASTROPY_HEADER = False +@pytest.fixture +def spec1d(): + np.random.seed(7) + flux = np.random.random(50)*u.Jy + sa = np.arange(0, 50)*u.pix + spec = Spectrum1D(flux, spectral_axis=sa) + return spec + + +@pytest.fixture +def spec1d_with_emission_line(): + np.random.seed(7) + sa = np.arange(0, 200)*u.pix + flux = (np.random.randn(200) + + 10*np.exp(-0.01*((sa.value-130)**2)) + + sa.value/100) * u.Jy + spec = Spectrum1D(flux, spectral_axis=sa) + return spec + + +@pytest.fixture +def spec1d_with_absorption_line(): + np.random.seed(7) + sa = np.arange(0, 200)*u.pix + flux = (np.random.randn(200) - + 10*np.exp(-0.01*((sa.value-130)**2)) + + sa.value/100) * u.Jy + spec = Spectrum1D(flux, spectral_axis=sa) + return spec + + def pytest_configure(config): if ASTROPY_HEADER: diff --git a/specreduce/tests/test_wavelength_calibration.py b/specreduce/tests/test_wavelength_calibration.py new file mode 100644 index 00000000..93048f67 --- /dev/null +++ b/specreduce/tests/test_wavelength_calibration.py @@ -0,0 +1,85 @@ +from numpy.testing import assert_allclose +import pytest + +from astropy.table import QTable +import astropy.units as u +from astropy.modeling.models import Polynomial1D +from astropy.modeling.fitting import LinearLSQFitter +from astropy.tests.helper import assert_quantity_allclose + +from specreduce import WavelengthCalibration1D + + +def test_linear_from_list(spec1d): + centers = [0, 10, 20, 30] + w = [5000, 5100, 5198, 5305]*u.AA + test = WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w) + spec2 = test.apply_to_spectrum(spec1d) + + assert_quantity_allclose(spec2.spectral_axis[0], 4998.8*u.AA) + assert_quantity_allclose(spec2.spectral_axis[-1], 5495.169999*u.AA) + + +def test_wavelength_from_table(spec1d): + centers = [0, 10, 20, 30] + w = [5000, 5100, 5198, 5305]*u.AA + table = QTable([w], names=["wavelength"]) + WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=table) + + +def test_linear_from_table(spec1d): + centers = [0, 10, 20, 30] + w = [5000, 5100, 5198, 5305]*u.AA + table = QTable([centers, w], names=["pixel_center", "wavelength"]) + test = WavelengthCalibration1D(spec1d, matched_line_list=table) + spec2 = test.apply_to_spectrum(spec1d) + + assert_quantity_allclose(spec2.spectral_axis[0], 4998.8*u.AA) + assert_quantity_allclose(spec2.spectral_axis[-1], 5495.169999*u.AA) + + +def test_poly_from_table(spec1d): + # This test is mostly to prove that you can use other models + centers = [0, 10, 20, 30, 40] + w = [5005, 5110, 5214, 5330, 5438]*u.AA + table = QTable([centers, w], names=["pixel_center", "wavelength"]) + + test = WavelengthCalibration1D(spec1d, matched_line_list=table, + model=Polynomial1D(2), fitter=LinearLSQFitter()) + test.apply_to_spectrum(spec1d) + + assert_allclose(test.model.parameters, [5.00477143e+03, 1.03457143e+01, 1.28571429e-02]) + + +def test_replace_spectrum(spec1d, spec1d_with_emission_line): + centers = [0, 10, 20, 30]*u.pix + w = [5000, 5100, 5198, 5305]*u.AA + test = WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w) + # Accessing this property causes fits the model and caches the resulting WCS + test.wcs + assert "wcs" in test.__dict__ + + # Replace the input spectrum, which should clear the cached properties + test.input_spectrum = spec1d_with_emission_line + assert "wcs" not in test.__dict__ + + +def test_expected_errors(spec1d): + centers = [0, 10, 20, 30, 40] + w = [5005, 5110, 5214, 5330, 5438]*u.AA + table = QTable([centers, w], names=["pixel_center", "wavelength"]) + + with pytest.raises(ValueError, match="Cannot specify line_wavelengths separately"): + WavelengthCalibration1D(spec1d, matched_line_list=table, line_wavelengths=w) + + with pytest.raises(ValueError, match="must have the same length"): + w2 = [5005, 5110, 5214, 5330, 5438, 5500]*u.AA + WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w2) + + with pytest.raises(ValueError, match="astropy.units.Quantity array or" + " as an astropy.table.QTable"): + w2 = [5005, 5110, 5214, 5330, 5438] + WavelengthCalibration1D(spec1d, line_pixels=centers, line_wavelengths=w2) + + with pytest.raises(ValueError, match="specify at least one"): + WavelengthCalibration1D(spec1d, line_pixels=centers) diff --git a/specreduce/wavelength_calibration.py b/specreduce/wavelength_calibration.py new file mode 100644 index 00000000..d870f06e --- /dev/null +++ b/specreduce/wavelength_calibration.py @@ -0,0 +1,230 @@ +from astropy.modeling.models import Linear1D +from astropy.modeling.fitting import LMLSQFitter, LinearLSQFitter +from astropy.table import QTable, hstack +import astropy.units as u +from functools import cached_property +from gwcs import wcs +from gwcs import coordinate_frames as cf +import numpy as np +from specutils import Spectrum1D + + +__all__ = ['WavelengthCalibration1D'] + + +def get_available_catalogs(): + """ + ToDo: Decide in what format to store calibration line catalogs (e.g., for lamps) + and write this function to determine the list of available catalog names. + """ + return [] + + +def concatenate_catalogs(): + """ + ToDo: Code logic to combine the lines from multiple catalogs if needed + """ + pass + + +class WavelengthCalibration1D(): + + def __init__(self, input_spectrum, matched_line_list=None, line_pixels=None, + line_wavelengths=None, catalog=None, model=Linear1D(), fitter=None): + """ + input_spectrum: `~specutils.Spectrum1D` + A one-dimensional Spectrum1D calibration spectrum from an arc lamp or similar. + matched_line_list: `~astropy.table.QTable`, optional + An `~astropy.table.QTable` table with (minimally) columns named + "pixel_center" and "wavelength" with known corresponding line pixel centers + and wavelengths populated. + line_pixels: list, array, `~astropy.table.QTable`, optional + List or array of line pixel locations to anchor the wavelength solution fit. + Will be converted to an astropy table internally if a list or array was input. + Can also be input as an `~astropy.table.QTable` table with (minimally) a column + named "pixel_center". + line_wavelengths: `~astropy.units.Quantity`, `~astropy.table.QTable`, optional + `astropy.units.Quantity` array of line wavelength values corresponding to the + line pixels defined in ``line_list``. Does not have to be in the same order] + (the lists will be sorted) but does currently need to be the same length as + line_list. Can also be input as an `~astropy.table.QTable` with (minimally) + a "wavelength" column. + catalog: list, str, `~astropy.table.QTable`, optional + The name of a catalog of line wavelengths to load and use in automated and + template-matching line matching. + model: `~astropy.modeling.Model` + The model to fit for the wavelength solution. Defaults to a linear model. + fitter: `~astropy.modeling.fitting.Fitter`, optional + The fitter to use in optimizing the model fit. Defaults to + `~astropy.modeling.fitting.LinearLSQFitter` if the model to fit is linear + or `~astropy.modeling.fitting.LMLSQFitter` if the model to fit is non-linear. + + Note that either ``matched_line_list`` or ``line_pixels`` must be specified, + and if ``matched_line_list`` is not input, at least one of ``line_wavelengths`` + or ``catalog`` must be specified. + """ + self._input_spectrum = input_spectrum + self._model = model + self._cached_properties = ['wcs',] + self.fitter = fitter + self._potential_wavelengths = None + self._catalog = catalog + + # ToDo: Implement having line catalogs + self._available_catalogs = get_available_catalogs() + + # We use either line_pixels or matched_line_list to create self._matched_line_list, + # and check that various requirements are fulfilled by the input args. + if matched_line_list is not None: + pixel_arg = "matched_line_list" + if not isinstance(matched_line_list, QTable): + raise ValueError("matched_line_list must be an astropy.table.QTable.") + self._matched_line_list = matched_line_list + elif line_pixels is not None: + pixel_arg = "line_pixels" + if isinstance(line_pixels, (list, np.ndarray)): + self._matched_line_list = QTable([line_pixels], names=["pixel_center"]) + elif isinstance(line_pixels, QTable): + self._matched_line_list = line_pixels + else: + raise ValueError("Either matched_line_list or line_pixels must be specified.") + + if "pixel_center" not in self._matched_line_list.columns: + raise ValueError(f"{pixel_arg} must have a 'pixel_center' column.") + + if self._matched_line_list["pixel_center"].unit is None: + self._matched_line_list["pixel_center"].unit = u.pix + + # Make sure our pixel locations are sorted + self._matched_line_list.sort("pixel_center") + + if (line_wavelengths is None and catalog is None + and "wavelength" not in self._matched_line_list.columns): + raise ValueError("You must specify at least one of line_wavelengths, " + "catalog, or 'wavelength' column in matched_line_list.") + + # Sanity checks on line_wavelengths value + if line_wavelengths is not None: + if (isinstance(self._matched_line_list, QTable) and + "wavelength" in self._matched_line_list.columns): + raise ValueError("Cannot specify line_wavelengths separately if there is" + " a 'wavelength' column in matched_line_list.") + if len(line_wavelengths) != len(self._matched_line_list): + raise ValueError("If line_wavelengths is specified, it must have the same " + f"length as {pixel_arg}") + if not isinstance(line_wavelengths, (u.Quantity, QTable)): + raise ValueError("line_wavelengths must be specified as an astropy.units.Quantity" + " array or as an astropy.table.QTable") + if isinstance(line_wavelengths, u.Quantity): + # Ensure frequency is descending or wavelength is ascending + if str(line_wavelengths.unit.physical_type) == "frequency": + line_wavelengths[::-1].sort() + else: + line_wavelengths.sort() + self._matched_line_list["wavelength"] = line_wavelengths + elif isinstance(line_wavelengths, QTable): + line_wavelengths.sort("wavelength") + self._matched_line_list = hstack([self._matched_line_list, line_wavelengths]) + + # Parse desired catalogs of lines for matching. + if catalog is not None: + # For now we avoid going into the later logic and just throw an error + raise NotImplementedError("No catalogs are available yet, please input " + "wavelengths with line_wavelengths or as a " + f"column in {pixel_arg}") + + if isinstance(catalog, QTable): + if "wavelength" not in catalog.columns: + raise ValueError("Catalog table must have a 'wavelength' column.") + self._catalog = catalog + else: + # This will need to be updated to match up with Tim's catalog code + if isinstance(catalog, list): + self._catalog = catalog + else: + self._catalog = [catalog] + for cat in self._catalog: + if isinstance(cat, str): + if cat not in self._available_catalogs: + raise ValueError(f"Line list '{cat}' is not an available catalog.") + + # Get the potential lines from any specified catalogs to use in matching + self._potential_wavelengths = concatenate_catalogs(self._catalog) + + def identify_lines(self): + """ + ToDo: Code matching algorithm between line pixel locations and potential line + wavelengths from catalogs. + """ + pass + + def _clear_cache(self, *attrs): + """ + provide convenience function to clearing the cache for cached_properties + """ + if not len(attrs): + attrs = self._cached_properties + for attr in attrs: + if attr in self.__dict__: + del self.__dict__[attr] + + @property + def available_catalogs(self): + return self._available_catalogs + + @property + def input_spectrum(self): + return self._input_spectrum + + @input_spectrum.setter + def input_spectrum(self, new_spectrum): + # We want to clear the refined locations if a new calibration spectrum is provided + self._clear_cache() + self._input_spectrum = new_spectrum + + @property + def model(self): + return self._model + + @model.setter + def model(self, new_model): + self._clear_cache() + self._model = new_model + + @cached_property + def wcs(self): + # computes and returns WCS after fitting self.model to self.refined_pixels + x = self._matched_line_list["pixel_center"] + y = self._matched_line_list["wavelength"] + + if self.fitter is None: + # Flexible defaulting if self.fitter is None + if self.model.linear: + fitter = LinearLSQFitter(calc_uncertainties=True) + else: + fitter = LMLSQFitter(calc_uncertainties=True) + else: + fitter = self.fitter + + # Fit the model + self._model = fitter(self._model, x, y) + + # Build a GWCS pipeline from the fitted model + pixel_frame = cf.CoordinateFrame(1, "SPECTRAL", [0,], axes_names=["x",], unit=[u.pix,]) + spectral_frame = cf.SpectralFrame(axes_names=["wavelength",], + unit=[self._matched_line_list["wavelength"].unit,]) + + pipeline = [(pixel_frame, self.model), (spectral_frame, None)] + + wcsobj = wcs.WCS(pipeline) + + return wcsobj + + def apply_to_spectrum(self, spectrum=None): + # returns spectrum1d with wavelength calibration applied + # actual line refinement and WCS solution should already be done so that this can + # be called on multiple science sources + spectrum = self.input_spectrum if spectrum is None else spectrum + updated_spectrum = Spectrum1D(spectrum.flux, wcs=self.wcs, mask=spectrum.mask, + uncertainty=spectrum.uncertainty) + return updated_spectrum