Skip to content

Commit

Permalink
Add SlicedLowLevelWCS support to reproject and fix a bug (#8172)
Browse files Browse the repository at this point in the history
Co-authored-by: James Davies <[email protected]>
Co-authored-by: Howard Bushouse <[email protected]>
  • Loading branch information
3 people committed Feb 20, 2024
1 parent 4c73900 commit 4cc0ac1
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 27 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ resample

- Use the same ``iscale`` value for resampling science data and variance arrays. [#8159]

- Changed to use the high-level APE 14 API (``pixel_to_world_values`` and
``world_to_pixel_values``) for reproject, which also fixed a bug, and
removed support for astropy model [#8172]

residual_fringe
---------------

Expand Down
31 changes: 8 additions & 23 deletions jwst/resample/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import warnings

import numpy as np
from astropy import wcs as fitswcs
from astropy.modeling import Model
from astropy import units as u
import gwcs

Expand Down Expand Up @@ -134,8 +132,9 @@ def reproject(wcs1, wcs2):
Parameters
----------
wcs1, wcs2 : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS` or `~astropy.modeling.Model`
WCS objects.
wcs1, wcs2 : `~astropy.wcs.WCS` or `~gwcs.wcs.WCS`
WCS objects that have `pixel_to_world_values` and `world_to_pixel_values`
methods.
Returns
-------
Expand All @@ -144,25 +143,11 @@ def reproject(wcs1, wcs2):
positions in ``wcs1`` and returns x, y positions in ``wcs2``.
"""

if isinstance(wcs1, fitswcs.WCS):
forward_transform = wcs1.all_pix2world
elif isinstance(wcs1, gwcs.WCS):
forward_transform = wcs1.forward_transform
elif issubclass(wcs1, Model):
forward_transform = wcs1
else:
raise TypeError("Expected input to be astropy.wcs.WCS or gwcs.WCS "
"object or astropy.modeling.Model subclass")

if isinstance(wcs2, fitswcs.WCS):
backward_transform = wcs2.all_world2pix
elif isinstance(wcs2, gwcs.WCS):
backward_transform = wcs2.backward_transform
elif issubclass(wcs2, Model):
backward_transform = wcs2.inverse
else:
raise TypeError("Expected input to be astropy.wcs.WCS or gwcs.WCS "
"object or astropy.modeling.Model subclass")
try:
forward_transform = wcs1.pixel_to_world_values
backward_transform = wcs2.world_to_pixel_values
except AttributeError as err:
raise TypeError("Input should be a WCS") from err

def _reproject(x, y):
sky = forward_transform(x, y)
Expand Down
5 changes: 3 additions & 2 deletions jwst/resample/tests/test_resample_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,9 +476,10 @@ def test_resample_undefined_variance(nircam_rate, shape):
im.var_poisson = np.ones(shape, dtype=im.var_poisson.dtype.type)
im.var_flat = np.ones(shape, dtype=im.var_flat.dtype.type)
im.meta.filename = "foo.fits"

c = ModelContainer([im])
ResampleStep.call(c, blendheaders=False)

with pytest.warns(RuntimeWarning, match="var_rnoise array not available"):
ResampleStep.call(c, blendheaders=False)


@pytest.mark.parametrize('ratio', [0.7, 1.2])
Expand Down
90 changes: 88 additions & 2 deletions jwst/resample/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Test various utility functions"""
from numpy.testing import assert_array_equal
from astropy import coordinates as coord
from astropy import wcs as fitswcs
from astropy.modeling import models as astmodels
from gwcs import coordinate_frames as cf
from gwcs.wcstools import wcs_from_fiducial
from numpy.testing import assert_allclose, assert_array_equal
import numpy as np
import pytest

Expand All @@ -9,7 +14,8 @@
from jwst.resample.resample_utils import (
build_mask,
build_driz_weight,
decode_context
decode_context,
reproject
)


Expand All @@ -25,6 +31,59 @@
JWST_NAMES_INV = '~' + JWST_NAMES


@pytest.fixture(scope='module')
def wcs_gwcs():
crval = (150.0, 2.0)
crpix = (500.0, 500.0)
shape = (1000, 1000)
pscale = 0.06 / 3600

prj = astmodels.Pix2Sky_TAN()
fiducial = np.array(crval)

pc = np.array([[-1., 0.], [0., 1.]])
pc_matrix = astmodels.AffineTransformation2D(pc, name='pc_rotation_matrix')
scale = astmodels.Scale(pscale, name='cdelt1') & astmodels.Scale(pscale, name='cdelt2')
transform = pc_matrix | scale

out_frame = cf.CelestialFrame(name='world', axes_names=('lon', 'lat'), reference_frame=coord.ICRS())
input_frame = cf.Frame2D(name="detector")
wnew = wcs_from_fiducial(fiducial, coordinate_frame=out_frame, projection=prj,
transform=transform, input_frame=input_frame)

output_bounding_box = ((0.0, float(shape[1])), (0.0, float(shape[0])))
offset1, offset2 = crpix
offsets = astmodels.Shift(-offset1, name='crpix1') & astmodels.Shift(-offset2, name='crpix2')

wnew.insert_transform('detector', offsets, after=True)
wnew.bounding_box = output_bounding_box

tr = wnew.pipeline[0].transform
pix_area = (
np.deg2rad(tr['cdelt1'].factor.value) *
np.deg2rad(tr['cdelt2'].factor.value)
)

wnew.pixel_area = pix_area
wnew.pixel_shape = shape[::-1]
wnew.array_shape = shape
return wnew


@pytest.fixture(scope='module')
def wcs_fitswcs(wcs_gwcs):
fits_wcs = fitswcs.WCS(wcs_gwcs.to_fits_sip())
return fits_wcs


@pytest.fixture(scope='module')
def wcs_slicedwcs(wcs_gwcs):
xmin, xmax = 100, 500
slices = (slice(xmin, xmax), slice(xmin, xmax))
sliced_wcs = fitswcs.wcsapi.SlicedLowLevelWCS(wcs_gwcs, slices)
return sliced_wcs


@pytest.mark.parametrize(
'dq, bitvalues, expected', [
(DQ, 0, np.array([1, 0, 0, 0, 0, 0, 0, 0, 0])),
Expand Down Expand Up @@ -116,3 +175,30 @@ def test_decode_context():

assert sorted(idx1) == [9, 12, 14, 19, 21, 25, 37, 40, 46, 58, 64, 65, 67, 77]
assert sorted(idx2) == [9, 20, 29, 36, 47, 49, 64, 69, 70, 79]


@pytest.mark.parametrize(
"wcs1, wcs2, offset",
[
("wcs_gwcs", "wcs_fitswcs", 0),
("wcs_fitswcs", "wcs_gwcs", 0),
("wcs_gwcs", "wcs_slicedwcs", 100),
("wcs_slicedwcs", "wcs_gwcs", -100),
("wcs_fitswcs", "wcs_slicedwcs", 100),
("wcs_slicedwcs", "wcs_fitswcs", -100),
]
)
def test_reproject(wcs1, wcs2, offset, request):
wcs1 = request.getfixturevalue(wcs1)
wcs2 = request.getfixturevalue(wcs2)
x = np.arange(150, 200)

f = reproject(wcs1, wcs2)
res = f(x, x)
assert_allclose(x, res[0] + offset)
assert_allclose(x, res[1] + offset)


def test_reproject_with_garbage_input():
with pytest.raises(TypeError):
reproject("foo", "bar")

0 comments on commit 4cc0ac1

Please sign in to comment.