Skip to content

Commit

Permalink
Add DataLoader
Browse files Browse the repository at this point in the history
This loader needs no multiprofit data.
  • Loading branch information
taranu committed Jun 3, 2024
1 parent 7b06fb2 commit 2db0df0
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 44 deletions.
53 changes: 36 additions & 17 deletions python/lsst/meas/extensions/multiprofit/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


from abc import ABC, abstractmethod
from typing import Any, Iterable, Self
from typing import Any, Iterable, Self, Type

import astropy.table
import astropy.units as u
Expand All @@ -32,7 +32,7 @@
import pydantic
from lsst.multiprofit.plots import bands_weights_lsst, plot_model_rgb

from .rebuild_coadd_multiband import PatchCoaddRebuilder
from .rebuild_coadd_multiband import DataLoader, PatchCoaddRebuilder

__all__ = [
"ObjectTableBase",
Expand Down Expand Up @@ -304,6 +304,7 @@ def plot_blend(
rebuilder: PatchCoaddRebuilder,
idx_row_parent: int,
weights: dict[str, float] = None,
table_ref_type: Type = TruthSummaryTable,
kwargs_plot_parent: dict[str, Any] = None,
kwargs_plot_children: dict[str, Any] = None,
) -> tuple[Figure, Axes, Figure, Axes]:
Expand All @@ -317,6 +318,8 @@ def plot_blend(
The row index of the parent object in the reference SourceCatalog.
weights
Multiplicative weights by band name for RGB plots.
table_ref_type
The type of reference table to construct when downselecting.
kwargs_plot_parent
Keyword arguments to pass to make RGB plots of the parent blend.
kwargs_plot_children
Expand All @@ -339,16 +342,19 @@ def plot_blend(
kwargs_plot_children = {}
if weights is None:
weights = bands_weights_lsst

plot_chi_hist = kwargs_plot_children.pop("plot_chi_hist", True)
rebuilder_ref = rebuilder.matches[rebuilder.name_model_ref].rebuilder
observations = {
catexp.band: catexp.get_source_observation(catexp.get_catalog()[idx_row_parent], skip_flags=True)
for catexp in rebuilder_ref.catexps
}

fig_rgb, ax_rgb, fig_gs, ax_gs, *_ = plot_model_rgb(
model=None, weights=weights, observations=observations, plot_singleband=False, **kwargs_plot_parent
model=None, weights=weights, observations=observations, plot_singleband=False, plot_chi_hist=False,
**kwargs_plot_parent
)
table_within_ref = downselect_table_axis(TruthSummaryTable(table=rebuilder.reference), ax_rgb)
table_within_ref = downselect_table_axis(table_ref_type(table=rebuilder.reference), ax_rgb)
plot_objects(table_within_ref, ax_rgb, weights, table_downselected=True)

objects_primary = rebuilder.objects[rebuilder.objects["detect_isPrimary"] == True] # noqa: E712
Expand All @@ -370,7 +376,7 @@ def plot_blend(
objects_mpf = rebuilder.objects_multiprofit
objects_mpf_within = {}
for name, matched in rebuilder.matches.items():
if matched.rebuilder:
if matched.rebuilder and objects_mpf:
objects_mpf_within[name] = downselect_table_axis(
ObjectTableMultiProFit(name_model=name, table=objects_mpf),
ax_rgb,
Expand All @@ -386,24 +392,37 @@ def plot_blend(
for name, matched in rebuilder.matches.items():
print(f"Model: {name}")
rebuilder_child = matched.rebuilder
if rebuilder_child:
is_dataloader = isinstance(rebuilder_child, DataLoader)
is_scarlet = is_dataloader and (name == "scarlet")
if is_scarlet or rebuilder_child:
try:
model = rebuilder_child.make_model(idx_child)
if is_dataloader:
model = None
observations = rebuilder_child.load_deblended_object(idx_child)
else:
model = rebuilder_child.make_model(idx_child)
observations = None

_, ax_rgb_c, *_ = plot_model_rgb(
model=model, weights=weights, plot_singleband=False, **kwargs_plot_children
model=model, weights=weights, plot_singleband=False,
plot_chi_hist=(not is_dataloader) and plot_chi_hist,
observations=observations,
**kwargs_plot_children
)
ax_rgb_c0 = ax_rgb_c[0][0]
plot_objects(table_within_ref, ax_rgb_c0, weights)
plot_objects(
objects_mpf_within[name],
ax_rgb_c0,
weights,
kwargs_annotate=kwargs_annotate_obs,
kwargs_scatter=kwargs_scatter_obs,
labels_extended=labels_extended_model,
)
tab_mpf = objects_mpf_within.get(name)
if tab_mpf:
plot_objects(
tab_mpf,
ax_rgb_c0,
weights,
kwargs_annotate=kwargs_annotate_obs,
kwargs_scatter=kwargs_scatter_obs,
labels_extended=labels_extended_model,
)
plt.show()
except Exception as exc:
print(f"failed to rebuild due to {exc}")
print(f"{idx_child=} failed to rebuild due to {exc}")

return fig_rgb, ax_rgb, fig_gs, ax_gs
Loading

0 comments on commit 2db0df0

Please sign in to comment.