diff --git a/saba/main.py b/saba/main.py index 1320744..d2fd187 100644 --- a/saba/main.py +++ b/saba/main.py @@ -1,7 +1,9 @@ from __future__ import (absolute_import, unicode_literals, division, print_function) import numpy as np -from collections import OrderedDict +import warnings + +from collections import OrderedDict, defaultdict from sherpa.fit import Fit from sherpa.data import Data1D, Data1DInt, Data2D, Data2DInt, DataSimulFit from sherpa.data import BaseData @@ -12,13 +14,13 @@ from sherpa.optmethods import GridSearch, LevMar, MonCar, NelderMead from sherpa.estmethods import Confidence, Covariance, Projection from sherpa.sim import MCMC -import warnings - from astropy.extern.six.moves import range from astropy.utils import format_doc from astropy.utils.exceptions import AstropyUserWarning from astropy.tests.helper import catch_warnings +from astropy.modeling.utils import _combine_equivalency_dict +from astropy.units import Quantity with catch_warnings(AstropyUserWarning) as warns: """this is to stop the import warning @@ -29,7 +31,6 @@ if "SherpaFitter" not in w.message.message: warnings.warn(w) -# from astropy.modeling __all__ = ('SherpaFitter', 'SherpaMCMC') @@ -96,8 +97,6 @@ class EstMethod(SherpaWrapper): 'projection': Projection} - - class SherpaMCMC(object): """ An interface which makes use of sherpa's MCMC(pyBLoCXS) functionality. @@ -325,8 +324,7 @@ def __init__(self, optimizer="levmar", statistic="leastsq", estmethod="covarianc # sherpa doesn't currently have a docstring for est_method but maybe the future setattr(self.__class__, 'est_config', property(lambda s: s._est_config, doc=self._est_method.__doc__)) - - def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, **kwargs): + def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, equivalencies=None, **kwargs): """ Fit the astropy model with a the sherpa fit routines. @@ -362,11 +360,14 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, """ tie_list = [] + models, x, y, z, xbinsize, ybinsize, err, bkg = self.remove_units(models, x, y, z, xbinsize, ybinsize, err, bkg, equivalencies) + try: n_inputs = models[0].n_inputs except TypeError: n_inputs = models.n_inputs + #print("x={x}\ny={y}\nz={z}\nerr={err}\nxbinsize={xbinsize}\nybinsize={ybinsize}\nbkg={bkg}".format(x=x, y=y, z=z, err=err, xbinsize=xbinsize, ybinsize=ybinsize, bkg=bkg)) self._data = Dataset(n_inputs, x, y, z, xbinsize, ybinsize, err, bkg, bkg_scale) if self._data.ndata > 1: @@ -389,7 +390,7 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, self._fitter = Fit(self._data.data, self._fitmodel.sherpa_model, self._stat_method, self._opt_method, self._est_method, **kwargs) self.fit_info = self._fitter.fit() - return self._fitmodel.get_astropy_model() + return self.restore_units(self._fitmodel.get_astropy_model()) def est_errors(self, sigma=None, maxiters=None, numcores=1, methoddict=None, parlist=None): """ @@ -439,6 +440,190 @@ def get_sampler(self, *args, **kwargs): """ return SherpaMCMC(self, *args, **kwargs) + def remove_units(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, equivalencies=None): + """ + This stripts data and models of their units so they can be fit. + """ + try: + models._supports_unit_fitting + n_models = 1 + except AttributeError: + n_models = len(models) + + _x = np.array(x) + _y = np.array(y) + n_data = 1 + if _x.ndim == 2 or (_x.dtype == np.object or _y.dtype == np.object): + n_data = _x.shape[0] + + if n_data > 1 and n_models > 1 and not n_data == n_models: #check if we can handle n_data and n_models + raise Exception("Don't know how to handle multiple models " + "unless there is one foreach dataset") + else: + n_dim = max(n_data,n_models) + if n_data == 1: + # make them match n_dim if n_dim is 1 it wont matter anyway although perhaps its inefficeint? + # do range so that each data and model are completely independent dont want pointer like problems + x = [x.copy() for _ in range(n_dim)] + y = [y.copy() for _ in range(n_dim)] + + if z is None: + z = n_dim * [z] + else: + z = [z.copy() for _ in range(n_dim)] + if err is None: + err = n_dim * [err] + else: + err = [err.copy() for _ in range(n_dim)] + if xbinsize is None: + xbinsize = n_dim * [xbinsize] + else: + xbinsize = [xbinsize.copy() for _ in range(n_dim)] + if ybinsize is None: + ybinsize = n_dim * [ybinsize] + else: + ybinsize = [ybinsize.copy() for _ in range(n_dim)] + if bkg is None: + bkg = n_dim * [bkg] + else: + bkg = [bkg.copy() for _ in range(n_dim)] + else: + if z is None: + z = n_dim * [z] + if err is None: + err = n_dim * [err] + if xbinsize is None: + xbinsize = n_dim * [xbinsize] + if ybinsize is None: + ybinsize = n_dim * [ybinsize] + if bkg is None: + bkg = n_dim * [bkg] + + if n_models == 1: + models = [models.copy() for _ in range(n_dim)] + + assert len(x) == len(y) == len(z) == len(err) == len(xbinsize) == len(ybinsize) == len(bkg), ValueError("lenghts of one of your parameters don't match models{lm}, x({lx}), y({ly}), " + "z{lz}, err({lerr}), xbinsize({lxbin}), ybinsize({lybin}), bkg({lbkg})".format(lm=len(models), lx=len(x), + ly=len(y), lz=len(z), lerr=len(err), lxbin=len(xbinsize), lybin=len(ybinsize), + lbkg=len(bkg))) + #iterate over all the things + self._units_sets = defaultdict(list) + _x = [] + _y = [] + _z = [] + _xbinsize = [] + _ybinsize = [] + _err = [] + _bkg = [] + _models = [] + #print("x={x}\ny={y}\nz={z}\nerr={err}\nxbinsize={xbinsize}\nybinsize={ybinsize}\nbkg={bkg}".format(x=repr(x), y=y, z=repr(z), err=repr(err), xbinsize=repr(xbinsize), ybinsize=repr(ybinsize), bkg=repr(bkg))) + for model, xx, yy, zz, xxbin, yybin, eerr, bbkg in zip(models, x, y, z, xbinsize, ybinsize, err, bkg): + + if model._supports_unit_fitting: + input_units_equivalencies = _combine_equivalency_dict(model.inputs, + equivalencies, + model.input_units_equivalencies) + if model.input_units is not None: + if isinstance(xx, Quantity): + self._units_sets['x'].append(1 * xx.unit) + xx = xx.to(model.input_units['x'], equivalencies=input_units_equivalencies['x']) + else: + self._units_sets['x'].append(None) + + + + if isinstance(yy, Quantity) and zz is not None: + self._units_sets['y'].append(1 * yy.unit) + yy = yy.to(model.input_units['y'], equivalencies=input_units_equivalencies['y']) + elif isinstance(yy, Quantity): + self._units_sets['y'].append(1 * yy.unit) + else: + self._units_sets['y'].append(None) + + if xxbin is not None and isinstance(xxbin, Quantity): + xxbin = xxbin.to(model.input_units['x'], equivalencies=input_units_equivalencies['x']) + + if yybin is not None and isinstance(yybin, Quantity) and z is not None: + yybin = yybin.to(model.input_units['y'], equivalencies=input_units_equivalencies['y']) + + if isinstance(zz, Quantity): + self._units_sets['z'].append(1 * zz.unit) + else: + self._units_sets['z'].append(None) + + if eerr is not None and isinstance(eerr, Quantity): + if z is not None: + eerr = eerr.to(self._units_sets['z']) + else: + eerr = eerr.to(self._units_sets['y']) + + if bbkg is not None and isinstance(bbkg, Quantity): + if z is not None: + bbkg = bbkg.to(self._units_sets['z']) + else: + bbkg = bbkg.to(self._units_sets['y']) + + + _x.append(xx) + _y.append(yy) + _z.append(zz) + _xbinsize.append(xxbin) + _ybinsize.append(yybin) + _err.append(eerr) + _bkg.append(bbkg) + _models.append(model.without_units_for_data(x=xx, y=yy, z=zz)) + else: + _x.append(xx) + _y.append(yy) + _z.append(zz) + _xbinsize.append(xxbin) + _ybinsize.append(yybin) + _err.append(eerr) + _bkg.append(bbkg) + _models.append(model) + self._units_sets['x'].append(None) + self._units_sets['y'].append(None) + self._units_sets['z'].append(None) + else: + if isinstance(xx, Quantity) or isinstance(yy, Quantity) or isinstance(zz, Quantity): + warnings.warn(AstropyUserWarning("This model{0} does not support being fit to data " + "with units the units will be ignored this may " + "produce erroneous results".format(len(self._units_sets['x'])))) + _x.append(xx) + _y.append(yy) + _z.append(zz) + _xbinsize.append(xxbin) + _ybinsize.append(yybin) + _err.append(eerr) + _bkg.append(bbkg) + _models.append(model) + self._units_sets['x'].append(None) + self._units_sets['y'].append(None) + self._units_sets['z'].append(None) + + #print("x={x}\ny={y}\nz={z}\nerr={err}\nxbinsize={xbinsize}\nybinsize={ybinsize}\nbkg={bkg}".format(x=_x, y=_y, z=_z, err=err, xbinsize=_xbinsize, ybinsize=_ybinsize, bkg=_bkg)) + print(_models,_x) + if n_dim == 1: + return _models[0], _x[0], _y[0], _z[0], _xbinsize[0], _ybinsize[0], _err[0], _bkg[0] + return _models, _x, _y, _z, _xbinsize, _ybinsize, _err, _bkg + + def restore_units(self,models): + """ + This retores the units to data . + """ + _models = [] + try: + if models._supports_unit_fitting and self._units_sets['x'][0] is not None and self._units_sets['y'][0] is not None: + return models.with_units_from_data(x=self._units_sets['x'][0], y=self._units_sets['y'][0], z=self._units_sets['z'][0]) + except AttributeError: + for n, model in enumerate(models): + if self._units_sets['x'][n] is not None and self._units_sets['y'][n] is not None: + _models.append(model.with_units_from_data(x=1 * self._units_sets['x'][n], y=1 * self._units_sets['y'][n], z=1 * self._units_sets['z'][n])) + else: + _models.append(model) + return _models + return models + class Dataset(SherpaWrapper): @@ -536,7 +721,7 @@ def _make_dataset(n_dim, x, y, z=None, xbinsize=None, ybinsize=None, err=None, b if z is None: assert x.shape == y.shape, "shape of x and y don't match in dataset %i" % n else: - z = np.asarray(z) + z = np.array(z) assert x.shape == y.shape == z.shape, "shapes x,y and z don't match in dataset %i" % n if xbinsize is not None: @@ -712,6 +897,8 @@ def get_astropy_model(self): else: return return_models[0] + def are_model_units_sane(models): + pass # TODO class Data1DIntBkg(Data1DInt): """ @@ -756,7 +943,7 @@ def get_background(self, index): return self._backgrounds[index] def __init__(self, name, xlo, xhi, y, bkg, staterror=None, bkg_scale=1, src_scale=1): - self._bkg = np.asanyarray(bkg) + self._bkg = np.array(bkg) self._bkg_scale = src_scale self.exposure = 1 @@ -812,7 +999,7 @@ def get_background(self, index): return self._backgrounds[index] def __init__(self, name, x, y, bkg, staterror=None, bkg_scale=1, src_scale=1): - self._bkg = np.asanyarray(bkg) + self._bkg = np.array(bkg) self._bkg_scale = src_scale self.exposure = 1 self.subtracted = False @@ -872,7 +1059,7 @@ def get_background(self, index): return self._backgrounds[index] def __init__(self, name, xlo, xhi, ylo, yhi, z, bkg, staterror=None, bkg_scale=1, src_scale=1): - self._bkg = np.asanyarray(bkg) + self._bkg = np.array(bkg) self._bkg_scale = src_scale self.exposure = 1 @@ -933,7 +1120,7 @@ def get_background(self, index): return self._backgrounds[index] def __init__(self, name, x, y, z, bkg, staterror=None, bkg_scale=1, src_scale=1): - self._bkg = np.asanyarray(bkg) + self._bkg = np.array(bkg) self._bkg_scale = src_scale self.exposure = 1 self.subtracted = False diff --git a/saba/tests/test_main.py b/saba/tests/test_main.py index 5ee449e..7e92a42 100644 --- a/saba/tests/test_main.py +++ b/saba/tests/test_main.py @@ -21,6 +21,7 @@ from saba import SherpaFitter, Dataset, ConvertedModel from astropy.modeling.models import Gaussian1D, Gaussian2D, Polynomial1D +from astropy import units as u _RANDOM_SEED = 0x1337 np.random.seed(_RANDOM_SEED) @@ -343,3 +344,25 @@ def test_check_fit(self, nbins=100): x_med = (xx[med_ind] + xx[med_ind + 1]) / 2.0 assert_allclose(self.params[nn], x_med, atol=0.1) + + + +class TestSherpaQuantiesFitter(object): + + def setup_class(self): + self.x = np.linspace(1, 5, 30) * u.micron + self.y = np.exp(-0.5 * (self.x - 2.5 * u.micron)**2 / (200 * u.nm)**2) * u.mJy + self.model_micron = Gaussian1D(mean=3 * u.micron, stddev=1 * u.micron, amplitude=1 * u.Jy) + self.model_thz = Gaussian1D(mean=110 * u.THz, stddev=10 * u.THz, amplitude=1 * u.Jy) + self.fitter = SherpaFitter() + + def test_basic(self): + self.fitter(self.model_micron, self.x, self.y) + + + def test_user_equivalencies(self): + self.fitter(self.model_thz, self.x, self.y, equivalencies={'x': u.spectral()}) + + def test_model_equivalencies(self): + self.model_thz.input_units_equivalencies = {'x': u.spectral()} + self.fitter(self.model_thz, self.x, self.y)