From 394f9b8e847968ddb2ee16dd91f367caeef2ad51 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Thu, 23 Jan 2025 16:08:39 -0800 Subject: [PATCH] Fix trajectory_predict_skypos mutating input array --- src/kbmod/filters/known_object_filters.py | 4 ++-- src/kbmod/trajectory_utils.py | 4 +++- tests/test_known_object_filters.py | 8 +++++++- tests/test_trajectory_utils.py | 19 +++++++++++++------ 4 files changed, 25 insertions(+), 10 deletions(-) diff --git a/src/kbmod/filters/known_object_filters.py b/src/kbmod/filters/known_object_filters.py index a3fa20f8..cfbed3cb 100644 --- a/src/kbmod/filters/known_object_filters.py +++ b/src/kbmod/filters/known_object_filters.py @@ -44,8 +44,8 @@ def __init__( ---------- table : astropy.table.Table A table containing our catalog of observations of known objects. - obstimes : list(float) - The MJD times of each observation within KBMOD results we want to match to + obstimes : np.array(float) + A numpy array of MJD times of each observation within KBMOD results we want to match to the known objects. matcher_name : str The name of the filter to apply to the results. This both determines diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py index 38a3f4c9..0fb33946 100644 --- a/src/kbmod/trajectory_utils.py +++ b/src/kbmod/trajectory_utils.py @@ -124,7 +124,9 @@ def trajectory_predict_skypos(trj, wcs, times): A SkyCoord with the transformed locations. """ dt = np.asarray(times) - dt -= dt[0] + # Note that we do a reassignment to avoid modifying the input, which may + # happen if `times` is already an array and `np.asarray` is a no-op. + dt = dt - dt[0] # Predict locations in pixel space. x_vals = trj.x + trj.vx * dt diff --git a/tests/test_known_object_filters.py b/tests/test_known_object_filters.py index c19ad802..986d741b 100644 --- a/tests/test_known_object_filters.py +++ b/tests/test_known_object_filters.py @@ -26,7 +26,10 @@ def setUp(self): # Create a fake dataset with 15 x 10 images and 25 obstimes. num_images = 25 - self.obstimes = np.array(create_fake_times(num_images)) + start_time = 58922.1 + self.obstimes = np.array(create_fake_times(num_images, t0=start_time)) + # Check that all obstimes are greater or equal to the start time + self.assertTrue(np.all(self.obstimes >= start_time)) ds = FakeDataSet(15, 10, self.obstimes, use_seed=True) self.wcs = make_fake_wcs(10.0, 15.0, 15, 10) ds.set_wcs(self.wcs) @@ -107,6 +110,9 @@ def setUp(self): time_offset=time_offset_mjd_close, ) + # check that the obstimes have been unmodified and are still not zero-offset + self.assertTrue(np.all(self.obstimes >= start_time)) + def test_known_objs_matcher_init( self, ): # Test that a table with no columns specified raises a ValueError diff --git a/tests/test_trajectory_utils.py b/tests/test_trajectory_utils.py index 9a1d6f91..0b3d2729 100644 --- a/tests/test_trajectory_utils.py +++ b/tests/test_trajectory_utils.py @@ -62,12 +62,19 @@ def test_predict_skypos(self): # Create a trajectory starting at the middle and traveling +2 pixels a day in x and -5 in y. trj = Trajectory(x=9, y=9, vx=2.0, vy=-5.0) - # Predict locations at times 0.0 and 1.0 - my_sky = trajectory_predict_skypos(trj, my_wcs, [0.0, 1.0]) - self.assertAlmostEqual(my_sky.ra[0].deg, 45.0) - self.assertAlmostEqual(my_sky.dec[0].deg, -15.0) - self.assertAlmostEqual(my_sky.ra[1].deg, 45.2, delta=0.01) - self.assertAlmostEqual(my_sky.dec[1].deg, -15.5, delta=0.01) + # Predict locations at times 57921.0 and 57922.0, both as a list and as a numpy array. + obstimes = [57921.0, 57922.0] + for curr_obstimes in [obstimes, np.array(obstimes)]: + my_sky = trajectory_predict_skypos(trj, my_wcs, curr_obstimes) + # Verify that the obstimes were not mutated + self.assertEqual(curr_obstimes[0], 57921.0) + self.assertEqual(curr_obstimes[1], 57922.0) + + # Verify that the predicted sky positions are correct. + self.assertAlmostEqual(my_sky.ra[0].deg, 45.0) + self.assertAlmostEqual(my_sky.dec[0].deg, -15.0) + self.assertAlmostEqual(my_sky.ra[1].deg, 45.2, delta=0.01) + self.assertAlmostEqual(my_sky.dec[1].deg, -15.5, delta=0.01) def test_trajectory_from_np_object(self): np_obj = np.array(