Skip to content

Commit

Permalink
Merge branch 'tickets/DM-43906'
Browse files Browse the repository at this point in the history
  • Loading branch information
taranu committed Jul 16, 2024
2 parents b3c3077 + b034810 commit 07d6d96
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 33 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ keywords = [
dependencies = [
"astropy",
"galsim",
"gauss2d",
"gauss2dfit",
"lsst-gauss2d",
"lsst-gauss2dfit",
"lsst-multiprofit",
"matplotlib",
"numpy",
Expand Down
10 changes: 5 additions & 5 deletions python/lsst/meas/extensions/multiprofit/fit_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from functools import cached_property
from typing import Any, Iterable, Mapping, Sequence

import gauss2d as g2
import gauss2d.fit as g2f
import lsst.gauss2d as g2
import lsst.gauss2d.fit as g2f
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.tasks.fit_coadd_multiband as fitMB
Expand Down Expand Up @@ -319,7 +319,7 @@ def get_psf_model(self, source):
param.value = math.sqrt(param.value**2 - sigma_subtract_sq)
return psf_model

def get_source_observation(self, source, **kwargs) -> g2f.Observation:
def get_source_observation(self, source, **kwargs) -> g2f.ObservationD:
if not kwargs.get("skip_flags"):
if (not source["detect_isPrimary"]) or source["merge_peak_sky"]:
raise NotPrimaryError(f"source {source[self.config_fit.column_id]} has invalid flags for fit")
Expand Down Expand Up @@ -374,7 +374,7 @@ def get_source_observation(self, source, **kwargs) -> g2f.Observation:

coordsys = g2.CoordinateSystem(1.0, 1.0, x_min_bbox, y_min_bbox)

obs = g2f.Observation(
obs = g2f.ObservationD(
image=g2.ImageD(img, coordsys),
sigma_inv=g2.ImageD(sigma_inv, coordsys),
mask_inv=g2.ImageB(mask, coordsys),
Expand Down Expand Up @@ -450,7 +450,7 @@ def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float)

def initialize_model(
self,
model: g2f.Model,
model: g2f.ModelD,
source: Mapping[str, Any],
catexps: list[CatalogExposureSourcesABC],
values_init: Mapping[g2f.ParameterD, float] | None = None,
Expand Down
6 changes: 3 additions & 3 deletions python/lsst/meas/extensions/multiprofit/fit_coadd_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import gauss2d.fit as g2f
import lsst.gauss2d.fit as g2f
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import lsst.pipe.tasks.fit_coadd_psf as fitCP
Expand Down Expand Up @@ -89,12 +89,12 @@ def check_source(self, source, config):

def initialize_model(
self,
model: g2f.Model,
model: g2f.ModelD,
config_data: CatalogPsfFitterConfigData,
limits_x: g2f.LimitsD = None,
limits_y: g2f.LimitsD = None,
) -> None:
"""Initialize a Model for a single source row.
"""Initialize a ModelD for a single source row.
Parameters
----------
Expand Down
14 changes: 7 additions & 7 deletions python/lsst/meas/extensions/multiprofit/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"ObjectTableBase",
"TruthSummaryTable",
"ObjectTable",
"ObjectTableCModel",
"ObjectTableCModelD",
"ObjectTableMultiProFit",
"ObjectTablePsf",
"downselect_table",
Expand All @@ -52,7 +52,7 @@
FigureAxes = tuple[Figure, Axes]


class ObjectTableBase(ABC, pydantic.BaseModel):
class ObjectTableBase(ABC, pydantic.BaseModelD):
"""Base class for retrieving columns from tract-based object tables."""

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)
Expand Down Expand Up @@ -155,11 +155,11 @@ def get_y(self):
return self.table["y"]


class ObjectTableCModel(ObjectTable):
"""Class for retrieving CModel fluxes from objectTable_tract."""
class ObjectTableCModelD(ObjectTable):
"""Class for retrieving CModelD fluxes from objectTable_tract."""

def get_flux(self, band: str) -> np.ndarray:
return self.table[f"{band}_cModelFlux"]
return self.table[f"{band}_cModelDFlux"]


class ObjectTableMultiProFit(ObjectTableBase):
Expand Down Expand Up @@ -360,7 +360,7 @@ def plot_blend(
objects_primary = rebuilder.objects[rebuilder.objects["detect_isPrimary"] == True] # noqa: E712
kwargs_annotate_obs = dict(color="white", fontsize=14, ha="right", va="top")
kwargs_scatter_obs = dict(c="white", marker="x", s=70)
table_within_cmodel = downselect_table_axis(ObjectTableCModel(table=objects_primary), ax_rgb)
table_within_cmodel = downselect_table_axis(ObjectTableCModelD(table=objects_primary), ax_rgb)
labels_extended_model = ("C", "E")
plot_objects(
table_within_cmodel,
Expand Down Expand Up @@ -390,7 +390,7 @@ def plot_blend(

for idx_child in idx_children:
for name, matched in rebuilder.matches.items():
print(f"Model: {name}")
print(f"ModelD: {name}")
rebuilder_child = matched.rebuilder
is_dataloader = isinstance(rebuilder_child, DataLoader)
is_scarlet = is_dataloader and (name == "scarlet")
Expand Down
28 changes: 14 additions & 14 deletions python/lsst/meas/extensions/multiprofit/rebuild_coadd_multiband.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["ModelRebuilder", "PatchModelMatches", "PatchCoaddRebuilder"]
__all__ = ["ModelDRebuilder", "PatchModelDMatches", "PatchCoaddRebuilder"]

from functools import cached_property
from typing import Iterable

import astropy.table
import astropy.units as u
import gauss2d.fit as g2f
import lsst.gauss2d.fit as g2f
import lsst.afw.table as afwTable
import lsst.daf.butler as dafButler
import lsst.geom as geom
Expand Down Expand Up @@ -142,7 +142,7 @@ def get_radec_unit(table: astropy.table.Table, coord_ra: str, coord_dec: str, de
return unit_ra


class DataLoader(pydantic.BaseModel):
class DataLoader(pydantic.BaseModelD):
"""A collection of data that can be used to rebuild models."""

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)
Expand Down Expand Up @@ -238,7 +238,7 @@ def from_butler(
def load_deblended_object(
self,
idx_row: int,
) -> list[g2f.Observation]:
) -> list[g2f.ObservationD]:
"""Load a deblended object from catexps.
Parameters
Expand All @@ -257,7 +257,7 @@ def load_deblended_object(
return observations


class ModelRebuilder(DataLoader):
class ModelDRebuilder(DataLoader):
"""A rebuilder of MultiProFit models from their inputs and best-fit
parameter values.
"""
Expand Down Expand Up @@ -292,7 +292,7 @@ def from_quantumGraph(
Returns
-------
rebuilder
A ModelRebuilder instance initialized with the necessary kwargs.
A ModelDRebuilder instance initialized with the necessary kwargs.
"""
if dataId is None:
quantum = next(iter(quantumgraph.outputQuanta)).quantum
Expand Down Expand Up @@ -343,8 +343,8 @@ def make_model(
idx_row: int,
config_data: CatalogSourceFitterConfigData = None,
init: bool = True,
) -> g2f.Model:
"""Make a Model for a single row from the originally fitted catalog.
) -> g2f.ModelD:
"""Make a ModelD for a single row from the originally fitted catalog.
Parameters
----------
Expand Down Expand Up @@ -398,22 +398,22 @@ def set_model(self, idx_row: int, config_data: CatalogSourceFitterConfigData = N
param.value = row[f"{prefix}{key}"] + offsets.get(type(param), 0.0)


class PatchModelMatches(pydantic.BaseModel):
class PatchModelDMatches(pydantic.BaseModelD):
"""Storage for MultiProFit tables matched to a reference catalog."""

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

matches: astropy.table.Table | None = pydantic.Field(doc="Catalogs of matches")
quantumgraph: QuantumGraph | None = pydantic.Field(doc="Quantum graph for fit task")
rebuilder: DataLoader | ModelRebuilder | None = pydantic.Field(doc="MultiProFit object model rebuilder")
rebuilder: DataLoader | ModelDRebuilder | None = pydantic.Field(doc="MultiProFit object model rebuilder")


class PatchCoaddRebuilder(pydantic.BaseModel):
class PatchCoaddRebuilder(pydantic.BaseModelD):
"""A rebuilder for patch-level coadd catalog/exposure fits."""

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True, frozen=True)

matches: dict[str, PatchModelMatches] = pydantic.Field("Model matches by algorithm name")
matches: dict[str, PatchModelDMatches] = pydantic.Field("ModelD matches by algorithm name")
name_model_ref: str = pydantic.Field(doc="The name of the reference model in matches")
objects: astropy.table.Table = pydantic.Field(doc="Object table")
objects_multiprofit: astropy.table.Table | None = pydantic.Field(doc="Object table for MultiProFit fits")
Expand Down Expand Up @@ -553,13 +553,13 @@ def from_butler(
matched["patch"][np.where(unmatched)[0]] = patches_unmatched
matched = matched[matched["patch"] == patch]
rebuilder = (
ModelRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId)
ModelDRebuilder.from_quantumGraph(butler, quantumgraph, dataId=dataId)
if is_mpf
else DataLoader.from_butler(
butler, data_id=dataId, bands=bands, collections=[collection_merged]
)
)
matches_name[name] = PatchModelMatches(
matches_name[name] = PatchModelDMatches(
matches=matched, quantumgraph=quantumgraph, rebuilder=rebuilder
)
return cls(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_fit_coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import os

import gauss2d.fit as g2f
import lsst.gauss2d.fit as g2f
import lsst.meas.extensions.multiprofit.fit_coadd_multiband as fitCMB
import lsst.meas.extensions.multiprofit.fit_coadd_psf as fitCP
import numpy as np
Expand Down Expand Up @@ -240,7 +240,6 @@ def source_fit_ser_shapelet_psf_results(
source_fit_ser_config.action_psf = fitCMB.SourceTablePsfComponentsAction()
task = fitCMB.MultiProFitSourceTask(config=source_fit_ser_config)
results = task.run(catalog_multi=catalog, catexps=[catexp])
source_fit_ser_config.action_psf = fitCMB.PsfComponentsAction
return results.output.to_pandas()


Expand Down

0 comments on commit 07d6d96

Please sign in to comment.