Skip to content

Commit

Permalink
Allow for different uncertainty types in template_comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
rosteen committed Nov 18, 2021
1 parent dc890d2 commit 31c083a
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 8 deletions.
45 changes: 38 additions & 7 deletions specutils/analysis/template_comparison.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance

from ..manipulation import (FluxConservingResampler,
LinearInterpolatedResampler,
Expand All @@ -8,7 +9,33 @@
__all__ = ['template_match', 'template_redshift']


def _normalize_for_template_matching(observed_spectrum, template_spectrum):
def _uncertainty_to_standard_deviation(uncertainty):
"""
Convenience function to convert other uncertainty types to standard deviation,
for consistency in calculations elsewhere.
Parameters
----------
uncertainty : :class:`~astropy.nddata.NDUncertainty`
The input uncertainty
Returns
-------
:class:`~numpy.ndarray`
The array of standard deviation values.
"""
if uncertainty is not None:
if isinstance(uncertainty, StdDevUncertainty):
stddev = uncertainty.array
elif isinstance(uncertainty, VarianceUncertainty):
stddev = np.sqrt(uncertainty.array)
elif isinstance(uncertainty, InverseVariance):
stddev = 1 / np.sqrt(uncertainty.array)

return stddev

def _normalize_for_template_matching(observed_spectrum, template_spectrum, stddev=None):
"""
Calculate a scale factor to be applied to the template spectrum so the
total flux in both spectra will be the same.
Expand All @@ -27,10 +54,10 @@ def _normalize_for_template_matching(observed_spectrum, template_spectrum):
A float which will normalize the template spectrum's flux so that it
can be compared to the observed spectrum.
"""
num = np.sum((observed_spectrum.flux*template_spectrum.flux) /
(observed_spectrum.uncertainty.array**2))
denom = np.sum((template_spectrum.flux /
observed_spectrum.uncertainty.array)**2)
if stddev is None:
stddev = _uncertainty_to_standard_deviation(observed_spectrum.uncertainty)
num = np.sum((observed_spectrum.flux*template_spectrum.flux) / (stddev**2))
denom = np.sum((template_spectrum.flux / stddev)**2)

return num/denom

Expand Down Expand Up @@ -89,16 +116,20 @@ def _chi_square_for_templates(observed_spectrum, template_spectrum, resample_met
template_obswavelength = fluxc_resample(template_spectrum,
observed_spectrum.spectral_axis)

# Convert the uncertainty to standard deviation if needed
stddev = _uncertainty_to_standard_deviation(observed_spectrum.uncertainty)

# Normalize spectra
normalization = _normalize_for_template_matching(observed_spectrum,
template_obswavelength)
template_obswavelength,
stddev)

# Numerator
num_right = normalization * template_obswavelength.flux
num = observed_spectrum.flux - num_right

# Denominator
denom = observed_spectrum.uncertainty.array * observed_spectrum.flux.unit
denom = stddev * observed_spectrum.flux.unit

# Get chi square
result = (num/denom)**2
Expand Down
58 changes: 57 additions & 1 deletion specutils/tests/test_template_comparison.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import astropy.units as u
import numpy as np
from astropy.nddata import StdDevUncertainty
from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance

from ..spectra.spectrum1d import Spectrum1D
from ..spectra.spectrum_collection import SpectrumCollection
Expand Down Expand Up @@ -367,3 +367,59 @@ def test_template_known_redshift():
assert len(tm_result) == 5
np.testing.assert_almost_equal(tm_result[1], redshift)
np.testing.assert_almost_equal(tm_result[3], 1.9062409482056814e-31)


def test_template_match_variance():
"""
Test template_match when both observed and template spectra have the same wavelength axis.
"""
# Seed np.random so that results are consistent
np.random.seed(42)

# Create test spectra
spec_axis = np.linspace(0, 50, 50) * u.AA
spec = Spectrum1D(spectral_axis=spec_axis,
flux=np.random.randn(50) * u.Jy,
uncertainty=VarianceUncertainty(np.random.sample(50)**2, unit='Jy2'))

spec1 = Spectrum1D(spectral_axis=spec_axis,
flux=np.random.randn(50) * u.Jy,
uncertainty=VarianceUncertainty(np.random.sample(50)**2, unit='Jy2'))

# Get result from template_match
tm_result = template_comparison.template_match(spec, spec1)

# Create new spectrum for comparison
spec_result = Spectrum1D(spectral_axis=spec_axis,
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1))

assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy)
np.testing.assert_almost_equal(tm_result[3], 40093.28353756253)


def test_template_match_inverse_variance():
"""
Test template_match when both observed and template spectra have the same wavelength axis.
"""
# Seed np.random so that results are consistent
np.random.seed(42)

# Create test spectra
spec_axis = np.linspace(0, 50, 50) * u.AA
spec = Spectrum1D(spectral_axis=spec_axis,
flux=np.random.randn(50) * u.Jy,
uncertainty=InverseVariance(1/np.random.sample(50)**2, unit='1 / Jy2'))

spec1 = Spectrum1D(spectral_axis=spec_axis,
flux=np.random.randn(50) * u.Jy,
uncertainty=InverseVariance(1/np.random.sample(50)**2, unit='1 / Jy2'))

# Get result from template_match
tm_result = template_comparison.template_match(spec, spec1)

# Create new spectrum for comparison
spec_result = Spectrum1D(spectral_axis=spec_axis,
flux=spec1.flux * template_comparison._normalize_for_template_matching(spec, spec1))

assert quantity_allclose(tm_result[0].flux, spec_result.flux, atol=0.01*u.Jy)
np.testing.assert_almost_equal(tm_result[3], 40093.28353756253)

0 comments on commit 31c083a

Please sign in to comment.