From 335557baee5d3f15125e937928298a63b0808fe9 Mon Sep 17 00:00:00 2001 From: John Parejko Date: Wed, 23 Oct 2024 16:01:53 -0700 Subject: [PATCH] Refactor getTemplate run/runQuantum/getOverlappingExposures arguments Bring `runQauntum` more in line with how we want arguments to be handled (load the inputs and pass them as named arguments). Refactor `getOverlappingExposures` so that it can be used more easily as outside of `runQuantum`. --- python/lsst/ip/diffim/getTemplate.py | 57 ++++++++++++++++------------ tests/test_getTemplate.py | 31 ++++++++++++--- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/python/lsst/ip/diffim/getTemplate.py b/python/lsst/ip/diffim/getTemplate.py index ff152798..538af454 100644 --- a/python/lsst/ip/diffim/getTemplate.py +++ b/python/lsst/ip/diffim/getTemplate.py @@ -115,15 +115,24 @@ def __init__(self, *args, **kwargs): def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) - results = self.getOverlappingExposures(inputs) - del inputs["skyMap"] # Only needed for the above. - inputs["coaddExposures"] = results.coaddExposures - inputs["dataIds"] = results.dataIds - inputs["physical_filter"] = butlerQC.quantum.dataId["physical_filter"] - outputs = self.run(**inputs) + bbox = inputs.pop("bbox") + wcs = inputs.pop("wcs") + coaddExposures = inputs.pop('coaddExposures') + skymap = inputs.pop("skyMap") + + # This should not happen with a properly configured execution context. + assert not inputs, "runQuantum got more inputs than expected" + + results = self.getOverlappingExposures(coaddExposures, bbox, skymap, wcs) + physical_filter = butlerQC.quantum.dataId["physical_filter"] + outputs = self.run(coaddExposures=results.coaddExposures, + bbox=bbox, + wcs=wcs, + dataIds=results.dataIds, + physical_filter=physical_filter) butlerQC.put(outputs, outputRefs) - def getOverlappingExposures(self, inputs): + def getOverlappingExposures(self, coaddExposures, bbox, skymap, wcs): """Return a data structure containing the coadds that overlap the specified bbox projected onto the sky, and a corresponding data structure of their dataIds. @@ -136,19 +145,18 @@ def getOverlappingExposures(self, inputs): Parameters ---------- - inputs : `dict` of task Inputs, containing: - - coaddExposures : `list` \ - [`lsst.daf.butler.DeferredDatasetHandle` of \ - `lsst.afw.image.Exposure`] - Data references to exposures that might overlap the desired - region. - - bbox : `lsst.geom.Box2I` - Template bounding box of the pixel geometry onto which the - coaddExposures will be resampled. - - skyMap : `lsst.skymap.SkyMap` - Geometry of the tracts and patches the coadds are defined on. - - wcs : `lsst.afw.geom.SkyWcs` - Template WCS onto which the coadds will be resampled. + coaddExposures : `list` \ + [`lsst.daf.butler.DeferredDatasetHandle` of \ + `lsst.afw.image.Exposure`] + Data references to exposures that might overlap the desired + region. + bbox : `lsst.geom.Box2I` + Template bounding box of the pixel geometry onto which the + coaddExposures will be resampled. + skyMap : `lsst.skymap.SkyMap` + Geometry of the tracts and patches the coadds are defined on. + wcs : `lsst.afw.geom.SkyWcs` + Template WCS onto which the coadds will be resampled. Returns ------- @@ -170,17 +178,16 @@ def getOverlappingExposures(self, inputs): Raised if no patches overlap the input detector bbox, or the input WCS is None. """ - if (wcs := inputs['wcs']) is None: + if wcs is None: raise pipeBase.NoWorkFound("Exposure has no WCS; cannot create a template.") # Exposure's validPolygon would be more accurate - detectorPolygon = geom.Box2D(inputs['bbox']) + detectorPolygon = geom.Box2D(bbox) overlappingArea = 0 coaddExposures = collections.defaultdict(list) dataIds = collections.defaultdict(list) - skymap = inputs['skyMap'] - for coaddRef in inputs['coaddExposures']: + for coaddRef in coaddExposures: dataId = coaddRef.dataId patchWcs = skymap[dataId['tract']].getWcs() patchBBox = skymap[dataId['tract']][dataId['patch']].getOuterBBox() @@ -199,7 +206,7 @@ def getOverlappingExposures(self, inputs): dataIds=dataIds) @timeMethod - def run(self, coaddExposures, bbox, wcs, dataIds, physical_filter): + def run(self, *, coaddExposures, bbox, wcs, dataIds, physical_filter): """Warp coadds from multiple tracts and patches to form a template to subtract from a science image. diff --git a/tests/test_getTemplate.py b/tests/test_getTemplate.py index 5359f428..ba89b117 100644 --- a/tests/test_getTemplate.py +++ b/tests/test_getTemplate.py @@ -209,8 +209,11 @@ def testRunOneTractInput(self): task = lsst.ip.diffim.GetTemplateTask() # Restrict to tract 0, since the box fits in just that tract. # Task modifies the input bbox, so pass a copy. - result = task.run({0: self.patches[0]}, lsst.geom.Box2I(box), - self.exposure.wcs, {0: self.dataIds[0]}, "a_test") + result = task.run(coaddExposures={0: self.patches[0]}, + bbox=lsst.geom.Box2I(box), + wcs=self.exposure.wcs, + dataIds={0: self.dataIds[0]}, + physical_filter="a_test") # All 4 patches from tract 0 are included in this template. self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 4) @@ -224,7 +227,11 @@ def testRunOneTractMultipleInputs(self): box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180)) task = lsst.ip.diffim.GetTemplateTask() # Task modifies the input bbox, so pass a copy. - result = task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test") + result = task.run(coaddExposures=self.patches, + bbox=lsst.geom.Box2I(box), + wcs=self.exposure.wcs, + dataIds=self.dataIds, + physical_filter="a_test") # All 4 patches from two tracts are included in this template. self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 8) @@ -236,7 +243,11 @@ def testRunTwoTracts(self): box = lsst.geom.Box2I(lsst.geom.Point2I(200, 200), lsst.geom.Point2I(600, 600)) task = lsst.ip.diffim.GetTemplateTask() # Task modifies the input bbox, so pass a copy. - result = task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test") + result = task.run(coaddExposures=self.patches, + bbox=lsst.geom.Box2I(box), + wcs=self.exposure.wcs, + dataIds=self.dataIds, + physical_filter="a_test") # All 4 patches from all 4 tracts are included in this template self._checkMetadata(result.template, task.config, box, self.exposure.wcs, 16) @@ -248,7 +259,11 @@ def testRunNoTemplate(self): box = lsst.geom.Box2I(lsst.geom.Point2I(1200, 1200), lsst.geom.Point2I(1600, 1600)) task = lsst.ip.diffim.GetTemplateTask() with self.assertRaisesRegex(lsst.pipe.base.NoWorkFound, "No patches found"): - task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test") + task.run(coaddExposures=self.patches, + bbox=lsst.geom.Box2I(box), + wcs=self.exposure.wcs, + dataIds=self.dataIds, + physical_filter="a_test") def testMissingPatches(self): """Test that a missing patch results in an appropriate mask. @@ -261,7 +276,11 @@ def testMissingPatches(self): box = lsst.geom.Box2I(lsst.geom.Point2I(0, 0), lsst.geom.Point2I(180, 180)) task = lsst.ip.diffim.GetTemplateTask() # Task modifies the input bbox, so pass a copy. - result = task.run(self.patches, lsst.geom.Box2I(box), self.exposure.wcs, self.dataIds, "a_test") + result = task.run(coaddExposures=self.patches, + bbox=lsst.geom.Box2I(box), + wcs=self.exposure.wcs, + dataIds=self.dataIds, + physical_filter="a_test") no_data = (result.template.mask.array & result.template.mask.getPlaneBitMask("NO_DATA")) != 0 self.assertTrue(all(np.isnan(result.template.image.array[no_data]))) self.assertTrue(all(np.isnan(result.template.variance.array[no_data])))