Skip to content

Commit

Permalink
add tests for specutils round-trip
Browse files Browse the repository at this point in the history
  • Loading branch information
weaverba137 committed Nov 15, 2023
1 parent e2c1581 commit a4db5e6
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python-version: ['3.9', '3.10'] # fuji+guadalupe, not ready for 3.10 yet
astropy-version: ['==5.0', '<5.1'] # fuji+guadalupe, latest
astropy-version: ['==5.0', '<5.1', '<6'] # fuji+guadalupe, latest
fitsio-version: ['==1.1.6', '<2'] # fuji+guadalupe, latest
numpy-version: ['<1.23'] # to keep asscalar, used by astropy
env:
Expand Down Expand Up @@ -78,6 +78,7 @@ jobs:
python -m pip install pytest pytest-cov coveralls
python -m pip install git+https://github.com/desihub/desiutil.git@${DESIUTIL_VERSION}#egg=desiutil
python -m pip install -r requirements.txt
python -m pip install specutils
python -m pip install -U 'numpy${{ matrix.numpy-version }}'
python -m pip install -U 'astropy${{ matrix.astropy-version }}'
python -m pip cache remove fitsio
Expand Down
125 changes: 118 additions & 7 deletions py/desispec/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
_specutils_imported = True
try:
from specutils import SpectrumList, Spectrum1D
from astropy.units import Unit
from astropy.nddata import InverseVariance, StdDevUncertainty
except ImportError:
_specutils_imported = False

Expand Down Expand Up @@ -674,12 +676,40 @@ def to_specutils(self):
Raises
------
ValueError
NameError
If ``specutils`` is not available in the environment.
"""
if not _specutils_imported:
raise ValueError("specutils is not available in the environment.")
raise NameError("specutils is not available in the environment.")
sl = SpectrumList()
AA = Unit('Angstrom')
specunit = Unit('10-17 erg cm-2 s-1 AA-1')
for i, band in enumerate(self.bands):
meta = {'band': band}
spectral_axis = self.wave[band] * AA
flux = self.flux[band] * specunit
uncertainty = InverseVariance(self.ivar[band] * (specunit**-2))
mask = self.mask[band] != 0
meta['int_mask'] = self.mask[band]
meta['resolution_data'] = self.resolution_data[band]
try:
meta['extra'] = self.extra[band]
except KeyError:
meta['extra'] = None
if i == 0:
#
# Only add these to the first item in the list.
#
meta['bands'] = self.bands
meta['fibermap'] = self.fibermap
meta['exp_fibermap'] = self.exp_fibermap
meta['desi_meta'] = self.meta
meta['single'] = self._single
meta['scores'] = self.scores
meta['scores_comments'] = self.scores_comments
meta['extra_catalog'] = self.extra_catalog
sl.append(Spectrum1D(flux=flux, spectral_axis=spectral_axis,
uncertainty=uncertainty, mask=mask, meta=meta))
return sl

@classmethod
Expand All @@ -698,21 +728,102 @@ def from_specutils(cls, spectra):
Raises
------
ValueError
NameError
If ``specutils`` is not available in the environment.
ValueError
If an unknown type is found in `spectra`.
"""
if not _specutils_imported:
raise ValueError("specutils is not available in the environment.")
raise NameError("specutils is not available in the environment.")
if isinstance(spectra, SpectrumList):
pass
try:
bands = spectra[0].meta['bands']
except KeyError:
#
# This is a big assumption; it doesn't capture ['b', 'z'] or ['r', 'z'].
#
bands = ['b', 'r', 'z'][0:len(spectra)]
elif isinstance(spectra, Spectrum1D):
#
# Assume this is a coadd across cameras.
#
pass
try:
bands = spectra.meta['bands']
except KeyError:
bands = ['brz']
else:
raise ValueError("Unknown type input to from_specutils!")
return cls()
#
# Load objects that are independent of band from the first item.
#
try:
fibermap = spectra[0].meta['fibermap']
except KeyError:
fibermap = None
try:
exp_fibermap = spectra[0].meta['exp_fibermap']
except KeyError:
exp_fibermap = None
try:
meta = spectra[0].meta['desi_meta']
except KeyError:
meta = None
try:
single = spectra[0].meta['single']
except KeyError:
single = False
try:
scores = spectra[0].meta['scores']
except KeyError:
scores = None
try:
scores_comments = spectra[0].meta['scores_comments']
except KeyError:
scores_comments = None
try:
extra_catalog = spectra[0].meta['extra_catalog']
except KeyError:
extra_catalog = None
#
# Load band-dependent quantities.
#
wave = dict()
flux = dict()
ivar = dict()
mask = dict()
resolution_data = None
extra = None
for i, band in enumerate(bands):
wave[band] = spectra[i].spectral_axis.value
flux[band] = spectra[i].flux.value
if isinstance(spectra[i].uncertainty, InverseVariance):
ivar[band] = spectra[i].uncertainty.array
elif isinstance(spectra[i].uncertainty, StdDevUncertainty):
# Future: may need np.isfinite() here?
ivar[band] = (spectra[i].uncertainty.array)**-2
else:
raise ValueError("Unknown uncertainty type!")
try:
mask[band] = spectra[i].meta['int_mask']
except KeyError:
try:
mask[band] = spectra.mask.astype(np.int32)
except AttributeError:
mask[band] = np.zeros(flux.shape, dtype=np.int32)
if 'resolution_data' in spectra[i].meta:
if resolution_data is None:
resolution_data = {band: spectra[i].meta['resolution_data']}
else:
resolution_data[band] = spectra[i].meta['resolution_data']
if 'extra' in spectra[i].meta:
if extra is None:
extra = {band: spectra[i].meta['extra']}
else:
extra[band] = spectra[i].meta['extra']
return cls(bands=bands, wave=wave, flux=flux, ivar=ivar, mask=mask,
resolution_data=resolution_data, fibermap=fibermap, exp_fibermap=exp_fibermap,
meta=meta, extra=extra, single=single, scores=scores,
scores_comments=scores_comments, extra_catalog=extra_catalog)


def stack(speclist):
Expand Down
74 changes: 54 additions & 20 deletions py/desispec/test/test_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@

from astropy.table import Table, vstack

_specutils_imported = True
try:
from specutils import SpectrumList, Spectrum1D
# from astropy.units import Unit
# from astropy.nddata import InverseVariance, StdDevUncertainty
except ImportError:
_specutils_imported = False

from desiutil.io import encode_table
from desispec.io import empty_fibermap
from desispec.io.util import add_columns
Expand Down Expand Up @@ -46,7 +54,7 @@ def setUp(self):
['NIGHT', 'EXPID', 'TILEID'],
[np.int32(0), np.int32(0), np.int32(0)],
)

for s in range(self.nspec):
fmap[s]["TARGETID"] = 456 + s
fmap[s]["FIBER"] = 123 + s
Expand Down Expand Up @@ -94,9 +102,9 @@ def setUp(self):
self.flux[b] = np.repeat(np.arange(self.nspec, dtype=float),
self.nwave).reshape( (self.nspec, self.nwave) ) + 3.0
self.ivar[b] = 1.0 / self.flux[b]
self.mask[b] = np.tile(np.arange(2, dtype=np.uint32),
self.mask[b] = np.tile(np.arange(2, dtype=np.uint32),
(self.nwave * self.nspec) // 2).reshape( (self.nspec, self.nwave) )
self.res[b] = np.zeros( (self.nspec, self.ndiag, self.nwave),
self.res[b] = np.zeros( (self.nspec, self.ndiag, self.nwave),
dtype=np.float64)
self.res[b][:,1,:] = 1.0
self.extra[b] = {}
Expand All @@ -117,7 +125,6 @@ def tearDown(self):
os.remove(self.filebuild)
pass


def verify(self, spec, fmap):
for key, val in self.meta.items():
assert(key in spec.meta)
Expand All @@ -135,12 +142,11 @@ def verify(self, spec, fmap):
if spec.extra_catalog is not None:
assert(np.all(spec.extra_catalog == self.extra_catalog))


def test_io(self):

# manually create the spectra and write
spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux,
ivar=self.ivar, mask=self.mask, resolution_data=self.res,
spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux,
ivar=self.ivar, mask=self.mask, resolution_data=self.res,
fibermap=self.fmap1, meta=self.meta, extra=self.extra)

self.verify(spec, self.fmap1)
Expand All @@ -165,8 +171,8 @@ def test_io(self):
self.verify(comp, self.fmap1)

# test I/O with the extra_catalog HDU enabled
spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux,
ivar=self.ivar, mask=self.mask, resolution_data=self.res,
spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux,
ivar=self.ivar, mask=self.mask, resolution_data=self.res,
fibermap=self.fmap1, meta=self.meta, extra=self.extra,
extra_catalog=self.extra_catalog)

Expand Down Expand Up @@ -259,8 +265,6 @@ def test_read_rows(self):
with self.assertRaises(ValueError):
subset = read_spectra(self.fileio, rows=rows, targetids=[1,2])



def test_read_columns(self):
"""test reading while subselecting columns"""
# manually create the spectra and write
Expand Down Expand Up @@ -302,9 +306,9 @@ def test_empty(self):

other = {}
for b in self.bands:
other[b] = Spectra(bands=[b], wave={b : self.wave[b]},
flux={b : self.flux[b]}, ivar={b : self.ivar[b]},
mask={b : self.mask[b]}, resolution_data={b : self.res[b]},
other[b] = Spectra(bands=[b], wave={b : self.wave[b]},
flux={b : self.flux[b]}, ivar={b : self.ivar[b]},
mask={b : self.mask[b]}, resolution_data={b : self.res[b]},
fibermap=self.fmap1, meta=self.meta, extra={b : self.extra[b]})

for b in self.bands:
Expand All @@ -315,18 +319,18 @@ def test_empty(self):
dummy = Spectra()
spec.update(dummy)

self.verify(spec, self.fmap1)
self.verify(spec, self.fmap1)

path = write_spectra(self.filebuild, spec)


def test_updateselect(self):
spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res, fibermap=self.fmap1,
spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res, fibermap=self.fmap1,
meta=self.meta, extra=self.extra)

other = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res, fibermap=self.fmap2,
other = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res, fibermap=self.fmap2,
meta=self.meta, extra=self.extra)

spec.update(other)
Expand Down Expand Up @@ -379,7 +383,6 @@ def test_updateselect(self):
nt.assert_array_equal(spec.fibermap['NIGHT'][self.nspec:], 0)
nt.assert_array_equal(spec.fibermap['TARGETID'][0:self.nspec], spec.fibermap['TARGETID'][self.nspec:])


def test_stack(self):
"""Test desispec.spectra.stack"""
sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
Expand Down Expand Up @@ -467,6 +470,37 @@ def test_slice(self):
for band in self.bands:
self.assertEqual(sp2.flux[band].shape[0], 3)

@unittest.skipUnless(_specutils_imported, "Unable to import specutils.")
def test_to_specutils(self):
"""Test conversion to a specutils object.
"""
sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res,
fibermap=self.fmap1, exp_fibermap=self.efmap1,
meta=self.meta, extra=self.extra, scores=self.scores,
extra_catalog=self.extra_catalog)
sl = sp1.to_specutils()
self.assertEqual(sl[0].meta['single'], sp1._single)
self.assertTrue((sl[0].mask == (sp1.mask[self.bands[0]] != 0)).all())
self.assertTrue((sl[1].flux.value == sp1.flux[sp1.bands[1]]).all())

@unittest.skipUnless(_specutils_imported, "Unable to import specutils.")
def test_from_specutils(self):
"""Test conversion from a specutils object.
"""
sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar,
mask=self.mask, resolution_data=self.res,
fibermap=self.fmap1, exp_fibermap=self.efmap1,
meta=self.meta, extra=self.extra, scores=self.scores,
extra_catalog=self.extra_catalog)
spectrum_list = sp1.to_specutils()
sp2 = Spectra.from_specutils(spectrum_list)
self.assertListEqual(sp1.bands, sp2.bands)
self.assertTrue((sp1.flux[self.bands[0]] == sp2.flux[self.bands[0]]).all())
self.assertTrue((sp1.ivar[self.bands[1]] == sp2.ivar[self.bands[1]]).all())
self.assertTrue((sp1.mask[self.bands[2]] == sp2.mask[self.bands[2]]).all())
self.assertDictEqual(sp1.meta, sp2.meta)


def test_suite():
"""Allows testing of only this module with the command::
Expand Down

0 comments on commit a4db5e6

Please sign in to comment.