diff --git a/python/lsst/pipe/tasks/postprocess.py b/python/lsst/pipe/tasks/postprocess.py index 684a1e745..d83743aab 100644 --- a/python/lsst/pipe/tasks/postprocess.py +++ b/python/lsst/pipe/tasks/postprocess.py @@ -44,6 +44,7 @@ import numpy as np import pandas as pd import astropy.table +import astropy.utils.metadata import lsst.geom import lsst.pex.config as pexConfig @@ -82,6 +83,95 @@ def flattenFilters(df, noDupCols=["coord_ra", "coord_dec"], camelCase=False, inp return newDf +class TableVStack: + """A helper class for stacking astropy tables without having them all in + memory at once. + + Parameters + ---------- + capacity : `int` + Full size of the final table. + + Notes + ----- + Unlike `astropy.table.vstack`, this class requires all tables to have the + exact same columns (it's slightly more strict than even the + ``join_type="exact"`` argument to `astropy.table.vstack`). + """ + + def __init__(self, capacity): + self.index = 0 + self.capacity = capacity + self.result = None + + @classmethod + def from_handles(cls, handles): + """Construct from an iterable of + `lsst.daf.butler.DeferredDatasetHandle`. + + Parameters + ---------- + handles : `~collections.abc.Iterable` [ \ + `lsst.daf.butler.DeferredDatasetHandle` ] + Iterable of handles. Must have a storage class that supports the + "rowcount" component, which is all that will be fetched. + + Returns + ------- + vstack : `TableVStack` + An instance of this class, initialized with capacity equal to the + sum of the rowcounts of all the given table handles. + """ + capacity = sum(handle.get("rowcount") for handle in handles) + return cls(capacity=capacity) + + def extend(self, table): + """Add a single table to the stack. + + Parameters + ---------- + table : `astropy.table.Table` + An astropy table instance. + """ + if self.result is None: + self.result = astropy.table.Table() + for name in table.colnames: + column = table[name] + column_cls = type(column) + self.result[name] = column_cls.info.new_like([column], self.capacity, name=name) + self.index = len(table) + self.result.meta = table.meta.copy() + else: + next_index = self.index + len(table) + for name in table.colnames: + self.result[name][self.index:next_index] = table[name] + self.index = next_index + self.result.meta = astropy.utils.metadata.merge(self.result.meta, table.meta) + + @classmethod + def vstack_handles(cls, handles): + """Vertically stack tables represented by deferred dataset handles. + + Parameters + ---------- + handles : `~collections.abc.Iterable` [ \ + `lsst.daf.butler.DeferredDatasetHandle` ] + Iterable of handles. Must have the "ArrowAstropy" storage class + and identical columns. + + Returns + ------- + table : `astropy.table.Table` + Concatenated table with the same columns as each input table and + the rows of all of them. + """ + handles = tuple(handles) # guard against single-pass iterators + vstack = cls.from_handles(handles) + for handle in handles: + vstack.extend(handle.get()) + return vstack.result + + class WriteObjectTableConnections(pipeBase.PipelineTaskConnections, defaultTemplates={"coaddName": "deep"}, dimensions=("tract", "patch", "skymap")): @@ -932,6 +1022,7 @@ class ConsolidateObjectTableConnections(pipeBase.PipelineTaskConnections, storageClass="ArrowAstropy", dimensions=("tract", "patch", "skymap"), multiple=True, + deferLoad=True, ) outputCatalog = connectionTypes.Output( doc="Pre-tract horizontal concatenation of the input objectTables", @@ -965,7 +1056,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) self.log.info("Concatenating %s per-patch Object Tables", len(inputs["inputCatalogs"])) - table = astropy.table.vstack(inputs["inputCatalogs"], join_type="exact") + table = TableVStack.vstack_handles(inputs["inputCatalogs"]) butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs) @@ -1142,7 +1233,8 @@ class ConsolidateSourceTableConnections(pipeBase.PipelineTaskConnections, name="{catalogType}sourceTable", storageClass="ArrowAstropy", dimensions=("instrument", "visit", "detector"), - multiple=True + multiple=True, + deferLoad=True, ) outputCatalog = connectionTypes.Output( doc="Per-visit concatenation of Source Table", @@ -1175,7 +1267,7 @@ def runQuantum(self, butlerQC, inputRefs, outputRefs): inputs = butlerQC.get(inputRefs) self.log.info("Concatenating %s per-detector Source Tables", len(inputs["inputCatalogs"])) - table = astropy.table.vstack(inputs["inputCatalogs"], join_type="exact") + table = TableVStack.vstack_handles(inputs["inputCatalogs"]) butlerQC.put(pipeBase.Struct(outputCatalog=table), outputRefs)