From 2db0df0f616567f36b8d6d7c33d9bee5bc2b2a0c Mon Sep 17 00:00:00 2001 From: Dan Taranu Date: Mon, 3 Jun 2024 13:30:55 -0700 Subject: [PATCH] Add DataLoader This loader needs no multiprofit data. --- .../lsst/meas/extensions/multiprofit/plots.py | 53 ++-- .../multiprofit/rebuild_coadd_multiband.py | 227 +++++++++++++++--- 2 files changed, 236 insertions(+), 44 deletions(-) diff --git a/python/lsst/meas/extensions/multiprofit/plots.py b/python/lsst/meas/extensions/multiprofit/plots.py index a88eea2..90d66f2 100644 --- a/python/lsst/meas/extensions/multiprofit/plots.py +++ b/python/lsst/meas/extensions/multiprofit/plots.py @@ -21,7 +21,7 @@ from abc import ABC, abstractmethod -from typing import Any, Iterable, Self +from typing import Any, Iterable, Self, Type import astropy.table import astropy.units as u @@ -32,7 +32,7 @@ import pydantic from lsst.multiprofit.plots import bands_weights_lsst, plot_model_rgb -from .rebuild_coadd_multiband import PatchCoaddRebuilder +from .rebuild_coadd_multiband import DataLoader, PatchCoaddRebuilder __all__ = [ "ObjectTableBase", @@ -304,6 +304,7 @@ def plot_blend( rebuilder: PatchCoaddRebuilder, idx_row_parent: int, weights: dict[str, float] = None, + table_ref_type: Type = TruthSummaryTable, kwargs_plot_parent: dict[str, Any] = None, kwargs_plot_children: dict[str, Any] = None, ) -> tuple[Figure, Axes, Figure, Axes]: @@ -317,6 +318,8 @@ def plot_blend( The row index of the parent object in the reference SourceCatalog. weights Multiplicative weights by band name for RGB plots. + table_ref_type + The type of reference table to construct when downselecting. kwargs_plot_parent Keyword arguments to pass to make RGB plots of the parent blend. 
kwargs_plot_children @@ -339,6 +342,8 @@ def plot_blend( kwargs_plot_children = {} if weights is None: weights = bands_weights_lsst + + plot_chi_hist = kwargs_plot_children.pop("plot_chi_hist", True) rebuilder_ref = rebuilder.matches[rebuilder.name_model_ref].rebuilder observations = { catexp.band: catexp.get_source_observation(catexp.get_catalog()[idx_row_parent], skip_flags=True) @@ -346,9 +351,10 @@ def plot_blend( } fig_rgb, ax_rgb, fig_gs, ax_gs, *_ = plot_model_rgb( - model=None, weights=weights, observations=observations, plot_singleband=False, **kwargs_plot_parent + model=None, weights=weights, observations=observations, plot_singleband=False, plot_chi_hist=False, + **kwargs_plot_parent ) - table_within_ref = downselect_table_axis(TruthSummaryTable(table=rebuilder.reference), ax_rgb) + table_within_ref = downselect_table_axis(table_ref_type(table=rebuilder.reference), ax_rgb) plot_objects(table_within_ref, ax_rgb, weights, table_downselected=True) objects_primary = rebuilder.objects[rebuilder.objects["detect_isPrimary"] == True] # noqa: E712 @@ -370,7 +376,7 @@ def plot_blend( objects_mpf = rebuilder.objects_multiprofit objects_mpf_within = {} for name, matched in rebuilder.matches.items(): - if matched.rebuilder: + if matched.rebuilder and objects_mpf: objects_mpf_within[name] = downselect_table_axis( ObjectTableMultiProFit(name_model=name, table=objects_mpf), ax_rgb, @@ -386,24 +392,37 @@ def plot_blend( for name, matched in rebuilder.matches.items(): print(f"Model: {name}") rebuilder_child = matched.rebuilder - if rebuilder_child: + is_dataloader = isinstance(rebuilder_child, DataLoader) + is_scarlet = is_dataloader and (name == "scarlet") + if is_scarlet or rebuilder_child: try: - model = rebuilder_child.make_model(idx_child) + if is_dataloader: + model = None + observations = rebuilder_child.load_deblended_object(idx_child) + else: + model = rebuilder_child.make_model(idx_child) + observations = None + _, ax_rgb_c, *_ = plot_model_rgb( - model=model, 
weights=weights, plot_singleband=False, **kwargs_plot_children + model=model, weights=weights, plot_singleband=False, + plot_chi_hist=(not is_dataloader) and plot_chi_hist, + observations=observations, + **kwargs_plot_children ) ax_rgb_c0 = ax_rgb_c[0][0] plot_objects(table_within_ref, ax_rgb_c0, weights) - plot_objects( - objects_mpf_within[name], - ax_rgb_c0, - weights, - kwargs_annotate=kwargs_annotate_obs, - kwargs_scatter=kwargs_scatter_obs, - labels_extended=labels_extended_model, - ) + tab_mpf = objects_mpf_within.get(name) + if tab_mpf: + plot_objects( + tab_mpf, + ax_rgb_c0, + weights, + kwargs_annotate=kwargs_annotate_obs, + kwargs_scatter=kwargs_scatter_obs, + labels_extended=labels_extended_model, + ) plt.show() except Exception as exc: - print(f"failed to rebuild due to {exc}") + print(f"{idx_child=} failed to rebuild due to {exc}") return fig_rgb, ax_rgb, fig_gs, ax_gs diff --git a/python/lsst/meas/extensions/multiprofit/rebuild_coadd_multiband.py b/python/lsst/meas/extensions/multiprofit/rebuild_coadd_multiband.py index 31ed996..99e734b 100644 --- a/python/lsst/meas/extensions/multiprofit/rebuild_coadd_multiband.py +++ b/python/lsst/meas/extensions/multiprofit/rebuild_coadd_multiband.py @@ -24,22 +24,62 @@ from functools import cached_property import astropy.table +import astropy.units as u import gauss2d.fit as g2f import lsst.afw.table as afwTable import lsst.daf.butler as dafButler +import numpy as np import lsst.geom as geom import pydantic +from lsst.meas.extensions.scarlet.io import updateCatalogFootprints from lsst.pipe.base import QuantumContext, QuantumGraph -from lsst.pipe.tasks.fit_coadd_multiband import CoaddMultibandFitTask -from lsst.skymap import BaseSkyMap +from lsst.pipe.tasks.fit_coadd_multiband import ( + CatalogExposureInputs, CoaddMultibandFitBaseTemplates, CoaddMultibandFitTask, + CoaddMultibandFitInputConnections, +) +from lsst.skymap import BaseSkyMap, TractInfo +from typing import Iterable -from .fit_coadd_multiband import 
CatalogExposurePsfs, CatalogSourceFitterConfigData, MultiProFitSourceTask +from .fit_coadd_multiband import ( + CatalogExposurePsfs, CatalogSourceFitterConfigData, MultiProFitSourceConfig, MultiProFitSourceTask, +) +astropy_to_geom_units = { + u.arcmin: geom.arcminutes, + u.arcsec: geom.arcseconds, + u.mas: geom.milliarcseconds, + u.deg: geom.degrees, + u.rad: geom.radians, +} + + +def astropy_unit_to_geom(unit: u.Unit, default=None) -> geom.AngleUnit: + unit_geom = astropy_to_geom_units.get(unit, default) + if unit_geom is None: + raise ValueError(f"{unit=} not found in {astropy_to_geom_units=}") + return unit_geom + + +def find_patches(tract_info: TractInfo, ra_array, dec_array, unit: geom.AngleUnit) -> list[int]: + radec = [geom.SpherePoint(ra, dec, units=unit) for ra, dec in zip(ra_array, dec_array, strict=True)] + points = np.array([geom.Point2I(tract_info.wcs.skyToPixel(coords)) for coords in radec]).reshape(-1, 2) + x_list, y_list = (points[:, idx]//tract_info.patch_inner_dimensions[idx] for idx in range(2)) + patches = [tract_info.getSequentialPatchIndexFromPair((x, y)) for x, y in zip(x_list, y_list)] + return patches + + +def get_radec_unit(table, coord_ra, coord_dec, default=None): + unit_ra, unit_dec = ( + astropy_unit_to_geom(table[coord].unit, default=default) for coord in (coord_ra, coord_dec) + ) + if unit_ra != unit_dec: + units = {coord: table[coord].unit for coord in (coord_ra, coord_dec)} + raise ValueError(f"Reference table has inconsistent {units=}") + return unit_ra -class ModelRebuilder(pydantic.BaseModel): - """A rebuilder of MultiProFit models from their inputs and best-fit - parameter values. 
- """ +class DataLoader(pydantic.BaseModel): + """A collection of data that can be used to rebuild models.""" model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True) @@ -49,14 +89,98 @@ class ModelRebuilder(pydantic.BaseModel): catalog_multi: afwTable.SourceCatalog = pydantic.Field( doc="Patch-level multiband reference catalog (deepCoadd_ref)", ) - fit_results: astropy.table.Table = pydantic.Field(doc="Multiprofit model fit results") - task_fit: MultiProFitSourceTask = pydantic.Field(doc="The task") @cached_property def channels(self) -> tuple[g2f.Channel]: channels = tuple(g2f.Channel.get(catexp.band) for catexp in self.catexps) return channels + @classmethod + def from_butler( + cls, + butler: dafButler.Butler, + data_id: dict[str], + bands: Iterable[str], + name_coadd=None, + **kwargs + ): + bands = tuple(bands) + if len(set(bands)) != len(bands): + raise ValueError(f"{bands=} is not a set") + if name_coadd is None: + name_coadd = CoaddMultibandFitBaseTemplates["name_coadd"] + + catalog_multi = butler.get( + CoaddMultibandFitInputConnections.cat_ref.name.format(name_coadd=name_coadd), + **data_id, **kwargs + ) + + catexps = {} + for band in bands: + data_id = {**data_id, "band": band} + catalog = butler.get( + CoaddMultibandFitInputConnections.cats_meas.name.format(name_coadd=name_coadd), + **data_id, **kwargs + ) + exposure = butler.get( + CoaddMultibandFitInputConnections.coadds.name.format(name_coadd=name_coadd), + **data_id, **kwargs + ) + models_scarlet = butler.get( + CoaddMultibandFitInputConnections.models_scarlet.name.format(name_coadd=name_coadd), + **data_id, **kwargs + ) + updateCatalogFootprints( + modelData=models_scarlet, + catalog=catalog, + band=data_id["band"], + imageForRedistribution=exposure, + removeScarletData=True, + updateFluxColumns=False, + ) + # The config and table are harmless dummies + catexps[band] = CatalogExposurePsfs( + catalog=catalog, exposure=exposure, table_psf_fits=astropy.table.Table(), + dataId=data_id, 
id_tract_patch=data_id["patch"], + channel=g2f.Channel.get(band), + config_fit=MultiProFitSourceConfig(), + ) + return cls( + catalog_multi=catalog_multi, + catexps=list(catexps.values()), + ) + + def load_deblended_object( + self, + idx_row: int, + ) -> list[g2f.Observation]: + """Load a deblended object from catexps. + + Parameters + ---------- + idx_row + The index of the object to load. + + Returns + ------- + observations + The observations of the object (deblended if it is a child). + """ + observations = [] + for catexp in self.catexps: + observations.append( + catexp.get_source_observation(catexp.get_catalog()[idx_row]) + ) + return observations + + +class ModelRebuilder(DataLoader): + """A rebuilder of MultiProFit models from their inputs and best-fit + parameter values.""" + + fit_results: astropy.table.Table = pydantic.Field(doc="Multiprofit model fit results") + task_fit: MultiProFitSourceTask = pydantic.Field(doc="The task") + @cached_property def config_data(self) -> CatalogSourceFitterConfigData: config_data = self.make_config_data() @@ -197,7 +321,7 @@ class PatchModelMatches(pydantic.BaseModel): matches: astropy.table.Table | None = pydantic.Field(doc="Catalogs of matches") quantumgraph: QuantumGraph | None = pydantic.Field(doc="Quantum graph for fit task") - rebuilder: ModelRebuilder | None = pydantic.Field(doc="MultiProFit object model rebuilder") + rebuilder: DataLoader | ModelRebuilder | None = pydantic.Field(doc="MultiProFit object model rebuilder") class PatchCoaddRebuilder(pydantic.BaseModel): @@ -208,7 +332,7 @@ class PatchCoaddRebuilder(pydantic.BaseModel): matches: dict[str, PatchModelMatches] = pydantic.Field("Model matches by algorithm name") name_model_ref: str = pydantic.Field(doc="The name of the reference model in matches") objects: astropy.table.Table = pydantic.Field(doc="Object table") - objects_multiprofit: astropy.table.Table = pydantic.Field(doc="Object table for MultiProFit fits") + objects_multiprofit: astropy.table.Table | 
None = pydantic.Field(doc="Object table for MultiProFit fits") reference: astropy.table.Table = pydantic.Field(doc="Reference object table") skymap: str = pydantic.Field(doc="The skymap name") @@ -223,35 +347,68 @@ def from_butler( tract: int, patch: int, collection_merged: str, - matches: dict[str, QuantumGraph], + matches: dict[str, QuantumGraph | None], + bands: Iterable[str] = None, name_model_ref: str = None, format_collection: str = "{run}", + load_multiprofit: bool = True, + dataset_type_ref: str = "truth_summary", ): + """Init a PatchCoaddRebuilder from a single Butler collection. + + Parameters + ---------- + butler + skymap + tract + patch + collection_merged + matches + bands + name_model_ref + format_collection + load_multiprofit + Whether to attempt to load an objectTable_tract_multiprofit. + dataset_type_ref + The dataset type of the reference catalog. + + Returns + ------- + rebuilder + The fully-configured PatchCoaddRebuilder. + """ if name_model_ref is None: for name, quantumgraph in matches.items(): if quantumgraph is not None: name_model_ref = name break if name_model_ref is None: - raise ValueError("At least one matches with a quantumgraph must be supplied") + raise ValueError("Must supply name_model_ref or at least one matches with a quantumgraph") dataId = dict(skymap=skymap, tract=tract, patch=patch) objects = butler.get( "objectTable_tract", collections=[collection_merged], storageClass="ArrowAstropy", **dataId ) objects = objects[objects["patch"] == patch] - objects_multiprofit = butler.get( - "objectTable_tract_multiprofit", - collections=[collection_merged], - storageClass="ArrowAstropy", - **dataId, - ) - objects_multiprofit = objects_multiprofit[objects_multiprofit["patch"] == patch] + if load_multiprofit: + objects_multiprofit = butler.get( + "objectTable_tract_multiprofit", + collections=[collection_merged], + storageClass="ArrowAstropy", + **dataId, + ) + objects_multiprofit = objects_multiprofit[objects_multiprofit["patch"] == 
patch] + else: + objects_multiprofit = None reference = butler.get( - "truth_summary", collections=[collection_merged], storageClass="ArrowAstropy", **dataId + dataset_type_ref, collections=[collection_merged], storageClass="ArrowAstropy", **dataId ) skymap_tract = butler.get(BaseSkyMap.SKYMAP_DATASET_TYPE_NAME, skymap=skymap)[tract] - # the ci_imsim truth_summary still has string patches - if reference["patch"].dtype != int: + unit_coord_ref = get_radec_unit(reference, "ra", "dec", default=geom.degrees) + if "patch" not in reference.columns: + patches = find_patches(skymap_tract, reference["ra"], reference["dec"], unit=unit_coord_ref) + reference["patch"] = patches + elif reference["patch"].dtype != int: + # the ci_imsim truth_summary still has string patches index_patch = skymap_tract[patch].index str_patch = f"{index_patch.y},{index_patch.x}" reference = reference[ @@ -259,8 +416,7 @@ def from_butler( ] del reference["patch"] reference["patch"] = patch - else: - reference = reference[reference["patch"] == patch] + reference = reference[reference["patch"] == patch] points = skymap_tract.wcs.skyToPixel( [geom.SpherePoint(row["ra"], row["dec"], units=geom.degrees) for row in reference] ) @@ -270,7 +426,7 @@ def from_butler( for name, quantumgraph in matches.items(): is_mpf = quantumgraph is not None matched = butler.get( - f"matched_truth_summary_objectTable_tract{'_multiprofit' if is_mpf else ''}", + f"matched_{dataset_type_ref}_objectTable_tract{'_multiprofit' if is_mpf else ''}", collections=[ format_collection.format(run=quantumgraph.metadata["output"], name=name) if is_mpf @@ -279,9 +435,26 @@ def from_butler( storageClass="ArrowAstropy", **dataId, ) + # unmatched ref objects don't have a patch set + # should probably be fixed in diff_matched + # but need to decide priority on matched - ref first? or target? 
+ unit_coord_ref = get_radec_unit( + matched, "refcat_ra", "refcat_dec", default=geom.degrees, + ) + unmatched = ( + matched["patch"].mask if np.ma.is_masked(matched["patch"]) else ~(matched["patch"] >= 0) + ) & np.isfinite(matched["refcat_ra"]) + patches_unmatched = find_patches( + skymap_tract, + matched["refcat_ra"][unmatched], matched["refcat_dec"][unmatched], + unit=unit_coord_ref + ) + matched["patch"][np.where(unmatched)[0]] = patches_unmatched matched = matched[matched["patch"] == patch] rebuilder = ( - ModelRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId) if is_mpf else None + ModelRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId) + if is_mpf else + DataLoader.from_butler(butler, data_id=dataId, bands=bands, collections=[collection_merged]) ) matches_name[name] = PatchModelMatches( matches=matched, quantumgraph=quantumgraph, rebuilder=rebuilder