Skip to content

Commit

Permalink
Refactor getTemplate run/runQuantum/getOverlappingExposures arguments
Browse files Browse the repository at this point in the history
Bring `runQauntum` more in line with how we want arguments to be
handled (load the inputs and pass them as named arguments).
Refactor `getOverlappingExposures` so that it can be used more easily
as outside of `runQuantum`.
  • Loading branch information
parejkoj committed Nov 13, 2024
1 parent bd46683 commit 335557b
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 31 deletions.
57 changes: 32 additions & 25 deletions python/lsst/ip/diffim/getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,24 @@ def __init__(self, *args, **kwargs):

def runQuantum(self, butlerQC, inputRefs, outputRefs):
inputs = butlerQC.get(inputRefs)
results = self.getOverlappingExposures(inputs)
del inputs["skyMap"] # Only needed for the above.
inputs["coaddExposures"] = results.coaddExposures
inputs["dataIds"] = results.dataIds
inputs["physical_filter"] = butlerQC.quantum.dataId["physical_filter"]
outputs = self.run(**inputs)
bbox = inputs.pop("bbox")
wcs = inputs.pop("wcs")
coaddExposures = inputs.pop('coaddExposures')
skymap = inputs.pop("skyMap")

# This should not happen with a properly configured execution context.
assert not inputs, "runQuantum got more inputs than expected"

results = self.getOverlappingExposures(coaddExposures, bbox, skymap, wcs)
physical_filter = butlerQC.quantum.dataId["physical_filter"]
outputs = self.run(coaddExposures=results.coaddExposures,
bbox=bbox,
wcs=wcs,
dataIds=results.dataIds,
physical_filter=physical_filter)
butlerQC.put(outputs, outputRefs)

def getOverlappingExposures(self, inputs):
def getOverlappingExposures(self, coaddExposures, bbox, skymap, wcs):
"""Return a data structure containing the coadds that overlap the
specified bbox projected onto the sky, and a corresponding data
structure of their dataIds.
Expand All @@ -136,19 +145,18 @@ def getOverlappingExposures(self, inputs):
Parameters
----------
inputs : `dict` of task Inputs, containing:
- coaddExposures : `list` \
[`lsst.daf.butler.DeferredDatasetHandle` of \
`lsst.afw.image.Exposure`]
Data references to exposures that might overlap the desired
region.
- bbox : `lsst.geom.Box2I`
Template bounding box of the pixel geometry onto which the
coaddExposures will be resampled.
- skyMap : `lsst.skymap.SkyMap`
Geometry of the tracts and patches the coadds are defined on.
- wcs : `lsst.afw.geom.SkyWcs`
Template WCS onto which the coadds will be resampled.
coaddExposures : `list` \
[`lsst.daf.butler.DeferredDatasetHandle` of \
`lsst.afw.image.Exposure`]
Data references to exposures that might overlap the desired
region.
bbox : `lsst.geom.Box2I`
Template bounding box of the pixel geometry onto which the
coaddExposures will be resampled.
skyMap : `lsst.skymap.SkyMap`
Geometry of the tracts and patches the coadds are defined on.
wcs : `lsst.afw.geom.SkyWcs`
Template WCS onto which the coadds will be resampled.
Returns
-------
Expand All @@ -170,17 +178,16 @@ def getOverlappingExposures(self, inputs):
Raised if no patches overlap the input detector bbox, or the input
WCS is None.
"""
if (wcs := inputs['wcs']) is None:
if wcs is None:
raise pipeBase.NoWorkFound("Exposure has no WCS; cannot create a template.")

# Exposure's validPolygon would be more accurate
detectorPolygon = geom.Box2D(inputs['bbox'])
detectorPolygon = geom.Box2D(bbox)
overlappingArea = 0
coaddExposures = collections.defaultdict(list)
dataIds = collections.defaultdict(list)

skymap = inputs['skyMap']
for coaddRef in inputs['coaddExposures']:
for coaddRef in coaddExposures:
dataId = coaddRef.dataId
patchWcs = skymap[dataId['tract']].getWcs()
patchBBox = skymap[dataId['tract']][dataId['patch']].getOuterBBox()
Expand All @@ -199,7 +206,7 @@ def getOverlappingExposures(self, inputs):
dataIds=dataIds)

@timeMethod
def run(self, coaddExposures, bbox, wcs, dataIds, physical_filter):
def run(self, *, coaddExposures, bbox, wcs, dataIds, physical_filter):
"""Warp coadds from multiple tracts and patches to form a template to
subtract from a science image.
Expand Down
31 changes: 25 additions & 6 deletions tests/test_getTemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,11 @@ def testRunOneTractInput(self):
task = lsst.ip.diffim.GetTemplateTask()
# Restrict to tract 0, since the box fits in just that tract.
# Task modifies the input bbox, so pass a copy.
result = task.run({0: self.patches[0]}, lsst.geom.Box2I(box),
self.exposure.wcs, {0: self.dataIds[0]}, "a_test")
result = task.run(coaddExposures={0: self.patches[0]},
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds={0: self.dataIds[0]},
physical_filter="a_test")

# All 4 patches from tract 0 are included in this template.
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 4)
Expand All @@ -224,7 +227,11 @@ def testRunOneTractMultipleInputs(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
task = lsst.ip.diffim.GetTemplateTask()
# Task modifies the input bbox, so pass a copy.
result = task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test")
result = task.run(coaddExposures=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
physical_filter="a_test")

# All 4 patches from two tracts are included in this template.
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 8)
Expand All @@ -236,7 +243,11 @@ def testRunTwoTracts(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(200, 200), lsst.geom.Point2I(600, 600))
task = lsst.ip.diffim.GetTemplateTask()
# Task modifies the input bbox, so pass a copy.
result = task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test")
result = task.run(coaddExposures=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
physical_filter="a_test")

# All 4 patches from all 4 tracts are included in this template
self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 16)
Expand All @@ -248,7 +259,11 @@ def testRunNoTemplate(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(1200, 1200), lsst.geom.Point2I(1600, 1600))
task = lsst.ip.diffim.GetTemplateTask()
with self.assertRaisesRegex(lsst.pipe.base.NoWorkFound, "No patches found"):
task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test")
task.run(coaddExposures=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
physical_filter="a_test")

def testMissingPatches(self):
"""Test that a missing patch results in an appropriate mask.
Expand All @@ -261,7 +276,11 @@ def testMissingPatches(self):
box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180))
task = lsst.ip.diffim.GetTemplateTask()
# Task modifies the input bbox, so pass a copy.
result = task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test")
result = task.run(coaddExposures=self.patches,
bbox=lsst.geom.Box2I(box),
wcs=self.exposure.wcs,
dataIds=self.dataIds,
physical_filter="a_test")
no_data = (result.template.mask.array & result.template.mask.getPlaneBitMask("NO_DATA")) != 0
self.assertTrue(all(np.isnan(result.template.image.array[no_data])))
self.assertTrue(all(np.isnan(result.template.variance.array[no_data])))
Expand Down

0 comments on commit 335557b

Please sign in to comment.