diff --git a/py/desispec/io/spectra.py b/py/desispec/io/spectra.py index 809b5275a..3c45a2ab2 100644 --- a/py/desispec/io/spectra.py +++ b/py/desispec/io/spectra.py @@ -35,6 +35,7 @@ from ..spectra import Spectra, stack from .meta import specprod_root +from ..util import argmatch def write_spectra(outfile, spec, units=None): """ @@ -497,17 +498,17 @@ def read_frame_as_spectra(filename, night=None, expid=None, band=None, single=Fa return spec -def read_tile_spectra(tileid, night, specprod=None, reduxdir=None, coadd=False, +def read_tile_spectra(tileid, night=None, specprod=None, reduxdir=None, coadd=False, single=False, targets=None, fibers=None, redrock=True, - group=None): + group='cumulative'): """ Read and return combined spectra for a tile/night Args: tileid (int) : Tile ID - night (int or str) : YEARMMDD night or tile group, e.g. 'deep' or 'all' Options: + night (int or str) : YEARMMDD night specprod (str) : overrides $SPECPROD reduxdir (str) : overrides $DESI_SPECTRO_REDUX/$SPECPROD coadd (bool) : if True, read coadds instead of per-exp spectra @@ -532,12 +533,15 @@ def read_tile_spectra(tileid, night, specprod=None, reduxdir=None, coadd=False, #- will automatically use $SPECPROD if specprod=None reduxdir = specprod_root(specprod) - tiledir = os.path.join(reduxdir, 'tiles') + tiledir = os.path.join(reduxdir, 'tiles', group) + if night is None: + nightdirglob = os.path.join(tiledir, str(tileid), '*') + tilenightdirs = sorted(glob.glob(nightdirglob)) + night = os.path.basename(tilenightdirs[-1]) + nightstr = str(night) - if group is not None: - tiledir = os.path.join(tiledir, group) - if group == 'cumulative': - nightstr = 'thru'+nightstr + if group == 'cumulative': + nightstr = 'thru'+nightstr tiledir = os.path.join(tiledir, str(tileid), str(night)) @@ -548,10 +552,11 @@ def read_tile_spectra(tileid, night, specprod=None, reduxdir=None, coadd=False, log.debug(f'Reading spectra from {tiledir}') prefix = 'spectra' - specfiles = glob.glob(f'{tiledir}/{prefix}-?-{tileid}-{nightstr}.fits*') + specglob = f'{tiledir}/{prefix}-?-{tileid}-{nightstr}.fits*' + specfiles = glob.glob(specglob) if len(specfiles) == 0: - raise ValueError(f'No spectra found in {tiledir}') + raise ValueError(f'No spectra found in {specglob}') specfiles = sorted(specfiles) @@ -559,6 +564,18 @@ def read_tile_spectra(tileid, night, specprod=None, reduxdir=None, coadd=False, redshifts = list() for filename in specfiles: log.debug(f'reading {os.path.basename(filename)}') + + #- if filtering by fibers, check if we need to read this file + if fibers is not None: + # filenames are like prefix-PETAL-tileid-night.* + thispetal = int(os.path.basename(filename).split('-')[1]) + petals = np.asarray(fibers)//500 + + if not np.any(np.isin(thispetal, petals)): + log.debug('Skipping petal %d, not needed by fibers %s', + thispetal, fibers) + continue + sp = read_spectra(filename, single=single) if targets is not None: keep = np.in1d(sp.fibermap['TARGETID'], targets) @@ -572,29 +589,17 @@ def read_tile_spectra(tileid, night, specprod=None, reduxdir=None, coadd=False, if redrock: #- Read matching redrock file for this spectra/coadd file rrfile = os.path.basename(filename).replace(prefix, 'redrock', 1) - log.debug(f'Reading {rrfile}') - rrfile = os.path.join(tiledir, rrfile) + rrfile = checkgzip(os.path.join(tiledir, rrfile)) + log.debug(f'Reading {os.path.basename(rrfile)}') rr = Table.read(rrfile, 'REDSHIFTS') #- Trim rr to only have TARGETIDs in filtered spectra sp keep = np.in1d(rr['TARGETID'], sp.fibermap['TARGETID']) rr = rr[keep] - #- spectra files can have multiple entries per TARGETID, - #- while redrock files have only 1. Expand to match spectra. - #- Note: astropy.table.join changes the order - if len(sp.fibermap) > len(rr): - rrx = Table() - rrx['TARGETID'] = sp.fibermap['TARGETID'] - rrx = astropy.table.join(rrx, rr, keys='TARGETID') - else: - rrx = rr - - #- Sort the rrx Table to match the order of sp['TARGETID'] - ii = np.argsort(sp.fibermap['TARGETID']) - jj = np.argsort(rrx['TARGETID']) - kk = np.argsort(ii[jj]) - rrx = rrx[kk] + #- match the Redrock entries to the spectra fibermap entries + ii = argmatch(rr['TARGETID'], sp.fibermap['TARGETID']) + rrx = rr[ii] #- Confirm that we got all that expanding and sorting correct assert np.all(sp.fibermap['TARGETID'] == rrx['TARGETID']) diff --git a/py/desispec/test/test_spectra.py b/py/desispec/test/test_spectra.py index 3053586fa..4301817d5 100644 --- a/py/desispec/test/test_spectra.py +++ b/py/desispec/test/test_spectra.py @@ -8,6 +8,7 @@ import time import copy import warnings +import tempfile import numpy as np import numpy.testing as nt @@ -15,9 +16,11 @@ from astropy.table import Table, vstack from desiutil.io import encode_table -from desispec.io import empty_fibermap +from desispec.io import empty_fibermap, findfile +from desispec.io import read_tile_spectra from desispec.io.util import add_columns import desispec.coaddition +from desispec.test.util import get_blank_spectra # Import all functions from the module we are testing. from desispec.spectra import * @@ -25,6 +28,44 @@ class TestSpectra(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Create specprod directory structure""" + cls.testDir = tempfile.mkdtemp() + cls.origEnv = { + "SPECPROD": None, + "DESI_SPECTRO_REDUX": None, + } + cls.testEnv = { + 'SPECPROD':'dailytest', + "DESI_SPECTRO_REDUX": os.path.join(cls.testDir, 'spectro', 'redux'), + } + + for e in cls.origEnv: + if e in os.environ: + cls.origEnv[e] = os.environ[e] + os.environ[e] = cls.testEnv[e] + + cls.reduxdir = os.path.join( + cls.testEnv['DESI_SPECTRO_REDUX'], + cls.testEnv['SPECPROD']) + + os.makedirs(cls.reduxdir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + """Cleanup test files if they exist. + """ + for e in cls.origEnv: + if cls.origEnv[e] is None: + del os.environ[e] + else: + os.environ[e] = cls.origEnv[e] + + if os.path.exists(cls.testDir): + shutil.rmtree(cls.testDir) + + def setUp(self): #- catch specific warnings so that we can find and fix # warnings.filterwarnings("error", ".*did not parse as fits unit.*") @@ -466,3 +507,97 @@ def test_slice(self): sp2 = sp1[[True,False,True,False,True]] for band in self.bands: self.assertEqual(sp2.flux[band].shape[0], 3) + + def test_read_tile_spectra(self): + """test desispec.io.read_tile_spectra""" + + #----- + #- Setup + np.random.seed(0) + nspec = 5 + nspec2 = 2 + tileid = 100 + night = 20201010 + spectra = get_blank_spectra(nspec) + spectra.fibermap['TARGETID'] = 100000 + np.arange(nspec) + spectra.fibermap['FIBER'] = np.arange(nspec) + spectra.fibermap['TILEID'] = 1234 + + #- extend with extra exposures of the first two targets + spectra = stack([spectra, spectra[0:nspec2]]) + + #- coadd_spectra is in-place update, so generate another copy + coadd = spectra[:] + desispec.coaddition.coadd(coadd, onetile=True) + + #- bookkeeping checks + self.assertEqual(len(spectra.fibermap), nspec+nspec2) + self.assertEqual(len(coadd.fibermap), nspec) + self.assertEqual(len(np.unique(spectra.fibermap['TARGETID'])), + len(np.unique(coadd.fibermap['TARGETID']))) + + #- Fake Redrock catalog + zcat = Table() + zcat['TARGETID'] = coadd.fibermap['TARGETID'] + zcat['Z'] = np.ones(nspec) + zcat['ZERR'] = 1e-6 * np.ones(nspec) + zcat['ZWARN'] = np.zeros(nspec, dtype=np.int32) + zcat['SPECTYPE'] = 'QSO' + zcat['SUBTYPE'] = 'LOZ' + zcat.meta['EXTNAME'] = 'REDSHIFTS' + + #- Write files + npetal = 3 + for petal in range(npetal): + specfile = findfile('spectra', tile=tileid, night=night, spectrograph=petal) + coaddfile = findfile('coadd', tile=tileid, night=night, spectrograph=petal) + rrfile = findfile('redrock', tile=tileid, night=night, spectrograph=petal) + + os.makedirs(os.path.dirname(specfile), exist_ok=True) + + write_spectra(specfile, spectra) + write_spectra(coaddfile, coadd) + zcat.write(rrfile) + + #- increment FIBERs and TARGETIDs for next petal + spectra.fibermap['FIBER'] += 500 + coadd.fibermap['FIBER'] += 500 + coadd.exp_fibermap['FIBER'] += 500 + + spectra.fibermap['TARGETID'] += 10000 + coadd.fibermap['TARGETID'] += 10000 + coadd.exp_fibermap['TARGETID'] += 10000 + zcat['TARGETID'] += 10000 + + #----- + #- Try reading it + + #- spectra + spectra, redshifts = read_tile_spectra(tileid, night=night, coadd=False, redrock=True) + self.assertEqual(len(spectra.fibermap), npetal*(nspec+nspec2)) + self.assertEqual(len(spectra.fibermap), len(redshifts)) + self.assertTrue(np.all(spectra.fibermap['TARGETID'] == redshifts['TARGETID'])) + + #- coadd + spectra, redshifts = read_tile_spectra(tileid, night=night, coadd=True, redrock=True) + self.assertEqual(len(spectra.fibermap), npetal*nspec) + self.assertEqual(len(spectra.fibermap), len(redshifts)) + self.assertTrue(np.all(spectra.fibermap['TARGETID'] == redshifts['TARGETID'])) + + #- coadd without redrock + spectra = read_tile_spectra(tileid, night=night, coadd=True, redrock=False) + self.assertEqual(len(spectra.fibermap), npetal*nspec) + + #- subset of fibers + #- Note: test files only have 5 spectra, so test fiber%500 < 5 + fibers = [1,3,502] + spectra, redshifts = read_tile_spectra(tileid, night=night, coadd=True, fibers=fibers, redrock=True) + self.assertEqual(len(spectra.fibermap), 3) + self.assertEqual(list(spectra.fibermap['FIBER']), fibers) + self.assertEqual(list(spectra.fibermap['TARGETID']), list(redshifts['TARGETID'])) + + #- auto-derive night + sp1 = read_tile_spectra(tileid, night=night, redrock=False) + sp2 = read_tile_spectra(tileid, redrock=False) + self.assertTrue(np.all(sp1.fibermap == sp2.fibermap)) + diff --git a/py/desispec/test/test_util.py b/py/desispec/test/test_util.py index 34ae16ef9..90eed51ba 100644 --- a/py/desispec/test/test_util.py +++ b/py/desispec/test/test_util.py @@ -405,3 +405,80 @@ def test_parse_keyval(self): key, value = util.parse_keyval("biz=False ") self.assertEqual(type(value), bool) self.assertEqual(value, False) + + def test_argmatch(self): + #- basic argmatch + a = np.array([1,3,2,4]) + b = np.array([3,2,1,4]) + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- b with duplicates + a = np.array([1,3,2,4]) + b = np.array([3,2,1,4,2,3]) + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- special case already matching + a = np.array([1,3,2,4]) + b = a.copy() + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- special case already sorted + a = np.array([1,2,3,4]) + b = a.copy() + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- a with extras (before, in middle, and after range of b values) + a = np.array([1,3,2,4,0,5]) + b = np.array([3,1,4]) + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- a has duplicates + a = np.array([1,3,3,2,4]) + b = np.array([3,2,1,4]) + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- equal length arrays, not not equal values + a = np.array([1,3,2,4]) + b = np.array([3,1,1,2]) + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + a = np.array([1,3,2,4,2]) + b = np.array([3,1,1,2,4]) + ii = util.argmatch(a, b) + self.assertTrue(np.all(a[ii] == b), f'{a=}, {ii=}, {a[ii]=} != {b=}') + + #- a can have extras, but not b + a = np.array([1,3,2,4]) + b = np.array([3,2,5,4]) + with self.assertRaises(ValueError): + ii = util.argmatch(a, b) + + #- Brute force random testing with shuffles + a = np.arange(10) + b = a.copy() + for test in range(100): + np.random.shuffle(a) + np.random.shuffle(b) + ii = util.argmatch(a,b) + self.assertTrue(np.all(a[ii] == b), f'test number {test}\n{a=}\n{ii=}\n{a[ii]=} !=\n{b=}') + + #- Brute force random testing with repeats and extras + for test in range(100): + a = np.random.randint(0,20, size=50) + b = np.random.randint(5,15, size=51) + + #- all values in b must be in a, so remove extras in b + #- Note: extras in a is ok, just not in b + keep = np.isin(b, a) + b = b[keep] + + ii = util.argmatch(a,b) + self.assertTrue(np.all(a[ii] == b), f'test number {test}\n{a=}\n{ii=}\n{a[ii]=} !=\n{b=}') + diff --git a/py/desispec/test/util.py b/py/desispec/test/util.py index b1f951530..e82bb734c 100644 --- a/py/desispec/test/util.py +++ b/py/desispec/test/util.py @@ -5,6 +5,7 @@ from desispec.resolution import Resolution from desispec.frame import Frame +from desispec.spectra import Spectra from desispec.io import empty_fibermap from desitarget.targetmask import desi_mask @@ -87,6 +88,10 @@ def get_models(nspec=10, nwave=1000, wavemin=4000, wavemax=5000): def set_resolmatrix(nspec,nwave): + """arguably typo function name, retained for backwards compatibility""" + return get_resolmatrix(nspec, nwave) + +def get_resolmatrix(nspec,nwave): """ Generate a Resolution Matrix Args: nspec: int @@ -107,3 +112,57 @@ def set_resolmatrix(nspec,nwave): kernel /= sum(kernel) Rdata[i,:,j] = kernel return Rdata + +def get_resolmatrix_fixedsigma(nspec,nwave): + """ Generate a Resolution Matrix with fixed sigma + Args: + nspec: int + nwave: int + + Returns: + Rdata: np.array + + """ + sigma = 3.0 + ndiag = 21 + xx = np.linspace(-ndiag/2.0, +ndiag/2.0, ndiag) + kernel = np.exp(-xx**2/(2*sigma**2)) + kernel /= sum(kernel) + Rdata = np.zeros( (nspec, len(xx), nwave) ) + + for i in range(nspec): + for j in range(nwave): + Rdata[i,:,j] = kernel + + return Rdata + +def get_blank_spectra(nspec): + """Generate a blank spectrum object with realistic wavelength coverage""" + + wave = dict( + b=np.arange(3600, 5800.1, 0.8), + r=np.arange(5760, 7620.1, 0.8), + z=np.arange(7520, 9824.1, 0.8), + ) + bands = tuple(wave.keys()) + flux = dict() + ivar = dict() + mask = dict() + rdat = dict() + for band in bands: + nwave = len(wave[band]) + flux[band] = np.ones((nspec, nwave)) + ivar[band] = np.zeros((nspec, nwave)) + mask[band] = np.zeros((nspec, nwave), dtype=np.int32) + rdat[band] = get_resolmatrix_fixedsigma(nspec, nwave) + + fm = empty_fibermap(nspec) + fm['FIBER'] = np.arange(nspec, dtype=np.int32) + fm['TARGETID'] = np.arange(nspec, dtype=np.int64) + + sp = Spectra(bands=bands, wave=wave, flux=flux, ivar=ivar, mask=mask, + resolution_data=rdat, fibermap=fm) + + return sp + + diff --git a/py/desispec/util.py b/py/desispec/util.py index 47397910c..4c74e820b 100644 --- a/py/desispec/util.py +++ b/py/desispec/util.py @@ -684,3 +684,42 @@ def itemindices(a): return idmap +def argmatch(a, b): + """ + Returns indices ii such that a[ii] == b + + Args: + a: array-like + b: array-like + + Returns indices ii such that a[ii] == b + + Both `a` and `b` are allowed to have repeats, and `a` values can be a + superset of `b`, but `b` cannot contain values that are not in `a` + because then no indices `ii` could result in `a[ii] == b`. + + Related: desitarget.geomask.match_to which is similar, but doesn't allow + duplicates in `b`. + """ + a = np.asarray(a) + b = np.asarray(b) + ii = np.argsort(a) + jj = np.argsort(b) + kk = np.searchsorted(a[ii], b[jj]) + try: + match_indices = ii[kk[np.argsort(jj)]] + except IndexError: + #- if b has elements not in a, that can fail; + #- only do expensive check if needed + bad_b = np.isin(b, a, invert=True) + if np.any(bad_b): + raise ValueError(f'b contains values not in a; impossible to match {set(b[bad_b])} to {a=}') + else: + #- this should not occur + raise RuntimeError(f'argmatch failure for unknown reason {a=}, {b=}') + + if not np.all(a[match_indices] == b): + #- this should not occur + raise RuntimeError(f'argmatch failure for unknown reason {a=} {match_indices=} {a[match_indices]=} != {b}') + + return match_indices