Skip to content

Commit

Permalink
Merge branch 'tickets/DM-43907'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Aug 14, 2024
2 parents 07d6d96 + 0c7e20b commit cace424
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ meas_extensions_multiprofit
``meas_extensions_multiprofit`` is a package in the `LSST Science Pipelines <https://pipelines.lsst.io>`_.

``meas_extensions_multiprofit`` provides tasks and wrappers for running the
`MultiProFit <https://github.com/lsst-dm/multiprofit>`_ source modelling code on Science Pipelines data repositories.
`MultiProFit <https://github.com/lsst/multiprofit>`_ source modelling code on Science Pipelines data repositories.
87 changes: 41 additions & 46 deletions python/lsst/meas/extensions/multiprofit/fit_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
import math
from functools import cached_property
from typing import Any, Iterable, Mapping, Sequence
from typing import Any, ClassVar, Iterable, Mapping, Sequence

import lsst.gauss2d as g2
import lsst.gauss2d.fit as g2f
Expand All @@ -34,18 +34,16 @@
import pydantic
from astropy.table import Table
from lsst.daf.butler.formatters.parquet import astropy_to_arrow
from lsst.multiprofit.config import set_config_from_dict
from lsst.multiprofit.errors import NoDataError, PsfRebuildFitFlagError
from lsst.multiprofit.fit_psf import CatalogPsfFitterConfig, CatalogPsfFitterConfigData
from lsst.multiprofit.fit_source import (
from lsst.multiprofit.utils import get_params_uniq, set_config_from_dict
from lsst.multiprofit.fitting.fit_psf import CatalogPsfFitterConfig, CatalogPsfFitterConfigData
from lsst.multiprofit.fitting.fit_source import (
CatalogExposureSourcesABC,
CatalogSourceFitterABC,
CatalogSourceFitterConfig,
CatalogSourceFitterConfigData,
)
from lsst.multiprofit.utils import get_params_uniq
from lsst.pex.config.configurableActions import ConfigurableAction, ConfigurableActionField
from pydantic.dataclasses import dataclass

from .errors import IsParentError, NotPrimaryError
from .utils import get_spanned_image
Expand Down Expand Up @@ -227,7 +225,7 @@ def setDefaults(self):
self.centroid_pixel_offset = -0.5


@dataclass(frozen=True, kw_only=True, config=fitMB.CatalogExposureConfig)
@pydantic.dataclasses.dataclass(frozen=True, kw_only=True, config=fitMB.CatalogExposureConfig)
class CatalogExposurePsfs(fitMB.CatalogExposureInputs, CatalogExposureSourcesABC):
"""Input data from lsst pipelines, parsed for MultiProFit."""

Expand Down Expand Up @@ -396,26 +394,8 @@ def __post_init__(self):
object.__setattr__(self, "psf_model_data", config_data)


class MultiProFitSourceTask(CatalogSourceFitterABC, fitMB.CoaddMultibandFitSubTask):
"""Run MultiProFit on Exposure/SourceCatalog pairs in multiple bands.
This task uses MultiProFit to fit a single model to all sources in a coadd,
using a previously-fit PSF model for each exposure. The task may also use
prior measurements from single- or merged multiband catalogs for
initialization.
Parameters
----------
**kwargs
Keyword arguments to pass to CoaddMultibandFitSubTask.__init__.
Notes
-----
See https://github.com/lsst-dm/multiprofit for more MultiProFit info.
"""

ConfigClass = MultiProFitSourceConfig
_DefaultName = "multiProFitSource"
class MultiProFitSourceFitter(CatalogSourceFitterABC):
"""A MultiProFit source fitter."""

def __init__(self, **kwargs: Any):
errors_expected = {} if "errors_expected" not in kwargs else kwargs.pop("errors_expected")
Expand All @@ -425,8 +405,7 @@ def __init__(self, **kwargs: Any):
for error_catalog in (IsParentError, NoDataError, NotPrimaryError, PsfRebuildFitFlagError):
if error_catalog not in errors_expected:
errors_expected[error_catalog] = error_catalog.column_name()
CatalogSourceFitterABC.__init__(self, errors_expected=errors_expected)
fitMB.CoaddMultibandFitSubTask.__init__(self, **kwargs)
super().__init__(errors_expected=errors_expected)

def copy_centroid_errors(
self,
Expand Down Expand Up @@ -603,11 +582,40 @@ def make_CatalogExposurePsfs(self, catexp: fitMB.CatalogExposureInputs) -> Catal
)
return catexp_psf

def validate_fit_inputs(
self,
catalog_multi: Sequence,
catexps: list[CatalogExposurePsfs],
config_data: CatalogSourceFitterConfigData = None,
logger: logging.Logger = None,
**kwargs: Any,
) -> None:
errors = []
for idx, catexp in enumerate(catexps):
if not isinstance(catexp, CatalogExposurePsfs):
errors.append(f"catexps[{idx=} {type(catexp)=} !isinstance(CatalogExposurePsfs)")
if errors:
raise RuntimeError("\n".join(errors))


class MultiProFitSourceTask(fitMB.CoaddMultibandFitSubTask):
"""Run MultiProFit on Exposure/SourceCatalog pairs in multiple bands.
This task uses MultiProFit to fit a single model to all sources in a coadd,
using a previously-fit PSF model for each exposure. The task may also use
prior measurements from single- or merged multiband catalogs for
initialization.
"""

ConfigClass: ClassVar = MultiProFitSourceConfig
_DefaultName: ClassVar = "multiProFitSource"

@utilsTimer.timeMethod
def run(
self,
catalog_multi: Sequence,
catexps: list[fitMB.CatalogExposureInputs],
fitter: MultiProFitSourceFitter | None = None,
**kwargs,
) -> pipeBase.Struct:
"""Run the MultiProFit source fit task on catalog-exposure pairs.
Expand All @@ -627,32 +635,19 @@ def run(
A table with fit parameters for the PSF model at the location
of each source.
"""
if fitter is None:
fitter = MultiProFitSourceFitter()
n_catexps = len(catexps)
catexps_conv: list[CatalogExposurePsfs] = [None] * n_catexps
channels: list[g2f.Channel] = [None] * n_catexps
for idx, catexp in enumerate(catexps):
if not isinstance(catexp, CatalogExposurePsfs):
catexp = self.make_CatalogExposurePsfs(catexp)
catexp = fitter.make_CatalogExposurePsfs(catexp)
catexps_conv[idx] = catexp
channels[idx] = catexp.channel
self.catexps = catexps
config_data = CatalogSourceFitterConfigData(channels=channels, config=self.config)
catalog = self.fit(
catalog = fitter.fit(
catalog_multi=catalog_multi, catexps=catexps_conv, config_data=config_data, **kwargs
)
return pipeBase.Struct(output=astropy_to_arrow(catalog))

def validate_fit_inputs(
self,
catalog_multi: Sequence,
catexps: list[CatalogExposurePsfs],
config_data: CatalogSourceFitterConfigData = None,
logger: logging.Logger = None,
**kwargs: Any,
) -> None:
errors = []
for idx, catexp in enumerate(catexps):
if not isinstance(catexp, CatalogExposurePsfs):
errors.append(f"catexps[{idx=} {type(catexp)=} !isinstance(CatalogExposurePsfs)")
if errors:
raise RuntimeError("\n".join(errors))
14 changes: 7 additions & 7 deletions python/lsst/meas/extensions/multiprofit/fit_coadd_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import ClassVar

import lsst.gauss2d.fit as g2f
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.tasks.fit_coadd_psf as fitCP
import lsst.utils.timer as utilsTimer
from lsst.daf.butler.formatters.parquet import astropy_to_arrow
from lsst.multiprofit.fit_psf import CatalogPsfFitter, CatalogPsfFitterConfig, CatalogPsfFitterConfigData
from lsst.multiprofit.fitting.fit_psf import (
CatalogPsfFitter, CatalogPsfFitterConfig, CatalogPsfFitterConfigData,
)
from lsst.pex.exceptions import InvalidParameterError

from .errors import IsParentError
Expand Down Expand Up @@ -58,14 +62,10 @@ class MultiProFitPsfTask(CatalogPsfFitter, fitCP.CoaddPsfFitSubTask):
----------
**kwargs
Keyword arguments to pass to CoaddPsfFitSubTask.__init__.
Notes
-----
See https://github.com/lsst-dm/multiprofit for more MultiProFit info.
"""

ConfigClass = MultiProFitPsfConfig
_DefaultName = "multiProFitPsf"
ConfigClass: ClassVar = MultiProFitPsfConfig
_DefaultName: ClassVar = "multiProFitPsf"

def __init__(self, **kwargs):
errors_expected = {} if "errors_expected" not in kwargs else kwargs.pop("errors_expected")
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fit_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
SersicComponentConfig,
SersicIndexParameterConfig,
)
from lsst.multiprofit.fit_psf import CatalogPsfFitterConfig
from lsst.multiprofit.modelconfig import ModelConfig
from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
from lsst.multiprofit.fitting.fit_psf import CatalogPsfFitterConfig
from lsst.pipe.tasks.fit_coadd_psf import CatalogExposurePsf

ROOT = os.environ.get("TESTDATA_CI_IMSIM_MINI", None)
Expand Down Expand Up @@ -260,7 +260,7 @@ def test_psf_fits(psf_fit_results):
if psf_fit_results is not None:
assert len(psf_fit_results) == n_test
for column in psf_fit_results.columns:
assert np.all(np.isfinite(psf_fit_results[column]))
assert column and np.all(np.isfinite(psf_fit_results[column]))
# TODO: Determine what checks can be done against previous values


Expand Down

0 comments on commit cace424

Please sign in to comment.