Skip to content

Commit

Permalink
Merge pull request #770 from dirac-institute/output_ra_dec
Browse files Browse the repository at this point in the history
Have KBMOD append RA, dec columns if possible
  • Loading branch information
jeremykubica authored Jan 16, 2025
2 parents 7436615 + c599f41 commit 4bd4386
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 6 deletions.
91 changes: 85 additions & 6 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .results import Results
from .trajectory_generator import create_trajectory_generator
from .trajectory_utils import predict_pixel_locations
from .wcs_utils import wcs_to_dict
from .work_unit import WorkUnit

Expand Down Expand Up @@ -187,7 +188,14 @@ def do_gpu_search(self, config, stack, trj_generator):
keep = self.load_and_filter_results(search, config)
return keep

def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=None):
def run_search(
self,
config,
stack,
trj_generator=None,
workunit=None,
extra_meta=None,
):
"""This function serves as the highest-level python interface for starting
a KBMOD search given an ImageStack and SearchConfiguration.
Expand All @@ -200,8 +208,8 @@ def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=Non
trj_generator : `TrajectoryGenerator`, optional
The object to generate the candidate trajectories for each pixel.
If None uses the default EclipticCenteredSearch
wcs : `astropy.wcs.WCS`, optional
A global WCS for all images in the search.
workunit : `WorkUnit`, optional
An optional WorkUnit with additional meta-data, including the per-image WCS.
extra_meta : `dict`, optional
Any additional metadata to save as part of the results file.
Expand Down Expand Up @@ -253,8 +261,11 @@ def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=Non
if config["save_all_stamps"]:
append_all_stamps(keep, stack, config["stamp_radius"])

# Append the WCS information if it is provided. This will be saved with the results.
keep.table.wcs = wcs
# Append additional information derived from the WorkUnit if one is provided,
# including a global WCS and per-time (RA, dec) predictions for each image.
if workunit is not None:
keep.table.wcs = workunit.wcs
append_ra_dec_to_results(workunit, keep)

# Create and save any additional meta data that should be saved with the results.
num_img = stack.img_count()
Expand Down Expand Up @@ -306,6 +317,74 @@ def run_search_from_work_unit(self, work):
work.config,
work.im_stack,
trj_generator=trj_generator,
wcs=work.wcs,
workunit=work,
extra_meta=extra_meta,
)


def append_ra_dec_to_results(workunit, results):
"""Append predicted (RA, dec) positions to the results.
Parameters
----------
workunit : `WorkUnit`
The WorkUnit with all the WCS information.
results : `Results`
The current table of results including the per-pixel trajectories.
This is modified in-place.
"""
num_results = len(results)
if num_results == 0:
return # Nothing to do

num_times = workunit.im_stack.img_count()
times = workunit.im_stack.build_zeroed_times()

# Predict where each candidate trajectory will be at each time step.
xp = predict_pixel_locations(times, results["x"], results["vx"], as_int=False)
yp = predict_pixel_locations(times, results["y"], results["vy"], as_int=False)

# Compute the predicted (RA, dec) positions for each trajectory in global space.
if workunit.wcs is not None:
logger.info("Found common WCS. Adding global_ra and global_dec columns.")

skypos = workunit.wcs.pixel_to_world(xp, yp)
results.table["global_ra"] = skypos.ra.degree
results.table["global_dec"] = skypos.dec.degree

# Loop over the trajectories to build the original positions.
all_ra = []
all_dec = []
for idx in range(num_results):
pos_tuples = [(xp[idx, j], yp[idx, j]) for j in range(num_times)]
skypos = workunit.image_positions_to_original_icrs(
image_indices=np.arange(num_times), # Compute for all times.
positions=pos_tuples,
input_format="xy",
output_format="radec",
filter_in_frame=False,
)

# We get back a list of SkyCoord, because we gave a list.
# So we flatten it and extract the coordinate values.
all_ra.append([skypos[j].ra.degree for j in range(num_times)])
all_dec.append([skypos[j].dec.degree for j in range(num_times)])

results.table["img_ra"] = all_ra
results.table["img_dec"] = all_dec
else:
logger.info("No common WCS found. Skipping global_ra and global_dec columns.")

# If there are no global WCS, we just predict per image.
all_ra = np.zeros((len(results), num_times))
all_dec = np.zeros((len(results), num_times))

for time_idx in range(num_times):
wcs = workunit.get_wcs(time_idx)
if wcs is not None:
skypos = wcs.pixel_to_world(xp[:, time_idx], yp[:, time_idx])
all_ra[:, time_idx] = skypos.ra.degree
all_dec[:, time_idx] = skypos.dec.degree

results.table["img_ra"] = all_ra
results.table["img_dec"] = all_dec
144 changes: 144 additions & 0 deletions tests/test_run_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Test some of the functions needed for running the search."""

import unittest

import numpy as np

from kbmod.configuration import SearchConfiguration
from kbmod.fake_data.fake_data_creator import create_fake_times, FakeDataSet
from kbmod.results import Results
from kbmod.run_search import append_ra_dec_to_results
from kbmod.search import *
from kbmod.wcs_utils import make_fake_wcs
from kbmod.work_unit import WorkUnit


class test_run_search(unittest.TestCase):
def test_append_ra_dec_global(self):
# Create a fake WorkUnit with 20 times, a completely random ImageStack,
# and no trajectories.
num_times = 20
fake_times = create_fake_times(num_times, t0=60676.0)
fake_ds = FakeDataSet(800, 600, fake_times)

# Append a global fake WCS and one for each time.
global_wcs = make_fake_wcs(20.0, 0.0, 800, 600, deg_per_pixel=0.5 / 3600.0)
all_wcs = []
for idx in range(num_times):
curr = make_fake_wcs(
20.01 + idx / 100.0, 0.01 + idx / 100.0, 800, 600, deg_per_pixel=0.5 / 3600.0
)
all_wcs.append(curr)

fake_wu = WorkUnit(
im_stack=fake_ds.stack,
config=SearchConfiguration(),
wcs=global_wcs,
per_image_wcs=all_wcs,
reprojected=True,
per_image_indices=[i for i in range(num_times)],
heliocentric_distance=np.full(num_times, 100.0),
obstimes=fake_times,
)

# Create three fake trajectories in the bounds of the images. We don't
# bother actually inserting them into the pixels.
trjs = [
Trajectory(x=5, y=10, vx=1, vy=1, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=400, y=300, vx=-5, vy=-2, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=100, y=500, vx=10, vy=-10, flux=1000.0, lh=1000.0, obs_count=num_times),
]
results = Results.from_trajectories(trjs)
self.assertEqual(len(results), 3)

append_ra_dec_to_results(fake_wu, results)

# The global RA should exist and be close to 20.0 for all observations.
self.assertEqual(len(results["global_ra"]), 3)
for i in range(3):
self.assertEqual(len(results["global_ra"][i]), num_times)
self.assertTrue(np.all(results["global_ra"][i] > 19.0))
self.assertTrue(np.all(results["global_ra"][i] < 21.0))

# The global Dec should exist and be close to 0.0 for all observations.
self.assertEqual(len(results["global_dec"]), 3)
for i in range(3):
self.assertEqual(len(results["global_dec"][i]), num_times)
self.assertTrue(np.all(results["global_dec"][i] > -1.0))
self.assertTrue(np.all(results["global_dec"][i] < 1.0))

# The per-image RA should exist, be close to 20.0 for all observations,
# and be different from the global RA
self.assertEqual(len(results["img_ra"]), 3)
for i in range(3):
self.assertEqual(len(results["img_ra"][i]), num_times)
self.assertTrue(np.all(results["img_ra"][i] > 19.0))
self.assertTrue(np.all(results["img_ra"][i] < 21.0))
self.assertFalse(np.any(results["img_ra"][i] == results["global_ra"][i]))

# The global Dec should exist and be close to 0.0 for all observations.
self.assertEqual(len(results["img_dec"]), 3)
for i in range(3):
self.assertEqual(len(results["img_dec"][i]), num_times)
self.assertTrue(np.all(results["img_dec"][i] > -1.0))
self.assertTrue(np.all(results["img_dec"][i] < 1.0))
self.assertFalse(np.any(results["img_dec"][i] == results["global_dec"][i]))

def test_append_ra_dec_no_global(self):
# Create a fake WorkUnit with 20 times, a completely random ImageStack,
# and no trajectories.
num_times = 20
fake_times = create_fake_times(num_times, t0=60676.0)
fake_ds = FakeDataSet(800, 600, fake_times)

# Append a global fake WCS and one for each time.
all_wcs = []
for idx in range(num_times):
curr = make_fake_wcs(
20.01 + idx / 100.0, 0.01 + idx / 100.0, 800, 600, deg_per_pixel=0.5 / 3600.0
)
all_wcs.append(curr)

fake_wu = WorkUnit(
im_stack=fake_ds.stack,
config=SearchConfiguration(),
wcs=None,
per_image_wcs=all_wcs,
reprojected=False,
per_image_indices=[i for i in range(num_times)],
obstimes=fake_times,
)

# Create three fake trajectories in the bounds of the images. We don't
# bother actually inserting them into the pixels.
trjs = [
Trajectory(x=5, y=10, vx=1, vy=1, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=400, y=300, vx=-5, vy=-2, flux=1000.0, lh=1000.0, obs_count=num_times),
Trajectory(x=100, y=500, vx=10, vy=-10, flux=1000.0, lh=1000.0, obs_count=num_times),
]
results = Results.from_trajectories(trjs)
self.assertEqual(len(results), 3)

append_ra_dec_to_results(fake_wu, results)

# The global RA and global dec should not exist.
self.assertFalse("global_ra" in results.colnames)
self.assertFalse("global_dec" in results.colnames)

# The per-image RA should exist, be close to 20.0 for all observations.
self.assertEqual(len(results["img_ra"]), 3)
for i in range(3):
self.assertEqual(len(results["img_ra"][i]), num_times)
self.assertTrue(np.all(results["img_ra"][i] > 19.0))
self.assertTrue(np.all(results["img_ra"][i] < 21.0))

# The global Dec should exist and be close to 0.0 for all observations.
self.assertEqual(len(results["img_dec"]), 3)
for i in range(3):
self.assertEqual(len(results["img_dec"][i]), num_times)
self.assertTrue(np.all(results["img_dec"][i] > -1.0))
self.assertTrue(np.all(results["img_dec"][i] < 1.0))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4bd4386

Please sign in to comment.