diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..f4952bab --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 +updates: +- package-ecosystem: pip + directory: "/" + schedule: + interval: daily + time: "10:00" + open-pull-requests-limit: 10 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..05211cb1 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,57 @@ +name: CI + +on: + push: + branches: + - master + tags: + - "*" + pull_request: + schedule: + # Run every Monday at 6am UTC + - cron: "0 6 * * 1" + +jobs: + test: + name: ${{ matrix.name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - name: Unit tests Python 3.7 + os: macos-latest + python-version: 3.7 + extras: test + test-command: pytest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip + run: | + pip install --upgrade pip + - name: Add conda to system path + run: | + # $CONDA is an environment variable pointing to the root of the miniconda directory + echo $CONDA/bin >> $GITHUB_PATH + - name: Install dependencies + run: | + conda env update --file env/environment-${PYTHONVERSION}.yml --name sedkit-${{ matrix.python-version }} + env: + PYTHONVERSION: ${{ matrix.python-version }} + - name: Install package + run: | + pip install -e .[${{ matrix.extras }}] + - name: Update conda + run: | + conda update --all + - name: Test with pytest + run: | + conda run -n sedkit-${{ matrix.python-version }} pytest diff --git a/.github/workflows/sedkit_workflow.yml b/.github/workflows/sedkit_workflow.yml deleted file mode 100644 index f05470c6..00000000 --- a/.github/workflows/sedkit_workflow.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: sedkit Workflow - -on: [push] - -jobs: - build-linux: - name: Python - ${{ matrix.python-version }} - runs-on: ${{ matrix.os }} - strategy: - max-parallel: 5 - matrix: - os: [ubuntu-latest] - python-version: ['3.7', '3.8'] - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Add conda to system path - run: | - # $CONDA is an environment variable pointing to the root of the miniconda directory - echo $CONDA/bin >> $GITHUB_PATH - - name: Install dependencies - run: | - conda env update --file env/environment-${{ matrix.python-version }}.yml --name sedkit-${{ matrix.python-version }} - env: - PYTHONVERSION: ${{ matrix.python-version }} - - name: Test with pytest - run: | - conda run -n sedkit-${{ matrix.python-version }} pytest diff --git a/ci/install_conda.sh b/ci/install_conda.sh deleted file mode 100755 index 53335a52..00000000 --- a/ci/install_conda.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -if [ -d "$HOME/miniconda3" ] && [ -e "$HOME/miniconda3/bin/conda" ]; then - echo "Miniconda install already present from cache: $HOME/miniconda3" - rm -rf $HOME/miniconda3/envs/hosts # Just in case... -else - echo "Installing Miniconda..." - rm -rf $HOME/miniconda3 # Just in case... 
- - if [ "${TRAVIS_OS_NAME}" == "osx" ]; then - wget http://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh || exit 1 - else - wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh || exit 1 - fi - - bash miniconda.sh -b -p "$HOME/miniconda3" || exit 1 -fi - -echo "Configuring Miniconda..." -conda config --set ssl_verify false || exit 1 -conda config --set always_yes true --set changeps1 false || exit 1 - -echo "Updating Miniconda" -conda update conda -conda update --all -conda info -a || exit 1 - -echo "Installing numpy" -conda install numpy \ No newline at end of file diff --git a/ci/setup_conda_env.sh b/ci/setup_conda_env.sh deleted file mode 100755 index 56844d52..00000000 --- a/ci/setup_conda_env.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -echo "Creating conda environment for Python $PYTHON_VERSION" -conda env create -f "env/environment-${PYTHON_VERSION}.yml" || exit 1 -export CONDA_ENV=sedkit-$PYTHON_VERSION -source activate $CONDA_ENV -pip install pytest pytest-cov coveralls diff --git a/env/environment-3.7.yml b/env/environment-3.7.yml index 2f3ae0b4..2441274e 100644 --- a/env/environment-3.7.yml +++ b/env/environment-3.7.yml @@ -1,27 +1,26 @@ name: sedkit-3.7 + channels: - - conda-forge - - http://ssb.stsci.edu/astroconda - - defaults +- http://ssb.stsci.edu/astroconda +- conda-forge +- defaults dependencies: - - numpy>=1.18.1 - - astropy>=4.3.1 - - astroquery>=0.4.2 - - python>=3.7 - - bokeh>=1.4.0 - - scipy>=1.4.1 - - pandas>=0.23.4 +- pip +- python==3.7.1 +- pip: + - astropy==4.3.1 + - astroquery==0.4.6 + - bokeh>=2.2.3 + - dill==0.3.4 + - dustmaps==1.0.9 + - emcee>=3.1.1 + - ipython==7.32.0 + - jupyter==1.0.0 + - numpy>=1.21.5 + - pandas>=1.3.5 - pickle5>=0.0.11 - - pip>=20.3.3 - - selenium>=2.49.2 - - sphinx>=3.4.3 - - sphinx-automodapi>=0.13 - - pytest>=6.2.1 - - pygments>=2.7.4 - - jupyter>=1.0.0 - - ipython>=7.12.0 - - pip: - - svo-filters==0.4.1 - - dustmaps==1.0.4 - - numpydoc==0.8.0 - - emcee>=3.0.2 + - PyQt5>=5.15.6 + - pytest==7.1.1 + - sphinx==4.5.0 + - scipy==1.7.3 + - svo-filters==0.4.1 diff --git a/env/environment-3.8.yml b/env/environment-3.8.yml index 6700352e..26dd390f 100644 --- a/env/environment-3.8.yml +++ b/env/environment-3.8.yml @@ -1,27 +1,26 @@ name: sedkit-3.8 + channels: - - conda-forge - - http://ssb.stsci.edu/astroconda - - defaults +- http://ssb.stsci.edu/astroconda +- conda-forge +- defaults dependencies: - - numpy>=1.18.1 - - astropy>=4.3.1 - - astroquery>=0.4.2 - - python>=3.8 - - bokeh>=1.4.0 - - scipy>=1.4.1 - - pandas>=0.23.4 - - pip>=20.3.3 - - selenium>=2.49.2 - - sphinx>=3.4.3 - - sphinx-automodapi>=0.13 - - pytest>=6.2.1 - - pygments>=2.7.4 - - tqdm>=4.62.3 - - jupyter>=1.0.0 - - ipython>=7.12.0 - - pip: - - svo-filters==0.4.1 - - dustmaps==1.0.4 - - numpydoc==0.8.0 - - emcee>=3.0.2 +- pip +- python==3.8.10 +- pip: + - astropy==5.0.4 + - astroquery==0.4.6 + - bokeh>=2.2.3 + - dill==0.3.4 + - dustmaps==1.0.9 + - emcee>=3.1.1 + - ipython==8.2.0 + - jupyter==1.0.0 + - numpy>=1.21.5 + - pandas>=1.3.5 + - pickle5>=0.0.11 + - PyQt5>=5.15.6 + - pytest==7.1.1 + - sphinx==4.5.0 + - scipy==1.8.0 + - svo-filters==0.4.1 diff --git a/sedkit/catalog.py b/sedkit/catalog.py index ee8e69e0..1368249a 100755 --- a/sedkit/catalog.py +++ b/sedkit/catalog.py @@ -7,6 +7,7 @@ """ import os +import dill import pickle from copy import copy from pkg_resources import resource_filename @@ -16,9 +17,10 @@ import astropy.table as at import astropy.units as q import numpy as np -from bokeh.models 
import HoverTool, ColumnDataSource, LabelSet +from bokeh.models import HoverTool, ColumnDataSource, LabelSet, TapTool, CustomJS from bokeh.plotting import figure, show from bokeh.models.glyphs import Patch +from bokeh.layouts import Row from .sed import SED from . import utilities as u @@ -33,20 +35,23 @@ def __init__(self, name='SED Catalog', marker='circle', color='blue', verbose=Tr self.name = name self.marker = marker self.color = color + self.palette = 'viridis' self.wave_units = q.um self.flux_units = q.erg/q.s/q.cm**2/q.AA + self.array_cols = ['sky_coords', 'SED', 'app_spec_SED', 'abs_spec_SED', 'app_phot_SED', 'abs_phot_SED', 'app_specphot_SED', 'abs_specphot_SED', 'app_SED', 'abs_SED', 'spectra'] + self.phot_cols = [] # List all the results columns - self.cols = ['name', 'ra', 'dec', 'age', 'age_unc', 'distance', 'distance_unc', + self.cols = ['name', 'age', 'age_unc', 'distance', 'distance_unc', 'parallax', 'parallax_unc', 'radius', 'radius_unc', 'spectral_type', 'spectral_type_unc', 'SpT', 'membership', 'reddening', 'fbol', 'fbol_unc', 'mbol', 'mbol_unc', 'Lbol', 'Lbol_unc', 'Lbol_sun', 'Lbol_sun_unc', 'Mbol', 'Mbol_unc', 'logg', 'logg_unc', - 'mass', 'mass_unc', 'Teff', 'Teff_unc', 'SED'] + 'mass', 'mass_unc', 'Teff', 'Teff_unc'] # A master table of all SED results - self.results = self.make_results_table(self) + self._results = self.make_results_table() # Try to set attributes from kwargs for k, v in kwargs.items(): @@ -72,8 +77,8 @@ def __add__(self, other, name=None): new_cat = Catalog(name=name or self.name) # Combine results - new_results = at.vstack([at.Table(self.results), at.Table(other.results)]) - new_cat.results = new_results + new_results = at.vstack([at.Table(self._results), at.Table(other._results)]) + new_cat._results = new_results return new_cat @@ -91,15 +96,15 @@ def add_column(self, name, data, unc=None): The uncertainty array """ # Make sure column doesn't exist - if name in self.results.colnames: + if name in self._results.colnames: raise ValueError("{}: Column already exists.".format(name)) # Make sure data is the right length - if len(data) != len(self.results): - raise ValueError("{} != {}: Data is not the right size for this catalog.".format(len(data), len(self.results))) + if len(data) != len(self._results): + raise ValueError("{} != {}: Data is not the right size for this catalog.".format(len(data), len(self._results))) # Add the column - self.results.add_column(data, name=name) + self._results.add_column(data, name=name) # Add uncertainty column if unc is not None: @@ -108,16 +113,16 @@ def add_column(self, name, data, unc=None): name = name + '_unc' # Make sure column doesn't exist - if name in self.results.colnames: + if name in self._results.colnames: raise ValueError("{}: Column already exists.".format(name)) # Make sure data is the right length - if len(unc) != len(self.results): + if len(unc) != len(self._results): raise ValueError( - "{} != {}: Data is not the right size for this catalog.".format(len(unc), len(self.results))) + "{} != {}: Data is not the right size for this catalog.".format(len(unc), len(self._results))) # Add the column - self.results.add_column(unc, name=name) + self._results.add_column(unc, name=name) def add_SED(self, sed): """Add an SED to the catalog @@ -130,7 +135,7 @@ def add_SED(self, sed): # Overwrite duplicate names idx = None if sed.name in self.results['name']: - self.message("{}: Target already in catalog. Overwriting with new SED...") + self.message("{}: Target already in catalog. 
Overwriting with new SED...".format(sed.name)) idx = np.where(self.results['name'] == sed.name)[0][0] # Turn off print statements @@ -145,7 +150,7 @@ def add_SED(self, sed): # Add the values and uncertainties if applicable new_row = {} - for col in self.cols[:-1]: + for col in self.cols: if col + '_unc' in self.cols: if isinstance(getattr(sed, col), tuple): @@ -164,6 +169,15 @@ def add_SED(self, sed): new_row[col] = val + # Store the spectra + new_row['spectra'] = [spec['spectrum'] for spec in sed.spectra] + + # Store the SED arrays + for pre in ['app', 'abs']: + for dat in ['phot_', 'spec_', 'specphot_', '']: + sed_name = '{}_{}SED'.format(pre, dat) + new_row[sed_name] = getattr(sed, sed_name).spectrum if getattr(sed, sed_name) is not None else None + # Add the SED new_row['SED'] = sed @@ -171,31 +185,42 @@ def add_SED(self, sed): for row in sed.photometry: # Add the column to the results table - if row['band'] not in self.results.colnames: - self.results.add_column(at.Column([np.nan] * len(self.results), dtype=np.float16, name=row['band'])) - self.results.add_column(at.Column([np.nan] * len(self.results), dtype=np.float16, name=row['band'] + '_unc')) - self.results.add_column(at.Column([np.nan] * len(self.results), dtype=np.float16, name='M_' + row['band'])) - self.results.add_column(at.Column([np.nan] * len(self.results), dtype=np.float16, name='M_' + row['band'] + '_unc')) + if row['band'] not in self._results.colnames: + self._results.add_column(at.Column([None] * len(self._results), dtype='O', name=row['band'])) + self._results.add_column(at.Column([None] * len(self._results), dtype='O', name=row['band'] + '_unc')) + self._results.add_column(at.Column([None] * len(self._results), dtype='O', name='M_' + row['band'])) + self._results.add_column(at.Column([None] * len(self._results), dtype='O', name='M_' + row['band'] + '_unc')) + self.phot_cols += [row['band']] # Add the apparent magnitude - new_row[row['band']] = row['app_magnitude'] + if u.isnumber(row['app_magnitude']): + new_row[row['band']] = row['app_magnitude'] + + # Add the apparent uncertainty + new_row['{}_unc'.format(row['band'])] = None if np.isnan(row['app_magnitude_unc']) else row['app_magnitude_unc'] - # Add the apparent uncertainty - new_row[row['band'] + '_unc'] = row['app_magnitude_unc'] + # Add the absolute magnitude + new_row['M_{}'.format(row['band'])] = None if np.isnan(row['abs_magnitude']) else row['abs_magnitude'] - # Add the absolute magnitude - new_row['M_' + row['band']] = row['abs_magnitude'] + # Add the absolute uncertainty + new_row['M_{}_unc'.format(row['band'])] = None if np.isnan(row['abs_magnitude_unc']) else row['abs_magnitude_unc'] - # Add the absolute uncertainty - new_row['M_' + row['band'] + '_unc'] = row['abs_magnitude_unc'] + # Ensure missing photometry columns are None + for band in self.phot_cols: + if band not in sed.photometry['band']: + new_row[band] = None + new_row['{}_unc'.format(band)] = None + new_row['M_{}'.format(band)] = None + new_row['M_{}_unc'.format(band)] = None - # Add the new row... + # Add the new row to the end of the list... 
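The duplicate-name handling just below replaces the existing catalog row with remove_row followed by insert_row rather than item assignment. For reference, a minimal standalone sketch of that pattern on a toy astropy table (the column names and values here are made up, not sedkit's):

import numpy as np
from astropy.table import Table

# Toy stand-in for the results table (made-up columns and values)
t = Table({'name': ['A', 'B', 'C'], 'Teff': [3000.0, 3500.0, np.nan]})

# Overwrite the row for 'B' in place, preserving its position
idx = int(np.where(t['name'] == 'B')[0][0])
t.remove_row(idx)
t.insert_row(idx, {'name': 'B', 'Teff': 3600.0})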
if idx is None: - self.results.add_row(new_row) + self._results.add_row(new_row) - # ...or replace existing + # ...or replace the existing row else: - self.results[idx] = new_row + self._results.remove_row(idx) + self._results.insert_row(idx, new_row) self.message("Successfully added SED '{}'".format(sed.name)) @@ -248,7 +273,7 @@ def export(self, parentdir='.', dirname=None, format='ipac', sources=True, zippe os.system('mkdir {}'.format(sourcedir)) # Export all SEDs - for source in self.results['SED']: + for source in self._results['SED']: source.export(sourcedir) # zip if desired @@ -269,7 +294,7 @@ def filter(self, param, value): param: str The parameter to filter by, e.g. 'Teff' value: str, float, int, sequence - The criteria to filter by, + The criteria to filter by, which can be single valued like 1400 or a range with operators [<,<=,>,>=], e.g. (>1200,<1400), () @@ -281,7 +306,13 @@ def filter(self, param, value): """ # Make a new catalog cat = Catalog() - cat.results = u.filter_table(self.results, **{param: value}) + + # If it's a list, just get the rows in the list + if isinstance(value, (list, np.ndarray)): + cat._results = self._results[[idx for idx, val in enumerate(self._results[param]) if val in value]] + + else: + cat._results = u.filter_table(self._results, **{param: value}) return cat @@ -324,29 +355,34 @@ def from_file(self, filepath, run_methods=['find_2MASS'], delimiter=','): def get_data(self, *args): """Fetch the data for the given columns """ - results = [] + data = [] + + # Fill results table + results = self.results.filled(np.nan) for x in args: # Get the data if '-' in x: x1, x2 = x.split('-') - if self.results[x1].unit != self.results[x2].unit: + if results[x1].unit != results[x2].unit: raise TypeError('Columns must be the same units.') - xunit = self.results[x1].unit - xdata = self.results[x1] - self.results[x2] - xerror = np.sqrt(self.results['{}_unc'.format(x1)]**2 + self.results['{}_unc'.format(x2)]**2) + xunit = results[x1].unit + xdata = np.array(results[x1].tolist()) - np.array(results[x2].tolist()) + xerr1 = np.array(results['{}_unc'.format(x1)].tolist()) + xerr2 = np.array(results['{}_unc'.format(x2)].tolist()) + xerror = np.sqrt(xerr1**2 + xerr2**2) else: - xunit = self.results[x].unit - xdata = self.results[x] - xerror = self.results['{}_unc'.format(x)] + xunit = results[x].unit + xdata = np.array(results[x].value.tolist()) if hasattr(results[x], 'unit') else np.array(results[x].tolist()) + xerror = np.array(results['{}_unc'.format(x)].value.tolist()) if hasattr(results['{}_unc'.format(x)], 'unit') else np.array(results['{}_unc'.format(x)].tolist()) # Append to results - results.append([xdata, xerror, xunit]) + data.append([xdata, xerror, xunit]) - return results + return data def get_SED(self, name_or_idx): """Retrieve the SED for the given object @@ -357,38 +393,99 @@ def get_SED(self, name_or_idx): The name or index of the SED to get """ # Add the index - self.results.add_index('name') + self._results.add_index('name') # Get the rows - if isinstance(name_or_idx, str) and name_or_idx in self.results['name']: - return copy(self.results.loc[name_or_idx]['SED']) + if isinstance(name_or_idx, str) and name_or_idx in self._results['name']: + return copy(self._results[self._results['name'] == name_or_idx]['SED'][0]) - elif isinstance(name_or_idx, int) and name_or_idx <= len(self.results): - return copy(self.results[name_or_idx]['SED']) + elif isinstance(name_or_idx, int) and name_or_idx <= len(self._results): + return copy(self._results[name_or_idx]['SED']) 
else: self.message('Could not retrieve SED {}'.format(name_or_idx)) - return + return - def load(self, file): - """Load a saved Catalog""" - if os.path.isfile(file): + def generate_SEDs(self, table): + """ + Generate SEDs from a Catalog results table - f = open(file) - cat = pickle.load(f) - f.close() + Parameters + ---------- + table: astropy.table.QTable + The table of data to use + + Returns + ------- + sequence + The list of SEDs for each row in the input table + """ + sed_list = [] + t = self.make_results_table() + for row in table: + s = SED(row['name'], verbose=False) + + for att in ['age', 'parallax', 'radius', 'spectral_type']: + setattr(self, att, (row[att] * t[att].unit, row['{}_unc'.format(att)] * t[att].unit) if row[att] is not None else None) + + s.sky_coords = row['sky_coords'] + s.membership = row['membership'] + s.reddening = row['reddening'] + # Add spectra + for spec in row['spectra']: + s.add_spectrum(spec) + + # Add photometry + for col in row.colnames: + if '.' in col and not col.startswith('M_') and not col.endswith('_unc'): + if row[col] is not None and not np.isnan(row[col]): + s.add_photometry(col, float(row[col]), float(row['{}_unc'.format(col)])) + + # Make the SED + s.make_sed() + + # Add SED object to the list + sed_list.append(s) + del s + + return sed_list + + def load(self, file, make_seds=False): + """ + Load a saved Catalog + + Parameters + ---------- + file: str + The file to load + """ + if os.path.isfile(file): + + # Open the file f = open(file, 'rb') - cat = pickle.load(f) + results = pickle.load(f) f.close() - self.results = cat + # Make SEDs again + if make_seds: + seds = self.generate_SEDs(results) + results.add_column(seds, name='SED') + + # Set results attribute + self._results = results + + self.message("Catalog loaded from {}".format(file)) + + else: + + self.message("Could not load Catalog from {}".format(file)) - @staticmethod def make_results_table(self): """Generate blank results table""" - results = at.QTable(names=self.cols, dtype=['O'] * len(self.cols)) + all_cols = self.cols + self.array_cols + results = at.QTable(names=all_cols, masked=True, dtype=['O'] * len(all_cols)) results.add_index('name') # Set the units @@ -428,9 +525,8 @@ def message(self, msg, pre='[sedkit]'): else: print("{} {}".format(pre, msg)) - def plot(self, x, y, marker=None, color=None, scale=['linear','linear'], - xlabel=None, ylabel=None, fig=None, order=None, identify=[], - id_color='red', label_points=False, draw=True, **kwargs): + def iplot(self, x, y, marker=None, color=None, scale=['linear','linear'], + xlabel=None, ylabel=None, draw=True, order=None, **kwargs): """Plot parameter x versus parameter y Parameters @@ -448,7 +544,7 @@ def plot(self, x, y, marker=None, color=None, scale=['linear','linear'], xlabel: str The label for the x-axis ylable : str - The label for the y-axis + The label for the y-axis fig: bokeh.plotting.figure (optional) The figure to plot on order: int @@ -497,15 +593,172 @@ def plot(self, x, y, marker=None, color=None, scale=['linear','linear'], if y not in params: raise ValueError("'{}' is not a valid y parameter. Please choose from {}".format(y, params)) + # Tooltip names can't have '.' 
or '-' + xname = source.add(source.data[x], x.replace('.', '_').replace('-', '_')) + yname = source.add(source.data[y], y.replace('.', '_').replace('-', '_')) + + # Make photometry source + phot_source = ColumnDataSource(data={'phot_wave': [], 'phot': []}) + phot_data = [row['app_phot_SED'][:2] if row['app_phot_SED'] is not None else [[], []] for row in self._results] + phot_len = max([len(i[0]) for i in phot_data]) + for idx, row in enumerate(self._results): + w, f = phot_data[idx] + w = np.concatenate([w, np.zeros(phot_len - len(w)) * np.nan]) + f = np.concatenate([f, np.zeros(phot_len - len(f)) * np.nan]) + _ = phot_source.add(w, 'phot_wave{}'.format(idx)) + _ = phot_source.add(f, 'phot{}'.format(idx)) + + # Make spectra source + spec_source = ColumnDataSource(data={'spec_wave': [], 'spec': []}) + spec_data = [row['app_spec_SED'][:2] if row['app_spec_SED'] is not None else [[], []] for row in self._results] + spec_len = max([len(i[0]) for i in spec_data]) + for idx, row in enumerate(self._results): + w, f = spec_data[idx] + w = np.concatenate([w, np.zeros(spec_len - len(w)) * np.nan]) + f = np.concatenate([f, np.zeros(spec_len - len(f)) * np.nan]) + _ = spec_source.add(w, 'spec_wave{}'.format(idx)) + _ = spec_source.add(f, 'spec{}'.format(idx)) + + # Set up hover tool + tips = [('Name', '@name'), (x, '@{}'.format(xname)), (y, '@{}'.format(yname))] + hover = HoverTool(tooltips=tips, names=['points']) + + callback = CustomJS(args=dict(source=source, phot_source=phot_source, spec_source=spec_source), code=""" + var data = source.data; + var phot_data = phot_source.data; + var spec_data = spec_source.data; + var selected = source.selected.indices; + phot_source.data['phot_wave'] = phot_data['phot_wave' + selected[0]]; + phot_source.data['phot'] = phot_data['phot' + selected[0]]; + phot_source.change.emit(); + spec_source.data['spec_wave'] = spec_data['spec_wave' + selected[0]]; + spec_source.data['spec'] = spec_data['spec' + selected[0]]; + spec_source.change.emit(); + """) + tap = TapTool(callback=callback) + + # Make the plot + TOOLS = ['pan', 'reset', 'box_zoom', 'wheel_zoom', 'save', hover, tap] + title = '{} v {}'.format(x, y) + fig = figure(plot_width=500, plot_height=500, title=title, y_axis_type=scale[1], x_axis_type=scale[0], tools=TOOLS) + + # Get marker class + size = kwargs.get('size', 8) + kwargs['size'] = size + marker = getattr(fig, marker or self.marker) + color = color or self.color + + # Plot nominal values and errors + marker(x, y, source=source, color=color, fill_alpha=0.7, name='points', **kwargs) + fig = u.errorbars(fig, x, y, xerr='{}_unc'.format(x), yerr='{}_unc'.format(y), source=source, color=color) + + # Set axis labels + xunit = source.data[x].unit if hasattr(source.data[x], 'unit') else None + yunit = source.data[y].unit if hasattr(source.data[y], 'unit') else None + fig.xaxis.axis_label = xlabel or '{}{}'.format(x, ' [{}]'.format(xunit) if xunit else '') + fig.yaxis.axis_label = ylabel or '{}{}'.format(y, ' [{}]'.format(yunit) if yunit else '') + + # Formatting + fig.legend.location = "top_right" + + # Draw sub figure + sub = figure(plot_width=500, plot_height=500, title='Selected Source', + x_axis_label=str(self.wave_units), y_axis_label=str(self.flux_units), + x_axis_type='log', y_axis_type='log') + sub.line('phot_wave', 'phot', source=phot_source, color='black', alpha=0.2) + sub.circle('phot_wave', 'phot', source=phot_source, size=8, color='red', alpha=0.8) + sub.line('spec_wave', 'spec', source=spec_source, color='red', alpha=0.5) + + # Make row layout + 
layout = Row(children=[fig, sub]) + + if draw: + show(layout) + + return layout + + def plot(self, x, y, marker=None, color=None, scale=['linear','linear'], + xlabel=None, ylabel=None, fig=None, order=None, identify=[], + id_color='red', label_points=False, draw=True, **kwargs): + """Plot parameter x versus parameter y + + Parameters + ---------- + x: str + The name of the x axis parameter, e.g. 'SpT' + y: str + The name of the y axis parameter, e.g. 'Teff' + marker: str (optional) + The name of the method for the desired marker + color: str (optional) + The color to use for the points + scale: sequence + The (x,y) scale for the plot + xlabel: str + The label for the x-axis + ylable : str + The label for the y-axis + fig: bokeh.plotting.figure (optional) + The figure to plot on + order: int + The polynomial order to fit + identify: idx, str, sequence + Names of sources to highlight in the plot + id_color: str + The color of the identified points + label_points: bool + Print the name of the object next to the point + + Returns + ------- + bokeh.plotting.figure.Figure + The figure object + """ + # Grab the source and valid params + source = copy(self.source) + params = [k for k in source.column_names if not k.endswith('_unc')] + + # If no uncertainty column for parameter, add it + if '{}_unc'.format(x) not in source.column_names: + _ = source.add([None] * len(source.data['name']), '{}_unc'.format(x)) + if '{}_unc'.format(y) not in source.column_names: + _ = source.add([None] * len(source.data['name']), '{}_unc'.format(y)) + + # Check if the x parameter is a color + xname = x.replace('.', '_').replace('-', '_') + if '-' in x and all([i in params for i in x.split('-')]): + colordata = self.get_data(x)[0] + if len(colordata) == 3: + _ = source.add(at.Column(data=colordata[0], unit=colordata[2]), x) + _ = source.add(at.Column(data=colordata[1], unit=colordata[2]), '{}_unc'.format(x)) + params.append(x) + + # Check if the y parameter is a color + yname = y.replace('.', '_').replace('-', '_') + if '-' in y and all([i in params for i in y.split('-')]): + colordata = self.get_data(y)[0] + if len(colordata) == 3: + _ = source.add(at.Column(data=colordata[0], unit=colordata[2]), y) + _ = source.add(at.Column(data=colordata[1], unit=colordata[2]), '{}_unc'.format(y)) + params.append(y) + + # Check the params are in the table + if x not in params: + raise ValueError("'{}' is not a valid x parameter. Please choose from {}".format(x, params)) + if y not in params: + raise ValueError("'{}' is not a valid y parameter. Please choose from {}".format(y, params)) + # Make the figure if fig is None: # Tooltip names can't have '.' 
or '-' - xname = source.add(source.data[x], x.replace('.', '_').replace('-', '_')) - yname = source.add(source.data[y], y.replace('.', '_').replace('-', '_')) + _ = source.add(at.Column(data=source.data[x]), xname) + _ = source.add(at.Column(data=source.data[y]), yname) + _ = source.add(at.Column(data=source.data['{}_unc'.format(x)]), '{}_unc'.format(xname)) + _ = source.add(at.Column(data=source.data['{}_unc'.format(y)]), '{}_unc'.format(yname)) # Set up hover tool - tips = [('Name', '@name'), (x, '@{}'.format(xname)), (y, '@{}'.format(yname))] + tips = [('Name', '@name'), ('Idx', '@idx'), (x, '@{0} (@{0}_unc)'.format(xname)), (y, '@{0} (@{0}_unc)'.format(yname))] hover = HoverTool(tooltips=tips, names=['points']) # Make the plot @@ -522,11 +775,10 @@ def plot(self, x, y, marker=None, color=None, scale=['linear','linear'], # Prep data names = source.data['name'] xval, xerr = source.data[x], source.data['{}_unc'.format(x)] - xval[xval == None] = np.nan - xerr[xerr == None] = np.nan yval, yerr = source.data[y], source.data['{}_unc'.format(y)] - yval[yval == None] = np.nan - yerr[yerr == None] = np.nan + + # Make error bars + fig = u.errorbars(fig, xval, yval, xerr=xerr, yerr=yerr, color=color) # Plot nominal values marker(x, y, source=source, color=color, fill_alpha=0.7, name='points', **kwargs) @@ -535,16 +787,6 @@ def plot(self, x, y, marker=None, color=None, scale=['linear','linear'], idx = [ni for ni, name in enumerate(names) if name in identify] fig.circle(xval[idx], yval[idx], size=size + 5, color=id_color, fill_color=None, line_width=2) - # Plot y errorbars - y_err_x = [(i, i) for i in source.data[x]] - y_err_y = [(yval[n] if np.isnan(i - j) else i - j, yval[n] if np.isnan(i + j) else i + j) for n, (i, j) in enumerate(zip(yval, yerr))] - fig.multi_line(y_err_x, y_err_y, color=color) - - # Plot x errorbars - x_err_y = [(i, i) for i in source.data[y]] - x_err_x = [(xval[n] if np.isnan(i - j) else i - j, xval[n] if np.isnan(i + j) else i + j) for n, (i, j) in enumerate(zip(xval, xerr))] - fig.multi_line(x_err_x, x_err_y, color=color) - # Label points if label_points: labels = LabelSet(x=x, y=y, text='name', level='glyph', x_offset=5, y_offset=5, source=source, render_mode='canvas') @@ -620,8 +862,6 @@ def plot_SEDs(self, name_or_idx, scale=['log', 'log'], normalize=None, **kwargs) normalized: bool Normalize the SEDs to 1 """ - COLORS = u.color_gen('Category10') - # Plot all SEDS if name_or_idx in ['all', '*']: name_or_idx = list(range(len(self.results))) @@ -630,6 +870,8 @@ def plot_SEDs(self, name_or_idx, scale=['log', 'log'], normalize=None, **kwargs) if isinstance(name_or_idx, (str, int)): name_or_idx = [name_or_idx] + COLORS = u.color_gen(kwargs.get('palette', self.palette), n=len(name_or_idx)) + # Make the plot TOOLS = ['pan', 'reset', 'box_zoom', 'wheel_zoom', 'save'] title = self.name @@ -639,11 +881,14 @@ def plot_SEDs(self, name_or_idx, scale=['log', 'log'], normalize=None, **kwargs) y_axis_label='Flux Density [{}]'.format(str(self.flux_units)), tools=TOOLS) - # Plot each SED + # Plot each SED if it has been calculated for obj in name_or_idx: - c = next(COLORS) targ = self.get_SED(obj) - fig = targ.plot(fig=fig, color=c, output=True, normalize=normalize, legend=targ.name, **kwargs) + if targ.calculated: + c = next(COLORS) + fig = targ.plot(fig=fig, color=c, one_color=True, output=True, normalize=normalize, label=targ.name, **kwargs) + else: + print("No SED to plot for source {}".format(obj)) return fig @@ -656,22 +901,37 @@ def remove_SED(self, name_or_idx): The name or index 
of the SED to remove """ # Add the index - self.results.add_index('name') + self._results.add_index('name') # Get the rows - if isinstance(name_or_idx, str) and name_or_idx in self.results['name']: - self.results = self.results[self.results['name'] != name_or_idx] + if isinstance(name_or_idx, str) and name_or_idx in self._results['name']: + self._results = self._results[self._results['name'] != name_or_idx] - elif isinstance(name_or_idx, int) and name_or_idx <= len(self.results): - self.results.remove_row([name_or_idx]) + elif isinstance(name_or_idx, int) and name_or_idx <= len(self._results): + self._results.remove_row([name_or_idx]) else: self.message('Could not remove SED {}'.format(name_or_idx)) return + @property + def results(self): + """ + Return results table + """ + # Get results table + res_tab = self._results[[col for col in self._results.colnames if col not in self.array_cols]] + + # Mask empty elements + for col in res_tab.columns.values(): + col.mask = [not bool(val) for val in col] + + return res_tab + def save(self, file): - """Save the serialized data + """ + Save the serialized data Parameters ---------- @@ -686,18 +946,31 @@ def save(self, file): if not os.path.isfile(file): os.system('touch {}'.format(file)) + # Get the pickle-safe data + results = copy(self._results) + results = results[[k for k in results.colnames if k != 'SED']] + # Write the file f = open(file, 'wb') - pickle.dump(self.results, f, pickle.HIGHEST_PROTOCOL) + dill.dump(results, f) f.close() self.message('Catalog saved to {}'.format(file)) + else: + + self.message('{}: Path does not exist. Try again.'.format(path)) + @property def source(self): """Generates a ColumnDataSource from the results table""" - # Remove SED column - results_dict = {key: val for key, val in dict(self.results).items() if key != 'SED'} + results = copy(self.results) + + # Remove array columns + results_dict = {key: val for key, val in dict(results).items()} + + # Add the index as a column in the table for tooltips + results_dict['idx'] = np.arange(len(self.results)) return ColumnDataSource(data=results_dict) diff --git a/sedkit/query.py b/sedkit/query.py index 87209a69..88d1af7f 100755 --- a/sedkit/query.py +++ b/sedkit/query.py @@ -15,7 +15,12 @@ import astropy.units as q from astroquery.vizier import Vizier from astroquery.sdss import SDSS +from bokeh.plotting import figure, show +from bokeh.models import HoverTool, ColumnDataSource import numpy as np +from svo_filters import Filter + +from . 
import utilities as u # A list of photometry catalogs from Vizier @@ -263,7 +268,7 @@ def query_SDSS_apogee_spectra(target=None, sky_coords=None, verbose=True, **kwar # return [wav, flx, err], catalog, header -def query_vizier(catalog, target=None, sky_coords=None, col_names=None, wildcards=['e_*'], target_names=None, search_radius=20 * q.arcsec, idx=0, cat_name=None, verbose=True, **kwargs): +def query_vizier(catalog, target=None, sky_coords=None, col_names=None, wildcards=['e_*'], target_names=None, search_radius=20 * q.arcsec, idx=0, cat_name=None, verbose=True, preview=False, **kwargs): """ Search Vizier for photometry in the given catalog @@ -285,6 +290,8 @@ def query_vizier(catalog, target=None, sky_coords=None, col_names=None, wildcard The search radius for the Vizier query idx: int The index of the record to use if multiple Vizier results + preview: bool + Make a plot of all photometry results """ # Get the catalog if catalog in PHOT_CATALOGS: @@ -363,33 +370,64 @@ def query_vizier(catalog, target=None, sky_coords=None, col_names=None, wildcard results = [] if n_rec > 0: - # Grab the record - rec = dict(viz_cat[idx]) - ref = viz_cat.meta['name'] + if preview: + + # Make the figure + prev = figure(width=900, height=400, y_axis_type="log", x_axis_type="log") + filters = [Filter(name) for name in names] + colors = u.color_gen(kwargs.get('palette', 'viridis'), n=n_rec) + phot_tips = [('Band', '@desc'), ('Wave', '@wav'), ('Flux', '@flx'), ('Unc', '@unc'), ('Idx', '@idx')] + hover = HoverTool(names=['phot'], tooltips=phot_tips) + prev.add_tools(hover) + + def valid(flx, err): + return err > 0 and not np.isnan(flx) and not np.isnan(err) + + # Get all mags from the queried results table + for n, row in enumerate(viz_cat): + try: + color = next(colors) + mags = [row[col] for col in cols if valid(row[col], row['e_{}'.format(col)])] + uncs = [row['e_{}'.format(col)] for col in cols if valid(row[col], row['e_{}'.format(col)])] + flxs, uncs = np.array([u.mag2flux(filt, m, e) for filt, m, e in zip(filters, mags, uncs)]).T + source = ColumnDataSource(data=dict(wav=[filt.wave_eff.value for filt in filters], flx=flxs, unc=uncs, idx=[n] * len(flxs), desc=names)) + prev.line('wav', 'flx', source=source, color=color, name='phot', alpha=0.1, hover_color="firebrick", hover_alpha=1) + prev.circle('wav', 'flx', source=source, color=color, size=8, alpha=0.3, hover_color="firebrick", hover_alpha=1) + prev = u.errorbars(prev, 'wav', 'flx', yerr='unc', source=source, color=color, alpha=0.3, hover_color="firebrick", hover_alpha=1) + except ValueError: + pass + + return prev + + else: + + # Grab the record + rec = dict(viz_cat[idx]) + ref = viz_cat.meta['name'] - # Pull out the photometry - for name, col in zip(names, cols): + # Pull out the photometry + for name, col in zip(names, cols): - # Data for this column - data = [] + # Data for this column + data = [] - # Add name - data.append(name) + # Add name + data.append(name) - # Check for nominal value - nom = rec.get(col) - data.append(nom) - if nom is None: - print("[sedkit] Could not find '{}' column in '{}' catalog.".format(col, cat_name)) + # Check for nominal value + nom = rec.get(col) + data.append(nom) + if nom is None: + print("[sedkit] Could not find '{}' column in '{}' catalog.".format(col, cat_name)) - # Check for wildcards - for wc in wildcards: - wc_col = wc.replace('*', col) - val = rec.get(wc_col) - data.append(val) + # Check for wildcards + for wc in wildcards: + wc_col = wc.replace('*', col) + val = rec.get(wc_col) + data.append(val) - # Add 
reference - data.append(ref) - results.append(data) + # Add reference + data.append(ref) + results.append(data) return results diff --git a/sedkit/relations.py b/sedkit/relations.py index ff8b5753..4bea8ed0 100644 --- a/sedkit/relations.py +++ b/sedkit/relations.py @@ -6,6 +6,7 @@ This is the code used to generate the polynomial relations used in sedkit's calculations """ +import os from pkg_resources import resource_filename import astropy.io.ascii as ii @@ -13,6 +14,9 @@ import astropy.table as at from astroquery.vizier import Vizier from bokeh.plotting import figure, show +from scipy.optimize import least_squares +from bokeh.models.glyphs import Patch +from bokeh.models import ColumnDataSource import numpy as np from . import utilities as u @@ -24,16 +28,25 @@ class Relation: """A base class to store raw data, fit a polynomial, and evaluate quickly""" - def __init__(self, file, add_columns=None, ref=None, **kwargs): + def __init__(self, table, add_columns=None, ref=None, **kwargs): """Load the data Parameters ---------- - file: str - The file to load + table: str, astropy.table.Table + The file or table to load """ # Load the file into a table - self.data = ii.read(file, **kwargs) + if isinstance(table, str): + if os.path.exists(table): + table = ii.read(table, **kwargs) + + # Make sure it's a table + if not isinstance(table, at.Table): + raise TypeError("{} is not a valid table of data. Please provide a astropy.table.Table or path to an ascii file to ingest.".format(type(table))) + + # Store the data + self.data = table self.ref = ref # Fill in masked values @@ -69,7 +82,7 @@ def add_column(self, colname, values): # Add the column self.data[colname] = values - def add_relation(self, rel_name, order, xrange=None, xunit=None, yunit=None, plot=True): + def add_relation(self, rel_name, order, xrange=None, xunit=None, yunit=None, reject_outliers=False, plot=True): """ Create a polynomial of the given *order* for *yparam* as a function of *xparam* which can be evaluated at any x value @@ -86,6 +99,9 @@ def add_relation(self, rel_name, order, xrange=None, xunit=None, yunit=None, plo The units of the x parameter values yunit: astropy.units.quantity.Quantity The units of the y parameter values + reject_outliers: bool + Use outlier rejection in the fit if polynomial + is of order 3 or less """ # Get params xparam, yparam = self._parse_rel_name(rel_name) @@ -105,33 +121,74 @@ def add_relation(self, rel_name, order, xrange=None, xunit=None, yunit=None, plo rel['y'] = rel['y'][idx] # Remove masked and NaN values - rel['x'], rel['y'] = self.validate_data(rel['x'], rel['y']) - - # Determine monotonicity - rel['monotonic'] = u.monotonic(rel['x']) + rel['x'], rel['y'], rel['weight'] = self.validate_data(rel['x'], rel['y']) # Set weighting - rel['weight'] = np.ones_like(rel['x']) if '{}_unc'.format(yparam) in self.data.colnames: - rel['weight'] = 1. / self.data['{}_unc'].format(yparam) + y_unc = np.array(self.data['{}_unc'.format(yparam)]) + rel['x'], rel['y'], y_unc = self.validate_data(rel['x'], rel['y'], y_unc) + rel['weight'] = 1. / y_unc + + # Determine monotonicity + rel['monotonic'] = u.monotonic(rel['x']) # Try to fit a polynomial try: - # Fit polynomial - rel['coeffs'], rel['C_p'] = np.polyfit(rel['x'], rel['y'], rel['order'], w=rel['weight'], cov=True) + # X array + rel['x_fit'] = np.linspace(rel['x'].min(), rel['x'].max(), 1000) - # Matrix with rows 1, spt, spt**2, ... 
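The reject_outliers branch added below swaps np.polyfit for scipy.optimize.least_squares with a soft_l1 loss. A standalone sketch of that fitting approach on made-up data, for reference (coefficients come back highest order first, matching the np.polyval convention used elsewhere in this module):

import numpy as np
from scipy.optimize import least_squares

def poly(x, *coeffs):
    # Horner evaluation; coeffs are ordered highest power first
    result = 0
    for c in coeffs:
        result = x * result + c
    return result

def residual(p, x, y):
    return y - poly(x, *p)

rng = np.random.default_rng(0)
x = np.linspace(0, 10, 50)
y = 0.5 * x**2 - 2 * x + 1 + rng.normal(0, 0.5, x.size)
y[::10] += 20                      # inject a few outliers

p0 = np.ones(3)                    # order-2 polynomial
fit = least_squares(residual, p0, loss='soft_l1', f_scale=0.1, args=(x, y))
print(fit.x)                       # robust coefficients, largely insensitive to the outliers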
- rel['matrix'] = np.vstack([rel['x']**(order-i) for i in range(order + 1)]).T + if reject_outliers: - # Matrix multiplication calculates the polynomial values - rel['yi'] = np.dot(rel['matrix'], rel['coeffs']) + def f(x, *c): + """Generic polynomial function""" + result = 0 + for coeff in c: + result = x * result + coeff + return result - # C_y = TT*C_z*TT.T - rel['C_yi'] = np.dot(rel['matrix'], np.dot(rel['C_p'], rel['matrix'].T)) + def residual(p, x, y): + """Residual calulation""" + return y - f(x, *p) - # Standard deviations are sqrt of diagonal - rel['sig_yi'] = np.sqrt(np.diag(rel['C_yi'])) + def errFit(hess_inv, resVariance): + return np.sqrt(np.diag(hess_inv * resVariance)) + + # TODO: This fails for order 4 or more + # Fit polynomial to data + p0 = np.ones(rel['order'] + 1) + res_robust = least_squares(residual, p0, loss='soft_l1', f_scale=0.1, args=(rel['x'], rel['y'])) + rel['coeffs'] = res_robust.x + rel['jac'] = res_robust.jac + rel['y_fit'] = f(rel['x_fit'], *rel['coeffs']) + + # Calculate errors on coefficients + rel['sig_coeffs'] = errFit(np.linalg.inv(np.dot(rel['jac'].T, rel['jac'])), (residual(rel['coeffs'], rel['x'], rel['y']) ** 2).sum() / (len(rel['y']) - len(p0))) + rel['sig_coeffs2'] = errFit(np.linalg.inv(2 * np.dot(rel['jac'].T, rel['jac'])), (residual(rel['coeffs'], rel['x'], rel['y']) ** 2).sum() / (len(rel['y']) - len(p0))) + + # Calculate upper and lower bounds on the fit + coeff_err = rel['coeffs'] - rel['sig_coeffs'] + rel['y_fit_err'] = f(rel['x_fit'], *coeff_err) + + else: + + # Fit polynomial + rel['coeffs'], rel['C_p'] = np.polyfit(rel['x'], rel['y'], rel['order'], w=rel['weight'], cov=True) + + # Matrix with rows 1, spt, spt**2, ... + rel['matrix'] = np.vstack([rel['x'] ** (order - i) for i in range(order + 1)]).T + + # Matrix multiplication calculates the polynomial values + rel['yi'] = np.dot(rel['matrix'], rel['coeffs']) + + # C_y = TT*C_z*TT.T + rel['C_yi'] = np.dot(rel['matrix'], np.dot(rel['C_p'], rel['matrix'].T)) + + # Standard deviations are sqrt of diagonal + rel['sig_yi'] = np.sqrt(np.diag(rel['C_yi'])) + + # Plot polynomial values + rel['y_fit'] = np.polyval(rel['coeffs'], rel['x_fit']) except Exception as exc: print(exc) @@ -195,7 +252,7 @@ def evaluate(self, rel_name, x_val, plot=False): if plot: plt = self.plot(rel_name) - plt.circle([x_val], [y_val.value], color='red', size=10, legend='{}({})'.format(rel['yparam'], x_val)) + plt.circle([x_val], [y_val.value if hasattr(y_val, 'unit') else y_val], color='red', size=10, legend='{}({})'.format(rel['yparam'], x_val)) if y_upper: plt.line([x_val, x_val], [y_val - y_lower, y_val + y_upper], color='red') show(plt) @@ -247,7 +304,7 @@ def plot(self, rel_name, **kwargs): # Make the figure fig = figure(x_axis_label=xparam, y_axis_label=yparam) - x, y = self.validate_data(self.data[xparam], self.data[yparam]) + x, y, _ = self.validate_data(self.data[xparam], self.data[yparam]) fig.circle(x, y, legend='Data', **kwargs) if rel_name in self.relations: @@ -256,13 +313,18 @@ def plot(self, rel_name, **kwargs): rel = self.relations[rel_name] # Plot polynomial values - xaxis = np.linspace(rel['x'].min(), rel['x'].max(), 100) - evals = np.polyval(rel['coeffs'], xaxis) - fig.line(xaxis, evals, color='black', legend='Fit') + fig.line(rel['x_fit'], rel['y_fit'], color='black', legend='Fit') + + # # Plot relation error + # xpat = np.hstack((rel['x_fit'], rel['x_fit'][::-1])) + # ypat = np.hstack((rel['y_fit'] + rel['y_fit_err'], (rel['y_fit'] - rel['y_fit_err'])[::-1])) + # err_source = 
ColumnDataSource(dict(xaxis=xpat, yaxis=ypat)) + # glyph = Patch(x='xaxis', y='yaxis', fill_color='black', line_color=None, fill_alpha=0.1) + # fig.add_glyph(err_source, glyph) return fig - def validate_data(self, X, Y): + def validate_data(self, X, Y, Y_unc=None): """ Validate the data for onlu numbers @@ -272,13 +334,19 @@ def validate_data(self, X, Y): The x-array Y: sequence The y-array + Y_unc: sequence + The uncertainty of the y-array Returns ------- sequence The validated arrays """ - valid = np.asarray([(float(x), float(y)) for x, y in zip(X, Y) if u.isnumber(x) and u.isnumber(y)]).T + if Y_unc is None: + Y_unc = np.ones_like(Y) + + # Check for valid numbers to plot + valid = np.asarray([(float(x), float(y), float(y_unc)) for x, y, y_unc in zip(X, Y, Y_unc) if u.isnumber(x) and u.isnumber(y) and u.isnumber(y_unc)]).T if len(valid) == 0: raise ValueError("No valid data in the arrays") diff --git a/sedkit/sed.py b/sedkit/sed.py index 9f6a5616..aea7dcc2 100755 --- a/sedkit/sed.py +++ b/sedkit/sed.py @@ -129,8 +129,6 @@ def __init__(self, name='My Target', verbose=True, method_list=None, **kwargs): # Attributes with setters self._name = None - self._ra = None - self._dec = None self._age = None self._distance = None self._parallax = None @@ -150,6 +148,7 @@ def __init__(self, name='My Target', verbose=True, method_list=None, **kwargs): # Static attributes self.evo_model = 'DUSTY00' + self.frame = 'icrs' self.reddening = 0 self.SpT = None self.multiple = False @@ -177,6 +176,7 @@ def __init__(self, name='My Target', verbose=True, method_list=None, **kwargs): # Attributes of arbitrary length self.all_names = [] + self.warnings = [] self.stitched_spectra = [] self.app_spec_SED = None self.app_phot_SED = None @@ -267,67 +267,115 @@ def add_photometry(self, band, mag, mag_unc=None, system='Vega', ref=None, **kwa """ # Make sure the magnitudes are floats if (not isinstance(mag, (float, np.float32))) or np.isnan(mag): - raise TypeError("{}: Magnitude must be a float.".format(type(mag))) + self.message("{}: {} magnitude must be a float, not {}".format(mag, band, type(mag))) - # Check the uncertainty - if not isinstance(mag_unc, (float, np.float32, type(None), np.ma.core.MaskedConstant)): - raise TypeError("{}: Magnitude uncertainty must be a float, NaN, or None.".format(type(mag_unc))) + else: + # Check the uncertainty + if not isinstance(mag_unc, (float, np.float32, type(None), np.ma.core.MaskedConstant)): + self.message("{}: {} magnitude uncertainty must be a (float, NaN, None), not {}.".format(mag_unc, band, type(mag_unc))) + + # Make NaN if 0 or None + if (isinstance(mag_unc, (float, int)) and mag_unc == 0) or isinstance(mag_unc, (np.ma.core.MaskedConstant, type(None))): + mag_unc = np.nan + + # Get the bandpass + if isinstance(band, str): + bp = svo.Filter(band) + elif isinstance(band, svo.Filter): + bp, band = band, band.name + else: + self.message('Not a recognized bandpass: {}'.format(band)) - # Make NaN if 0 or None - if (isinstance(mag_unc, (float, int)) and mag_unc == 0) or isinstance(mag_unc, (np.ma.core.MaskedConstant, type(None))): - mag_unc = np.nan + # Convert to Vega + mag, mag_unc = u.convert_mag(band, mag, mag_unc, old=system, new=self.mag_system) - # Get the bandpass - if isinstance(band, str): - bp = svo.Filter(band) - elif isinstance(band, svo.Filter): - bp, band = band, band.name - else: - self.message('Not a recognized bandpass: {}'.format(band)) + # Convert bandpass to desired units + bp.wave_units = self.wave_units - # Convert to Vega - mag, mag_unc = 
u.convert_mag(band, mag, mag_unc, old=system, new=self.mag_system) + # Drop the current band if it exists + if band in self.photometry['band']: + self.drop_photometry(band) - # Convert bandpass to desired units - bp.wave_units = self.wave_units + # Apply the dereddening by subtracting the (bandpass extinction vector)*(source dust column density) + mag -= bp.ext_vector * self.reddening - # Drop the current band if it exists - if band in self.photometry['band']: - self.drop_photometry(band) + # Make a dict for the new point + mag = round(mag, 3) + mag_unc = mag_unc if np.isnan(mag_unc) else round(mag_unc, 3) + eff = bp.wave_eff.astype(np.float16) + new_photometry = {'band': band, 'eff': eff, 'app_magnitude': mag, 'app_magnitude_unc': mag_unc, 'bandpass': bp, 'ref': ref} - # Apply the dereddening by subtracting the (bandpass extinction vector)*(source dust column density) - mag -= bp.ext_vector * self.reddening + # Add the kwargs + new_photometry.update(kwargs) - # Make a dict for the new point - mag = round(mag, 3) - mag_unc = mag_unc if np.isnan(mag_unc) else round(mag_unc, 3) - eff = bp.wave_eff.astype(np.float16) - new_photometry = {'band': band, 'eff': eff, 'app_magnitude': mag, 'app_magnitude_unc': mag_unc, 'bandpass': bp, 'ref': ref} + # Add it to the table + self._photometry.add_row(new_photometry) + self.message("Setting {} photometry to {:.3f} ({:.3f}) with reference '{}'".format(band, mag, mag_unc, ref)) - # Add the kwargs - new_photometry.update(kwargs) + # Set SED as uncalculated + self.calculated = False - # Add it to the table - self._photometry.add_row(new_photometry) - self.message("Setting {} photometry to {:.3f} ({:.3f}) with reference '{}'".format(band, mag, mag_unc, ref)) + # Update photometry max and min wavelengths + self._calculate_phot_lims() - # Set SED as uncalculated - self.calculated = False + def add_photometry_row(self, row, bands=None, rename=None, unc_wildcard='e_*', ref=None): + """ + Parse a table row into a separate table of photometry. Used mostly to parse the + photometry from a single row of Vizier table output for ingestion into SED object. 
- # Update photometry max and min wavelengths - self._calculate_phot_lims() + Parameters + ---------- + row: astropy.table.Row + The table row of data to add + bands: sequence + The list of bands to preserve + rename: sequence + A list of new band names + unc_wildcard: str + The wildcard to look for when fetching uncertainties + ref: str, sequence + The reference for all photometry of a list of references for each band + + Returns + ------- + astropy.table.Table + The resulting table + """ + self.message("Reading photometry from Table row") + + # Assume SDSS, Gaia, 2MASS, WISE + if bands is None: + bands = ['FUVmag', 'NUVmag', 'umag', 'gmag', 'rmag', 'imag', 'zmag', 'Gmag', 'Jmag', 'Hmag', 'Kmag', 'W1mag', 'W2mag', + 'W3mag', 'W4mag'] + rename = ['GALEX.FUV', 'GALEX.NUV', 'SDSS.u', 'SDSS.g', 'SDSS.r', 'SDSS.i', 'SDSS.z', 'Gaia.G', '2MASS.J', '2MASS.H', '2MASS.Ks', + 'WISE.W1', 'WISE.W2', 'WISE.W3', 'WISE.W4'] + + # Add photometry to phot table + for idx, band in enumerate(bands): + if band in row.colnames: + goodmag = isinstance(row[band], (float, np.float16, np.float32)) + goodunc = isinstance(row[unc_wildcard.replace('*', band)], (float, np.float16, np.float32)) + if goodmag: + mag = row[band] + unc = row[unc_wildcard.replace('*', band)] if goodunc else np.nan + rf = row['ref'] if isinstance(ref, str) else ref[idx] if isinstance(ref, (list, tuple, np.ndarray)) else None + band = rename[idx] if rename is not None else band + + # Add the table to the object + self.add_photometry(**{'band': band, 'mag': mag, 'mag_unc': unc, 'ref': rf}) - def add_photometry_file(self, file): + def add_photometry_table(self, table, **kwargs): """ - Add a table of photometry from an ASCII file that contains the columns 'band', 'magnitude', and 'uncertainty' + Add photometry from a table or ASCII file that contains the columns 'band', 'magnitude', and 'uncertainty' Parameters ---------- - file: str - The path to the ascii file + table: str, astropy.table.Table, astropy.table.row.Row + The file path to an ASCII table, astropy Table, or astropy Row to convert to a table () """ - # Read the data - table = ii.read(file) + # Read the table data + if isinstance(table, str): + table = ii.read(table) # Test to see if columns are present cols = ['band', 'magnitude', 'uncertainty'] @@ -650,6 +698,13 @@ def calculate_Teff(self): def _calibrate_photometry(self, name='photometry'): """ Calculate the absolute magnitudes and flux values of all rows in the photometry table + + Parameters + ---------- + name: str + The name of the attribute to calibrate, ['photometry', 'synthetic_photometry'] + phot_flag: float + The survey-survey color to flag as suspicious """ # Reset photometric SED if name == 'photometry': @@ -708,6 +763,9 @@ def _calibrate_photometry(self, name='photometry'): # Set SED as uncalculated self.calculated = False + # Flag suspicious photometry + self._flag_photometry() + def _calibrate_spectra(self): """ Create composite spectra and flux calibrate @@ -799,32 +857,9 @@ def compare_model(self, modelgrid, fit_to='spec', rebin=True, **kwargs): @property def dec(self): """ - A property for declination - """ - return self._dec - - @dec.setter - def dec(self, dec, dec_unc=None, frame='icrs'): - """ - Set the declination of the source - - Padecmeters - ---------- - dec: astropy.units.quantity.Quantity - The declination - dec_unc: astropy.units.quantity.Quantity (optional) - The uncertainty - frame: str - The reference frame + A property for dec """ - if not isinstance(dec, (q.quantity.Quantity, str)): - raise 
TypeError("{}: Cannot interpret dec".format(dec)) - - # Make sure it's decimal degrees - self._dec = Angle(dec) - if self.ra is not None: - sky_coords = SkyCoord(ra=self.ra, dec=self.dec, unit=(q.degree, q.degree), frame='icrs') - self._set_sky_coords(sky_coords, simbad=False) + return self.sky_coords.dec if self.sky_coords is not None else None @property def distance(self): @@ -1152,7 +1187,7 @@ def find_PanSTARRS(self, **kwargs): """ self.find_photometry('PanSTARRS', **kwargs) - def find_photometry(self, catalog, col_names=None, target_names=None, search_radius=None, idx=0, **kwargs): + def find_photometry(self, catalog, col_names=None, target_names=None, search_radius=None, idx=0, preview=False, **kwargs): """ Search Vizier for photometry in the given catalog @@ -1168,23 +1203,33 @@ def find_photometry(self, catalog, col_names=None, target_names=None, search_rad The search radius for the Vizier query idx: int The index of the record to use if multiple Vizier results + preview: bool + Plot a preview of all queried photometry for visual inspection """ # Get the Vizier catalog - results = qu.query_vizier(catalog, col_names=col_names, target_names=target_names, target=self.name, sky_coords=self.sky_coords, search_radius=search_radius or self.search_radius, verbose=self.verbose, idx=idx, **kwargs) + results = qu.query_vizier(catalog, col_names=col_names, target_names=target_names, target=self.name, sky_coords=self.sky_coords, search_radius=search_radius or self.search_radius, verbose=self.verbose, idx=idx, preview=preview, **kwargs) - # Parse the record - for result in results: + if preview: + self.plot(fig=results) - # Get result - band, mag, unc, ref = result + else: - # Ensure Vegamag - system = 'AB' if 'SDSS' in band else 'Vega' + # Parse the record + for result in results: - self.add_photometry(band, mag, unc, ref=ref, system=system) + # Get result + band, mag, unc, ref = result - # Pause to prevent ConnectionError with astroquery - time.sleep(self.wait) + # Ensure Vegamag + # TODO: Vizier results in Vegamag already? 
+ system = 'Vega' + # system = 'AB' if 'SDSS' in band else 'Vega' + + # Add the magnitude + self.add_photometry(band, mag, unc, ref=ref, system=system) + + # Pause to prevent ConnectionError with astroquery + time.sleep(self.wait) def find_SDSS(self, **kwargs): """ @@ -1278,7 +1323,7 @@ def find_Simbad(self, search_radius=None, include=['parallax', 'spectral_type', if self.sky_coords is None: sky_coords = tuple(viz_cat[0][['RA', 'DEC']]) sky_coords = SkyCoord(ra=sky_coords[0], dec=sky_coords[1], unit=(q.degree, q.degree), frame='icrs') - self._set_sky_coords(sky_coords, simbad=False) + self.sky_coords = sky_coords # Check for a parallax if 'parallax' in include and not hasattr(obj['PLX_VALUE'], 'mask'): @@ -1426,6 +1471,29 @@ def fit_spectral_type(self): # Run the fit self.fit_modelgrid(spl) + def _flag_photometry(self): + """ + Check that two adjascent photometric bands are reasonable + + Parameters + ---------- + band1: str + The first band name + band2: str + The second band name + phot_flag: float + The maximum allowed magnitude difference + """ + checks = [('SDSS.g', 'PS1.g', 0.3), ('SDSS.r', 'PS1.r', 0.3), ('SDSS.i', 'PS1.i', 0.3), ('SDSS.z', 'PS1.z', 0.3), + ('SDSS.z', '2MASS.J', 1.5), ('PS1.y', '2MASS.J', 1.5), ('2MASS.Ks', 'WISE.W1', 2)] + for band1, band2, flag in checks: + if self.get_mag(band1) is not None and self.get_mag(band2) is not None: + m1 = self.get_mag(band1) + m2 = self.get_mag(band2) + md = abs(m1[0] - m2[0]) + if md > flag and not np.isnan(m1[1]) and not np.isnan(m2[1]): + self.warning('{} and {} photometry are not smooth! Ratio = {}. Check your photometry!'.format(band1, band2, md)) + @property def flux_units(self): """ @@ -2008,6 +2076,8 @@ def make_sed(self): # Make sure there is data if len(self.spectra) == 0 and len(self.photometry) == 0: self.message('Cannot make the SED without spectra or photometry!') + self.calculated = False + return # Calculate flux and calibrate @@ -2324,7 +2394,7 @@ def photometry(self): def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic_photometry=False, best_fit=True, normalize=None, scale=['log', 'log'], output=False, fig=None, - color='#1f77b4', one_color=False, **kwargs): + color='#1f77b4', one_color=False, label=None, **kwargs): """ Plot the SED @@ -2354,6 +2424,10 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic The Boheh plot to add the SED to color: str The color for the plot points and lines + one_color: bool + Plots ass SED data using a single color + label: str + The legend label for the integral, defaults to Teff value Returns ------- @@ -2410,7 +2484,7 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic TOOLS = ['pan', 'reset', 'box_zoom', 'wheel_zoom', 'save'] xlab = 'Wavelength [{}]'.format(self.wave_units) ylab = 'Flux Density [{}]'.format(str(self.flux_units)) - self.fig = figure(plot_width=800, plot_height=500, title=self.name, + self.fig = figure(plot_width=900, plot_height=400, title=self.name, y_axis_type=scale[1], x_axis_type=scale[0], x_axis_label=xlab, y_axis_label=ylab, tools=TOOLS) @@ -2430,7 +2504,7 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic # Set up hover tool phot_tips = [('Band', '@desc'), ('Wave', '@x'), ('Flux', '@y'), ('Unc', '@z')] - hover = HoverTool(names=['photometry', 'nondetection'], tooltips=phot_tips, mode='vline') + hover = HoverTool(names=['photometry', 'nondetection'], tooltips=phot_tips) self.fig.add_tools(hover) # Plot points with errors @@ -2438,12 
+2512,7 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic if len(pts) > 0: source = ColumnDataSource(data=dict(x=pts['x'], y=pts['y'], z=pts['z'], desc=[b.decode("utf-8") for b in pts['desc']])) self.fig.circle('x', 'y', source=source, legend_label='Photometry', name='photometry', color=color, fill_alpha=0.7, size=8) - y_err_x = [] - y_err_y = [] - for name, px, py, err in pts: - y_err_x.append((px, px)) - y_err_y.append((py - err, py + err)) - self.fig.multi_line(y_err_x, y_err_y, color=color) + self.fig = u.errorbars(self.fig, 'x', 'y', yerr='z', source=source, color=color) # Plot points without errors pts = np.array([(bnd, wav, flx * const, err * const) for bnd, wav, flx, err in np.array(self.photometry['band', 'eff', pre + 'flux', pre + 'flux_unc']) if (np.isnan(err) or err <= 0) and not np.isnan(flx)], dtype=[('desc', 'S20'), ('x', float), ('y', float), ('z', float)]) @@ -2464,16 +2533,11 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic if len(pts) > 0: source = ColumnDataSource(data=dict(x=pts['x'], y=pts['y'], z=pts['z'], desc=[b.decode("utf-8") for b in pts['desc']])) self.fig.square('x', 'y', source=source, legend_label='Synthetic Photometry', name='synthetic photometry', color=color, fill_alpha=0.7, size=8) - y_err_x = [] - y_err_y = [] - for name, px, py, err in pts: - y_err_x.append((px, px)) - y_err_y.append((py - err, py + err)) - self.fig.multi_line(y_err_x, y_err_y, color=color) + self.fig = u.errorbars(self.fig, 'x', 'y', yerr='z', source=source, color=color) # Plot the SED with linear interpolation completion if integral: - label = str(self.Teff[0]) if self.Teff is not None else 'Integral' + label = label or str(self.Teff[0]) if self.Teff is not None else 'Integral' self.fig.line(full_SED.wave, full_SED.flux * const, line_color=color if one_color else 'black', alpha=0.3, legend_label=label) if best_fit and len(self.best_fit) > 0: @@ -2503,32 +2567,9 @@ def plot(self, app=True, photometry=True, spectra=True, integral=True, synthetic @property def ra(self): """ - A property for right ascension - """ - return self._ra - - @ra.setter - def ra(self, ra, ra_unc=None, frame='icrs'): + A property for ra """ - Set the right ascension of the source - - Parameters - ---------- - ra: astropy.units.quantity.Quantity - The right ascension - ra_unc: astropy.units.quantity.Quantity (optional) - The uncertainty - frame: str - The reference frame - """ - if not isinstance(ra, (q.quantity.Quantity, str)): - raise TypeError("{}: Cannot interpret ra".format(ra)) - - # Make sure it's decimal degrees - self._ra = Angle(ra) - if self.dec is not None: - sky_coords = SkyCoord(ra=self.ra, dec=self.dec, unit=q.degree, frame='icrs') - self._set_sky_coords(sky_coords, simbad=False) + return self.sky_coords.ra if self.sky_coords is not None else None @property def radius(self): @@ -2609,6 +2650,9 @@ def results(self): elif isinstance(attr, (str, float, bytes, int)): rows.append([param, attr, '--', '--']) + elif hasattr(attr, 'deg'): + rows.append([param, attr.deg, '--', 'deg']) + else: pass @@ -2653,7 +2697,7 @@ def sky_coords(self): return self._sky_coords @sky_coords.setter - def sky_coords(self, sky_coords, frame='icrs'): + def sky_coords(self, sky_coords): """ A setter for sky coordinates @@ -2663,47 +2707,40 @@ def sky_coords(self, sky_coords, frame='icrs'): The sky coordinates frame: str The coordinate frame + simbad: bool + Search Simbad by the coordinates """ - # Make sure it's a sky coordinate - if not isinstance(sky_coords, 
(SkyCoord, tuple)): - raise TypeError('Sky coordinates must be astropy.coordinates.SkyCoord or (ra, dec) tuple.') - - if isinstance(sky_coords, tuple) and len(sky_coords) == 2: + # Allow None to be set + if sky_coords is None: + self._sky_coords = None - if isinstance(sky_coords[0], str): - sky_coords = SkyCoord(ra=sky_coords[0], dec=sky_coords[1], unit=(q.degree, q.degree), frame=frame) + else: - elif isinstance(sky_coords[0], (float, Angle, q.quantity.Quantity)): - sky_coords = SkyCoord(ra=sky_coords[0], dec=sky_coords[1], unit=q.degree, frame=frame) + # Make sure it's a sky coordinate + if not isinstance(sky_coords, (SkyCoord, tuple)): + raise TypeError('Sky coordinates must be astropy.coordinates.SkyCoord or (ra, dec) tuple.') - else: - raise TypeError("Cannot convert type {} to coordinates.".format(type(sky_coords[0]))) + # If it's already SkyCoord just set it + if isinstance(sky_coords, SkyCoord): + self._sky_coords = sky_coords + self.message("Setting sky_coords to {}".format(self.sky_coords)) + self.get_reddening() - self._set_sky_coords(sky_coords) + # If it's coordinates, make it into a SkyCoord + if isinstance(sky_coords, tuple) and len(sky_coords) == 2: - def _set_sky_coords(self, sky_coords, simbad=True): - """ - Calculate and set attributes from sky coords - - Parameters - ---------- - sky_coords: astropy.coordinates.SkyCoord - The sky coordinates - simbad: bool - Search Simbad by the coordinates - """ - # Set the sky coordinates - self._sky_coords = sky_coords - self._ra = sky_coords.ra.degree - self._dec = sky_coords.dec.degree - self.message("Setting sky_coords to {}".format(self.sky_coords)) + if isinstance(sky_coords[0], str): + self._sky_coords = SkyCoord(ra=sky_coords[0], dec=sky_coords[1], unit=(q.degree, q.degree), frame=self.frame) + self.message("Setting sky_coords to {}".format(self.sky_coords)) + self.get_reddening() - # Try to calculate reddening - self.get_reddening() + elif isinstance(sky_coords[0], (float, Angle, q.quantity.Quantity)): + self._sky_coords = SkyCoord(ra=sky_coords[0], dec=sky_coords[1], unit=q.degree, frame=self.frame) + self.message("Setting sky_coords to {}".format(self.sky_coords)) + self.get_reddening() - # Try to find the source in Simbad - if simbad: - self.find_Simbad() + else: + raise TypeError("Cannot convert type {} to coordinates.".format(type(sky_coords[0]))) def spectrum_from_modelgrid(self, model_grid, snr=10, **kwargs): """ @@ -2777,7 +2814,7 @@ def spectral_type(self, spectral_type): if '+' in spectral_type: self.multiple = True raw_SpT = spectral_type - self.message("{}: This source appears to be a multiple".format(spectral_type)) + self.warning("{}: This source appears to be a multiple".format(spectral_type)) spec_type = u.specType(spectral_type) spectral_type, spectral_type_unc, prefix, gravity, lum_class = spec_type @@ -2923,6 +2960,19 @@ def _validate_and_set_param(self, param, values, units, set_uncalculated=True, t return valid + def warning(self, msg): + """ + Display and/or save a warning + + Parameters + ---------- + msg: str + The message to print + """ + self.message(msg, pre='[WARNING]') + if msg not in self.warnings: + self.warnings.append(msg) + @property def wave_units(self): """ diff --git a/sedkit/tests/test_catalog.py b/sedkit/tests/test_catalog.py index 3fae87cf..6a0da07c 100644 --- a/sedkit/tests/test_catalog.py +++ b/sedkit/tests/test_catalog.py @@ -76,6 +76,7 @@ def test_filter(self): # Check there are two SEDs self.assertEqual(len(cat.results), 2) + print('SpT:', cat.get_data('spectral_type')) # Filter so 
there is only one result f_cat = cat.filter('spectral_type', '>30') @@ -101,13 +102,10 @@ def test_get_data(self): def test_get_SED(self): """Test get_SED method""" - # Make the catalog + # Get the SED using name cat = copy.copy(self.cat) - cat.add_SED(self.vega) - - # Get the SED + cat.add_SED(copy.copy(self.vega)) s = cat.get_SED('Vega') - self.assertEqual(type(s), type(self.vega)) def test_plot(self): @@ -122,7 +120,7 @@ def test_plot(self): self.assertEqual(str(type(plt)), "") # Color-color plot - plt = cat.plot('WISE.W1-WISE.W2', 'WISE.W1-WISE.W2', order=1) + plt = cat.plot('distance', 'parallax', order=1) self.assertEqual(str(type(plt)), "") # Bad columns @@ -130,10 +128,27 @@ def test_plot(self): self.assertRaises(ValueError, cat.plot, 'foo', 'parallax') # Fit polynomial - cat.plot('spectral_type', 'parallax', order=1) + cat.plot('distance', 'parallax', order=1) # Identify sources - cat.plot('spectral_type', 'parallax', identify=['Vega']) + cat.plot('distance', 'parallax', identify=['Vega']) + + def test_iplot(self): + """Test iplot method""" + # Make the catalog + cat = copy.copy(self.cat) + cat.add_SED(self.vega) + cat.add_SED(self.sirius) + + # Simple plot + plt = cat.iplot('distance', 'parallax') + + # Color-color plot + plt = cat.iplot('distance', 'parallax', order=1) + + # Bad columns + self.assertRaises(ValueError, cat.iplot, 'spectral_type', 'foo') + self.assertRaises(ValueError, cat.iplot, 'foo', 'parallax') def test_plot_SEDs(self): """Test plot_SEDs method""" @@ -149,6 +164,8 @@ def test_save_and_load(self): """Test save and load methods""" # Make the catalog cat = copy.copy(self.cat) + cat.add_SED(self.vega) + cat.add_SED(self.sirius) cat.save('test.p') # Try to load it diff --git a/sedkit/tests/test_sed.py b/sedkit/tests/test_sed.py index 502508db..0f47a4ee 100644 --- a/sedkit/tests/test_sed.py +++ b/sedkit/tests/test_sed.py @@ -43,13 +43,13 @@ def test_add_photometry(self): s.drop_photometry(0) self.assertEqual(len(s.photometry), 0) - def test_add_photometry_file(self): + def test_add_photometry_table(self): """Test that photometry is added properly from file""" s = copy.copy(self.sed) # Add the photometry f = resource_filename('sedkit', 'data/L3_photometry.txt') - s.add_photometry_file(f) + s.add_photometry_table(f) self.assertEqual(len(s.photometry), 8) def test_add_spectrum(self): @@ -93,19 +93,10 @@ def test_attributes(self): self.assertRaises(TypeError, setattr, s, 'age', (4, 0.1)) self.assertRaises(TypeError, setattr, s, 'age', (4*q.Jy, 0.1*q.Jy)) - # Dec - s.dec = 1.2345 * q.deg - self.assertRaises(TypeError, setattr, s, 'dec', 1.2345) - - # RA - s.ra = 1.2345 * q.deg - self.assertRaises(TypeError, setattr, s, 'ra', 1.2345) - # Sky coords s.sky_coords = 1.2345 * q.deg, 1.2345 * q.deg s.sky_coords = '1.2345', '1.2345' self.assertRaises(TypeError, setattr, s, 'sky_coords', 'foo') - self.assertRaises(TypeError, setattr, s, 'sky_coords', None) # Distance s.distance = None @@ -234,7 +225,6 @@ def test_find_SDSS_spectra(self): s = sed.SED() s.sky_coords = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') s.find_SDSS_spectra(search_radius=20 * q.arcsec) - assert len(s.spectra) > 0 def test_run_methods(self): """Test that the method_list argument works""" diff --git a/sedkit/utilities.py b/sedkit/utilities.py index a7eea203..755e272a 100755 --- a/sedkit/utilities.py +++ b/sedkit/utilities.py @@ -106,7 +106,8 @@ def convert_mag(band, mag, mag_unc, old='AB', new='Vega'): # TODO: Add other bandpasses AB_to_Vega = {'Johnson.U': 0.79, 'Johnson.B': -0.09, 'Johnson.V': 0.02, 
'Cousins.R': 0.21, 'Cousins.I': 0.45, '2MASS.J': 0.91, '2MASS.H': 1.39, '2MASS.Ks': 1.85, - 'SDSS.u': 0.91, 'SDSS.g': -0.08, 'SDSS.r': 0.16, 'SDSS.i': 0.37, 'SDSS.z': 0.54} + 'SDSS.u': 0.91 - 0.04, 'SDSS.g': -0.08, 'SDSS.r': 0.16, 'SDSS.i': 0.37, 'SDSS.z': 0.54 + 0.02, + 'PS1.g': -0.08, 'PS1.r': 0.16, 'PS1.i': 0.37, 'PS1.z': 0.54, 'PS1.y': 0.63} if old == 'AB' and new == 'Vega': corr = AB_to_Vega.get(band, 0) @@ -188,7 +189,7 @@ def isnumber(s, nan=False): if isinstance(s, (str, bytes)): return s.replace('.', '').replace('-', '').replace(' + ', '').isnumeric() - elif isinstance(s, (int, float, np.int64, np.float32, np.float64)): + elif isinstance(s, (int, float, np.int64, np.float16, np.float32, np.float64)): if np.isnan(s) and not nan: return False else: @@ -284,7 +285,7 @@ def filter_table(table, **kwargs): if not value.endswith('*'): value = value + '$' - # Strip souble quotes + # Strip double quotes value = value.replace("'", '').replace('"', '').replace('*', '(.*)') # Regex @@ -306,7 +307,7 @@ def filter_table(table, **kwargs): if any([value.startswith(o) for o in ['<', '>', '=']]): value = [value] - # Assume eqality if no operator + # Assume equality if no operator else: value = ['== ' + value] @@ -319,28 +320,53 @@ def filter_table(table, **kwargs): # Equality if cond.startswith('='): - v = cond.replace('=', '') - table = table[table[param] == eval(v)] + val = cond.replace('=', '') + idx = [] + for i, v in enumerate(table[param].value): + if v is not None: + if v == float(val): + idx.append(i) + table = table[idx] # Less than or equal elif cond.startswith('<='): - v = cond.replace('<=', '') - table = table[table[param] <= eval(v)] + val = cond.replace('<=', '') + idx = [] + for i, v in enumerate(table[param].value): + if v is not None: + if v <= float(val): + idx.append(i) + table = table[idx] # Less than elif cond.startswith('<'): - v = cond.replace('<', '') - table = table[table[param] < eval(v)] + val = cond.replace('<', '') + idx = [] + for i, v in enumerate(table[param].value): + if v is not None: + if v < float(val): + idx.append(i) + table = table[idx] # Greater than or equal elif cond.startswith('>='): - v = cond.replace('>=', '') - table = table[table[param] >= eval(v)] + val = cond.replace('>=', '') + idx = [] + for i, v in enumerate(table[param].value): + if v is not None: + if v >= float(val): + idx.append(i) + table = table[idx] # Greater than elif cond.startswith('>'): - v = cond.replace('>', '') - table = table[table[param] > eval(v)] + val = cond.replace('>', '') + idx = [] + for i, v in enumerate(table[param].value): + if v is not None: + if v > float(val): + idx.append(i) + table = table[idx] else: raise ValueError("'{}' operator not understood.".format(cond)) @@ -500,29 +526,48 @@ def errfunc(p, a1, a2): return norm_factor -def errorbars(fig, x, y, xerr=None, xupper=None, xlower=None, yerr=None, yupper=None, ylower=None, color='red', point_kwargs={}, error_kwargs={}): +def errorbars(fig, x, y, xerr=None, xupper=None, xlower=None, yerr=None, yupper=None, ylower=None, source=None, color='red', name='errors', **kwargs): """ Hack to make errorbar plots in bokeh Parameters ---------- - x: sequence - The x axis data - y: sequence - The y axis data - xerr: sequence (optional) - The x axis errors - yerr: sequence (optional) - The y axis errors + x: sequence, str + The x axis data or ColumnDataSource key + y: sequence, str + The y axis data or ColumnDataSource key + xerr: sequence, str (optional) + The x axis symmetric errors or ColumnDataSource key + xlower: sequence, str 
(optional) + The x axis lower errors or ColumnDataSource key + xupper: sequence, str (optional) + The x axis upper errors or ColumnDataSource key + yerr: sequence, str (optional) + The y axis symmetric errors or ColumnDataSource key + ylower: sequence, str (optional) + The y axis lower errors or ColumnDataSource key + yupper: sequence, str (optional) + The y axis upper errors or ColumnDataSource key color: str The marker and error bar color - point_kwargs: dict - kwargs for the point styling - error_kwargs: dict - kwargs for the error bar styling + name: str + A name for the glyph legend: str The text for the legend """ + if source is not None: + + # Get data from ColumnDataSource + data = source.data + x = data[x] + xerr = data.get(xerr) + xlower = data.get(xlower) + xupper = data.get(xupper) + y = data[y] + yerr = data.get(yerr) + ylower = data.get(ylower) + yupper = data.get(yupper) + # Add x errorbars if possible if xerr is not None or (xupper is not None and xlower is not None): x_err_x = [] @@ -530,23 +575,25 @@ def errorbars(fig, x, y, xerr=None, xupper=None, xlower=None, yerr=None, yupper= # Symmetric uncertainties if xerr is not None: + if hasattr(xerr, 'unit'): + xerr = xerr.value for px, py, err in zip(x, y, xerr): - try: - x_err_x.append((px - err, px + err)) - x_err_y.append((py, py)) - except TypeError: - pass + if isnumber(px) and isnumber(py): + x_err_x.append([px - err, px + err]) + x_err_y.append([py, py]) # Asymmetric uncertainties elif xupper is not None and xlower is not None: + if hasattr(xlower, 'unit'): + xlower = xlower.value + if hasattr(xupper, 'unit'): + xupper = xupper.value for px, py, lower, upper in zip(x, y, xlower, xupper): - try: - x_err_x.append((px - lower, px + upper)) - x_err_y.append((py, py)) - except TypeError: - pass + if isnumber(px) and isnumber(py): + x_err_x.append([px - lower, px + upper]) + x_err_y.append([py, py]) - fig.multi_line(x_err_x, x_err_y, color=color, **error_kwargs) + fig.multi_line(x_err_x, x_err_y, color=color, name=name, **kwargs) # Add y errorbars if possible if yerr is not None or (yupper is not None and ylower is not None): @@ -555,23 +602,27 @@ def errorbars(fig, x, y, xerr=None, xupper=None, xlower=None, yerr=None, yupper= # Symmetric uncertainties if yerr is not None: + if hasattr(yerr, 'unit'): + yerr = yerr.value for px, py, err in zip(x, y, yerr): - try: - y_err_y.append((py - err, py + err)) - y_err_x.append((px, px)) - except TypeError: - pass + if isnumber(px) and isnumber(py): + y_err_y.append([py - err, py + err]) + y_err_x.append([px, px]) # Asymmetric uncertainties elif yupper is not None and ylower is not None: + if hasattr(ylower, 'unit'): + ylower = ylower.value + if hasattr(yupper, 'unit'): + yupper = yupper.value for px, py, lower, upper in zip(x, y, ylower, yupper): - try: - y_err_y.append((py - lower, py + upper)) - y_err_x.append((px, px)) - except TypeError: - pass + if isnumber(px) and isnumber(py): + y_err_y.append([py - lower, py + upper]) + y_err_x.append([px, px]) + + fig.multi_line(y_err_x, y_err_y, color=color, name=name, **kwargs) - fig.multi_line(y_err_x, y_err_y, color=color, **error_kwargs) + return fig def goodness(f1, f2, e1=None, e2=None, weights=None): @@ -990,7 +1041,7 @@ def specType(SpT, types=[i for i in 'OBAFGKMLTY'], verbose=False): if MK: # Get the stuff before and after the MK class - pre, suf = SpT.split(MK) + pre, suf = SpT.split(MK)[:2] # Get the numerical value val = float(re.findall(r'[0-9]\.?[0-9]?', suf)[0]) diff --git a/setup.py b/setup.py index 81357e91..6ca0533d 100755 
--- a/setup.py
+++ b/setup.py
@@ -68,7 +68,7 @@ def run(self):
           'Topic :: Software Development :: Libraries :: Python Modules',
       ],
       packages=find_packages(exclude=["examples"]),
-      version='1.2.0',
+      version='1.2.1',
      setup_requires=['setuptools_scm'],
      install_requires=['numpy', 'astropy', 'bokeh', 'emcee', 'pysynphot', 'scipy', 'astroquery', 'dustmaps', 'pandas','svo_filters', 'healpy'],
      include_package_data=True,
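
The _flag_photometry method added to sedkit/sed.py above compares pairs of overlapping bands and warns when their magnitudes disagree by more than a per-pair threshold. A minimal standalone sketch of that check, with the thresholds copied from the diff and a hypothetical get_mag stand-in that returns (magnitude, uncertainty) tuples or None:

    import numpy as np

    # Band pairs and maximum allowed magnitude differences, as in SED._flag_photometry
    CHECKS = [('SDSS.g', 'PS1.g', 0.3), ('SDSS.r', 'PS1.r', 0.3), ('SDSS.i', 'PS1.i', 0.3),
              ('SDSS.z', 'PS1.z', 0.3), ('SDSS.z', '2MASS.J', 1.5), ('PS1.y', '2MASS.J', 1.5),
              ('2MASS.Ks', 'WISE.W1', 2)]

    def flag_photometry(get_mag, warn=print):
        """Warn when two bands that should roughly agree differ by more than the threshold"""
        for band1, band2, max_diff in CHECKS:
            m1, m2 = get_mag(band1), get_mag(band2)   # (mag, unc) tuples or None
            if m1 is None or m2 is None:
                continue
            diff = abs(m1[0] - m2[0])
            if diff > max_diff and not np.isnan(m1[1]) and not np.isnan(m2[1]):
                warn('{} and {} differ by {:.2f} mag. Check your photometry!'.format(band1, band2, diff))

    # Example with made-up magnitudes: the discrepant SDSS.z / 2MASS.J pair is flagged
    mags = {'SDSS.z': (14.2, 0.05), '2MASS.J': (12.1, 0.03), 'PS1.z': (14.25, 0.04)}
    flag_photometry(mags.get)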
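The updated AB_to_Vega table in sedkit/utilities.py adjusts the SDSS u and z offsets and adds Pan-STARRS entries. A rough sketch of the correction those offsets imply, assuming the usual convention m_AB = m_Vega + offset (consistent with the 2MASS values already in the table); the example band and magnitude are invented:

    # AB -> Vega offsets from the updated convert_mag table (subset, values from this diff)
    AB_TO_VEGA = {'SDSS.u': 0.87, 'SDSS.g': -0.08, 'SDSS.r': 0.16, 'SDSS.i': 0.37, 'SDSS.z': 0.56,
                  'PS1.g': -0.08, 'PS1.r': 0.16, 'PS1.i': 0.37, 'PS1.z': 0.54, 'PS1.y': 0.63}

    def ab_to_vega(band, mag, unc=None):
        """Move an AB magnitude onto the Vega system, assuming m_AB = m_Vega + offset"""
        offset = AB_TO_VEGA.get(band, 0)    # unknown bands are left unchanged
        return mag - offset, unc            # a constant offset leaves the uncertainty alone

    print(ab_to_vega('PS1.y', 15.20, 0.03))   # -> (14.57, 0.03)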
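utilities.errorbars now accepts a ColumnDataSource plus column-name keys, which is how SED.plot attaches error bars to the same source as its scatter glyphs in this diff. A possible usage sketch with invented data, using bokeh 2.x figure keywords to match the pinned environments:

    from bokeh.models import ColumnDataSource
    from bokeh.plotting import figure
    from sedkit import utilities as u

    # Invented photometry-like points with symmetric y uncertainties in column 'z'
    source = ColumnDataSource(data=dict(x=[1.2, 1.6, 2.2],
                                        y=[3.1e-13, 2.4e-13, 1.1e-13],
                                        z=[2e-14, 1e-14, 3e-14]))

    fig = figure(plot_width=600, plot_height=300, x_axis_type='log', y_axis_type='log')
    fig.circle('x', 'y', source=source, size=8)

    # Column names are looked up in the source, mirroring the call in SED.plot
    fig = u.errorbars(fig, 'x', 'y', yerr='z', source=source, color='black')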
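With the ra and dec setters removed, coordinates are set only through sky_coords (a SkyCoord, a tuple of strings, or a tuple of floats/Angles/Quantities), and ra/dec become read-only views of it. A usage sketch mirroring the updated test_sed attributes test; the values are placeholders, and note that setting coordinates also triggers a reddening lookup per the diff, so this may need network access:

    import astropy.units as q
    from astropy.coordinates import SkyCoord
    from sedkit import sed

    s = sed.SED()

    # Any of the forms exercised in test_sed.test_attributes should be accepted
    s.sky_coords = 1.2345 * q.deg, 1.2345 * q.deg                    # Quantity tuple
    s.sky_coords = '1.2345', '1.2345'                                # string tuple, parsed as degrees
    s.sky_coords = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs')

    # ra and dec are now read straight from sky_coords (None when unset)
    print(s.ra, s.dec)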