Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Units Implimentation #23

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 201 additions & 14 deletions saba/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import (absolute_import, unicode_literals, division,
print_function)
import numpy as np
from collections import OrderedDict
import warnings

from collections import OrderedDict, defaultdict
from sherpa.fit import Fit
from sherpa.data import Data1D, Data1DInt, Data2D, Data2DInt, DataSimulFit
from sherpa.data import BaseData
Expand All @@ -12,13 +14,13 @@
from sherpa.optmethods import GridSearch, LevMar, MonCar, NelderMead
from sherpa.estmethods import Confidence, Covariance, Projection
from sherpa.sim import MCMC
import warnings

from astropy.extern.six.moves import range
from astropy.utils import format_doc
from astropy.utils.exceptions import AstropyUserWarning
from astropy.tests.helper import catch_warnings

from astropy.modeling.utils import _combine_equivalency_dict
from astropy.units import Quantity

with catch_warnings(AstropyUserWarning) as warns:
"""this is to stop the import warning
Expand All @@ -29,7 +31,6 @@
if "SherpaFitter" not in w.message.message:
warnings.warn(w)

# from astropy.modeling

__all__ = ('SherpaFitter', 'SherpaMCMC')

Expand Down Expand Up @@ -96,8 +97,6 @@ class EstMethod(SherpaWrapper):
'projection': Projection}




class SherpaMCMC(object):
"""
An interface which makes use of sherpa's MCMC(pyBLoCXS) functionality.
Expand Down Expand Up @@ -325,8 +324,7 @@ def __init__(self, optimizer="levmar", statistic="leastsq", estmethod="covarianc
# sherpa doesn't currently have a docstring for est_method but maybe the future
setattr(self.__class__, 'est_config', property(lambda s: s._est_config, doc=self._est_method.__doc__))


def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, **kwargs):
def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, bkg_scale=1, equivalencies=None, **kwargs):
"""
Fit the astropy model with a the sherpa fit routines.

Expand Down Expand Up @@ -362,11 +360,14 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
"""

tie_list = []
models, x, y, z, xbinsize, ybinsize, err, bkg = self.remove_units(models, x, y, z, xbinsize, ybinsize, err, bkg, equivalencies)

try:
n_inputs = models[0].n_inputs
except TypeError:
n_inputs = models.n_inputs

#print("x={x}\ny={y}\nz={z}\nerr={err}\nxbinsize={xbinsize}\nybinsize={ybinsize}\nbkg={bkg}".format(x=x, y=y, z=z, err=err, xbinsize=xbinsize, ybinsize=ybinsize, bkg=bkg))
self._data = Dataset(n_inputs, x, y, z, xbinsize, ybinsize, err, bkg, bkg_scale)

if self._data.ndata > 1:
Expand All @@ -389,7 +390,7 @@ def __call__(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None,
self._fitter = Fit(self._data.data, self._fitmodel.sherpa_model, self._stat_method, self._opt_method, self._est_method, **kwargs)
self.fit_info = self._fitter.fit()

return self._fitmodel.get_astropy_model()
return self.restore_units(self._fitmodel.get_astropy_model())

def est_errors(self, sigma=None, maxiters=None, numcores=1, methoddict=None, parlist=None):
"""
Expand Down Expand Up @@ -439,6 +440,190 @@ def get_sampler(self, *args, **kwargs):
"""
return SherpaMCMC(self, *args, **kwargs)

def remove_units(self, models, x, y, z=None, xbinsize=None, ybinsize=None, err=None, bkg=None, equivalencies=None):
"""
This stripts data and models of their units so they can be fit.
"""
try:
models._supports_unit_fitting
n_models = 1
except AttributeError:
n_models = len(models)

_x = np.array(x)
_y = np.array(y)
n_data = 1
if _x.ndim == 2 or (_x.dtype == np.object or _y.dtype == np.object):
n_data = _x.shape[0]

if n_data > 1 and n_models > 1 and not n_data == n_models: #check if we can handle n_data and n_models
raise Exception("Don't know how to handle multiple models "
"unless there is one foreach dataset")
else:
n_dim = max(n_data,n_models)
if n_data == 1:
# make them match n_dim if n_dim is 1 it wont matter anyway although perhaps its inefficeint?
# do range so that each data and model are completely independent dont want pointer like problems
x = [x.copy() for _ in range(n_dim)]
y = [y.copy() for _ in range(n_dim)]

if z is None:
z = n_dim * [z]
else:
z = [z.copy() for _ in range(n_dim)]
if err is None:
err = n_dim * [err]
else:
err = [err.copy() for _ in range(n_dim)]
if xbinsize is None:
xbinsize = n_dim * [xbinsize]
else:
xbinsize = [xbinsize.copy() for _ in range(n_dim)]
if ybinsize is None:
ybinsize = n_dim * [ybinsize]
else:
ybinsize = [ybinsize.copy() for _ in range(n_dim)]
if bkg is None:
bkg = n_dim * [bkg]
else:
bkg = [bkg.copy() for _ in range(n_dim)]
else:
if z is None:
z = n_dim * [z]
if err is None:
err = n_dim * [err]
if xbinsize is None:
xbinsize = n_dim * [xbinsize]
if ybinsize is None:
ybinsize = n_dim * [ybinsize]
if bkg is None:
bkg = n_dim * [bkg]

if n_models == 1:
models = [models.copy() for _ in range(n_dim)]

assert len(x) == len(y) == len(z) == len(err) == len(xbinsize) == len(ybinsize) == len(bkg), ValueError("lenghts of one of your parameters don't match models{lm}, x({lx}), y({ly}), "
"z{lz}, err({lerr}), xbinsize({lxbin}), ybinsize({lybin}), bkg({lbkg})".format(lm=len(models), lx=len(x),
ly=len(y), lz=len(z), lerr=len(err), lxbin=len(xbinsize), lybin=len(ybinsize),
lbkg=len(bkg)))
#iterate over all the things
self._units_sets = defaultdict(list)
_x = []
_y = []
_z = []
_xbinsize = []
_ybinsize = []
_err = []
_bkg = []
_models = []
#print("x={x}\ny={y}\nz={z}\nerr={err}\nxbinsize={xbinsize}\nybinsize={ybinsize}\nbkg={bkg}".format(x=repr(x), y=y, z=repr(z), err=repr(err), xbinsize=repr(xbinsize), ybinsize=repr(ybinsize), bkg=repr(bkg)))
for model, xx, yy, zz, xxbin, yybin, eerr, bbkg in zip(models, x, y, z, xbinsize, ybinsize, err, bkg):

if model._supports_unit_fitting:
input_units_equivalencies = _combine_equivalency_dict(model.inputs,
equivalencies,
model.input_units_equivalencies)
if model.input_units is not None:
if isinstance(xx, Quantity):
self._units_sets['x'].append(1 * xx.unit)
xx = xx.to(model.input_units['x'], equivalencies=input_units_equivalencies['x'])
else:
self._units_sets['x'].append(None)



if isinstance(yy, Quantity) and zz is not None:
self._units_sets['y'].append(1 * yy.unit)
yy = yy.to(model.input_units['y'], equivalencies=input_units_equivalencies['y'])
elif isinstance(yy, Quantity):
self._units_sets['y'].append(1 * yy.unit)
else:
self._units_sets['y'].append(None)

if xxbin is not None and isinstance(xxbin, Quantity):
xxbin = xxbin.to(model.input_units['x'], equivalencies=input_units_equivalencies['x'])

if yybin is not None and isinstance(yybin, Quantity) and z is not None:
yybin = yybin.to(model.input_units['y'], equivalencies=input_units_equivalencies['y'])

if isinstance(zz, Quantity):
self._units_sets['z'].append(1 * zz.unit)
else:
self._units_sets['z'].append(None)

if eerr is not None and isinstance(eerr, Quantity):
if z is not None:
eerr = eerr.to(self._units_sets['z'])
else:
eerr = eerr.to(self._units_sets['y'])

if bbkg is not None and isinstance(bbkg, Quantity):
if z is not None:
bbkg = bbkg.to(self._units_sets['z'])
else:
bbkg = bbkg.to(self._units_sets['y'])


_x.append(xx)
_y.append(yy)
_z.append(zz)
_xbinsize.append(xxbin)
_ybinsize.append(yybin)
_err.append(eerr)
_bkg.append(bbkg)
_models.append(model.without_units_for_data(x=xx, y=yy, z=zz))
else:
_x.append(xx)
_y.append(yy)
_z.append(zz)
_xbinsize.append(xxbin)
_ybinsize.append(yybin)
_err.append(eerr)
_bkg.append(bbkg)
_models.append(model)
self._units_sets['x'].append(None)
self._units_sets['y'].append(None)
self._units_sets['z'].append(None)
else:
if isinstance(xx, Quantity) or isinstance(yy, Quantity) or isinstance(zz, Quantity):
warnings.warn(AstropyUserWarning("This model{0} does not support being fit to data "
"with units the units will be ignored this may "
"produce erroneous results".format(len(self._units_sets['x']))))
_x.append(xx)
_y.append(yy)
_z.append(zz)
_xbinsize.append(xxbin)
_ybinsize.append(yybin)
_err.append(eerr)
_bkg.append(bbkg)
_models.append(model)
self._units_sets['x'].append(None)
self._units_sets['y'].append(None)
self._units_sets['z'].append(None)

#print("x={x}\ny={y}\nz={z}\nerr={err}\nxbinsize={xbinsize}\nybinsize={ybinsize}\nbkg={bkg}".format(x=_x, y=_y, z=_z, err=err, xbinsize=_xbinsize, ybinsize=_ybinsize, bkg=_bkg))
print(_models,_x)
if n_dim == 1:
return _models[0], _x[0], _y[0], _z[0], _xbinsize[0], _ybinsize[0], _err[0], _bkg[0]
return _models, _x, _y, _z, _xbinsize, _ybinsize, _err, _bkg

def restore_units(self,models):
"""
This retores the units to data .
"""
_models = []
try:
if models._supports_unit_fitting and self._units_sets['x'][0] is not None and self._units_sets['y'][0] is not None:
return models.with_units_from_data(x=self._units_sets['x'][0], y=self._units_sets['y'][0], z=self._units_sets['z'][0])
except AttributeError:
for n, model in enumerate(models):
if self._units_sets['x'][n] is not None and self._units_sets['y'][n] is not None:
_models.append(model.with_units_from_data(x=1 * self._units_sets['x'][n], y=1 * self._units_sets['y'][n], z=1 * self._units_sets['z'][n]))
else:
_models.append(model)
return _models
return models


class Dataset(SherpaWrapper):

Expand Down Expand Up @@ -536,7 +721,7 @@ def _make_dataset(n_dim, x, y, z=None, xbinsize=None, ybinsize=None, err=None, b
if z is None:
assert x.shape == y.shape, "shape of x and y don't match in dataset %i" % n
else:
z = np.asarray(z)
z = np.array(z)
assert x.shape == y.shape == z.shape, "shapes x,y and z don't match in dataset %i" % n

if xbinsize is not None:
Expand Down Expand Up @@ -712,6 +897,8 @@ def get_astropy_model(self):
else:
return return_models[0]

def are_model_units_sane(models):
pass # TODO

class Data1DIntBkg(Data1DInt):
"""
Expand Down Expand Up @@ -756,7 +943,7 @@ def get_background(self, index):
return self._backgrounds[index]

def __init__(self, name, xlo, xhi, y, bkg, staterror=None, bkg_scale=1, src_scale=1):
self._bkg = np.asanyarray(bkg)
self._bkg = np.array(bkg)
self._bkg_scale = src_scale
self.exposure = 1

Expand Down Expand Up @@ -812,7 +999,7 @@ def get_background(self, index):
return self._backgrounds[index]

def __init__(self, name, x, y, bkg, staterror=None, bkg_scale=1, src_scale=1):
self._bkg = np.asanyarray(bkg)
self._bkg = np.array(bkg)
self._bkg_scale = src_scale
self.exposure = 1
self.subtracted = False
Expand Down Expand Up @@ -872,7 +1059,7 @@ def get_background(self, index):
return self._backgrounds[index]

def __init__(self, name, xlo, xhi, ylo, yhi, z, bkg, staterror=None, bkg_scale=1, src_scale=1):
self._bkg = np.asanyarray(bkg)
self._bkg = np.array(bkg)
self._bkg_scale = src_scale
self.exposure = 1

Expand Down Expand Up @@ -933,7 +1120,7 @@ def get_background(self, index):
return self._backgrounds[index]

def __init__(self, name, x, y, z, bkg, staterror=None, bkg_scale=1, src_scale=1):
self._bkg = np.asanyarray(bkg)
self._bkg = np.array(bkg)
self._bkg_scale = src_scale
self.exposure = 1
self.subtracted = False
Expand Down
23 changes: 23 additions & 0 deletions saba/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from saba import SherpaFitter, Dataset, ConvertedModel
from astropy.modeling.models import Gaussian1D, Gaussian2D, Polynomial1D
from astropy import units as u

_RANDOM_SEED = 0x1337
np.random.seed(_RANDOM_SEED)
Expand Down Expand Up @@ -343,3 +344,25 @@ def test_check_fit(self, nbins=100):
x_med = (xx[med_ind] + xx[med_ind + 1]) / 2.0

assert_allclose(self.params[nn], x_med, atol=0.1)



class TestSherpaQuantiesFitter(object):

def setup_class(self):
self.x = np.linspace(1, 5, 30) * u.micron
self.y = np.exp(-0.5 * (self.x - 2.5 * u.micron)**2 / (200 * u.nm)**2) * u.mJy
self.model_micron = Gaussian1D(mean=3 * u.micron, stddev=1 * u.micron, amplitude=1 * u.Jy)
self.model_thz = Gaussian1D(mean=110 * u.THz, stddev=10 * u.THz, amplitude=1 * u.Jy)
self.fitter = SherpaFitter()

def test_basic(self):
self.fitter(self.model_micron, self.x, self.y)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it makes sense to actually run a fitter here, so you can look at the return value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just trying to check that no exceptions are thrown at the moment.
I was planning extending the tests over this weekend to test the values make sense.


def test_user_equivalencies(self):
self.fitter(self.model_thz, self.x, self.y, equivalencies={'x': u.spectral()})

def test_model_equivalencies(self):
self.model_thz.input_units_equivalencies = {'x': u.spectral()}
self.fitter(self.model_thz, self.x, self.y)