diff --git a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py index 33c2988155..1d97d28df2 100644 --- a/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py +++ b/python/lsst/daf/butler/direct_query_driver/_result_page_converter.py @@ -326,17 +326,35 @@ class GeneralResultPageConverter(ResultPageConverter): # numpydoc ignore=PR01 def __init__(self, spec: GeneralResultSpec, ctx: ResultPageConverterContext) -> None: self.spec = spec - - result_columns = spec.get_result_columns() + # In case `spec.include_dimension_records` is True then in addition to + # columns returned by the query we have to add columns from dimension + # records that are not returned by the query. These columns belong to + # either cached or skypix dimensions. + query_result_columns = set(spec.get_result_columns()) + output_columns = spec.get_all_result_columns() + universe = spec.dimensions.universe self.converters: list[_GeneralColumnConverter] = [] - for column in result_columns: + for column in output_columns: column_name = qt.ColumnSet.get_qualified_name(column.logical_table, column.field) - if column.field == TimespanDatabaseRepresentation.NAME: - self.converters.append(_TimespanGeneralColumnConverter(column_name, ctx.db)) + converter: _GeneralColumnConverter + if column not in query_result_columns and column.field is not None: + # This must be a field from a cached dimension record or + # skypix record. + assert isinstance(column.logical_table, str), "Do not expect AnyDatasetType here" + element = universe[column.logical_table] + if isinstance(element, SkyPixDimension): + converter = _SkypixRecordGeneralColumnConverter(element, column.field) + else: + converter = _CachedRecordGeneralColumnConverter( + element, column.field, ctx.dimension_record_cache + ) + elif column.field == TimespanDatabaseRepresentation.NAME: + converter = _TimespanGeneralColumnConverter(column_name, ctx.db) elif column.field == "ingest_date": - self.converters.append(_TimestampGeneralColumnConverter(column_name)) + converter = _TimestampGeneralColumnConverter(column_name) else: - self.converters.append(_DefaultGeneralColumnConverter(column_name)) + converter = _DefaultGeneralColumnConverter(column_name) + self.converters.append(converter) def convert(self, raw_rows: Iterable[sqlalchemy.Row]) -> GeneralResultPage: rows = [tuple(cvt.convert(row) for cvt in self.converters) for row in raw_rows] @@ -422,3 +440,47 @@ def __init__(self, name: str, db: Database): def convert(self, row: sqlalchemy.Row) -> Any: timespan = self.timespan_class.extract(row._mapping, self.name) return timespan + + +class _CachedRecordGeneralColumnConverter(_GeneralColumnConverter): + """Helper for converting result row into a field value for cached + dimension records. + + Parameters + ---------- + element : `DimensionElement` + Dimension element, must be of cached type. + field : `str` + Name of the field to extract from the dimension record. + cache : `DimensionRecordCache` + Cache for dimension records. + """ + + def __init__(self, element: DimensionElement, field: str, cache: DimensionRecordCache) -> None: + self._record_converter = _CachedDimensionRecordRowConverter(element, cache) + self._field = field + + def convert(self, row: sqlalchemy.Row) -> Any: + record = self._record_converter.convert(row) + return getattr(record, self._field) + + +class _SkypixRecordGeneralColumnConverter(_GeneralColumnConverter): + """Helper for converting result row into a field value for skypix + dimension records. + + Parameters + ---------- + element : `SkyPixDimension` + Dimension element. + field : `str` + Name of the field to extract from the dimension record. + """ + + def __init__(self, element: SkyPixDimension, field: str) -> None: + self._record_converter = _SkypixDimensionRecordRowConverter(element) + self._field = field + + def convert(self, row: sqlalchemy.Row) -> Any: + record = self._record_converter.convert(row) + return getattr(record, self._field) diff --git a/python/lsst/daf/butler/queries/_general_query_results.py b/python/lsst/daf/butler/queries/_general_query_results.py index 3a98264bff..f2215b7bdb 100644 --- a/python/lsst/daf/butler/queries/_general_query_results.py +++ b/python/lsst/daf/butler/queries/_general_query_results.py @@ -35,11 +35,11 @@ from .._dataset_ref import DatasetRef from .._dataset_type import DatasetType -from ..dimensions import DataCoordinate, DimensionGroup +from ..dimensions import DataCoordinate, DimensionElement, DimensionGroup, DimensionRecord from ._base import QueryResultsBase from .driver import QueryDriver from .result_specs import GeneralResultSpec -from .tree import QueryTree +from .tree import QueryTree, ResultColumn class GeneralResultTuple(NamedTuple): @@ -99,9 +99,20 @@ def __iter__(self) -> Iterator[dict[str, Any]]: fields (separated from dataset type name by dot). """ for page in self._driver.execute(self._spec, self._tree): - columns = tuple(str(column) for column in page.spec.get_result_columns()) + columns = tuple(str(column) for column in page.spec.get_all_result_columns()) for row in page.rows: - yield dict(zip(columns, row)) + try: + yield dict(zip(columns, row, strict=True)) + except ValueError: + message = ( + "Inconsistent size of columns and data rows in general result -- " + f"columns: {columns}, row: {row}" + ) + if len(columns) > len(row): + # Probably means that server does not know about + # `include_dimension_records` in the spec. + message += " (may mean that Butler server needs an upgrade)" + raise ValueError(message) from None def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTuple]: """Iterate over result rows and return data coordinate, and dataset @@ -125,14 +136,10 @@ def iter_tuples(self, *dataset_types: DatasetType) -> Iterator[GeneralResultTupl run_key = f"{dataset_type.name}.run" dataset_keys.append((dataset_type, dimensions, id_key, run_key)) for row in self: - values = tuple( - row[key] for key in itertools.chain(all_dimensions.required, all_dimensions.implied) - ) - data_coordinate = DataCoordinate.from_full_values(all_dimensions, values) + data_coordinate = self._make_data_id(row, all_dimensions) refs = [] for dataset_type, dimensions, id_key, run_key in dataset_keys: - values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied)) - data_id = DataCoordinate.from_full_values(dimensions, values) + data_id = self._make_data_id(row, dimensions) refs.append(DatasetRef(dataset_type, data_id, row[run_key], id=row[id_key])) yield GeneralResultTuple(data_id=data_coordinate, refs=refs, raw_row=row) @@ -141,6 +148,19 @@ def dimensions(self) -> DimensionGroup: # Docstring inherited return self._spec.dimensions + @property + def has_dimension_records(self) -> bool: + """Whether all data IDs in this iterable contain dimension records.""" + return self._spec.include_dimension_records + + def with_dimension_records(self) -> GeneralQueryResults: + """Return a results object for which `has_dimension_records` is + `True`. + """ + if self.has_dimension_records: + return self + return self._copy(tree=self._tree, include_dimension_records=True) + def count(self, *, exact: bool = True, discard: bool = False) -> int: # Docstring inherited. return self._driver.count(self._tree, self._spec, exact=exact, discard=discard) @@ -152,3 +172,27 @@ def _copy(self, tree: QueryTree, **kwargs: Any) -> GeneralQueryResults: def _get_datasets(self) -> frozenset[str]: # Docstring inherited. return frozenset(self._spec.dataset_fields) + + def _make_data_id(self, row: dict[str, Any], dimensions: DimensionGroup) -> DataCoordinate: + values = tuple(row[key] for key in itertools.chain(dimensions.required, dimensions.implied)) + data_coordinate = DataCoordinate.from_full_values(dimensions, values) + if self.has_dimension_records: + records = { + name: self._make_dimension_record(row, dimensions.universe[name]) + for name in dimensions.elements + } + data_coordinate = data_coordinate.expanded(records) + return data_coordinate + + def _make_dimension_record(self, row: dict[str, Any], element: DimensionElement) -> DimensionRecord: + column_map = list( + zip( + element.schema.dimensions.names, + element.dimensions.names, + ) + ) + for field in element.schema.remainder.names: + column_map.append((field, str(ResultColumn(element.name, field)))) + d = {k: row[v] for k, v in column_map} + record_cls = element.RecordClass + return record_cls(**d) diff --git a/python/lsst/daf/butler/queries/result_specs.py b/python/lsst/daf/butler/queries/result_specs.py index 5bdc459f7c..6e3b1360e2 100644 --- a/python/lsst/daf/butler/queries/result_specs.py +++ b/python/lsst/daf/butler/queries/result_specs.py @@ -213,6 +213,11 @@ class GeneralResultSpec(ResultSpecBase): dataset_fields: Mapping[str, set[DatasetFieldName]] """Dataset fields included in this query.""" + include_dimension_records: bool = False + """Whether to include fields for all dimension records, in addition to + explicitly specified in `dimension_fields`. + """ + find_first: bool """Whether this query requires find-first resolution for a dataset. @@ -241,6 +246,33 @@ def get_result_columns(self) -> ColumnSet: result.dimension_fields[element_name].update(fields_for_element) for dataset_type, fields_for_dataset in self.dataset_fields.items(): result.dataset_fields[dataset_type].update(fields_for_dataset) + if self.include_dimension_records: + # This only adds record fields for non-cached and non-skypix + # elements, this is what we want when generating query. We could + # potentially add those too but it may make queries slower, so + # instead we query cached dimension records separately and add them + # to the result page in the page converter. + _add_dimension_records_to_column_set(self.dimensions, result) + return result + + def get_all_result_columns(self) -> ColumnSet: + """Return all columns that have to appear in the result. This includes + columns for all dimension records for all dimensions if + ``include_dimension_records`` is `True`. + + Returns + ------- + columns : `ColumnSet` + Full column set. + """ + dimensions = self.dimensions + result = self.get_result_columns() + if self.include_dimension_records: + for element_name in dimensions.elements: + element = dimensions.universe[element_name] + # Non-cached dimensions are already there, but it does not harm + # to add them again. + result.dimension_fields[element_name].update(element.schema.remainder.names) return result @pydantic.model_validator(mode="after") diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py index 6f3035bc9e..b1ede0adf2 100644 --- a/python/lsst/daf/butler/remote_butler/_query_driver.py +++ b/python/lsst/daf/butler/remote_butler/_query_driver.py @@ -257,7 +257,7 @@ def _convert_query_result_page( def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel) -> GeneralResultPage: """Convert GeneralResultModel to a general result page.""" - columns = spec.get_result_columns() + columns = spec.get_all_result_columns() serializers = [ columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py index e934948374..be487af947 100644 --- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py +++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py @@ -79,7 +79,7 @@ def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResult def _convert_general_result(page: GeneralResultPage) -> GeneralResultModel: """Convert GeneralResultPage to a serializable model.""" - columns = page.spec.get_result_columns() + columns = page.spec.get_all_result_columns() serializers = [ columns.get_column_spec(column.logical_table, column.field).serializer() for column in columns ] diff --git a/python/lsst/daf/butler/tests/butler_queries.py b/python/lsst/daf/butler/tests/butler_queries.py index 139ce47d4b..4c1367fcf7 100644 --- a/python/lsst/daf/butler/tests/butler_queries.py +++ b/python/lsst/daf/butler/tests/butler_queries.py @@ -436,7 +436,9 @@ def test_general_query(self) -> None: self.assertEqual(len(row_tuple.refs), 1) self.assertEqual(row_tuple.refs[0].datasetType, flat) self.assertTrue(row_tuple.refs[0].dataId.hasFull()) + self.assertFalse(row_tuple.refs[0].dataId.hasRecords()) self.assertTrue(row_tuple.data_id.hasFull()) + self.assertFalse(row_tuple.data_id.hasRecords()) self.assertEqual(row_tuple.data_id.dimensions, dimensions) self.assertEqual(row_tuple.raw_row["flat.run"], "imported_g") @@ -511,6 +513,46 @@ def test_general_query(self) -> None: {Timespan(t1, t2), Timespan(t2, t3), Timespan(t3, None), Timespan.makeEmpty(), None}, ) + dimensions = butler.dimensions["detector"].minimal_group + + # Include dimension records into query. + with butler.query() as query: + query = query.join_dimensions(dimensions) + result = query.general(dimensions).order_by("detector") + rows = list(result.with_dimension_records()) + self.assertEqual( + rows[0], + { + "instrument": "Cam1", + "detector": 1, + "instrument.visit_max": 1024, + "instrument.visit_system": 1, + "instrument.exposure_max": 512, + "instrument.detector_max": 4, + "instrument.class_name": "lsst.pipe.base.Instrument", + "detector.full_name": "Aa", + "detector.name_in_raft": "a", + "detector.raft": "A", + "detector.purpose": "SCIENCE", + }, + ) + + dimensions = butler.dimensions.conform(["detector", "physical_filter"]) + + # DataIds should come with records. + with butler.query() as query: + query = query.join_dataset_search("flat", "imported_g") + result = query.general(dimensions, dataset_fields={"flat": ...}, find_first=True).order_by( + "detector" + ) + result = result.with_dimension_records() + row_tuples = list(result.iter_tuples(flat)) + self.assertEqual(len(row_tuples), 3) + for row_tuple in row_tuples: + self.assertTrue(row_tuple.data_id.hasRecords()) + self.assertEqual(len(row_tuple.refs), 1) + self.assertTrue(row_tuple.refs[0].dataId.hasRecords()) + def test_query_ingest_date(self) -> None: """Test general query returning ingest_date field.""" before_ingest = astropy.time.Time.now()