From a4db5e6daaef0c0db671a48b49cd187f6bce8330 Mon Sep 17 00:00:00 2001 From: Benjamin Alan Weaver Date: Wed, 15 Nov 2023 15:05:49 -0700 Subject: [PATCH] add tests for specutils round-trip --- .github/workflows/python-package.yml | 3 +- py/desispec/spectra.py | 125 +++++++++++++++++++++++++-- py/desispec/test/test_spectra.py | 74 +++++++++++----- 3 files changed, 174 insertions(+), 28 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index fc4b18980..89d413b62 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -18,7 +18,7 @@ jobs: matrix: os: [ubuntu-latest] python-version: ['3.9', '3.10'] # fuji+guadalupe, not ready for 3.10 yet - astropy-version: ['==5.0', '<5.1'] # fuji+guadalupe, latest + astropy-version: ['==5.0', '<5.1', '<6'] # fuji+guadalupe, latest fitsio-version: ['==1.1.6', '<2'] # fuji+guadalupe, latest numpy-version: ['<1.23'] # to keep asscalar, used by astropy env: @@ -78,6 +78,7 @@ jobs: python -m pip install pytest pytest-cov coveralls python -m pip install git+https://github.com/desihub/desiutil.git@${DESIUTIL_VERSION}#egg=desiutil python -m pip install -r requirements.txt + python -m pip install specutils python -m pip install -U 'numpy${{ matrix.numpy-version }}' python -m pip install -U 'astropy${{ matrix.astropy-version }}' python -m pip cache remove fitsio diff --git a/py/desispec/spectra.py b/py/desispec/spectra.py index d7e9709b4..9864eb620 100644 --- a/py/desispec/spectra.py +++ b/py/desispec/spectra.py @@ -23,6 +23,8 @@ _specutils_imported = True try: from specutils import SpectrumList, Spectrum1D + from astropy.units import Unit + from astropy.nddata import InverseVariance, StdDevUncertainty except ImportError: _specutils_imported = False @@ -674,12 +676,40 @@ def to_specutils(self): Raises ------ - ValueError + NameError If ``specutils`` is not available in the environment. """ if not _specutils_imported: - raise ValueError("specutils is not available in the environment.") + raise NameError("specutils is not available in the environment.") sl = SpectrumList() + AA = Unit('Angstrom') + specunit = Unit('10-17 erg cm-2 s-1 AA-1') + for i, band in enumerate(self.bands): + meta = {'band': band} + spectral_axis = self.wave[band] * AA + flux = self.flux[band] * specunit + uncertainty = InverseVariance(self.ivar[band] * (specunit**-2)) + mask = self.mask[band] != 0 + meta['int_mask'] = self.mask[band] + meta['resolution_data'] = self.resolution_data[band] + try: + meta['extra'] = self.extra[band] + except KeyError: + meta['extra'] = None + if i == 0: + # + # Only add these to the first item in the list. + # + meta['bands'] = self.bands + meta['fibermap'] = self.fibermap + meta['exp_fibermap'] = self.exp_fibermap + meta['desi_meta'] = self.meta + meta['single'] = self._single + meta['scores'] = self.scores + meta['scores_comments'] = self.scores_comments + meta['extra_catalog'] = self.extra_catalog + sl.append(Spectrum1D(flux=flux, spectral_axis=spectral_axis, + uncertainty=uncertainty, mask=mask, meta=meta)) return sl @classmethod @@ -698,21 +728,102 @@ def from_specutils(cls, spectra): Raises ------ - ValueError + NameError If ``specutils`` is not available in the environment. + ValueError + If an unknown type is found in `spectra`. """ if not _specutils_imported: - raise ValueError("specutils is not available in the environment.") + raise NameError("specutils is not available in the environment.") if isinstance(spectra, SpectrumList): - pass + try: + bands = spectra[0].meta['bands'] + except KeyError: + # + # This is a big assumption; it doesn't capture ['b', 'z'] or ['r', 'z']. + # + bands = ['b', 'r', 'z'][0:len(spectra)] elif isinstance(spectra, Spectrum1D): # # Assume this is a coadd across cameras. # - pass + try: + bands = spectra.meta['bands'] + except KeyError: + bands = ['brz'] else: raise ValueError("Unknown type input to from_specutils!") - return cls() + # + # Load objects that are independent of band from the first item. + # + try: + fibermap = spectra[0].meta['fibermap'] + except KeyError: + fibermap = None + try: + exp_fibermap = spectra[0].meta['exp_fibermap'] + except KeyError: + exp_fibermap = None + try: + meta = spectra[0].meta['desi_meta'] + except KeyError: + meta = None + try: + single = spectra[0].meta['single'] + except KeyError: + single = False + try: + scores = spectra[0].meta['scores'] + except KeyError: + scores = None + try: + scores_comments = spectra[0].meta['scores_comments'] + except KeyError: + scores_comments = None + try: + extra_catalog = spectra[0].meta['extra_catalog'] + except KeyError: + extra_catalog = None + # + # Load band-dependent quantities. + # + wave = dict() + flux = dict() + ivar = dict() + mask = dict() + resolution_data = None + extra = None + for i, band in enumerate(bands): + wave[band] = spectra[i].spectral_axis.value + flux[band] = spectra[i].flux.value + if isinstance(spectra[i].uncertainty, InverseVariance): + ivar[band] = spectra[i].uncertainty.array + elif isinstance(spectra[i].uncertainty, StdDevUncertainty): + # Future: may need np.isfinite() here? + ivar[band] = (spectra[i].uncertainty.array)**-2 + else: + raise ValueError("Unknown uncertainty type!") + try: + mask[band] = spectra[i].meta['int_mask'] + except KeyError: + try: + mask[band] = spectra.mask.astype(np.int32) + except AttributeError: + mask[band] = np.zeros(flux.shape, dtype=np.int32) + if 'resolution_data' in spectra[i].meta: + if resolution_data is None: + resolution_data = {band: spectra[i].meta['resolution_data']} + else: + resolution_data[band] = spectra[i].meta['resolution_data'] + if 'extra' in spectra[i].meta: + if extra is None: + extra = {band: spectra[i].meta['extra']} + else: + extra[band] = spectra[i].meta['extra'] + return cls(bands=bands, wave=wave, flux=flux, ivar=ivar, mask=mask, + resolution_data=resolution_data, fibermap=fibermap, exp_fibermap=exp_fibermap, + meta=meta, extra=extra, single=single, scores=scores, + scores_comments=scores_comments, extra_catalog=extra_catalog) def stack(speclist): diff --git a/py/desispec/test/test_spectra.py b/py/desispec/test/test_spectra.py index 42f1cee7e..3d6166367 100644 --- a/py/desispec/test/test_spectra.py +++ b/py/desispec/test/test_spectra.py @@ -14,6 +14,14 @@ from astropy.table import Table, vstack +_specutils_imported = True +try: + from specutils import SpectrumList, Spectrum1D + # from astropy.units import Unit + # from astropy.nddata import InverseVariance, StdDevUncertainty +except ImportError: + _specutils_imported = False + from desiutil.io import encode_table from desispec.io import empty_fibermap from desispec.io.util import add_columns @@ -46,7 +54,7 @@ def setUp(self): ['NIGHT', 'EXPID', 'TILEID'], [np.int32(0), np.int32(0), np.int32(0)], ) - + for s in range(self.nspec): fmap[s]["TARGETID"] = 456 + s fmap[s]["FIBER"] = 123 + s @@ -94,9 +102,9 @@ def setUp(self): self.flux[b] = np.repeat(np.arange(self.nspec, dtype=float), self.nwave).reshape( (self.nspec, self.nwave) ) + 3.0 self.ivar[b] = 1.0 / self.flux[b] - self.mask[b] = np.tile(np.arange(2, dtype=np.uint32), + self.mask[b] = np.tile(np.arange(2, dtype=np.uint32), (self.nwave * self.nspec) // 2).reshape( (self.nspec, self.nwave) ) - self.res[b] = np.zeros( (self.nspec, self.ndiag, self.nwave), + self.res[b] = np.zeros( (self.nspec, self.ndiag, self.nwave), dtype=np.float64) self.res[b][:,1,:] = 1.0 self.extra[b] = {} @@ -117,7 +125,6 @@ def tearDown(self): os.remove(self.filebuild) pass - def verify(self, spec, fmap): for key, val in self.meta.items(): assert(key in spec.meta) @@ -135,12 +142,11 @@ def verify(self, spec, fmap): if spec.extra_catalog is not None: assert(np.all(spec.extra_catalog == self.extra_catalog)) - def test_io(self): # manually create the spectra and write - spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, - ivar=self.ivar, mask=self.mask, resolution_data=self.res, + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, + ivar=self.ivar, mask=self.mask, resolution_data=self.res, fibermap=self.fmap1, meta=self.meta, extra=self.extra) self.verify(spec, self.fmap1) @@ -165,8 +171,8 @@ def test_io(self): self.verify(comp, self.fmap1) # test I/O with the extra_catalog HDU enabled - spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, - ivar=self.ivar, mask=self.mask, resolution_data=self.res, + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, + ivar=self.ivar, mask=self.mask, resolution_data=self.res, fibermap=self.fmap1, meta=self.meta, extra=self.extra, extra_catalog=self.extra_catalog) @@ -259,8 +265,6 @@ def test_read_rows(self): with self.assertRaises(ValueError): subset = read_spectra(self.fileio, rows=rows, targetids=[1,2]) - - def test_read_columns(self): """test reading while subselecting columns""" # manually create the spectra and write @@ -302,9 +306,9 @@ def test_empty(self): other = {} for b in self.bands: - other[b] = Spectra(bands=[b], wave={b : self.wave[b]}, - flux={b : self.flux[b]}, ivar={b : self.ivar[b]}, - mask={b : self.mask[b]}, resolution_data={b : self.res[b]}, + other[b] = Spectra(bands=[b], wave={b : self.wave[b]}, + flux={b : self.flux[b]}, ivar={b : self.ivar[b]}, + mask={b : self.mask[b]}, resolution_data={b : self.res[b]}, fibermap=self.fmap1, meta=self.meta, extra={b : self.extra[b]}) for b in self.bands: @@ -315,18 +319,18 @@ def test_empty(self): dummy = Spectra() spec.update(dummy) - self.verify(spec, self.fmap1) + self.verify(spec, self.fmap1) path = write_spectra(self.filebuild, spec) def test_updateselect(self): - spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, - mask=self.mask, resolution_data=self.res, fibermap=self.fmap1, + spec = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, + mask=self.mask, resolution_data=self.res, fibermap=self.fmap1, meta=self.meta, extra=self.extra) - other = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, - mask=self.mask, resolution_data=self.res, fibermap=self.fmap2, + other = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, + mask=self.mask, resolution_data=self.res, fibermap=self.fmap2, meta=self.meta, extra=self.extra) spec.update(other) @@ -379,7 +383,6 @@ def test_updateselect(self): nt.assert_array_equal(spec.fibermap['NIGHT'][self.nspec:], 0) nt.assert_array_equal(spec.fibermap['TARGETID'][0:self.nspec], spec.fibermap['TARGETID'][self.nspec:]) - def test_stack(self): """Test desispec.spectra.stack""" sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, @@ -467,6 +470,37 @@ def test_slice(self): for band in self.bands: self.assertEqual(sp2.flux[band].shape[0], 3) + @unittest.skipUnless(_specutils_imported, "Unable to import specutils.") + def test_to_specutils(self): + """Test conversion to a specutils object. + """ + sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, + mask=self.mask, resolution_data=self.res, + fibermap=self.fmap1, exp_fibermap=self.efmap1, + meta=self.meta, extra=self.extra, scores=self.scores, + extra_catalog=self.extra_catalog) + sl = sp1.to_specutils() + self.assertEqual(sl[0].meta['single'], sp1._single) + self.assertTrue((sl[0].mask == (sp1.mask[self.bands[0]] != 0)).all()) + self.assertTrue((sl[1].flux.value == sp1.flux[sp1.bands[1]]).all()) + + @unittest.skipUnless(_specutils_imported, "Unable to import specutils.") + def test_from_specutils(self): + """Test conversion from a specutils object. + """ + sp1 = Spectra(bands=self.bands, wave=self.wave, flux=self.flux, ivar=self.ivar, + mask=self.mask, resolution_data=self.res, + fibermap=self.fmap1, exp_fibermap=self.efmap1, + meta=self.meta, extra=self.extra, scores=self.scores, + extra_catalog=self.extra_catalog) + spectrum_list = sp1.to_specutils() + sp2 = Spectra.from_specutils(spectrum_list) + self.assertListEqual(sp1.bands, sp2.bands) + self.assertTrue((sp1.flux[self.bands[0]] == sp2.flux[self.bands[0]]).all()) + self.assertTrue((sp1.ivar[self.bands[1]] == sp2.ivar[self.bands[1]]).all()) + self.assertTrue((sp1.mask[self.bands[2]] == sp2.mask[self.bands[2]]).all()) + self.assertDictEqual(sp1.meta, sp2.meta) + def test_suite(): """Allows testing of only this module with the command::