Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Change LDOs to use masked products by default #661

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 52 additions & 27 deletions spectral_cube/lower_dimensional_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import spectral_axis
from .io.core import LowerDimensionalObjectWrite
from .utils import SliceWarning, BeamWarning, SmoothingWarning, FITSWarning
from .cube_utils import convert_bunit
from . import cube_utils
from . import wcs_utils
from .masks import BooleanArrayMask, MaskBase

Expand All @@ -26,7 +26,6 @@
MultiBeamMixinClass, BeamMixinClass,
HeaderMixinClass
)
from . import cube_utils

__all__ = ['LowerDimensionalObject', 'Projection', 'Slice', 'OneDSpectrum']
class LowerDimensionalObject(u.Quantity, BaseNDClass, HeaderMixinClass):
Expand Down Expand Up @@ -89,6 +88,10 @@ def __getitem__(self, key, **kwargs):
else:
newwcs = None

print(new_qty)
print(new_qty._data)
print(new_qty.value)

new = self.__class__(value=new_qty.value,
unit=new_qty.unit,
copy=False,
Expand Down Expand Up @@ -132,20 +135,28 @@ def array(self):
Get a pure array representation of the LDO. Useful when multiplying
and using numpy indexing tricks.
"""
return np.asarray(self)
return self.filled_data[:].value

@property
def _data(self):
# the _data property is required by several other mixins
# (which probably means defining it here is a bad design)
return self.array
# @property
# def _data(self):
# # the _data property is required by several other mixins
# # (which probably means defining it here is a bad design)
# return self.__data

@property
def quantity(self):
"""
Get a pure `~astropy.units.Quantity` representation of the LDO.
"""
return u.Quantity(self)
return u.Quantity(self.filled_data[:])

@property
def value(self):
"""
Get a unitless numpy array with the mask applied.
"""
return np.asarray(self.filled_data[:])
# return np.asarray(self)

def to(self, unit, equivalencies=[], freq=None):
"""
Expand Down Expand Up @@ -248,26 +259,33 @@ def _initial_set_mask(self, mask):
matters: ``self`` must have ``_wcs``, for example.
"""
if mask is None:
mask = BooleanArrayMask(np.ones_like(self.value, dtype=bool),
self._wcs, shape=self.value.shape)
mask = BooleanArrayMask(np.isfinite(self._data),
self._wcs, shape=self._data.shape)

elif isinstance(mask, np.ndarray):
if mask.shape != self.value.shape:
if mask.shape != self._data.shape:
raise ValueError("Mask shape must match the {0} shape."
.format(self.__class__.__name__)
)
mask = BooleanArrayMask(mask, self._wcs, shape=self.value.shape)
mask = BooleanArrayMask(mask, self._wcs, shape=self._data.shape)
elif isinstance(mask, MaskBase):
pass
else:
raise TypeError("mask of type {} is not a supported mask "
"type.".format(type(mask)))

# Validate the mask before setting
mask._validate_wcs(new_data=self.value, new_wcs=self._wcs,
mask._validate_wcs(new_data=self._data, new_wcs=self._wcs,
wcs_tolerance=self._wcs_tolerance)

self._mask = mask

def __repr__(self):
prefixstr = '<' + self.__class__.__name__ + ' '
arrstr = np.array2string(self.filled_data[:].value, separator=',',
prefix=prefixstr)
return '{0}{1}{2:s}>'.format(prefixstr, arrstr, self._unitstr)


class Projection(LowerDimensionalObject, SpatialCoordMixinClass,
MaskableArrayMixinClass, BeamMixinClass):
Expand All @@ -282,8 +300,16 @@ def __new__(cls, value, unit=None, dtype=None, copy=True, wcs=None,
if wcs is not None and wcs.wcs.naxis != 2:
raise ValueError("wcs should have two dimension")

self = u.Quantity.__new__(cls, value, unit=unit, dtype=dtype,
copy=copy).view(cls)
# self = u.Quantity.__new__(cls, value, unit=unit, dtype=dtype,
# copy=copy).view(cls)

self = super().__new__(cls, value, unit=unit, dtype=dtype,
copy=copy).view(cls)

# self = cls.__new__(cls, value, unit=unit, dtype=dtype,
# copy=copy).view(cls)

self._data = np.asarray(value)
self._wcs = wcs
self._meta = {} if meta is None else meta
self._wcs_tolerance = wcs_tolerance
Expand Down Expand Up @@ -383,7 +409,7 @@ def _new_projection_with(self, data=None, wcs=None, mask=None, meta=None,
fill_value = self._fill_value if fill_value is None else fill_value

if beam is None:
if hasattr(self, 'beam'):
if self._beam is not None:
beam = self.beam

newproj = self.__class__(value=data, wcs=wcs, mask=mask, meta=meta,
Expand Down Expand Up @@ -412,7 +438,7 @@ def from_hdu(hdu):
mywcs = wcs.WCS(hdu.header)

if "BUNIT" in hdu.header:
unit = convert_bunit(hdu.header["BUNIT"])
unit = cube_utils.convert_bunit(hdu.header["BUNIT"])
meta["BUNIT"] = hdu.header["BUNIT"]
else:
unit = None
Expand Down Expand Up @@ -646,8 +672,13 @@ def __new__(cls, value, unit=None, dtype=None, copy=True, wcs=None,
if wcs is not None and wcs.wcs.naxis != 1:
raise ValueError("wcs should have two dimension")

self = u.Quantity.__new__(cls, value, unit=unit, dtype=dtype,
copy=copy).view(cls)
self = super().__new__(cls, value, unit=unit, dtype=dtype,
copy=copy).view(cls)

# self = u.Quantity.__new__(cls, value, unit=unit, dtype=dtype,
# copy=copy).view(cls)

self._data = np.asarray(value)
self._wcs = wcs
self._meta = {} if meta is None else meta
self._wcs_tolerance = wcs_tolerance
Expand All @@ -670,12 +701,6 @@ def __new__(cls, value, unit=None, dtype=None, copy=True, wcs=None,

return self

def __repr__(self):
prefixstr = '<' + self.__class__.__name__ + ' '
arrstr = np.array2string(self.filled_data[:].value, separator=',',
prefix=prefixstr)
return '{0}{1}{2:s}>'.format(prefixstr, arrstr, self._unitstr)

@staticmethod
def from_hdu(hdu):
'''
Expand All @@ -696,7 +721,7 @@ def from_hdu(hdu):
mywcs = wcs.WCS(hdu.header)

if "BUNIT" in hdu.header:
unit = convert_bunit(hdu.header["BUNIT"])
unit = cube_utils.convert_bunit(hdu.header["BUNIT"])
meta["BUNIT"] = hdu.header["BUNIT"]
else:
unit = None
Expand Down
23 changes: 23 additions & 0 deletions spectral_cube/tests/test_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,3 +790,26 @@ def test_spatial_world(view, data_adv, use_dask):
for result, expected in zip(w2_flat, world):
print(result.shape, expected.flatten().shape)
assert_allclose(result, expected.flatten())


def test_LDO_fill_value():
new_quant = twelve_qty_2d.copy()

new_quant[0, :] = np.NaN

proj = Projection(new_quant)

proj_fill0 = proj.with_fill_value(0.)

assert (np.isfinite(new_quant) == np.isfinite(proj)).all()

# Try using the filled data.
assert proj_fill0.fill_value == 0.0

assert np.isfinite(proj._get_filled_data(fill=0.)).all()

assert np.isfinite(proj_fill0.filled_data[:]).all()
assert np.isfinite(proj_fill0.unitless_filled_data[:]).all()

assert np.isfinite(proj_fill0.quantity).all()
assert np.isfinite(proj_fill0.value).all()