Skip to content

Commit

Permalink
Make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
fred3m committed Jan 9, 2025
1 parent 727e023 commit 5165111
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 96 deletions.
2 changes: 1 addition & 1 deletion python/lsst/meas/extensions/scarlet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def buildMonochromeObservation(
if bbox is None:
bbox = utils.bboxToScarletBox(imageForRedistribution.getBBox())
assert bbox is not None # needed for typing
parents = catalog[catalog["deblend_level"] == 0]
parents = catalog[catalog["parent"] == 0]
footprintImage = utils.footprintsToNumpy(parents, bbox.shape, bbox.origin[::-1])
# Extract the image array to re-distribute its flux
images = scl.Image(
Expand Down
67 changes: 37 additions & 30 deletions python/lsst/meas/extensions/scarlet/scarletDeblendTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,21 +542,24 @@ class ScarletDeblendConfig(pexConfig.Config):

# Size restrictions
maxNumberOfPeaks = pexConfig.Field[int](
default=600,
# default=600,
default=-1,
doc=(
"Only deblend the brightest maxNumberOfPeaks peaks in the parent"
" (<= 0: unlimited)"
),
)
maxFootprintArea = pexConfig.Field[int](
default=2_000_000,
# default=2_000_000,
default=-1,
doc=(
"Maximum area for footprints before they are ignored as large; "
"non-positive means no threshold applied"
),
)
maxAreaTimesPeaks = pexConfig.Field[int](
default=1_000_000_000,
# default=1_000_000_000,
default=-1,
doc=(
"Maximum rectangular footprint area * nPeaks in the footprint. "
"This was introduced in DM-33690 to prevent fields that are crowded or have a "
Expand Down Expand Up @@ -809,14 +812,10 @@ def _addSchemaKeys(self, schema: afwTable.Schema):
doc="The type of model used, for example "
"MultiExtendedSource, SingleExtendedSource, PointSource",
)
self.deblendLevelKey = schema.addField(
"deblend_level",
self.deblendDepth = schema.addField(
"deblend_depth",
type=np.int32,
doc="The level of deblending in the hierarchy. "
"Currently this has the values: \n"
" - 0: The top level parent source of the blend\n"
" - 1: A child of a parent that also has children\n"
" - 2: A child of a higher level parent that has no children."
doc="The depth of deblending in the hierarchy."
)
self.parentNPeaksKey = schema.addField(
"deblend_parentNPeaks",
Expand Down Expand Up @@ -947,12 +946,12 @@ def deblend(
"""Deblend a data cube of multiband images
Deblending iterates over sources from the input catalog,
which are blends of peaks with overlapping PSFs (level 0 parents).
which are blends of peaks with overlapping PSFs (depth 0 parents).
In many cases those footprints can be subdived into multiple
deconvovled footprints, which have an intermediate (level 1)
deconvovled footprints, which have an intermediate
parent record added to the catalog and are be deblended separately.
All deblended peaks have a source record (level 2)
added to the catalog.
All deblended peaks have a source record added to the catalog,
each of which has a depth one greater than the parent.
Parameters
----------
Expand Down Expand Up @@ -1031,8 +1030,9 @@ def deblend(
psfParents, parentHierarchy = self._buildParentHierarchy(catalog, context)

self.log.info(
f"Subdivided the level 0 parents to create {np.sum(catalog[self.deblendLevelKey] == 1)} "
"level 1 parents."
"Subdivided the top level parents to create "
f"{np.sum((catalog[self.deblendDepth] == 1) & (catalog[self.nPeaksKey] > 1))} "
"deconvolved parents."
)

# Attach full image objects to the task to simplify the API
Expand All @@ -1055,7 +1055,6 @@ def deblend(
)

psfParent = catalog.find(psfParentId)
psfParent.set(self.deblendLevelKey, 0)

# Since we use the first peak for the parent object, we should
# propagate its flags to the parent source.
Expand All @@ -1067,7 +1066,6 @@ def deblend(
# so there is no deconvolved parent record.
blendModel = self._deblendParent(
blendRecord=psfParent,
catalog=catalog,
children=children,
footprint=psfParents[psfParentId],
)
Expand Down Expand Up @@ -1098,7 +1096,6 @@ def deblend(
deconvolvedParent = catalog.find(deconvolvedParentId)
blendModel = self._deblendParent(
blendRecord=deconvolvedParent,
catalog=catalog,
children=children[deconvolvedParentId],
)
if blendModel is None:
Expand Down Expand Up @@ -1147,20 +1144,31 @@ def deblend(
mask, mask.getPlaneBitMask(self.config.notDeblendedMask)
)

nDeconvolvedParents = np.sum(catalog[self.deblendLevelKey] == 1)
nDeblendedSources = np.sum(catalog[self.deblendLevelKey] == 2)
nDeconvolvedParents = np.sum((catalog[self.deblendDepth] == 1) & (catalog[self.nPeaksKey] > 1))
nDeblendedSources = np.sum((catalog[self.deblendDepth] > 0) & (catalog[self.nPeaksKey] == 1))
self.log.info(
f"Deblender results: {nPsfBlendedParents} parent sources were "
f"split into {nDeconvolvedParents} deconvovled parents,"
f"resulting in {nDeblendedSources} deblended sources, "
f"for a total catalog size of {len(catalog)} sources",
)
return catalog, modelData

table = afwTable.SourceTable.make(self.schema)
sortedCatalog = afwTable.SourceCatalog(table)
parents = catalog[catalog["parent"] == 0]
sortedCatalog.extend(parents, deep=True)
parentIds = np.unique(catalog["parent"])
for parentId in parentIds[1:]:
parent = catalog.find(parentId)
if not parent.get(self.deblendSkippedKey):
children = catalog[catalog["parent"] == parentId]
sortedCatalog.extend(children, deep=True)

return sortedCatalog, modelData

def _deblendParent(
self,
blendRecord: afwTable.SourceRecord,
catalog: afwTable.SourceCatalog,
children: dict[int, afwTable.SourceRecord],
footprint: afwDet.Footprint | None = None,
) -> scl.Image | None:
Expand Down Expand Up @@ -1283,7 +1291,6 @@ def _deblendParent(
chi2=chi2,
)
scarletSource.record_id = sourceRecord.getId()
catalog.append(sourceRecord)

# Store the blend information so that it can be persisted
if self.config.version == "lite":
Expand Down Expand Up @@ -1554,7 +1561,7 @@ def _addChildren(parent: afwTable.SourceRecord):
children = {}
footprint = parent.getFootprint()
for peak in footprint.peaks:
child = self._createDeblendedSource(
child = self._addDeblendedSource(
parent=parent,
peak=peak,
catalog=catalog,
Expand Down Expand Up @@ -1673,10 +1680,10 @@ def _addDeconvolvedParents(
deconvolvedParents.append(deconvolvedParent)
deconvolvedParent.setParent(parent.getId())
deconvolvedParent.setFootprint(footprint)
deconvolvedParent.set(self.deblendLevelKey, 1)
deconvolvedParent.set(self.deblendDepth, parent[self.deblendDepth] + 1)
return deconvolvedParents

def _createDeblendedSource(
def _addDeblendedSource(
self,
parent: afwTable.SourceRecord,
peak: afwDet.PeakRecord,
Expand All @@ -1703,7 +1710,7 @@ def _createDeblendedSource(
src :
The new child source record.
"""
src = catalog.makeRecord()
src = catalog.addNew()
for key in self.toCopyFromParent:
src.set(key, parent.get(key))
# The peak catalog is the same for all bands,
Expand All @@ -1725,8 +1732,8 @@ def _createDeblendedSource(
src.set(self.peakCenter, geom.Point2I(peak["i_x"], peak["i_y"]))
src.set(self.peakIdKey, peak["id"])

# Set the deblend level
src.set(self.deblendLevelKey, 2)
# Set the deblend depth
src.set(self.deblendDepth, parent[self.deblendDepth] + 1)

return src

Expand Down
68 changes: 3 additions & 65 deletions tests/test_deblend.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,59 +78,6 @@ def setUp(self):
for b, coadd in enumerate(self.coadds):
coadd.setPsf(psfs[b])

def _insert_blank_source(self, modelData, catalog):
# Add parent
parent = catalog.addNew()
parent.setParent(0)
parent["deblend_nChild"] = 1
parent["deblend_nPeaks"] = 1
ss = SpanSet.fromShape(5, Stencil.CIRCLE, offset=(30, 70))
footprint = Footprint(ss)
peak = footprint.addPeak(30, 70, 0)
parent.setFootprint(footprint)

# Add the zero flux source
dtype = np.float32
center = (70, 30)
origin = (center[0] - 5, center[1] - 5)
psf = list(modelData.blends.values())[0].psf
src = catalog.addNew()
src.setParent(parent.getId())
src["deblend_peak_center_x"] = center[1]
src["deblend_peak_center_y"] = center[0]
src["deblend_nPeaks"] = 1

sources = {
src.getId(): {
"components": [],
"factorized": [
{
"origin": origin,
"peak": center,
"spectrum": np.zeros((len(self.bands),), dtype=dtype),
"morph": np.zeros((11, 11), dtype=dtype),
"shape": (11, 11),
}
],
"peak_id": peak.getId(),
}
}

blendData = scl.io.ScarletBlendData.from_dict(
{
"origin": origin,
"shape": (11, 11),
"psf_center": center,
"psf_shape": psf.shape,
"psf": psf.flatten(),
"sources": sources,
"bands": self.bands,
}
)
pid = parent.getId()
modelData.blends[pid] = blendData
return pid, src.getId()

def _deblend(self, version):
schema = SourceCatalog.Table.makeMinimalSchema()
# Adjust config options to test skipping parents
Expand Down Expand Up @@ -179,8 +126,6 @@ def _deblend(self, version):
def test_deblend_task(self):
catalog, modelData, config = self._deblend("lite")

bad_blend_id, bad_src_id = self._insert_blank_source(modelData, catalog)

# Attach the footprints in each band and compare to the full
# data model. This is done in each band, both with and without
# flux re-distribution to test all of the different possible
Expand Down Expand Up @@ -208,9 +153,10 @@ def test_deblend_task(self):
)

# Check that the number of deblended children is consistent
parents = catalog[catalog["deblend_level"] == 0]
parents = catalog[(catalog["deblend_depth"] == 0) & ~catalog["deblend_skipped"]]
self.assertEqual(
np.sum(parents["deblend_nChild"]), np.sum(catalog["deblend_level"] == 2)
np.sum(parents["deblend_nChild"]),
np.sum((catalog["deblend_depth"] > 0) & (catalog["deblend_nPeaks"] == 1))
)

# Check that the models have not been cleared
Expand All @@ -223,10 +169,6 @@ def test_deblend_task(self):
children = catalog[catalog["parent"] == parent.get("id")]
# Check that nChild is set correctly
self.assertEqual(len(children), parent.get("deblend_nChild"))
# Check that parent columns are propagated
# to their children
if parent.getId() == bad_blend_id:
continue
for parentCol, childCol in config.columnInheritance.items():
np.testing.assert_array_equal(
parent.get(parentCol), children[childCol]
Expand Down Expand Up @@ -344,10 +286,6 @@ def test_deblend_task(self):
skipped = largeFootprint | denseFootprint
np.testing.assert_array_equal(skipped, catalog["deblend_skipped"])

# Check that the zero flux source was flagged
for src in catalog:
np.testing.assert_equal(src["deblend_zeroFlux"], src.getId() == bad_src_id)


class MemoryTester(lsst.utils.tests.MemoryTestCase):
pass
Expand Down

0 comments on commit 5165111

Please sign in to comment.