Skip to content

Commit

Permalink
Document rebuilder methods
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Jun 10, 2024
1 parent 2db0df0 commit 8b2135c
Showing 1 changed file with 132 additions and 29 deletions.
161 changes: 132 additions & 29 deletions python/lsst/meas/extensions/multiprofit/rebuild_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,30 @@
__all__ = ["ModelRebuilder", "PatchModelMatches", "PatchCoaddRebuilder"]

from functools import cached_property
from typing import Iterable

import astropy.table
import astropy.units as u
import gauss2d.fit as g2f
import lsst.afw.table as afwTable
import lsst.daf.butler as dafButler
import numpy as np
import lsst.geom as geom
import numpy as np
import pydantic
from lsst.meas.extensions.scarlet.io import updateCatalogFootprints
from lsst.pipe.base import QuantumContext, QuantumGraph
from lsst.pipe.tasks.fit_coadd_multiband import (
CatalogExposureInputs, CoaddMultibandFitBaseTemplates, CoaddMultibandFitTask,
CoaddMultibandFitBaseTemplates,
CoaddMultibandFitInputConnections,
CoaddMultibandFitTask,
)
from lsst.skymap import BaseSkyMap, TractInfo
from typing import Iterable

from .fit_coadd_multiband import (
CatalogExposurePsfs, CatalogSourceFitterConfigData, MultiProFitSourceConfig, MultiProFitSourceTask,
CatalogExposurePsfs,
CatalogSourceFitterConfigData,
MultiProFitSourceConfig,
MultiProFitSourceTask,
)

astropy_to_geom_units = {
Expand All @@ -54,21 +58,81 @@


def astropy_unit_to_geom(unit: u.Unit, default=None) -> geom.AngleUnit:
"""Convert an astropy unit to an lsst.geom unit.
Parameters
----------
unit
The astropy unit to convert.
default
The default value to return if no known conversion is found.
Returns
-------
unit_geom
The equivalent unit, if found.
Raises
------
ValueError
Raised if no equivalent unit is found.
"""
unit_geom = astropy_to_geom_units.get(unit, default)
if unit_geom is None:
raise ValueError(f"{unit=} not found in {astropy_to_geom_units=}")
return unit_geom


def find_patches(tract_info: TractInfo, ra_array, dec_array, unit: geom.AngleUnit) -> list[int]:
"""Find the patches containing a list of ra/dec values within a tract.
Parameters
----------
tract_info
The TractInfo object for the tract.
ra_array
The array of right ascension values.
dec_array
The array of declination values (must be same length as ra_array).
unit
The unit of the RA/dec values.
Returns
-------
patches
A list of patches containing the specified RA/dec values.
"""
radec = [geom.SpherePoint(ra, dec, units=unit) for ra, dec in zip(ra_array, dec_array, strict=True)]
points = np.array([geom.Point2I(tract_info.wcs.skyToPixel(coords)) for coords in radec])
x_list, y_list = (points[:, idx]//tract_info.patch_inner_dimensions[idx] for idx in range(2))
x_list, y_list = (points[:, idx] // tract_info.patch_inner_dimensions[idx] for idx in range(2))
patches = [tract_info.getSequentialPatchIndexFromPair((x, y)) for x, y in zip(x_list, y_list)]
return patches


def get_radec_unit(table, coord_ra, coord_dec, default=None):
def get_radec_unit(table: astropy.table.Table, coord_ra: str, coord_dec: str, default=None):
"""Get the RA/dec units for columns in a table.
Parameters
----------
table
The table to determine units for.
coord_ra
The key of the right ascension column.
coord_dec
The key of the declination column.
default
The default value to return if no unit is found.
Returns
-------
unit
The unit of the RA/dec columns or None if none is found.
Raises
------
ValueError
Raised if the units are inconsistent.
"""
unit_ra, unit_dec = (
astropy_unit_to_geom(table[coord].unit, default=default) for coord in (coord_ra, coord_dec)
)
Expand Down Expand Up @@ -97,38 +161,56 @@ def channels(self) -> tuple[g2f.Channel]:

@classmethod
def from_butler(
cls,
butler: dafButler.Butler,
data_id: dict[str],
bands: Iterable[str],
name_coadd=None,
**kwargs
cls, butler: dafButler.Butler, data_id: dict[str], bands: Iterable[str], name_coadd=None, **kwargs
):
"""Construct a DataLoader from a Butler and dataId.
Parameters
----------
butler
The butler to load from.
data_id
Key-value pairs for the {name_coadd}Coadd_* dataId.
bands
The list of bands to load.
name_coadd
The prefix of the Coadd datasettype name.
**kwargs
Additional keyword arguments to pass to the init method for
`CoaddMultibandFitInputConnections`.
Returns
-------
data_loader
An initialized DataLoader.
"""
bands = tuple(bands)
if len(set(bands)) != len(bands):
raise ValueError(f"{bands=} is not a set")
if name_coadd is None:
name_coadd = CoaddMultibandFitBaseTemplates["name_coadd"]

catalog_multi = butler.get(
CoaddMultibandFitInputConnections.cat_ref.name.format(name_coadd=name_coadd),
**data_id, **kwargs
CoaddMultibandFitInputConnections.cat_ref.name.format(name_coadd=name_coadd), **data_id, **kwargs
)

catexps = {}
for band in bands:
data_id["band"] = band
catalog = butler.get(
CoaddMultibandFitInputConnections.cats_meas.name.format(name_coadd=name_coadd),
**data_id, **kwargs
**data_id,
**kwargs,
)
exposure = butler.get(
CoaddMultibandFitInputConnections.coadds.name.format(name_coadd=name_coadd),
**data_id, **kwargs
**data_id,
**kwargs,
)
models_scarlet = butler.get(
CoaddMultibandFitInputConnections.models_scarlet.name.format(name_coadd=name_coadd),
**data_id, **kwargs
**data_id,
**kwargs,
)
updateCatalogFootprints(
modelData=models_scarlet,
Expand All @@ -140,8 +222,11 @@ def from_butler(
)
# The config and table are harmless dummies
catexps[band] = CatalogExposurePsfs(
catalog=catalog, exposure=exposure, table_psf_fits=astropy.table.Table(),
dataId=data_id, id_tract_patch=data_id["patch"],
catalog=catalog,
exposure=exposure,
table_psf_fits=astropy.table.Table(),
dataId=data_id,
id_tract_patch=data_id["patch"],
channel=g2f.Channel.get(band),
config_fit=MultiProFitSourceConfig(),
)
Expand All @@ -168,15 +253,14 @@ def load_deblended_object(
"""
observations = []
for catexp in self.catexps:
observations.append(
catexp.get_source_observation(catexp.get_catalog()[idx_row])
)
observations.append(catexp.get_source_observation(catexp.get_catalog()[idx_row]))
return observations


class ModelRebuilder(DataLoader):
"""A rebuilder of MultiProFit models from their inputs and best-fit
parameter values."""
parameter values.
"""

fit_results: astropy.table.Table = pydantic.Field(doc="Multiprofit model fit results")
task_fit: MultiProFitSourceTask = pydantic.Field(doc="The task")
Expand Down Expand Up @@ -354,19 +438,32 @@ def from_butler(
load_multiprofit: bool = True,
dataset_type_ref: str = "truth_summary",
):
"""Init a PatchCoaddRebuilder from a single Butler collection.
"""Construct a PatchCoaddRebuilder from a single Butler collection.
Parameters
----------
butler
The butler to load from.
skymap
The skymap for the collection.
tract
The skymap tract id.
patch
The skymap patch id.
collection_merged
The name of the collection with the merged objectTable(s).
matches
A dictionary of model names with corresponding QuantumGraphs.
These may be None but must be provided for MultiProFit model
reconstruction to be possible.
bands
The list of bands to load data for.
name_model_ref
The name of the model to use as a reference. Must be a key in
`matches`.
format_collection
A format string for the output collection(s) defined in the
`matches` QuantumGraphs.
load_multiprofit
Whether to attempt to load an objectTable_tract_multiprofit.
dataset_type_ref
Expand Down Expand Up @@ -439,22 +536,28 @@ def from_butler(
# should probably be fixed in diff_matched
# but need to decide priority on matched - ref first? or target?
unit_coord_ref = get_radec_unit(
matched, "refcat_ra", "refcat_dec", default=geom.degrees,
matched,
"refcat_ra",
"refcat_dec",
default=geom.degrees,
)
unmatched = (
matched["patch"].mask if np.ma.is_masked(matched["patch"]) else ~(matched["patch"] >= 0)
) & np.isfinite(matched["refcat_ra"])
patches_unmatched = find_patches(
skymap_tract,
matched["refcat_ra"][unmatched], matched["refcat_dec"][unmatched],
unit=unit_coord_ref
matched["refcat_ra"][unmatched],
matched["refcat_dec"][unmatched],
unit=unit_coord_ref,
)
matched["patch"][np.where(unmatched)[0]] = patches_unmatched
matched = matched[matched["patch"] == patch]
rebuilder = (
ModelRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId)
if is_mpf else
DataLoader.from_butler(butler, data_id=dataId, bands=bands, collections=[collection_merged])
if is_mpf
else DataLoader.from_butler(
butler, data_id=dataId, bands=bands, collections=[collection_merged]
)
)
matches_name[name] = PatchModelMatches(
matches=matched, quantumgraph=quantumgraph, rebuilder=rebuilder
Expand Down

0 comments on commit 8b2135c

Please sign in to comment.