Skip to content

Commit

Permalink
Modify masking and NaN handling in HorneExtract (#163)
Browse files Browse the repository at this point in the history
* Update NaN handling and masking in HorneExtract
* Removed masking so NaNs propagate into 1D spectra
  • Loading branch information
ojustino authored Feb 21, 2023
1 parent 7043f78 commit 799ff49
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 32 deletions.
56 changes: 38 additions & 18 deletions specreduce/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ def _parse_image(self, image,
elif mask is not None:
pass
else:
mask = ~np.isfinite(img)
# if user provides no mask at all, don't mask anywhere
mask = np.zeros_like(img)

if img.shape != mask.shape:
raise ValueError('image and mask shapes must match.')
Expand Down Expand Up @@ -484,13 +485,15 @@ def __call__(self, image=None, trace_object=None,
# parse image and replace optional arguments with updated values
self.image = self._parse_image(image, variance, mask, unit, disp_axis)
variance = self.image.uncertainty.array
mask = self.image.mask
unit = self.image.unit

# mask any previously uncaught invalid values
or_mask = np.logical_or(mask,
img = np.ma.masked_array(self.image.data, mask=mask)

# create separate mask including any previously uncaught non-finite
# values for purposes of calculating fit
or_mask = np.logical_or(img.mask,
~np.isfinite(self.image.data))
img = np.ma.masked_array(self.image.data, or_mask)
mask = img.mask

# If the trace is not flat, shift the rows in each column
# so the image is aligned along the trace:
Expand All @@ -510,26 +513,43 @@ def __call__(self, image=None, trace_object=None,
)

# co-add signal in each image column
ncols = img.shape[crossdisp_axis]
xd_pixels = np.arange(ncols) # y plot dir / x spec dir
coadd = img.sum(axis=disp_axis) / ncols
nrows = img.shape[crossdisp_axis]
xd_pixels = np.arange(nrows) # counted in y dir on plot (or x in spec)

row_mask = np.logical_or.reduce(or_mask, axis=disp_axis)
coadd = np.ma.masked_array(np.sum(img, axis=disp_axis) / nrows,
mask=row_mask)
# (mask rows with non-finite sums for fit to work later on)

# fit source profile, using Gaussian model as a template
# fit source profile to brightest row, using Gaussian model as template
# NOTE: could add argument for users to provide their own model
gauss_prof = models.Gaussian1D(amplitude=coadd.max(),
mean=coadd.argmax(), stddev=2)

# Fit extraction kernel to column with combined gaussian/bkgrd model
# Fit extraction kernel to column's finite values with combined model
# (must exclude masked indices manually; LevMarLSQFitter does not)
ext_prof = gauss_prof + bkgrd_prof
fitter = fitting.LevMarLSQFitter()
fit_ext_kernel = fitter(ext_prof, xd_pixels, coadd)
fit_ext_kernel = fitter(ext_prof,
xd_pixels[~row_mask], coadd[~row_mask])

# use compound model to fit a kernel to each image column
# use compound model to fit a kernel to each fully finite image column
# NOTE: infers Gaussian1D source profile; needs generalization for others
col_mask = np.logical_or.reduce(or_mask, axis=crossdisp_axis)
nonf_col = [np.nan] * img.shape[crossdisp_axis]

kernel_vals = []
norms = []
for col_pix in range(img.shape[disp_axis]):
# set gaussian model's mean as column's corresponding trace value
# for now, skip columns with any non-finite values
# NOTE: fit and other kernel operations should support masking again
# once a fix is in for renormalizing columns with non-finite values
if col_mask[col_pix]:
kernel_vals.append(nonf_col)
norms.append(np.nan)
continue

# else, set compound model's mean to column's matching trace value
fit_ext_kernel.mean_0 = mean_init_guess[col_pix]

# NOTE: support for variable FWHMs forthcoming and would be here
Expand All @@ -543,15 +563,15 @@ def __call__(self, image=None, trace_object=None,
* fit_ext_kernel.stddev_0 * np.sqrt(2*np.pi))

# transform fit-specific information
kernel_vals = np.array(kernel_vals).T
kernel_vals = np.vstack(kernel_vals).T
norms = np.array(norms)

# calculate kernel normalization, masking NaNs
g_x = np.ma.sum(kernel_vals**2 / variance, axis=crossdisp_axis)
# calculate kernel normalization
g_x = np.sum(kernel_vals**2 / variance, axis=crossdisp_axis)

# sum by column weights
weighted_img = np.ma.divide(img * kernel_vals, variance)
result = np.ma.sum(weighted_img, axis=crossdisp_axis) / g_x
weighted_img = np.divide(img * kernel_vals, variance)
result = np.sum(weighted_img, axis=crossdisp_axis) / g_x

# multiply kernel normalization into the extracted signal
extraction = result * norms
Expand Down
28 changes: 14 additions & 14 deletions specreduce/tests/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import astropy.units as u
from astropy.nddata import CCDData, VarianceUncertainty, UnknownUncertainty
from astropy.tests.helper import assert_quantity_allclose
from astropy.utils.exceptions import AstropyUserWarning

from specreduce.extract import (
BoxcarExtract, HorneExtract, OptimalExtract, _align_along_trace
Expand Down Expand Up @@ -128,6 +127,8 @@ def test_horne_image_validation():
== np.arange(image.shape[extract.disp_axis]) * u.pix)


# ignore Astropy warning for extractions that aren't best fit with a Gaussian:
@pytest.mark.filterwarnings("ignore:The fit may be unsuccessful")
def test_horne_variance_errors():
trace = FlatTrace(image, 3.0)

Expand Down Expand Up @@ -155,6 +156,7 @@ def test_horne_variance_errors():
mask=image.mask, unit=u.Jy)


@pytest.mark.filterwarnings("ignore:The fit may be unsuccessful")
def test_horne_non_flat_trace():
# create a synthetic "2D spectrum" and its non-flat trace
n_rows, n_cols = (10, 50)
Expand All @@ -181,19 +183,17 @@ def test_horne_non_flat_trace():
# ensure that mask is correctly unrolled back to its original alignment:
np.testing.assert_allclose(unrolled, original)

# These synthetic extractions don't fit well with a Gaussian, so will pass warning:
with pytest.warns(AstropyUserWarning, match="The fit may be unsuccessful"):
# Extract the spectrum from the non-flat image+trace
extract_non_flat = HorneExtract(
rolled, ArrayTrace(rolled, exact_trace),
variance=err, mask=mask, unit=u.Jy
)()

# Also extract the spectrum from the image after alignment with a flat trace
extract_flat = HorneExtract(
unrolled, FlatTrace(unrolled, n_rows // 2),
variance=err, mask=mask, unit=u.Jy
)()
# Extract the spectrum from the non-flat image+trace
extract_non_flat = HorneExtract(
rolled, ArrayTrace(rolled, exact_trace),
variance=err, mask=mask, unit=u.Jy
)()

# Also extract the spectrum from the image after alignment with a flat trace
extract_flat = HorneExtract(
unrolled, FlatTrace(unrolled, n_rows // 2),
variance=err, mask=mask, unit=u.Jy
)()

# ensure both extractions are equivalent:
assert_quantity_allclose(extract_non_flat.flux, extract_flat.flux)

0 comments on commit 799ff49

Please sign in to comment.