diff --git a/README.rst b/README.rst index 32b91da..841c137 100644 --- a/README.rst +++ b/README.rst @@ -5,4 +5,4 @@ meas_extensions_multiprofit ``meas_extensions_multiprofit`` is a package in the `LSST Science Pipelines `_. ``meas_extensions_multiprofit`` provides tasks and wrappers for running the -`MultiProFit `_ source modelling code on Science Pipelines data repositories. +`MultiProFit `_ source modelling code on Science Pipelines data repositories. diff --git a/python/lsst/meas/extensions/multiprofit/fit_coadd_multiband.py b/python/lsst/meas/extensions/multiprofit/fit_coadd_multiband.py index 052a2e5..7a865b5 100644 --- a/python/lsst/meas/extensions/multiprofit/fit_coadd_multiband.py +++ b/python/lsst/meas/extensions/multiprofit/fit_coadd_multiband.py @@ -22,7 +22,7 @@ import logging import math from functools import cached_property -from typing import Any, Iterable, Mapping, Sequence +from typing import Any, ClassVar, Iterable, Mapping, Sequence import lsst.gauss2d as g2 import lsst.gauss2d.fit as g2f @@ -34,18 +34,16 @@ import pydantic from astropy.table import Table from lsst.daf.butler.formatters.parquet import astropy_to_arrow -from lsst.multiprofit.config import set_config_from_dict from lsst.multiprofit.errors import NoDataError, PsfRebuildFitFlagError -from lsst.multiprofit.fit_psf import CatalogPsfFitterConfig, CatalogPsfFitterConfigData -from lsst.multiprofit.fit_source import ( +from lsst.multiprofit.utils import get_params_uniq, set_config_from_dict +from lsst.multiprofit.fitting.fit_psf import CatalogPsfFitterConfig, CatalogPsfFitterConfigData +from lsst.multiprofit.fitting.fit_source import ( CatalogExposureSourcesABC, CatalogSourceFitterABC, CatalogSourceFitterConfig, CatalogSourceFitterConfigData, ) -from lsst.multiprofit.utils import get_params_uniq from lsst.pex.config.configurableActions import ConfigurableAction, ConfigurableActionField -from pydantic.dataclasses import dataclass from .errors import IsParentError, NotPrimaryError from .utils import get_spanned_image @@ -227,7 +225,7 @@ def setDefaults(self): self.centroid_pixel_offset = -0.5 -@dataclass(frozen=True, kw_only=True, config=fitMB.CatalogExposureConfig) +@pydantic.dataclasses.dataclass(frozen=True, kw_only=True, config=fitMB.CatalogExposureConfig) class CatalogExposurePsfs(fitMB.CatalogExposureInputs, CatalogExposureSourcesABC): """Input data from lsst pipelines, parsed for MultiProFit.""" @@ -396,26 +394,8 @@ def __post_init__(self): object.__setattr__(self, "psf_model_data", config_data) -class MultiProFitSourceTask(CatalogSourceFitterABC, fitMB.CoaddMultibandFitSubTask): - """Run MultiProFit on Exposure/SourceCatalog pairs in multiple bands. - - This task uses MultiProFit to fit a single model to all sources in a coadd, - using a previously-fit PSF model for each exposure. The task may also use - prior measurements from single- or merged multiband catalogs for - initialization. - - Parameters - ---------- - **kwargs - Keyword arguments to pass to CoaddMultibandFitSubTask.__init__. - - Notes - ----- - See https://github.com/lsst-dm/multiprofit for more MultiProFit info. - """ - - ConfigClass = MultiProFitSourceConfig - _DefaultName = "multiProFitSource" +class MultiProFitSourceFitter(CatalogSourceFitterABC): + """A MultiProFit source fitter.""" def __init__(self, **kwargs: Any): errors_expected = {} if "errors_expected" not in kwargs else kwargs.pop("errors_expected") @@ -425,8 +405,7 @@ def __init__(self, **kwargs: Any): for error_catalog in (IsParentError, NoDataError, NotPrimaryError, PsfRebuildFitFlagError): if error_catalog not in errors_expected: errors_expected[error_catalog] = error_catalog.column_name() - CatalogSourceFitterABC.__init__(self, errors_expected=errors_expected) - fitMB.CoaddMultibandFitSubTask.__init__(self, **kwargs) + super().__init__(errors_expected=errors_expected) def copy_centroid_errors( self, @@ -603,11 +582,40 @@ def make_CatalogExposurePsfs(self, catexp: fitMB.CatalogExposureInputs) -> Catal ) return catexp_psf + def validate_fit_inputs( + self, + catalog_multi: Sequence, + catexps: list[CatalogExposurePsfs], + config_data: CatalogSourceFitterConfigData = None, + logger: logging.Logger = None, + **kwargs: Any, + ) -> None: + errors = [] + for idx, catexp in enumerate(catexps): + if not isinstance(catexp, CatalogExposurePsfs): + errors.append(f"catexps[{idx=} {type(catexp)=} !isinstance(CatalogExposurePsfs)") + if errors: + raise RuntimeError("\n".join(errors)) + + +class MultiProFitSourceTask(fitMB.CoaddMultibandFitSubTask): + """Run MultiProFit on Exposure/SourceCatalog pairs in multiple bands. + + This task uses MultiProFit to fit a single model to all sources in a coadd, + using a previously-fit PSF model for each exposure. The task may also use + prior measurements from single- or merged multiband catalogs for + initialization. + """ + + ConfigClass: ClassVar = MultiProFitSourceConfig + _DefaultName: ClassVar = "multiProFitSource" + @utilsTimer.timeMethod def run( self, catalog_multi: Sequence, catexps: list[fitMB.CatalogExposureInputs], + fitter: MultiProFitSourceFitter | None = None, **kwargs, ) -> pipeBase.Struct: """Run the MultiProFit source fit task on catalog-exposure pairs. @@ -627,32 +635,19 @@ def run( A table with fit parameters for the PSF model at the location of each source. """ + if fitter is None: + fitter = MultiProFitSourceFitter() n_catexps = len(catexps) catexps_conv: list[CatalogExposurePsfs] = [None] * n_catexps channels: list[g2f.Channel] = [None] * n_catexps for idx, catexp in enumerate(catexps): if not isinstance(catexp, CatalogExposurePsfs): - catexp = self.make_CatalogExposurePsfs(catexp) + catexp = fitter.make_CatalogExposurePsfs(catexp) catexps_conv[idx] = catexp channels[idx] = catexp.channel self.catexps = catexps config_data = CatalogSourceFitterConfigData(channels=channels, config=self.config) - catalog = self.fit( + catalog = fitter.fit( catalog_multi=catalog_multi, catexps=catexps_conv, config_data=config_data, **kwargs ) return pipeBase.Struct(output=astropy_to_arrow(catalog)) - - def validate_fit_inputs( - self, - catalog_multi: Sequence, - catexps: list[CatalogExposurePsfs], - config_data: CatalogSourceFitterConfigData = None, - logger: logging.Logger = None, - **kwargs: Any, - ) -> None: - errors = [] - for idx, catexp in enumerate(catexps): - if not isinstance(catexp, CatalogExposurePsfs): - errors.append(f"catexps[{idx=} {type(catexp)=} !isinstance(CatalogExposurePsfs)") - if errors: - raise RuntimeError("\n".join(errors)) diff --git a/python/lsst/meas/extensions/multiprofit/fit_coadd_psf.py b/python/lsst/meas/extensions/multiprofit/fit_coadd_psf.py index 27c83aa..ef477a2 100644 --- a/python/lsst/meas/extensions/multiprofit/fit_coadd_psf.py +++ b/python/lsst/meas/extensions/multiprofit/fit_coadd_psf.py @@ -19,13 +19,17 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from typing import ClassVar + import lsst.gauss2d.fit as g2f import lsst.pex.config as pexConfig import lsst.pipe.base as pipeBase import lsst.pipe.tasks.fit_coadd_psf as fitCP import lsst.utils.timer as utilsTimer from lsst.daf.butler.formatters.parquet import astropy_to_arrow -from lsst.multiprofit.fit_psf import CatalogPsfFitter, CatalogPsfFitterConfig, CatalogPsfFitterConfigData +from lsst.multiprofit.fitting.fit_psf import ( + CatalogPsfFitter, CatalogPsfFitterConfig, CatalogPsfFitterConfigData, +) from lsst.pex.exceptions import InvalidParameterError from .errors import IsParentError @@ -58,14 +62,10 @@ class MultiProFitPsfTask(CatalogPsfFitter, fitCP.CoaddPsfFitSubTask): ---------- **kwargs Keyword arguments to pass to CoaddPsfFitSubTask.__init__. - - Notes - ----- - See https://github.com/lsst-dm/multiprofit for more MultiProFit info. """ - ConfigClass = MultiProFitPsfConfig - _DefaultName = "multiProFitPsf" + ConfigClass: ClassVar = MultiProFitPsfConfig + _DefaultName: ClassVar = "multiProFitPsf" def __init__(self, **kwargs): errors_expected = {} if "errors_expected" not in kwargs else kwargs.pop("errors_expected") diff --git a/tests/test_fit_coadd.py b/tests/test_fit_coadd.py index 7eb3e87..8902334 100644 --- a/tests/test_fit_coadd.py +++ b/tests/test_fit_coadd.py @@ -37,9 +37,9 @@ SersicComponentConfig, SersicIndexParameterConfig, ) -from lsst.multiprofit.fit_psf import CatalogPsfFitterConfig from lsst.multiprofit.modelconfig import ModelConfig from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig +from lsst.multiprofit.fitting.fit_psf import CatalogPsfFitterConfig from lsst.pipe.tasks.fit_coadd_psf import CatalogExposurePsf ROOT = os.environ.get("TESTDATA_CI_IMSIM_MINI", None) @@ -260,7 +260,7 @@ def test_psf_fits(psf_fit_results): if psf_fit_results is not None: assert len(psf_fit_results) == n_test for column in psf_fit_results.columns: - assert np.all(np.isfinite(psf_fit_results[column])) + assert column and np.all(np.isfinite(psf_fit_results[column])) # TODO: Determine what checks can be done against previous values