diff --git a/specreduce/background.py b/specreduce/background.py index 88482695..9daa54d1 100644 --- a/specreduce/background.py +++ b/specreduce/background.py @@ -8,7 +8,7 @@ from astropy import units as u from specreduce.extract import _ap_weight_image, _to_spectrum1d_pixels -from specreduce.tracing import Trace, FlatTrace +from specreduce.tracing import FlatTrace, Trace __all__ = ['Background'] diff --git a/specreduce/extract.py b/specreduce/extract.py index 9965794f..6d2b75ce 100644 --- a/specreduce/extract.py +++ b/specreduce/extract.py @@ -10,7 +10,7 @@ from astropy.nddata import NDData from specreduce.core import SpecreduceOperation -from specreduce.tracing import Trace, FlatTrace +from specreduce.tracing import FlatTrace, Trace from specutils import Spectrum1D __all__ = ['BoxcarExtract', 'HorneExtract', 'OptimalExtract'] diff --git a/specreduce/tests/test_tracing.py b/specreduce/tests/test_tracing.py index 1d0fc983..ae18154a 100644 --- a/specreduce/tests/test_tracing.py +++ b/specreduce/tests/test_tracing.py @@ -141,3 +141,17 @@ def test_kosmos_trace(): raise RuntimeError('Trace was erroneously calculated on all-NaN image') # could try to catch warning thrown for all-nan bins + + +def test_mutability(): + trace = FlatTrace(IM, 10) + assert trace.trace_pos == 10 + assert trace.trace_pos == trace.trace[0] + + trace_shifted = trace + 10 + assert trace_shifted.trace_pos == 20 + assert trace_shifted.trace[0] == 20 + + with pytest.raises(AttributeError): + # this attribute shouldn't be writable: + trace_shifted.trace_pos = 15 diff --git a/specreduce/tracing.py b/specreduce/tracing.py index 6b2557ca..b7a4cdd0 100644 --- a/specreduce/tracing.py +++ b/specreduce/tracing.py @@ -1,7 +1,7 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, field import warnings from astropy.modeling import fitting, models @@ -13,7 +13,7 @@ __all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace'] -@dataclass +@dataclass(init=False, frozen=True) class Trace: """ Basic tracing class that by default traces the middle of the image. @@ -22,44 +22,28 @@ class Trace: ---------- image : `~astropy.nddata.CCDData` Image to be traced - - Properties - ---------- - shape : tuple - Shape of the array describing the trace """ image: CCDData + _trace: (np.ndarray, None) = field(repr=False) + + def __init__(self, image, _trace=None): + object.__setattr__(self, '_trace', self._default_trace(image)) + object.__setattr__(self, 'image', image) def __post_init__(self): - self.trace_pos = self.image.shape[0] / 2 - self.trace = np.ones_like(self.image[0]) * self.trace_pos + # this class only exists to catch __post_init__ calls in its + # subclasses, so that super().__post_init__ calls work correctly. + pass def __getitem__(self, i): return self.trace[i] - @property - def shape(self): - return self.trace.shape - - def shift(self, delta): - """ - Shift the trace by delta pixels perpendicular to the axis being traced - - Parameters - ---------- - delta : float - Shift to be applied to the trace - """ - # act on self.trace.data to ignore the mask and then re-mask when calling _bound_trace - self.trace = np.asarray(self.trace.data) + delta - self._bound_trace() - def _bound_trace(self): """ Mask trace positions that are outside the upper/lower bounds of the image. """ ny = self.image.shape[0] - self.trace = np.ma.masked_outside(self.trace, 0, ny-1) + object.__setattr__(self, '_trace', np.ma.masked_outside(self._trace, 0, ny - 1)) def __add__(self, delta): """ @@ -77,8 +61,41 @@ def __sub__(self, delta): """ return self.__add__(-delta) + def shift(self, delta): + """ + Shift the trace by delta pixels perpendicular to the axis being traced + + Parameters + ---------- + delta : float + Shift to be applied to the trace + """ + # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace + object.__setattr__(self, '_trace', np.asarray(self._trace.data) + delta) + self._bound_trace() -@dataclass + @property + def shape(self): + return self._trace.shape + + @property + def trace(self): + return self._trace + + @staticmethod + def _default_trace(image, trace_pos=None): + """ + Compute a default trace position and trace array using only + the image dimensions. + """ + if trace_pos is None: + trace_pos = image.shape[0] / 2 + + trace = np.ones_like(image[0]) * trace_pos + return trace + + +@dataclass(init=False, frozen=True) class FlatTrace(Trace): """ Trace that is constant along the axis being traced @@ -92,10 +109,12 @@ class FlatTrace(Trace): trace_pos : float Position of the trace """ - trace_pos: float + _trace_pos: (float, np.ndarray) = field(repr=False) - def __post_init__(self): - self.set_position(self.trace_pos) + def __init__(self, image, trace_pos): + object.__setattr__(self, '_trace_pos', trace_pos) + super().__init__(image, _trace=self._default_trace(image, trace_pos)) + self.set_position(trace_pos) def set_position(self, trace_pos): """ @@ -106,12 +125,29 @@ def set_position(self, trace_pos): trace_pos : float Position of the trace """ - self.trace_pos = trace_pos - self.trace = np.ones_like(self.image[0]) * self.trace_pos + object.__setattr__(self, '_trace_pos', trace_pos) + object.__setattr__(self, '_trace', np.ones_like(self.image[0]) * trace_pos) self._bound_trace() + @property + def trace_pos(self): + return self._trace_pos -@dataclass + def shift(self, delta): + """ + Shift the trace by delta pixels perpendicular to the axis being traced + + Parameters + ---------- + delta : float + Shift to be applied to the trace + """ + # act on self._trace.data to ignore the mask and then re-mask when calling _bound_trace + object.__setattr__(self, '_trace_pos', self._trace_pos + delta) + super().shift(delta) + + +@dataclass(init=False, frozen=True) class ArrayTrace(Trace): """ Define a trace given an array of trace positions @@ -121,24 +157,26 @@ class ArrayTrace(Trace): trace : `numpy.ndarray` Array containing trace positions """ - trace: np.ndarray - def __post_init__(self): + def __init__(self, image, trace): + super().__init__(image, _trace=trace) nx = self.image.shape[1] - nt = len(self.trace) + nt = len(trace) if nt != nx: if nt > nx: # truncate trace to fit image - self.trace = self.trace[0:nx] + trace = trace[0:nx] else: # assume trace starts at beginning of image and pad out trace to fit. # padding will be the last value of the trace, but will be masked out. - padding = np.ma.MaskedArray(np.ones(nx - nt) * self.trace[-1], mask=True) - self.trace = np.ma.hstack([self.trace, padding]) + padding = np.ma.MaskedArray(np.ones(nx - nt) * trace[-1], mask=True) + trace = np.ma.hstack([trace, padding]) + object.__setattr__(self, '_trace', trace) + self._bound_trace() -@dataclass +@dataclass(init=False, frozen=True) class KosmosTrace(Trace): """ Trace the spectrum aperture in an image. @@ -192,14 +230,24 @@ class KosmosTrace(Trace): 4) add other interpolation modes besides spline, maybe via specutils.manipulation methods? """ - bins: int = 20 - guess: float = None - window: int = None - peak_method: str = 'gaussian' + bins: int + guess: float + window: int + peak_method: str _crossdisp_axis = 0 _disp_axis = 1 - def __post_init__(self): + def _process_init_kwargs(self, **kwargs): + for attr, value in kwargs.items(): + object.__setattr__(self, attr, value) + + def __init__(self, image, bins=20, guess=None, window=None, peak_method='gaussian'): + # This method will assign the user supplied value (or default) to the attrs: + self._process_init_kwargs( + bins=bins, guess=guess, window=window, peak_method=peak_method + ) + super().__init__(image, _trace=self._default_trace(image)) + # handle multiple image types and mask uncaught invalid values if isinstance(self.image, NDData): img = np.ma.masked_invalid(np.ma.masked_array(self.image.data, @@ -223,7 +271,7 @@ def __post_init__(self): if not isinstance(self.bins, int): warnings.warn('TRACE: Converting bins to int') - self.bins = int(self.bins) + object.__setattr__(self, 'bins', int(self.bins)) if self.bins < 4: raise ValueError('bins must be >= 4') @@ -240,7 +288,7 @@ def __post_init__(self): "length of the image's spatial direction") elif self.window is not None and not isinstance(self.window, int): warnings.warn('TRACE: Converting window to int') - self.window = int(self.window) + object.__setattr__(self, 'window', int(self.window)) # set max peak location by user choice or wavelength with max avg flux ztot = img.sum(axis=self._disp_axis) / img.shape[self._disp_axis] @@ -343,4 +391,4 @@ def __post_init__(self): warnings.warn("TRACE ERROR: No valid points found in trace") trace_y = np.tile(np.nan, len(x_bins)) - self.trace = np.ma.masked_invalid(trace_y) + object.__setattr__(self, '_trace', np.ma.masked_invalid(trace_y))