Skip to content

Commit

Permalink
Fix trajectory_predict_skypos mutating input array
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonbb committed Jan 24, 2025
1 parent 7436615 commit 394f9b8
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
4 changes: 2 additions & 2 deletions src/kbmod/filters/known_object_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(
----------
table : astropy.table.Table
A table containing our catalog of observations of known objects.
obstimes : list(float)
The MJD times of each observation within KBMOD results we want to match to
obstimes : np.array(float)
A numpy array of MJD times of each observation within KBMOD results we want to match to
the known objects.
matcher_name : str
The name of the filter to apply to the results. This both determines
Expand Down
4 changes: 3 additions & 1 deletion src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ def trajectory_predict_skypos(trj, wcs, times):
A SkyCoord with the transformed locations.
"""
dt = np.asarray(times)
dt -= dt[0]
# Note that we do a reassignment to avoid modifying the input, which may
# happen if `times` is already an array and `np.asarray` is a no-op.
dt = dt - dt[0]

# Predict locations in pixel space.
x_vals = trj.x + trj.vx * dt
Expand Down
8 changes: 7 additions & 1 deletion tests/test_known_object_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ def setUp(self):

# Create a fake dataset with 15 x 10 images and 25 obstimes.
num_images = 25
self.obstimes = np.array(create_fake_times(num_images))
start_time = 58922.1
self.obstimes = np.array(create_fake_times(num_images, t0=start_time))
# Check that all obstimes are greater or equal to the start time
self.assertTrue(np.all(self.obstimes >= start_time))
ds = FakeDataSet(15, 10, self.obstimes, use_seed=True)
self.wcs = make_fake_wcs(10.0, 15.0, 15, 10)
ds.set_wcs(self.wcs)
Expand Down Expand Up @@ -107,6 +110,9 @@ def setUp(self):
time_offset=time_offset_mjd_close,
)

# check that the obstimes have been unmodified and are still not zero-offset
self.assertTrue(np.all(self.obstimes >= start_time))

def test_known_objs_matcher_init(
self,
): # Test that a table with no columns specified raises a ValueError
Expand Down
19 changes: 13 additions & 6 deletions tests/test_trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,19 @@ def test_predict_skypos(self):
# Create a trajectory starting at the middle and traveling +2 pixels a day in x and -5 in y.
trj = Trajectory(x=9, y=9, vx=2.0, vy=-5.0)

# Predict locations at times 0.0 and 1.0
my_sky = trajectory_predict_skypos(trj, my_wcs, [0.0, 1.0])
self.assertAlmostEqual(my_sky.ra[0].deg, 45.0)
self.assertAlmostEqual(my_sky.dec[0].deg, -15.0)
self.assertAlmostEqual(my_sky.ra[1].deg, 45.2, delta=0.01)
self.assertAlmostEqual(my_sky.dec[1].deg, -15.5, delta=0.01)
# Predict locations at times 57921.0 and 57922.0, both as a list and as a numpy array.
obstimes = [57921.0, 57922.0]
for curr_obstimes in [obstimes, np.array(obstimes)]:
my_sky = trajectory_predict_skypos(trj, my_wcs, curr_obstimes)
# Verify that the obstimes were not mutated
self.assertEqual(curr_obstimes[0], 57921.0)
self.assertEqual(curr_obstimes[1], 57922.0)

# Verify that the predicted sky positions are correct.
self.assertAlmostEqual(my_sky.ra[0].deg, 45.0)
self.assertAlmostEqual(my_sky.dec[0].deg, -15.0)
self.assertAlmostEqual(my_sky.ra[1].deg, 45.2, delta=0.01)
self.assertAlmostEqual(my_sky.dec[1].deg, -15.5, delta=0.01)

def test_trajectory_from_np_object(self):
np_obj = np.array(
Expand Down

0 comments on commit 394f9b8

Please sign in to comment.