Skip to content

Commit

Permalink
Update unit tests for numpy 2 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
melanieclarke committed Sep 20, 2024
1 parent d2e7e2b commit b07e8fc
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions jwst/coron/tests/test_coron.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def test_fourier_imshift():
""" Test of fourier_imshift() in imageregistration.py """

image = np.zeros((5, 5), dtype=np.float32)
image = np.zeros((5, 5), dtype=np.float64)
image[1:4, 1:4] += 1.0
image[2, 2] += 2.0
shift = [1.2, 0.6]
Expand All @@ -31,7 +31,7 @@ def test_fourier_imshift():
def test_shift_subtract():
""" Test of shift_subtract() in imageregistration.py """

target = np.arange((15), dtype=np.float32).reshape((3, 5))
target = np.arange((15), dtype=np.float64).reshape((3, 5))
reference = target + 0.1
reference[1, 0] -= 0.2
reference[2, 0] += 2.3
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_shift_subtract():
def test_align_fourierLSQ():
""" Test of align_fourierLSQ() in imageregistration.py """

target = np.arange((15), dtype=np.float32).reshape((3, 5))
target = np.arange((15), dtype=np.float64).reshape((3, 5))
reference = target + 0.1
reference[1, 0] -= 0.2
reference[2, 0] += 2.3
Expand All @@ -85,8 +85,8 @@ def test_align_fourierLSQ():
def test_align_array():
""" Test of align_array() in imageregistration.py """

temp = np.arange((15), dtype=np.float32).reshape((3, 5))
targ = np.zeros((3, 3, 5))
temp = np.arange((15), dtype=np.float64).reshape((3, 5))
targ = np.zeros((3, 3, 5), dtype=np.float64)
targ[:] = temp
targ[0, 1, 1] += 0.3
targ[0, 2, 1] += 0.7
Expand All @@ -108,7 +108,7 @@ def test_align_array():
ref[0, 2] += 1.3
ref[1, 4] -= 0.6

mask = temp * 0 + 1
mask = np.full(temp.shape, 1)
mask[1, 1] = 0
mask[1, 2] = 0
aligned, shifts = imageregistration.align_array(ref, targ, mask)
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_align_models():
""" Test of align_models() in imageregistration.py """

temp = np.arange((15), dtype=np.float32).reshape((3, 5))
targ = np.zeros((3, 3, 5))
targ = np.zeros((3, 3, 5), dtype=np.float32)
targ[:] = temp
targ[0, 1, 1] += 0.3
targ[0, 2, 1] += 0.7
Expand All @@ -168,8 +168,7 @@ def test_align_models():
ref[1, 2, 3] -= 5.0
ref[2, 0, 3] -= 1.6

mask = ref[0, :, :]
mask = ref[0, :, :] * 0 + 1
mask = np.full(ref[0].shape, 1)
mask[1, 1] = 0
mask[1, 2] = 0

Expand All @@ -181,10 +180,10 @@ def test_align_models():
results_sub = am_results.data[:2, 2, :3]

truth_results_sub = np.array(
[[10.0, 11.7, 12.0], [10.036278, 11.138131, 10.180669]],
[[10.0, 11.7, 12.0], [10.0, 11.2, 10.2]],
)

npt.assert_allclose(results_sub, truth_results_sub, atol=1e-6)
npt.assert_allclose(results_sub, truth_results_sub, rtol=1e-2)


def test_KLT():
Expand Down

0 comments on commit b07e8fc

Please sign in to comment.