Skip to content

Commit

Permalink
Generalize reproject_and_coadd to N-dimensions and fix test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Jun 5, 2024
1 parent bc503a4 commit f5e3016
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 220 deletions.
21 changes: 20 additions & 1 deletion reproject/array_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

__all__ = ["map_coordinates"]
__all__ = ["map_coordinates", "sample_array_edges"]


def map_coordinates(image, coords, **kwargs):
Expand Down Expand Up @@ -35,3 +35,22 @@ def map_coordinates(image, coords, **kwargs):
values[reset] = kwargs.get("cval", 0.0)

return values


def sample_array_edges(shape, *, n_samples):
# Given an N-dimensional array shape, sample each edge of the array using
# the requested number of samples (which will include vertices). To do this
# we iterate through the dimensions and for each one we sample the points
# in that dimension and iterate over the combination of other vertices.
# Returns an array with dimensions (N, n_samples)
all_positions = []
ndim = len(shape)
shape = np.array(shape)
for idim in range(ndim):
for vertex in range(2**ndim):
positions = -0.5 + shape * ((vertex & (2 ** np.arange(ndim))) > 0).astype(int)
positions = np.broadcast_to(positions, (n_samples, ndim)).copy()
positions[:, idim] = np.linspace(-0.5, shape[idim] - 0.5, n_samples)
all_positions.append(positions)
positions = np.unique(np.vstack(all_positions), axis=0).T
return positions
180 changes: 77 additions & 103 deletions reproject/mosaicking/coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS

from ..array_utils import sample_array_edges
from ..utils import parse_input_data, parse_input_weights, parse_output_projection
from .background import determine_offset_matrix, solve_corrections_sgd
from .subset_array import ReprojectedArraySubset
Expand All @@ -30,15 +31,13 @@ def reproject_and_coadd(
output_footprint=None,
block_sizes=None,
progress_bar=None,
blank_pixel_value=np.nan,
blank_pixel_value=0,
**kwargs,
):
"""
Given a set of input images, reproject and co-add these to a single
Given a set of input data, reproject and co-add these to a single
final image.
This currently only works with 2-d images with celestial WCS.
Parameters
----------
input_data : iterable
Expand Down Expand Up @@ -149,24 +148,31 @@ def reproject_and_coadd(

wcs_out, shape_out = parse_output_projection(output_projection, shape_out=shape_out)

if output_array is not None and output_array.shape != shape_out:
if output_array is None:
output_array = np.zeros(shape_out)
elif output_array.shape != shape_out:
raise ValueError(
"If you specify an output array, it must have a shape matching "
f"the output shape {shape_out}"
)
if output_footprint is not None and output_footprint.shape != shape_out:

if output_footprint is None:
output_footprint = np.zeros(shape_out)
elif output_footprint.shape != shape_out:
raise ValueError(
"If you specify an output footprint array, it must have a shape matching "
f"the output shape {shape_out}"
)

if output_array is None:
output_array = np.zeros(shape_out)
if output_footprint is None:
output_footprint = np.zeros(shape_out)
# Define 'on-the-fly' mode: in the case where we don't need to match
# the backgrounds and we are combining with 'mean' or 'sum', we don't
# have to keep track of the intermediate arrays and can just modify
# the output array on-the-fly
on_the_fly = not match_background and combine_function in ("mean", "sum")

# Start off by reprojecting individual images to the final projection
if match_background:

if not on_the_fly:
arrays = []

for idata in progress_bar(range(len(input_data))):
Expand All @@ -192,71 +198,42 @@ def reproject_and_coadd(
# significant distortion (when the edges of the input image become
# convex in the output projection), and transforming every edge pixel,
# which provides a lot of redundant information.
if array_in.ndim == 2:
ny, nx = array_in.shape
n_per_edge = 11
xs = np.linspace(-0.5, nx - 0.5, n_per_edge)
ys = np.linspace(-0.5, ny - 0.5, n_per_edge)
xs = np.concatenate((xs, np.full(n_per_edge, xs[-1]), xs, np.full(n_per_edge, xs[0])))
ys = np.concatenate((np.full(n_per_edge, ys[0]), ys, np.full(n_per_edge, ys[-1]), ys))
xc_out, yc_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(xs, ys))
shape_out_cel = shape_out
elif array_in.ndim == 3:
# for cubes, we only handle single corners now
nz, ny, nx = array_in.shape
xc = np.array([-0.5, nx - 0.5, nx - 0.5, -0.5])
yc = np.array([-0.5, -0.5, ny - 0.5, ny - 0.5])
zc = np.array([-0.5, nz - 0.5])
# TODO: figure out what to do here if the low_level_wcs doesn't support subsetting
xc_out, yc_out = wcs_out.low_level_wcs.celestial.world_to_pixel(
wcs_in.celestial.pixel_to_world(xc, yc)
)
zc_out = wcs_out.low_level_wcs.spectral.world_to_pixel(
wcs_in.spectral.pixel_to_world(zc)
)
shape_out_cel = shape_out[1:]
else:
raise ValueError(f"Wrong number of dimensions: {array_in.ndim}")

edges = sample_array_edges(array_in.shape, n_samples=11)[::-1]
edges_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(*edges))[::-1]

# Determine the cutout parameters

# In some cases, images might not have valid coordinates in the corners,
# such as all-sky images or full solar disk views. In this case we skip
# this step and just use the full output WCS for reprojection.

if np.any(np.isnan(xc_out)) or np.any(np.isnan(yc_out)):
imin = 0
imax = shape_out_cel[1]
jmin = 0
jmax = shape_out_cel[0]
else:
imin = max(0, int(np.floor(xc_out.min() + 0.5)))
imax = min(shape_out_cel[1], int(np.ceil(xc_out.max() + 0.5)))
jmin = max(0, int(np.floor(yc_out.min() + 0.5)))
jmax = min(shape_out_cel[0], int(np.ceil(yc_out.max() + 0.5)))
ndim_out = len(shape_out)

if imax < imin or jmax < jmin:
skip_data = False
if np.any(np.isnan(edges_out)):
bounds = list(zip([0] * ndim_out, shape_out))
else:
bounds = []
for idim in range(ndim_out):
imin = max(0, int(np.floor(edges_out[idim].min() + 0.5)))
imax = min(shape_out[idim], int(np.ceil(edges_out[idim].max() + 0.5)))
bounds.append((imin, imax))
if imax < imin:
skip_data = True
break

if skip_data:
continue

if array_in.ndim == 2:
if isinstance(wcs_out, WCS):
wcs_out_indiv = wcs_out[jmin:jmax, imin:imax]
else:
wcs_out_indiv = SlicedLowLevelWCS(
wcs_out.low_level_wcs, (slice(jmin, jmax), slice(imin, imax))
)
shape_out_indiv = (jmax - jmin, imax - imin)
kmin, kmax = None, None # for reprojectedarraysubset below
elif array_in.ndim == 3:
kmin = max(0, int(np.floor(zc_out.min() + 0.5)))
kmax = min(shape_out[0], int(np.ceil(zc_out.max() + 0.5)))
if isinstance(wcs_out, WCS):
wcs_out_indiv = wcs_out[kmin:kmax, jmin:jmax, imin:imax]
else:
wcs_out_indiv = SlicedLowLevelWCS(
wcs_out.low_level_wcs, (slice(kmin, kmax), slice(jmin, jmax), slice(imin, imax))
)
shape_out_indiv = (kmax - kmin, jmax - jmin, imax - imin)
slice_out = tuple([slice(imin, imax) for (imin, imax) in bounds])

if isinstance(wcs_out, WCS):
wcs_out_indiv = wcs_out[slice_out]
else:
wcs_out_indiv = SlicedLowLevelWCS(wcs_out.low_level_wcs, slice_out)

shape_out_indiv = [imax - imin for (imin, imax) in bounds]

if block_sizes is not None:
if len(block_sizes) == len(input_data) and len(block_sizes[idata]) == len(shape_out):
Expand Down Expand Up @@ -296,22 +273,20 @@ def reproject_and_coadd(
weights[reset] = 0.0
footprint *= weights

array = ReprojectedArraySubset(array, footprint, imin, imax, jmin, jmax, kmin, kmax)
array = ReprojectedArraySubset(array, footprint, bounds)

# TODO: make sure we gracefully handle the case where the
# output image is empty (due e.g. to no overlap).

if match_background:
arrays.append(array)
if on_the_fly:
# By default, values outside of the footprint are set to NaN
# but we set these to 0 here to avoid getting NaNs in the
# means/sums.
array.array[array.footprint == 0] = 0
output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint
else:
if combine_function in ("mean", "sum"):
# By default, values outside of the footprint are set to NaN
# but we set these to 0 here to avoid getting NaNs in the
# means/sums.
array.array[array.footprint == 0] = 0

output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint
arrays.append(array)

# If requested, try and match the backgrounds.
if match_background and len(arrays) > 1:
Expand All @@ -322,11 +297,6 @@ def reproject_and_coadd(
for array, correction in zip(arrays, corrections, strict=True):
array.array -= correction

if combine_function == "min":
output_array[...] = np.inf
elif combine_function == "max":
output_array[...] = -np.inf

if combine_function in ("mean", "sum"):
if match_background:
# if we're not matching the background, this part has already been done
Expand All @@ -336,37 +306,41 @@ def reproject_and_coadd(
# means/sums.
array.array[array.footprint == 0] = 0

output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint
output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint

if combine_function == "mean":
with np.errstate(invalid="ignore"):
output_array /= output_footprint
output_array[output_footprint == 0] = blank_pixel_value

elif combine_function in ("first", "last", "min", "max"):
if match_background:
for array in arrays:
if combine_function == "first":
mask = output_footprint[array.view_in_original_array] == 0
elif combine_function == "last":
mask = array.footprint > 0
elif combine_function == "min":
mask = (array.footprint > 0) & (
array.array < output_array[array.view_in_original_array]
)
elif combine_function == "max":
mask = (array.footprint > 0) & (
array.array > output_array[array.view_in_original_array]
)

output_footprint[array.view_in_original_array] = np.where(
mask, array.footprint, output_footprint[array.view_in_original_array]
if combine_function == "min":
output_array[...] = np.inf
elif combine_function == "max":
output_array[...] = -np.inf

for array in arrays:
if combine_function == "first":
mask = output_footprint[array.view_in_original_array] == 0
elif combine_function == "last":
mask = array.footprint > 0
elif combine_function == "min":
mask = (array.footprint > 0) & (
array.array < output_array[array.view_in_original_array]
)
output_array[array.view_in_original_array] = np.where(
mask, array.array, output_array[array.view_in_original_array]
elif combine_function == "max":
mask = (array.footprint > 0) & (
array.array > output_array[array.view_in_original_array]
)

output_footprint[array.view_in_original_array] = np.where(
mask, array.footprint, output_footprint[array.view_in_original_array]
)
output_array[array.view_in_original_array] = np.where(
mask, array.array, output_array[array.view_in_original_array]
)

output_array[output_footprint == 0] = blank_pixel_value

return output_array, output_footprint
Loading

0 comments on commit f5e3016

Please sign in to comment.