DM-47980: Add an option to include dimension records into general query result #1135

Merged 2 commits on Dec 19, 2024
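For orientation, a hedged usage sketch of the new option, pieced together from the tests added in this PR; the repository path, data, and the `Butler.from_config` call are assumptions, while `query.general`, `order_by`, `with_dimension_records`, and the `"detector.full_name"` key follow the test code near the end of this diff.

```python
# Hypothetical usage sketch of the new option (repository path and data are
# assumptions; the query calls mirror the tests added by this PR).
from lsst.daf.butler import Butler

butler = Butler.from_config("my-repo")  # assumed repository
dimensions = butler.dimensions["detector"].minimal_group

with butler.query() as query:
    result = query.general(dimensions).order_by("detector")
    # Opt in to dimension records; each row is a plain dict and record fields
    # are keyed as "<element>.<field>".
    for row in result.with_dimension_records():
        print(row["detector"], row["detector.full_name"])
```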
python/lsst/daf/butler/direct_query_driver/_result_page_converter.py
@@ -326,21 +326,48 @@

def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None:
self.spec = spec

result_columns = spec.get_result_columns()
# If `spec.include_dimension_records` is True then, in addition to the
# columns returned by the query, we have to add columns from dimension
# records that are not returned by the query. These columns belong to
# either cached or skypix dimensions.
columns = spec.get_result_columns()
universe = spec.dimensions.universe
self.converters: list[_GeneralColumnConverter] = []
for column in result_columns:
self.record_converters: dict[DimensionElement, _DimensionRecordRowConverter] = {}
for column in columns:
column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field)
converter: _GeneralColumnConverter
if column.field == TimespanDatabaseRepresentation.NAME:
self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db))
converter = _TimespanGeneralColumnConverter(column_name, ctx.db)
elif column.field == "ingest_date":
self.converters.append(_TimestampGeneralColumnConverter(column_name))
converter = _TimestampGeneralColumnConverter(column_name)
else:
self.converters.append(_DefaultGeneralColumnConverter(column_name))
converter = _DefaultGeneralColumnConverter(column_name)
self.converters.append(converter)

if spec.include_dimension_records:
universe = self.spec.dimensions.universe
for element_name in self.spec.dimensions.elements:
element = universe[element_name]
if isinstance(element, SkyPixDimension):
self.record_converters[element] = _SkypixDimensionRecordRowConverter(element)
elif element.is_cached:
self.record_converters[element] = _CachedDimensionRecordRowConverter(
element, ctx.dimension_record_cache
)

def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage:
rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows]
return GeneralResultPage(spec=self.spec, rows=rows)
rows = []
dimension_records = None
if self.spec.include_dimension_records:
dimension_records = {element: DimensionRecordSet(element) for element in self.record_converters}
for row in raw_rows:
rows.append(tuple(cvt.convert(row) for cvt in self.converters))
if dimension_records:
for element, converter in self.record_converters.items():
dimension_records[element].add(converter.convert(row))

return GeneralResultPage(spec=self.spec, rows=rows, dimension_records=dimension_records)


class _GeneralColumnConverter:
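A minimal stand-alone sketch of the per-page grouping done by the converter above: column converters produce the row tuples, while record converters for cached and skypix elements fill one record set per element as the rows stream by. Plain callables and sets stand in for the real converter classes and `DimensionRecordSet`.

```python
# Sketch only: plain callables/sets stand in for _GeneralColumnConverter,
# _DimensionRecordRowConverter, and DimensionRecordSet.
from collections.abc import Callable, Iterable
from typing import Any

def convert_page(
    raw_rows: Iterable[tuple[Any, ...]],
    column_converters: list[Callable[[tuple[Any, ...]], Any]],
    record_converters: dict[str, Callable[[tuple[Any, ...]], Any]],
) -> tuple[list[tuple[Any, ...]], dict[str, set[Any]] | None]:
    rows: list[tuple[Any, ...]] = []
    dimension_records: dict[str, set[Any]] | None = None
    if record_converters:
        # One record set per cached/skypix element, filled while converting rows.
        dimension_records = {element: set() for element in record_converters}
    for row in raw_rows:
        rows.append(tuple(cvt(row) for cvt in column_converters))
        if dimension_records is not None:
            for element, converter in record_converters.items():
                dimension_records[element].add(converter(row))
    return rows, dimension_records
```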
100 changes: 86 additions & 14 deletions python/lsst/daf/butler/queries/_general_query_results.py
@@ -35,11 +35,11 @@

from .._dataset_ref import DatasetRef
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DimensionGroup
from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord, DimensionRecordSet
from ._base import QueryResultsBase
from .driver import QueryDriver
from .result_specs import GeneralResultSpec
from .tree import QueryTree
from .tree import QueryTree, ResultColumn


class GeneralResultTuple(NamedTuple):
@@ -101,7 +101,11 @@
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
for row in page.rows:
yield dict(zip(columns, row))
result = dict(zip(columns, row, strict=True))
if page.dimension_records:
records = self._get_cached_dimension_records(result, page.dimension_records)
self._add_dimension_records(result, records)
yield result

def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]:
"""Iterate over result rows and return data coordinate, and dataset
@@ -124,23 +124,40 @@
id_key = f"{dataset_type.name}.dataset_id"
run_key = f"{dataset_type.name}.run"
dataset_keys.append((dataset_type, dimensions, id_key, run_key))
for row in self:
values = tuple(
row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied)
)
data_coordinate = DataCoordinate.from_full_values(all_dimensions, values)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_id = DataCoordinate.from_full_values(dimensions, values)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)
for page in self._driver.execute(self._spec, self._tree):
columns = tuple(str(column) for column in page.spec.get_result_columns())
for page_row in page.rows:
row = dict(zip(columns, page_row, strict=True))
if page.dimension_records:
cached_records = self._get_cached_dimension_records(row, page.dimension_records)
self._add_dimension_records(row, cached_records)
else:
cached_records = {}
data_coordinate = self._make_data_id(row, all_dimensions, cached_records)
refs = []
for dataset_type, dimensions, id_key, run_key in dataset_keys:
data_id = data_coordinate.subset(dimensions)
refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key]))
yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row)
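A hedged usage sketch of the reworked `iter_tuples`, following the test added later in this PR; `butler`, `dimensions`, and the `flat` dataset type are assumptions carried over from that test.

```python
# Usage sketch; mirrors the test added in butler_queries.py below.
with butler.query() as query:
    query = query.join_dataset_search("flat", "imported_g")
    result = query.general(
        dimensions, dataset_fields={"flat": ...}, find_first=True
    ).with_dimension_records()
    for data_id, refs, raw_row in result.iter_tuples(flat):
        assert data_id.hasRecords()          # expanded from cached/skypix records
        assert refs[0].dataId.hasRecords()   # dataset refs share the expansion
```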

@property
def dimensions(self) -> DimensionGroup:
# Docstring inherited
return self._spec.dimensions

@property
def has_dimension_records(self) -> bool:
"""Whether all data IDs in this iterable contain dimension records."""
return self._spec.include_dimension_records

def with_dimension_records(self) -> GeneralQueryResults:
"""Return a results object for which `has_dimension_records` is
`True`.
"""
if self.has_dimension_records:
return self
return self._copy(tree=self._tree, include_dimension_records=True)

def count(self, *, exact: bool = True, discard: bool = False) -> int:
# Docstring inherited.
return self._driver.count(self._tree, self._spec, exact=exact, discard=discard)
@@ -152,3 +152,54 @@
def _get_datasets(self) -> frozenset[str]:
# Docstring inherited.
return frozenset(self._spec.dataset_fields)

def _make_data_id(
self,
row: dict[str, Any],
dimensions: DimensionGroup,
cached_row_records: dict[DimensionElement, DimensionRecord],
) -> DataCoordinate:
values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied))
data_coordinate = DataCoordinate.from_full_values(dimensions, values)
if self.has_dimension_records:
records = {}
for name in dimensions.elements:
element = dimensions.universe[name]
record = cached_row_records.get(element)
if record is None:
record = self._make_dimension_record(row, dimensions.universe[name])
records[name] = record
data_coordinate = data_coordinate.expanded(records)
return data_coordinate

def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord:
column_map = list(
zip(
element.schema.dimensions.names,
element.dimensions.names,
)
)
for field in element.schema.remainder.names:
column_map.append((field, str(ResultColumn(element.name, field))))
d = {k: row[v] for k, v in column_map}
record_cls = element.RecordClass
return record_cls(**d)
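A stand-alone illustration of the column mapping built above: the record's key dimensions are read from plain dimension columns while the remaining schema fields come from qualified "<element>.<field>" columns. Plain dicts stand in for the query row and the `DimensionRecord` constructor; the detector values are assumptions taken from the tests below.

```python
# Sketch of _make_dimension_record's column mapping (dicts stand in for the
# real row and DimensionRecord; values follow the test data below).
row = {
    "instrument": "Cam1",
    "detector": 1,
    "detector.full_name": "Aa",
    "detector.raft": "A",
}
# Key dimensions map schema field names to dimension column names ...
column_map = [("instrument", "instrument"), ("id", "detector")]
# ... remaining schema fields map to qualified "<element>.<field>" columns.
column_map += [("full_name", "detector.full_name"), ("raft", "detector.raft")]
record_kwargs = {k: row[v] for k, v in column_map}
# record_kwargs == {"instrument": "Cam1", "id": 1, "full_name": "Aa", "raft": "A"}
```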


def _get_cached_dimension_records(
self, row: dict[str, Any], dimension_records: dict[DimensionElement, DimensionRecordSet]
) -> dict[DimensionElement, DimensionRecord]:
"""Find cached dimension records matching this row."""
records = {}
for element, element_records in dimension_records.items():
required_values = tuple(row[key] for key in element.required.names)
records[element] = element_records.find_with_required_values(required_values)
return records

def _add_dimension_records(
self, row: dict[str, Any], records: dict[DimensionElement, DimensionRecord]
) -> None:
"""Extend row with the fields from cached dimension records."""
for element, record in records.items():
for name, value in record.toDict().items():
if name not in element.schema.required.names:
row[f"{element.name}.{name}"] = value
5 changes: 5 additions & 0 deletions python/lsst/daf/butler/queries/driver.py
@@ -47,6 +47,7 @@
from ..dimensions import (
DataCoordinate,
DataIdValue,
DimensionElement,
DimensionGroup,
DimensionRecord,
DimensionRecordSet,
@@ -120,6 +121,10 @@ class GeneralResultPage:
# spec.get_result_columns().
rows: list[tuple[Any, ...]]

# This map contains dimension records for cached and skypix elements,
# and only when spec.include_dimension_records is True.
dimension_records: dict[DimensionElement, DimensionRecordSet] | None


ResultPage: TypeAlias = Union[
DataCoordinateResultPage, DimensionRecordResultPage, DatasetRefResultPage, GeneralResultPage
11 changes: 11 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
@@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase):
dataset_fields: Mapping[str, set[DatasetFieldName]]
"""Dataset fields included in this query."""

include_dimension_records: bool = False
"""Whether to include fields for all dimension records, in addition to
explicitly specified in `dimension_fields`.
"""

find_first: bool
"""Whether this query requires find-first resolution for a dataset.

@@ -241,6 +246,12 @@ def get_result_columns(self) -> ColumnSet:
result.dimension_fields[element_name].update(fields_for_element)
for dataset_type, fields_for_dataset in self.dataset_fields.items():
result.dataset_fields[dataset_type].update(fields_for_dataset)
if self.include_dimension_records:
# This only adds record fields for non-cached and non-skypix
# elements, which is what we want when generating the query. When
# `include_dimension_records` is True, dimension records for cached
# and skypix elements are added to result pages by the page converter.
_add_dimension_records_to_column_set(self.dimensions, result)
return result

@pydantic.model_validator(mode="after")
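The comment in `get_result_columns` above describes a split: only elements whose records cannot be recovered later (neither cached nor skypix) contribute record columns to the query itself. An illustrative predicate for that split; the element objects and flags below are stand-ins, not the real `DimensionElement` API surface.

```python
# Illustrative only: FakeElement stands in for lsst.daf.butler DimensionElement.
from dataclasses import dataclass

@dataclass
class FakeElement:
    name: str
    is_cached: bool = False
    is_skypix: bool = False

def elements_needing_query_columns(elements: list[FakeElement]) -> list[str]:
    # Cached and skypix records are attached by the page converter instead,
    # so only the remaining elements add record fields to the query columns.
    return [e.name for e in elements if not (e.is_cached or e.is_skypix)]

elements = [
    FakeElement("instrument", is_cached=True),
    FakeElement("detector", is_cached=True),
    FakeElement("visit"),
    FakeElement("htm7", is_skypix=True),
]
print(elements_needing_query_columns(elements))  # ['visit']
```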
32 changes: 29 additions & 3 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
@@ -40,7 +40,14 @@

from ...butler import Butler
from .._dataset_type import DatasetType
from ..dimensions import DataCoordinate, DataIdValue, DimensionGroup, DimensionRecord, DimensionUniverse
from ..dimensions import (
DataCoordinate,
DataIdValue,
DimensionGroup,
DimensionRecord,
DimensionRecordSet,
DimensionUniverse,
)
from ..queries.driver import (
DataCoordinateResultPage,
DatasetRefResultPage,
@@ -257,12 +264,31 @@

def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage:
"""Convert GeneralResultModel to a general result page."""
if spec.include_dimension_records:
# dimension_records must not be None when `include_dimension_records`
# is True, but it will be None if the remote server has not been upgraded.
if model.dimension_records is None:
raise ValueError(
"Missing dimension records in general result -- " "it is likely that server needs an upgrade."
)

columns = spec.get_result_columns()
serializers = [
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers))
tuple(serializer.deserialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)

universe = spec.dimensions.universe
dimension_records = None
if model.dimension_records is not None:
dimension_records = {}
for name, records in model.dimension_records.items():
element = universe[name]
dimension_records[element] = DimensionRecordSet(
element, (DimensionRecord.from_simple(r, universe) for r in records)
)

return GeneralResultPage(spec=spec, rows=rows, dimension_records=dimension_records)
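The `None` check above exists purely for backward compatibility with servers that predate this field; a trimmed stand-alone sketch of that guard (the error message follows the diff, the helper name is hypothetical):

```python
# Hypothetical helper illustrating the old-server guard above.
def check_dimension_records(
    include_dimension_records: bool, dimension_records: dict | None
) -> None:
    # Older servers never populate dimension_records, so None together with the
    # option requested means the server needs an upgrade, not "no records".
    if include_dimension_records and dimension_records is None:
        raise ValueError(
            "Missing dimension records in general result -- "
            "it is likely that the server needs an upgrade."
        )
```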
@@ -84,6 +84,13 @@ def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel:
columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns
]
rows = [
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers)) for row in page.rows
tuple(serializer.serialize(value) for value, serializer in zip(row, serializers, strict=True))
for row in page.rows
]
return GeneralResultModel(rows=rows)
dimension_records = None
if page.dimension_records is not None:
dimension_records = {
element.name: [record.to_simple() for record in records]
for element, records in page.dimension_records.items()
}
return GeneralResultModel(rows=rows, dimension_records=dimension_records)
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/remote_butler/server_models.py
@@ -313,6 +313,10 @@ class GeneralResultModel(pydantic.BaseModel):

type: Literal["general"] = "general"
rows: list[tuple[Any, ...]]
# Dimension records indexed by element name; only cached and skypix
# elements are included. The default is used for compatibility with older
# servers that do not set this field.
dimension_records: dict[str, list[SerializedDimensionRecord]] | None = None


class QueryErrorResultModel(pydantic.BaseModel):
42 changes: 42 additions & 0 deletions python/lsst/daf/butler/tests/butler_queries.py
@@ -436,7 +436,9 @@ def test_general_query(self) -> None:
self.assertEqual(len(row_tuple.refs), 1)
self.assertEqual(row_tuple.refs[0].datasetType, flat)
self.assertTrue(row_tuple.refs[0].dataId.hasFull())
self.assertFalse(row_tuple.refs[0].dataId.hasRecords())
self.assertTrue(row_tuple.data_id.hasFull())
self.assertFalse(row_tuple.data_id.hasRecords())
self.assertEqual(row_tuple.data_id.dimensions, dimensions)
self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g")

@@ -511,6 +513,46 @@ def test_general_query(self) -> None:
{Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None},
)

dimensions = butler.dimensions["detector"].minimal_group

# Include dimension records into query.
with butler.query() as query:
query = query.join_dimensions(dimensions)
result = query.general(dimensions).order_by("detector")
rows = list(result.with_dimension_records())
self.assertEqual(
rows[0],
{
"instrument": "Cam1",
"detector": 1,
"instrument.visit_max": 1024,
"instrument.visit_system": 1,
"instrument.exposure_max": 512,
"instrument.detector_max": 4,
"instrument.class_name": "lsst.pipe.base.Instrument",
"detector.full_name": "Aa",
"detector.name_in_raft": "a",
"detector.raft": "A",
"detector.purpose": "SCIENCE",
},
)

dimensions = butler.dimensions.conform(["detector", "physical_filter"])

# DataIds should come with records.
with butler.query() as query:
query = query.join_dataset_search("flat", "imported_g")
result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by(
"detector"
)
result = result.with_dimension_records()
row_tuples = list(result.iter_tuples(flat))
self.assertEqual(len(row_tuples), 3)
for row_tuple in row_tuples:
self.assertTrue(row_tuple.data_id.hasRecords())
self.assertEqual(len(row_tuple.refs), 1)
self.assertTrue(row_tuple.refs[0].dataId.hasRecords())

def test_query_ingest_date(self) -> None:
"""Test general query returning ingest_date field."""
before_ingest = astropy.time.Time.now()