diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index f83a4ad347..937ff8e6c9 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -482,6 +482,11 @@ def get_known_repos(cls) -> set[str]: """ return ButlerRepoIndex.get_known_repos() + @abstractmethod + def _caching_context(self) -> AbstractContextManager[None]: + """Context manager that enables caching.""" + raise NotImplementedError() + @abstractmethod def transaction(self) -> AbstractContextManager[None]: """Context manager supporting `Butler` transactions. diff --git a/python/lsst/daf/butler/_named.py b/python/lsst/daf/butler/_named.py index 55850645bd..e3ff851b72 100644 --- a/python/lsst/daf/butler/_named.py +++ b/python/lsst/daf/butler/_named.py @@ -266,7 +266,7 @@ def freeze(self) -> NamedKeyMapping[K, V]: to a new variable (and considering any previous references invalidated) should allow for more accurate static type checking. """ - if not isinstance(self._dict, MappingProxyType): + if not isinstance(self._dict, MappingProxyType): # type: ignore[unreachable] self._dict = MappingProxyType(self._dict) # type: ignore return self @@ -578,7 +578,7 @@ def freeze(self) -> NamedValueAbstractSet[K]: to a new variable (and considering any previous references invalidated) should allow for more accurate static type checking. """ - if not isinstance(self._mapping, MappingProxyType): + if not isinstance(self._mapping, MappingProxyType): # type: ignore[unreachable] self._mapping = MappingProxyType(self._mapping) # type: ignore return self diff --git a/python/lsst/daf/butler/_registry_shim.py b/python/lsst/daf/butler/_registry_shim.py index 67f50a16e1..4d2653abe0 100644 --- a/python/lsst/daf/butler/_registry_shim.py +++ b/python/lsst/daf/butler/_registry_shim.py @@ -102,6 +102,10 @@ def refresh(self) -> None: # Docstring inherited from a base class. self._registry.refresh() + def caching_context(self) -> contextlib.AbstractContextManager[None]: + # Docstring inherited from a base class. + return self._butler._caching_context() + @contextlib.contextmanager def transaction(self, *, savepoint: bool = False) -> Iterator[None]: # Docstring inherited from a base class. diff --git a/python/lsst/daf/butler/direct_butler.py b/python/lsst/daf/butler/direct_butler.py index 6b70ecb1e1..80bcf5d1ae 100644 --- a/python/lsst/daf/butler/direct_butler.py +++ b/python/lsst/daf/butler/direct_butler.py @@ -299,6 +299,10 @@ def isWriteable(self) -> bool: # Docstring inherited. return self._registry.isWriteable() + def _caching_context(self) -> contextlib.AbstractContextManager[None]: + """Context manager that enables caching.""" + return self._registry.caching_context() + @contextlib.contextmanager def transaction(self) -> Iterator[None]: """Context manager supporting `Butler` transactions. 
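Both the new `Butler._caching_context()` hook and the registry's `caching_context()` delegate to the same registry-level context manager. A minimal usage sketch, assuming an existing `butler` instance and an iterable of `refs` (both hypothetical, not part of this change):

    # Enable registry record caching around a burst of repeated lookups,
    # e.g. while building a quantum graph.
    with butler.registry.caching_context():
        # Collection and collection-summary lookups inside this block are
        # served from in-memory caches instead of issuing one query per call.
        for ref in refs:
            butler.get(ref)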
diff --git a/python/lsst/daf/butler/persistence_context.py b/python/lsst/daf/butler/persistence_context.py index b366564d45..8830fd9550 100644 --- a/python/lsst/daf/butler/persistence_context.py +++ b/python/lsst/daf/butler/persistence_context.py @@ -33,7 +33,7 @@ import uuid from collections.abc import Callable, Hashable from contextvars import Context, ContextVar, Token, copy_context -from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast +from typing import TYPE_CHECKING, ParamSpec, TypeVar if TYPE_CHECKING: from ._dataset_ref import DatasetRef @@ -198,4 +198,4 @@ def run(self, function: Callable[_Q, _T], *args: _Q.args, **kwargs: _Q.kwargs) - # cast the result as we know this is exactly what the return type will # be. result = self._ctx.run(self._functionRunner, function, *args, **kwargs) # type: ignore - return cast(_T, result) + return result diff --git a/python/lsst/daf/butler/registry/_caching_context.py b/python/lsst/daf/butler/registry/_caching_context.py new file mode 100644 index 0000000000..2674a54f69 --- /dev/null +++ b/python/lsst/daf/butler/registry/_caching_context.py @@ -0,0 +1,79 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ["CachingContext"] + +from typing import TYPE_CHECKING + +from ._collection_record_cache import CollectionRecordCache +from ._collection_summary_cache import CollectionSummaryCache +from ._dataset_type_cache import DatasetTypeCache + +if TYPE_CHECKING: + from .interfaces import DatasetRecordStorage + + +class CachingContext: + """Collection of caches for various types of records retrieved from + database. + + Notes + ----- + Caching is usually disabled for most of the record types, but it can be + explicitly and temporarily enabled in some context (e.g. quantum graph + building) using Registry method. This class is a collection of cache + instances which will be `None` when caching is disabled. Instance of this + class is passed to the relevant managers that can use it to query or + populate caches when caching is enabled. + + Dataset type cache is always enabled for now, this avoids the need for + explicitly enabling caching in pipetask executors. 
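The registry-side implementation of `caching_context()` is not shown in this hunk; a plausible shape, sketched here as an assumption (the `_caching_context` attribute name is hypothetical, and a real implementation would also need to handle nested invocations):

    import contextlib
    from collections.abc import Iterator

    from lsst.daf.butler.registry._caching_context import CachingContext

    class _RegistrySketch:
        """Stand-in for the registry side; not part of this diff."""

        def __init__(self, caching_context: CachingContext) -> None:
            self._caching_context = caching_context

        @contextlib.contextmanager
        def caching_context(self) -> Iterator[None]:
            # Turn the collection caches on for the duration of the block,
            # then drop them so later calls query the database again.
            self._caching_context.enable()
            try:
                yield
            finally:
                self._caching_context.disable()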
+ """ + + collection_records: CollectionRecordCache | None = None + """Cache for collection records (`CollectionRecordCache`).""" + + collection_summaries: CollectionSummaryCache | None = None + """Cache for collection summary records (`CollectionSummaryCache`).""" + + dataset_types: DatasetTypeCache[DatasetRecordStorage] + """Cache for dataset types, never disabled (`DatasetTypeCache`).""" + + def __init__(self) -> None: + self.dataset_types = DatasetTypeCache() + + def enable(self) -> None: + """Enable caches, initializes all caches.""" + self.collection_records = CollectionRecordCache() + self.collection_summaries = CollectionSummaryCache() + + def disable(self) -> None: + """Disable caches, sets all caches to `None`.""" + self.collection_records = None + self.collection_summaries = None diff --git a/python/lsst/daf/butler/registry/_collection_record_cache.py b/python/lsst/daf/butler/registry/_collection_record_cache.py new file mode 100644 index 0000000000..da00bb6a35 --- /dev/null +++ b/python/lsst/daf/butler/registry/_collection_record_cache.py @@ -0,0 +1,165 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("CollectionRecordCache",) + +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .interfaces import CollectionRecord + + +class CollectionRecordCache: + """Cache for collection records. + + Notes + ----- + This class stores collection records and can retrieve them using either + collection name or collection key. One complication is that key type can be + either collection name or a distinct integer value. To optimize storage + when the key is the same as collection name, this class only stores key to + record mapping when key is of a non-string type. + + In come contexts (e.g. ``resolve_wildcard``) a full list of collections is + needed. To signify that cache content can be used in such contexts, cache + defines special ``full`` flag that needs to be set by client. + """ + + def __init__(self) -> None: + self._by_name: dict[str, CollectionRecord] = {} + # This dict is only used for records whose key type is not str. 
+ self._by_key: dict[Any, CollectionRecord] = {} + self._full = False + + @property + def full(self) -> bool: + """`True` if cache holds all known collection records (`bool`).""" + return self._full + + def add(self, record: CollectionRecord) -> None: + """Add one record to the cache. + + Parameters + ---------- + record : `CollectionRecord` + Collection record, replaces any existing record with the same name + or key. + """ + # In case we replace same record name with different key, find the + # existing record and drop its key. + if (old_record := self._by_name.get(record.name)) is not None: + self._by_key.pop(old_record.key) + if (old_record := self._by_key.get(record.key)) is not None: + self._by_name.pop(old_record.name) + self._by_name[record.name] = record + if not isinstance(record.key, str): + self._by_key[record.key] = record + + def set(self, records: Iterable[CollectionRecord], *, full: bool = False) -> None: + """Replace cache contents with the new set of records. + + Parameters + ---------- + records : `~collections.abc.Iterable` [`CollectionRecord`] + Collection records. + full : `bool` + If `True` then ``records`` contain all known collection records. + """ + self.clear() + for record in records: + self._by_name[record.name] = record + if not isinstance(record.key, str): + self._by_key[record.key] = record + self._full = full + + def clear(self) -> None: + """Remove all records from the cache.""" + self._by_name = {} + self._by_key = {} + self._full = False + + def discard(self, record: CollectionRecord) -> None: + """Remove single record from the cache. + + Parameters + ---------- + record : `CollectionRecord` + Collection record to remove. + """ + self._by_name.pop(record.name, None) + if not isinstance(record.key, str): + self._by_key.pop(record.key, None) + + def get_by_name(self, name: str) -> CollectionRecord | None: + """Return collection record given its name. + + Parameters + ---------- + name : `str` + Collection name. + + Returns + ------- + record : `CollectionRecord` or `None` + Collection record, `None` is returned if the name is not in the + cache. + """ + return self._by_name.get(name) + + def get_by_key(self, key: Any) -> CollectionRecord | None: + """Return collection record given its key. + + Parameters + ---------- + key : `Any` + Collection key. + + Returns + ------- + record : `CollectionRecord` or `None` + Collection record, `None` is returned if the key is not in the + cache. + """ + if isinstance(key, str): + return self._by_name.get(key) + return self._by_key.get(key) + + def records(self) -> Iterator[CollectionRecord]: + """Return iterator for the set of records in the cache, can only be + used if `full` is true. + + Raises + ------ + RuntimeError + Raised if ``self.full`` is `False`. + """ + if not self._full: + raise RuntimeError("cannot call records() if cache is not full") + return iter(self._by_name.values()) diff --git a/python/lsst/daf/butler/registry/_collection_summary_cache.py b/python/lsst/daf/butler/registry/_collection_summary_cache.py new file mode 100644 index 0000000000..ed5b2f2fa2 --- /dev/null +++ b/python/lsst/daf/butler/registry/_collection_summary_cache.py @@ -0,0 +1,86 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. 
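A usage sketch for the `CollectionRecordCache` defined above; the `record` and `all_records` values are assumed to come from a collection manager and are hypothetical here:

    from lsst.daf.butler.registry._collection_record_cache import CollectionRecordCache

    cache = CollectionRecordCache()
    cache.add(record)
    # Lookups work by name and by key; string keys are served from the
    # by-name mapping, non-string keys from the separate by-key mapping.
    assert cache.get_by_name(record.name) is record
    assert cache.get_by_key(record.key) is record

    # Iterating over all records is only allowed once the cache has been
    # marked as complete with ``full=True``.
    cache.set(all_records, full=True)
    known_names = {rec.name for rec in cache.records()}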
+# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("CollectionSummaryCache",) + +from collections.abc import Iterable, Mapping +from typing import Any + +from ._collection_summary import CollectionSummary + + +class CollectionSummaryCache: + """Cache for collection summaries. + + Notes + ----- + This class stores `CollectionSummary` records indexed by collection keys. + For cache to be usable the records that are given to `update` method have + to include all dataset types, i.e. the query that produces records should + not be constrained by dataset type. + """ + + def __init__(self) -> None: + self._cache: dict[Any, CollectionSummary] = {} + + def update(self, summaries: Mapping[Any, CollectionSummary]) -> None: + """Add records to the cache. + + Parameters + ---------- + summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`] + Summary records indexed by collection key, records must include all + dataset types. + """ + self._cache.update(summaries) + + def find_summaries(self, keys: Iterable[Any]) -> tuple[dict[Any, CollectionSummary], set[Any]]: + """Return summary records given a set of keys. + + Parameters + ---------- + keys : `~collections.abc.Iterable` [`Any`] + Sequence of collection keys. + + Returns + ------- + summaries : `dict` [`Any`, `CollectionSummary`] + Dictionary of summaries indexed by collection keys, includes + records found in the cache. + missing_keys : `set` [`Any`] + Collection keys that are not present in the cache. + """ + found = {} + not_found = set() + for key in keys: + if (summary := self._cache.get(key)) is not None: + found[key] = summary + else: + not_found.add(key) + return found, not_found diff --git a/python/lsst/daf/butler/registry/_dataset_type_cache.py b/python/lsst/daf/butler/registry/_dataset_type_cache.py new file mode 100644 index 0000000000..3f1665dfa3 --- /dev/null +++ b/python/lsst/daf/butler/registry/_dataset_type_cache.py @@ -0,0 +1,162 @@ +# This file is part of daf_butler. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. 
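The intended pattern for the `CollectionSummaryCache` defined above is to ask the cache first, query the database only for the missing keys, and feed the results back in. A sketch, where `collection_keys` and `query_summaries_from_db` are hypothetical placeholders:

    from lsst.daf.butler.registry._collection_summary_cache import CollectionSummaryCache

    cache = CollectionSummaryCache()
    found, missing = cache.find_summaries(collection_keys)
    if missing:
        # The query must not be constrained by dataset type, otherwise the
        # cached summaries would be incomplete.
        fetched = query_summaries_from_db(missing)
        cache.update(fetched)
        found.update(fetched)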
If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +__all__ = ("DatasetTypeCache",) + +from collections.abc import Iterable, Iterator +from typing import Generic, TypeVar + +from .._dataset_type import DatasetType + +_T = TypeVar("_T") + + +class DatasetTypeCache(Generic[_T]): + """Cache for dataset types. + + Notes + ----- + This class caches the mapping of dataset type name to the corresponding + `DatasetType` instance. The registry manager also needs to cache the corresponding + "storage" instance, so this class allows storing an additional opaque object + along with the dataset type. + + In some contexts (e.g. ``resolve_wildcard``) a full list of dataset types + is needed. To signify that the cache content can be used in such contexts, + the cache defines a special ``full`` flag that needs to be set by the client. + """ + + def __init__(self) -> None: + self._cache: dict[str, tuple[DatasetType, _T | None]] = {} + self._full = False + + @property + def full(self) -> bool: + """`True` if cache holds all known dataset types (`bool`).""" + return self._full + + def add(self, dataset_type: DatasetType, extra: _T | None = None) -> None: + """Add one record to the cache. + + Parameters + ---------- + dataset_type : `DatasetType` + Dataset type, replaces any existing dataset type with the same + name. + extra : `Any`, optional + Additional opaque object stored with this dataset type. + """ + self._cache[dataset_type.name] = (dataset_type, extra) + + def set(self, data: Iterable[DatasetType | tuple[DatasetType, _T | None]], *, full: bool = False) -> None: + """Replace cache contents with the new set of dataset types. + + Parameters + ---------- + data : `~collections.abc.Iterable` + Sequence of `DatasetType` instances or tuples of `DatasetType` and + an extra opaque object. + full : `bool` + If `True` then ``data`` contains all known dataset types. + """ + self.clear() + for item in data: + if isinstance(item, DatasetType): + item = (item, None) + self._cache[item[0].name] = item + self._full = full + + def clear(self) -> None: + """Remove everything from the cache.""" + self._cache = {} + self._full = False + + def discard(self, name: str) -> None: + """Remove named dataset type from the cache. + + Parameters + ---------- + name : `str` + Name of the dataset type to remove. + """ + self._cache.pop(name, None) + + def get(self, name: str) -> tuple[DatasetType | None, _T | None]: + """Return cached info given dataset type name. + + Parameters + ---------- + name : `str` + Dataset type name. + + Returns + ------- + dataset_type : `DatasetType` or `None` + Cached dataset type, `None` is returned if the name is not in the + cache. + extra : `Any` or `None` + Cached opaque data, `None` is returned if the name is not in the + cache or no extra info was stored for this dataset type.
+ """ + item = self._cache.get(name) + if item is None: + return (None, None) + return item + + def get_dataset_type(self, name: str) -> DatasetType | None: + """Return dataset type given its name. + + Parameters + ---------- + name : `str` + Dataset type name. + + Returns + ------- + dataset_type : `DatasetType` or `None` + Cached dataset type, `None` is returned if the name is not in the + cache. + """ + item = self._cache.get(name) + if item is None: + return None + return item[0] + + def items(self) -> Iterator[tuple[DatasetType, _T | None]]: + """Return iterator for the set of items in the cache, can only be + used if `full` is true. + + Raises + ------ + RuntimeError + Raised if ``self.full`` is `False`. + """ + if not self._full: + raise RuntimeError("cannot call items() if cache is not full") + return iter(self._cache.values()) diff --git a/python/lsst/daf/butler/registry/_registry.py b/python/lsst/daf/butler/registry/_registry.py index 2f0cb3231d..21e651314c 100644 --- a/python/lsst/daf/butler/registry/_registry.py +++ b/python/lsst/daf/butler/registry/_registry.py @@ -118,6 +118,11 @@ def refresh(self) -> None: """ raise NotImplementedError() + @abstractmethod + def caching_context(self) -> contextlib.AbstractContextManager[None]: + """Context manager that enables caching.""" + raise NotImplementedError() + @contextlib.contextmanager @abstractmethod def transaction(self, *, savepoint: bool = False) -> Iterator[None]: diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py index f611cd630c..f39ee607e9 100644 --- a/python/lsst/daf/butler/registry/collections/_base.py +++ b/python/lsst/daf/butler/registry/collections/_base.py @@ -34,18 +34,18 @@ from abc import abstractmethod from collections import namedtuple from collections.abc import Iterable, Iterator, Set -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, TypeVar, cast import sqlalchemy -from ..._timespan import Timespan, TimespanDatabaseRepresentation -from ...dimensions import DimensionUniverse +from ..._timespan import TimespanDatabaseRepresentation from .._collection_type import CollectionType from .._exceptions import MissingCollectionError from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple from ..wildcards import CollectionWildcard if TYPE_CHECKING: + from .._caching_context import CachingContext from ..interfaces import Database, DimensionRecordStorageManager @@ -157,150 +157,10 @@ def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type) ) -class DefaultRunRecord(RunRecord): - """Default `RunRecord` implementation. - - This method assumes the same run table definition as produced by - `makeRunTableSpec` method. The only non-fixed name in the schema - is the PK column name, this needs to be passed in a constructor. - - Parameters - ---------- - db : `Database` - Registry database. - key - Unique collection ID, can be the same as ``name`` if ``name`` is used - for identification. Usually this is an integer or string, but can be - other database-specific type. - name : `str` - Run collection name. - table : `sqlalchemy.schema.Table` - Table for run records. - idColumnName : `str` - Name of the identifying column in run table. - host : `str`, optional - Name of the host where run was produced. - timespan : `Timespan`, optional - Timespan for this run. 
- """ - - def __init__( - self, - db: Database, - key: Any, - name: str, - *, - table: sqlalchemy.schema.Table, - idColumnName: str, - host: str | None = None, - timespan: Timespan | None = None, - ): - super().__init__(key=key, name=name, type=CollectionType.RUN) - self._db = db - self._table = table - self._host = host - if timespan is None: - timespan = Timespan(begin=None, end=None) - self._timespan = timespan - self._idName = idColumnName - - def update(self, host: str | None = None, timespan: Timespan | None = None) -> None: - # Docstring inherited from RunRecord. - if timespan is None: - timespan = Timespan(begin=None, end=None) - row = { - self._idName: self.key, - "host": host, - } - self._db.getTimespanRepresentation().update(timespan, result=row) - count = self._db.update(self._table, {self._idName: self.key}, row) - if count != 1: - raise RuntimeError(f"Run update affected {count} records; expected exactly one.") - self._host = host - self._timespan = timespan - - @property - def host(self) -> str | None: - # Docstring inherited from RunRecord. - return self._host - - @property - def timespan(self) -> Timespan: - # Docstring inherited from RunRecord. - return self._timespan - - -class DefaultChainedCollectionRecord(ChainedCollectionRecord): - """Default `ChainedCollectionRecord` implementation. - - This method assumes the same chain table definition as produced by - `makeCollectionChainTableSpec` method. All column names in the table are - fixed and hard-coded in the methods. - - Parameters - ---------- - db : `Database` - Registry database. - key - Unique collection ID, can be the same as ``name`` if ``name`` is used - for identification. Usually this is an integer or string, but can be - other database-specific type. - name : `str` - Collection name. - table : `sqlalchemy.schema.Table` - Table for chain relationship records. - universe : `DimensionUniverse` - Object managing all known dimensions. - """ - - def __init__( - self, - db: Database, - key: Any, - name: str, - *, - table: sqlalchemy.schema.Table, - universe: DimensionUniverse, - ): - super().__init__(key=key, name=name, universe=universe) - self._db = db - self._table = table - self._universe = universe - - def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None: - # Docstring inherited from ChainedCollectionRecord. - rows = [] - position = itertools.count() - for child in manager.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False): - rows.append( - { - "parent": self.key, - "child": child.key, - "position": next(position), - } - ) - with self._db.transaction(): - self._db.delete(self._table, ["parent"], {"parent": self.key}) - self._db.insert(self._table, *rows) - - def _load(self, manager: CollectionManager) -> tuple[str, ...]: - # Docstring inherited from ChainedCollectionRecord. - sql = ( - sqlalchemy.sql.select( - self._table.columns.child, - ) - .select_from(self._table) - .where(self._table.columns.parent == self.key) - .order_by(self._table.columns.position) - ) - with self._db.query(sql) as sql_result: - return tuple(manager[row[self._table.columns.child]].name for row in sql_result.mappings()) - - K = TypeVar("K") -class DefaultCollectionManager(Generic[K], CollectionManager): +class DefaultCollectionManager(CollectionManager[K]): """Default `CollectionManager` implementation. 
This implementation uses record classes defined in this module and is @@ -331,72 +191,34 @@ def __init__( collectionIdName: str, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ): super().__init__(registry_schema_version=registry_schema_version) self._db = db self._tables = tables self._collectionIdName = collectionIdName - self._records: dict[K, CollectionRecord] = {} # indexed by record ID self._dimensions = dimensions + self._caching_context = caching_context def refresh(self) -> None: # Docstring inherited from CollectionManager. - sql = sqlalchemy.sql.select( - *(list(self._tables.collection.columns) + list(self._tables.run.columns)) - ).select_from(self._tables.collection.join(self._tables.run, isouter=True)) - # Put found records into a temporary instead of updating self._records - # in place, for exception safety. - records = [] - chains = [] - TimespanReprClass = self._db.getTimespanRepresentation() - with self._db.query(sql) as sql_result: - sql_rows = sql_result.mappings().fetchall() - for row in sql_rows: - collection_id = row[self._tables.collection.columns[self._collectionIdName]] - name = row[self._tables.collection.columns.name] - type = CollectionType(row["type"]) - record: CollectionRecord - if type is CollectionType.RUN: - record = DefaultRunRecord( - key=collection_id, - name=name, - db=self._db, - table=self._tables.run, - idColumnName=self._collectionIdName, - host=row[self._tables.run.columns.host], - timespan=TimespanReprClass.extract(row), - ) - elif type is CollectionType.CHAINED: - record = DefaultChainedCollectionRecord( - db=self._db, - key=collection_id, - table=self._tables.collection_chain, - name=name, - universe=self._dimensions.universe, - ) - chains.append(record) - else: - record = CollectionRecord(key=collection_id, name=name, type=type) - records.append(record) - self._setRecordCache(records) - for chain in chains: - try: - chain.refresh(self) - except MissingCollectionError: - # This indicates a race condition in which some other client - # created a new collection and added it as a child of this - # (pre-existing) chain between the time we fetched all - # collections and the time we queried for parent-child - # relationships. - # Because that's some other unrelated client, we shouldn't care - # about that parent collection anyway, so we just drop it on - # the floor (a manual refresh can be used to get it back). - self._removeCachedRecord(chain) + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.clear() + + def _fetch_all(self) -> list[CollectionRecord[K]]: + """Retrieve all records into cache if not done so yet.""" + if self._caching_context.collection_records is not None: + if self._caching_context.collection_records.full: + return list(self._caching_context.collection_records.records()) + records = self._fetch_by_key(None) + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.set(records, full=True) + return records def register( self, name: str, type: CollectionType, doc: str | None = None - ) -> tuple[CollectionRecord, bool]: + ) -> tuple[CollectionRecord[K], bool]: # Docstring inherited from CollectionManager. 
registered = False record = self._getByName(name) @@ -411,7 +233,7 @@ def register( assert isinstance(inserted_or_updated, bool) registered = inserted_or_updated assert row is not None - collection_id = row[self._collectionIdName] + collection_id = cast(K, row[self._collectionIdName]) if type is CollectionType.RUN: TimespanReprClass = self._db.getTimespanRepresentation() row, _ = self._db.sync( @@ -420,25 +242,20 @@ def register( returning=("host",) + TimespanReprClass.getFieldNames(), ) assert row is not None - record = DefaultRunRecord( - db=self._db, + record = RunRecord[K]( key=collection_id, name=name, - table=self._tables.run, - idColumnName=self._collectionIdName, host=row["host"], timespan=TimespanReprClass.extract(row), ) elif type is CollectionType.CHAINED: - record = DefaultChainedCollectionRecord( - db=self._db, + record = ChainedCollectionRecord[K]( key=collection_id, name=name, - table=self._tables.collection_chain, - universe=self._dimensions.universe, + children=[], ) else: - record = CollectionRecord(key=collection_id, name=name, type=type) + record = CollectionRecord[K](key=collection_id, name=name, type=type) self._addCachedRecord(record) return record, registered @@ -453,19 +270,48 @@ def remove(self, name: str) -> None: ) self._removeCachedRecord(record) - def find(self, name: str) -> CollectionRecord: + def find(self, name: str) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. result = self._getByName(name) if result is None: raise MissingCollectionError(f"No collection with name '{name}' found.") return result - def __getitem__(self, key: Any) -> CollectionRecord: + def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]: + """Return multiple records given their names.""" + names = list(names) + # To protect against potential races in cache updates. + records: dict[str, CollectionRecord | None] = {} + if self._caching_context.collection_records is not None: + for name in names: + records[name] = self._caching_context.collection_records.get_by_name(name) + fetch_names = [name for name, record in records.items() if record is None] + else: + fetch_names = list(names) + records = {name: None for name in fetch_names} + if fetch_names: + for record in self._fetch_by_name(fetch_names): + records[record.name] = record + self._addCachedRecord(record) + missing_names = [name for name, record in records.items() if record is None] + if len(missing_names) == 1: + raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.") + elif len(missing_names) > 1: + raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.") + return [cast(CollectionRecord[K], records[name]) for name in names] + + def __getitem__(self, key: Any) -> CollectionRecord[K]: # Docstring inherited from CollectionManager. 
- try: - return self._records[key] - except KeyError as err: - raise MissingCollectionError(f"Collection with key '{key}' not found.") from err + if self._caching_context.collection_records is not None: + if (record := self._caching_context.collection_records.get_by_key(key)) is not None: + return record + if records := self._fetch_by_key([key]): + record = records[0] + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.add(record) + return record + else: + raise MissingCollectionError(f"Collection with key '{key}' not found.") def resolve_wildcard( self, @@ -475,13 +321,13 @@ def resolve_wildcard( done: set[str] | None = None, flatten_chains: bool = True, include_chains: bool | None = None, - ) -> list[CollectionRecord]: + ) -> list[CollectionRecord[K]]: # Docstring inherited if done is None: done = set() include_chains = include_chains if include_chains is not None else not flatten_chains - def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]: + def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]: if record.name in done: return if record.type in collection_types: @@ -490,28 +336,29 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect yield record if flatten_chains and record.type is CollectionType.CHAINED: done.add(record.name) - for name in cast(ChainedCollectionRecord, record).children: + for child in self._find_many(cast(ChainedCollectionRecord[K], record).children): # flake8 can't tell that we only delete this closure when # we're totally done with it. - yield from resolve_nested(self.find(name), done) # noqa: F821 + yield from resolve_nested(child, done) # noqa: F821 - result: list[CollectionRecord] = [] + result: list[CollectionRecord[K]] = [] if wildcard.patterns is ...: - for record in self._records.values(): + for record in self._fetch_all(): result.extend(resolve_nested(record, done)) del resolve_nested return result - for name in wildcard.strings: - result.extend(resolve_nested(self.find(name), done)) + if wildcard.strings: + for record in self._find_many(wildcard.strings): + result.extend(resolve_nested(record, done)) if wildcard.patterns: - for record in self._records.values(): + for record in self._fetch_all(): if any(p.fullmatch(record.name) for p in wildcard.patterns): result.extend(resolve_nested(record, done)) del resolve_nested return result - def getDocumentation(self, key: Any) -> str | None: + def getDocumentation(self, key: K) -> str | None: # Docstring inherited from CollectionManager. sql = ( sqlalchemy.sql.select(self._tables.collection.columns.doc) @@ -521,27 +368,76 @@ def getDocumentation(self, key: Any) -> str | None: with self._db.query(sql) as sql_result: return sql_result.scalar() - def setDocumentation(self, key: Any, doc: str | None) -> None: + def setDocumentation(self, key: K, doc: str | None) -> None: # Docstring inherited from CollectionManager. self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc}) - def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: - """Set internal record cache to contain given records, - old cached records will be removed. 
- """ - self._records = {} - for record in records: - self._records[record.key] = record - - def _addCachedRecord(self, record: CollectionRecord) -> None: + def _addCachedRecord(self, record: CollectionRecord[K]) -> None: """Add single record to cache.""" - self._records[record.key] = record + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.add(record) - def _removeCachedRecord(self, record: CollectionRecord) -> None: + def _removeCachedRecord(self, record: CollectionRecord[K]) -> None: """Remove single record from cache.""" - del self._records[record.key] + if self._caching_context.collection_records is not None: + self._caching_context.collection_records.discard(record) - @abstractmethod - def _getByName(self, name: str) -> CollectionRecord | None: + def _getByName(self, name: str) -> CollectionRecord[K] | None: """Find collection record given collection name.""" + if self._caching_context.collection_records is not None: + if (record := self._caching_context.collection_records.get_by_name(name)) is not None: + return record + records = self._fetch_by_name([name]) + for record in records: + self._addCachedRecord(record) + return records[0] if records else None + + @abstractmethod + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]: + """Fetch collection record from database given its name.""" + raise NotImplementedError() + + @abstractmethod + def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]: + """Fetch collection record from database given its key, or fetch all + collctions if argument is None. + """ raise NotImplementedError() + + def update_chain( + self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False + ) -> ChainedCollectionRecord[K]: + # Docstring inherited from CollectionManager. + children_as_wildcard = CollectionWildcard.from_names(children) + for record in self.resolve_wildcard( + children_as_wildcard, + flatten_chains=True, + include_chains=True, + collection_types={CollectionType.CHAINED}, + ): + if record == chain: + raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.") + if flatten: + children = tuple( + record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True) + ) + + rows = [] + position = itertools.count() + names = [] + for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False): + rows.append( + { + "parent": chain.key, + "child": child.key, + "position": next(position), + } + ) + names.append(child.name) + with self._db.transaction(): + self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key}) + self._db.insert(self._tables.collection_chain, *rows) + + record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names)) + self._addCachedRecord(record) + return record diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py index e5e635e61c..9fc2e32271 100644 --- a/python/lsst/daf/butler/registry/collections/nameKey.py +++ b/python/lsst/daf/butler/registry/collections/nameKey.py @@ -26,16 +26,17 @@ # along with this program. If not, see . from __future__ import annotations -from ... import ddl - __all__ = ["NameKeyCollectionManager"] +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy +from ... 
import ddl from ..._timespan import TimespanDatabaseRepresentation -from ..interfaces import VersionTuple +from .._collection_type import CollectionType +from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple from ._base import ( CollectionTablesTuple, DefaultCollectionManager, @@ -44,7 +45,8 @@ ) if TYPE_CHECKING: - from ..interfaces import CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext + from .._caching_context import CachingContext + from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext _KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True) @@ -68,7 +70,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> ) -class NameKeyCollectionManager(DefaultCollectionManager): +class NameKeyCollectionManager(DefaultCollectionManager[str]): """A `CollectionManager` implementation that uses collection names for primary/foreign keys and aggressively loads all collection/run records in the database into memory. @@ -85,6 +87,7 @@ def initialize( context: StaticTablesContext, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> NameKeyCollectionManager: # Docstring inherited from CollectionManager. @@ -93,6 +96,7 @@ def initialize( tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore collectionIdName="name", dimensions=dimensions, + caching_context=caching_context, registry_schema_version=registry_schema_version, ) @@ -152,9 +156,117 @@ def addRunForeignKey( ) return copy - def _getByName(self, name: str) -> CollectionRecord | None: - # Docstring inherited from DefaultCollectionManager. - return self._records.get(name) + def getParentChains(self, key: str) -> set[str]: + # Docstring inherited from CollectionManager. + table = self._tables.collection_chain + sql = ( + sqlalchemy.sql.select(table.columns["parent"]) + .select_from(table) + .where(table.columns["child"] == key) + ) + with self._db.query(sql) as sql_result: + parent_names = set(sql_result.scalars().all()) + return parent_names + + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]: + # Docstring inherited from base class. + return self._fetch_by_key(names) + + def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]: + # Docstring inherited from base class. + sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from( + self._tables.collection.join(self._tables.run, isouter=True) + ) + + chain_sql = sqlalchemy.sql.select( + self._tables.collection_chain.columns["parent"], + self._tables.collection_chain.columns["position"], + self._tables.collection_chain.columns["child"], + ) + + records: list[CollectionRecord[str]] = [] + # We want to keep transactions as short as possible. When we fetch + # everything we want to quickly fetch things into memory and finish + # transaction. When we fetch just few records we need to process result + # of the first query before we can run the second one. 
+ if collection_ids is not None: + sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids)) + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + + if chained_ids: + # Retrieve chained collection compositions + chain_sql = chain_sql.where( + self._tables.collection_chain.columns["parent"].in_(chained_ids) + ) + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records += self._rows_to_chains(chain_rows, chained_ids) + + else: + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + records += self._rows_to_chains(chain_rows, chained_ids) + + return records + + def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]: + """Convert rows returned from collection query to a list of records + and a list of chained collection names. + """ + records: list[CollectionRecord[str]] = [] + TimespanReprClass = self._db.getTimespanRepresentation() + chained_ids: list[str] = [] + for row in rows: + name = row[self._tables.collection.columns.name] + type = CollectionType(row["type"]) + record: CollectionRecord[str] + if type is CollectionType.RUN: + record = RunRecord[str]( + key=name, + name=name, + host=row[self._tables.run.columns.host], + timespan=TimespanReprClass.extract(row), + ) + records.append(record) + elif type is CollectionType.CHAINED: + # Need to delay chained collection construction until we + # fetch their children's names. + chained_ids.append(name) + else: + record = CollectionRecord[str](key=name, name=name, type=type) + records.append(record) + + return records, chained_ids + + def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]: + """Convert rows returned from collection chain query to a list of + records.
+ """ + chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids} + for row in rows: + chains_defs[row["parent"]].append((row["position"], row["child"])) + + records: list[CollectionRecord[str]] = [] + for name, children in chains_defs.items(): + children_names = [child for _, child in sorted(children)] + record = ChainedCollectionRecord[str]( + key=name, + name=name, + children=children_names, + ) + records.append(record) + + return records @classmethod def currentVersions(cls) -> list[VersionTuple]: diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py index 8e49140c8d..3b7dbb1de4 100644 --- a/python/lsst/daf/butler/registry/collections/synthIntKey.py +++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py @@ -30,13 +30,14 @@ __all__ = ["SynthIntKeyCollectionManager"] -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy from ..._timespan import TimespanDatabaseRepresentation -from ..interfaces import CollectionRecord, VersionTuple +from .._collection_type import CollectionType +from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple from ._base import ( CollectionTablesTuple, DefaultCollectionManager, @@ -45,6 +46,7 @@ ) if TYPE_CHECKING: + from .._caching_context import CachingContext from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext @@ -73,43 +75,11 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) -> ) -class SynthIntKeyCollectionManager(DefaultCollectionManager): +class SynthIntKeyCollectionManager(DefaultCollectionManager[int]): """A `CollectionManager` implementation that uses synthetic primary key (auto-incremented integer) for collections table. - - Most of the logic, including caching policy, is implemented in the base - class, this class only adds customizations specific to this particular - table schema. - - Parameters - ---------- - db : `Database` - Interface to the underlying database engine and namespace. - tables : `NamedTuple` - Named tuple of SQLAlchemy table objects. - collectionIdName : `str` - Name of the column in collections table that identifies it (PK). - dimensions : `DimensionRecordStorageManager` - Manager object for the dimensions in this `Registry`. """ - def __init__( - self, - db: Database, - tables: CollectionTablesTuple, - collectionIdName: str, - dimensions: DimensionRecordStorageManager, - registry_schema_version: VersionTuple | None = None, - ): - super().__init__( - db=db, - tables=tables, - collectionIdName=collectionIdName, - dimensions=dimensions, - registry_schema_version=registry_schema_version, - ) - self._nameCache: dict[str, CollectionRecord] = {} # indexed by collection name - @classmethod def initialize( cls, @@ -117,6 +87,7 @@ def initialize( context: StaticTablesContext, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> SynthIntKeyCollectionManager: # Docstring inherited from CollectionManager. 
@@ -125,6 +96,7 @@ def initialize( tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore collectionIdName="collection_id", dimensions=dimensions, + caching_context=caching_context, registry_schema_version=registry_schema_version, ) @@ -184,29 +156,134 @@ def addRunForeignKey( ) return copy - def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None: - """Set internal record cache to contain given records, - old cached records will be removed. + def getParentChains(self, key: int) -> set[str]: + # Docstring inherited from CollectionManager. + chain = self._tables.collection_chain + collection = self._tables.collection + sql = ( + sqlalchemy.sql.select(collection.columns["name"]) + .select_from(collection) + .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"]) + .where(chain.columns["child"] == key) + ) + with self._db.query(sql) as sql_result: + parent_names = set(sql_result.scalars().all()) + return parent_names + + def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]: + # Docstring inherited from base class. + return self._fetch("name", names) + + def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]: + # Docstring inherited from base class. + return self._fetch(self._collectionIdName, collection_ids) + + def _fetch( + self, column_name: str, collections: Iterable[int | str] | None + ) -> list[CollectionRecord[int]]: + collection_chain = self._tables.collection_chain + collection = self._tables.collection + sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from( + collection.join(self._tables.run, isouter=True) + ) + + chain_sql = ( + sqlalchemy.sql.select( + collection_chain.columns["parent"], + collection_chain.columns["position"], + collection.columns["name"].label("child_name"), + ) + .select_from(collection_chain) + .join( + collection, + onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName], + ) + ) + + records: list[CollectionRecord[int]] = [] + # We want to keep transactions as short as possible. When we fetch + # everything we want to quickly fetch things into memory and finish the + # transaction. When we fetch just a few records we need to process the first + # query before we can run the second one. + if collections is not None: + sql = sql.where(collection.columns[column_name].in_(collections)) + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + + if chained_ids: + chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids))) + + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records += self._rows_to_chains(chain_rows, chained_ids) + + else: + with self._db.transaction(): + with self._db.query(sql) as sql_result: + sql_rows = sql_result.mappings().fetchall() + with self._db.query(chain_sql) as sql_result: + chain_rows = sql_result.mappings().fetchall() + + records, chained_ids = self._rows_to_records(sql_rows) + records += self._rows_to_chains(chain_rows, chained_ids) + + return records + + def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]: + """Convert rows returned from collection query to a list of records + and a dict of chained collection names.
+ """ + records: list[CollectionRecord[int]] = [] + chained_ids: dict[int, str] = {} + TimespanReprClass = self._db.getTimespanRepresentation() + for row in rows: + key: int = row[self._collectionIdName] + name: str = row[self._tables.collection.columns.name] + type = CollectionType(row["type"]) + record: CollectionRecord[int] + if type is CollectionType.RUN: + record = RunRecord[int]( + key=key, + name=name, + host=row[self._tables.run.columns.host], + timespan=TimespanReprClass.extract(row), + ) + records.append(record) + elif type is CollectionType.CHAINED: + # Need to delay chained collection construction until to + # fetch their children names. + chained_ids[key] = name + else: + record = CollectionRecord[int](key=key, name=name, type=type) + records.append(record) + return records, chained_ids + + def _rows_to_chains( + self, rows: Iterable[Mapping], chained_ids: dict[int, str] + ) -> list[CollectionRecord[int]]: + """Convert rows returned from collection chain query to a list of + records. """ - self._records = {} - self._nameCache = {} - for record in records: - self._records[record.key] = record - self._nameCache[record.name] = record - - def _addCachedRecord(self, record: CollectionRecord) -> None: - """Add single record to cache.""" - self._records[record.key] = record - self._nameCache[record.name] = record - - def _removeCachedRecord(self, record: CollectionRecord) -> None: - """Remove single record from cache.""" - del self._records[record.key] - del self._nameCache[record.name] - - def _getByName(self, name: str) -> CollectionRecord | None: - # Docstring inherited from DefaultCollectionManager. - return self._nameCache.get(name) + chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids} + for row in rows: + chains_defs[row["parent"]].append((row["position"], row["child_name"])) + + records: list[CollectionRecord[int]] = [] + for key, children in chains_defs.items(): + name = chained_ids[key] + children_names = [child for _, child in sorted(children)] + record = ChainedCollectionRecord[int]( + key=key, + name=name, + children=children_names, + ) + records.append(record) + + return records @classmethod def currentVersions(cls) -> list[VersionTuple]: diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py index d52d337a2b..692d8585a3 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py @@ -4,9 +4,11 @@ __all__ = ("ByDimensionsDatasetRecordStorageManagerUUID",) +import dataclasses import logging import warnings from collections import defaultdict +from collections.abc import Iterable, Mapping from typing import TYPE_CHECKING, Any import sqlalchemy @@ -30,6 +32,7 @@ ) if TYPE_CHECKING: + from ..._caching_context import CachingContext from ...interfaces import ( CollectionManager, CollectionRecord, @@ -54,6 +57,34 @@ class MissingDatabaseTableError(RuntimeError): """Exception raised when a table is not found in a database.""" +@dataclasses.dataclass +class _DatasetTypeRecord: + """Contents of a single dataset type record.""" + + dataset_type: DatasetType + dataset_type_id: int + tag_table_name: str + calib_table_name: str | None + + +class _SpecTableFactory: + """Factory for `sqlalchemy.schema.Table` instances that builds table + instances using provided `ddl.TableSpec` definition and verifies that + table exists in the database. 
+ """ + + def __init__(self, db: Database, name: str, spec: ddl.TableSpec): + self._db = db + self._name = name + self._spec = spec + + def __call__(self) -> sqlalchemy.schema.Table: + table = self._db.getExistingTable(self._name, self._spec) + if table is None: + raise MissingDatabaseTableError(f"Table {self._name} is missing from database schema.") + return table + + class ByDimensionsDatasetRecordStorageManagerBase(DatasetRecordStorageManager): """A manager class for datasets that uses one dataset-collection table for each group of dataset types that share the same dimensions. @@ -90,6 +121,8 @@ class ByDimensionsDatasetRecordStorageManagerBase(DatasetRecordStorageManager): tables used by this class. summaries : `CollectionSummaryManager` Structure containing tables that summarize the contents of collections. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. """ def __init__( @@ -100,6 +133,7 @@ def __init__( dimensions: DimensionRecordStorageManager, static: StaticDatasetTablesTuple, summaries: CollectionSummaryManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ): super().__init__(registry_schema_version=registry_schema_version) @@ -108,8 +142,7 @@ def __init__( self._dimensions = dimensions self._static = static self._summaries = summaries - self._byName: dict[str, ByDimensionsDatasetRecordStorage] = {} - self._byId: dict[int, ByDimensionsDatasetRecordStorage] = {} + self._caching_context = caching_context @classmethod def initialize( @@ -119,6 +152,7 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> DatasetRecordStorageManager: # Docstring inherited from DatasetRecordStorageManager. @@ -131,6 +165,8 @@ def initialize( context, collections=collections, dimensions=dimensions, + dataset_type_table=static.dataset_type, + caching_context=caching_context, ) return cls( db=db, @@ -138,6 +174,7 @@ def initialize( dimensions=dimensions, static=static, summaries=summaries, + caching_context=caching_context, registry_schema_version=registry_schema_version, ) @@ -205,60 +242,34 @@ def addDatasetForeignKey( def refresh(self) -> None: # Docstring inherited from DatasetRecordStorageManager. 
- byName: dict[str, ByDimensionsDatasetRecordStorage] = {} - byId: dict[int, ByDimensionsDatasetRecordStorage] = {} - dataset_types: dict[int, DatasetType] = {} - c = self._static.dataset_type.columns - with self._db.query(self._static.dataset_type.select()) as sql_result: - sql_rows = sql_result.mappings().fetchall() - for row in sql_rows: - name = row[c.name] - dimensions = self._dimensions.loadDimensionGraph(row[c.dimensions_key]) - calibTableName = row[c.calibration_association_table] - datasetType = DatasetType( - name, dimensions, row[c.storage_class], isCalibration=(calibTableName is not None) + if self._caching_context.dataset_types is not None: + self._caching_context.dataset_types.clear() + + def _make_storage(self, record: _DatasetTypeRecord) -> ByDimensionsDatasetRecordStorage: + """Create storage instance for a dataset type record.""" + tags_spec = makeTagTableSpec(record.dataset_type, type(self._collections), self.getIdColumnType()) + tags_table_factory = _SpecTableFactory(self._db, record.tag_table_name, tags_spec) + calibs_table_factory = None + if record.calib_table_name is not None: + calibs_spec = makeCalibTableSpec( + record.dataset_type, + type(self._collections), + self._db.getTimespanRepresentation(), + self.getIdColumnType(), ) - tags = self._db.getExistingTable( - row[c.tag_association_table], - makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()), - ) - if tags is None: - raise MissingDatabaseTableError( - f"Table {row[c.tag_association_table]} is missing from database schema." - ) - if calibTableName is not None: - calibs = self._db.getExistingTable( - row[c.calibration_association_table], - makeCalibTableSpec( - datasetType, - type(self._collections), - self._db.getTimespanRepresentation(), - self.getIdColumnType(), - ), - ) - if calibs is None: - raise MissingDatabaseTableError( - f"Table {row[c.calibration_association_table]} is missing from database schema." - ) - else: - calibs = None - storage = self._recordStorageType( - db=self._db, - datasetType=datasetType, - static=self._static, - summaries=self._summaries, - tags=tags, - calibs=calibs, - dataset_type_id=row["id"], - collections=self._collections, - use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, - ) - byName[datasetType.name] = storage - byId[storage._dataset_type_id] = storage - dataset_types[row["id"]] = datasetType - self._byName = byName - self._byId = byId - self._summaries.refresh(dataset_types) + calibs_table_factory = _SpecTableFactory(self._db, record.calib_table_name, calibs_spec) + storage = self._recordStorageType( + db=self._db, + datasetType=record.dataset_type, + static=self._static, + summaries=self._summaries, + tags_table_factory=tags_table_factory, + calibs_table_factory=calibs_table_factory, + dataset_type_id=record.dataset_type_id, + collections=self._collections, + use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, + ) + return storage def remove(self, name: str) -> None: # Docstring inherited from DatasetRecordStorageManager. @@ -281,31 +292,48 @@ def remove(self, name: str) -> None: def find(self, name: str) -> DatasetRecordStorage | None: # Docstring inherited from DatasetRecordStorageManager. - return self._byName.get(name) + if self._caching_context.dataset_types is not None: + _, storage = self._caching_context.dataset_types.get(name) + if storage is not None: + return storage + else: + # On the first cache miss populate the cache with complete list + # of dataset types (if it was not done yet). 
+ if not self._caching_context.dataset_types.full: + self._fetch_dataset_types() + # Try again + _, storage = self._caching_context.dataset_types.get(name) + if self._caching_context.dataset_types.full: + # If not in cache then dataset type is not defined. + return storage + record = self._fetch_dataset_type_record(name) + if record is not None: + storage = self._make_storage(record) + if self._caching_context.dataset_types is not None: + self._caching_context.dataset_types.add(storage.datasetType, storage) + return storage + else: + return None - def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]: + def register(self, datasetType: DatasetType) -> bool: # Docstring inherited from DatasetRecordStorageManager. if datasetType.isComponent(): raise ValueError( f"Component dataset types can not be stored in registry. Rejecting {datasetType.name}" ) - storage = self._byName.get(datasetType.name) - if storage is None: + record = self._fetch_dataset_type_record(datasetType.name) + if record is None: dimensionsKey = self._dimensions.saveDimensionGraph(datasetType.dimensions) tagTableName = makeTagTableName(datasetType, dimensionsKey) - calibTableName = ( - makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None - ) - # The order is important here, we want to create tables first and - # only register them if this operation is successful. We cannot - # wrap it into a transaction because database class assumes that - # DDL is not transaction safe in general. - tags = self._db.ensureTableExists( + self._db.ensureTableExists( tagTableName, makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()), ) + calibTableName = ( + makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None + ) if calibTableName is not None: - calibs = self._db.ensureTableExists( + self._db.ensureTableExists( calibTableName, makeCalibTableSpec( datasetType, @@ -314,8 +342,6 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool self.getIdColumnType(), ), ) - else: - calibs = None row, inserted = self._db.sync( self._static.dataset_type, keys={"name": datasetType.name}, @@ -331,28 +357,25 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool }, returning=["id", "tag_association_table"], ) - assert row is not None - storage = self._recordStorageType( - db=self._db, - datasetType=datasetType, - static=self._static, - summaries=self._summaries, - tags=tags, - calibs=calibs, - dataset_type_id=row["id"], - collections=self._collections, - use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai, - ) - self._byName[datasetType.name] = storage - self._byId[storage._dataset_type_id] = storage + # Make sure that cache is updated + if self._caching_context.dataset_types is not None and row is not None: + record = _DatasetTypeRecord( + dataset_type=datasetType, + dataset_type_id=row["id"], + tag_table_name=tagTableName, + calib_table_name=calibTableName, + ) + storage = self._make_storage(record) + self._caching_context.dataset_types.add(datasetType, storage) else: - if datasetType != storage.datasetType: + if datasetType != record.dataset_type: raise ConflictingDefinitionError( f"Given dataset type {datasetType} is inconsistent " - f"with database definition {storage.datasetType}." + f"with database definition {record.dataset_type}." 
) inserted = False - return storage, bool(inserted) + + return bool(inserted) def resolve_wildcard( self, @@ -406,15 +429,13 @@ def resolve_wildcard( raise TypeError( "Universal wildcard '...' is not permitted for dataset types in this context." ) - for storage in self._byName.values(): - result[storage.datasetType].add(None) + for datasetType in self._fetch_dataset_types(): + result[datasetType].add(None) if components: try: - result[storage.datasetType].update( - storage.datasetType.storageClass.allComponents().keys() - ) + result[datasetType].update(datasetType.storageClass.allComponents().keys()) if ( - storage.datasetType.storageClass.allComponents() + datasetType.storageClass.allComponents() and not already_warned and components_deprecated ): @@ -426,7 +447,7 @@ def resolve_wildcard( already_warned = True except KeyError as err: _LOG.warning( - f"Could not load storage class {err} for {storage.datasetType.name}; " + f"Could not load storage class {err} for {datasetType.name}; " "if it has components they will not be included in query results.", ) elif wildcard.patterns: @@ -438,29 +459,28 @@ def resolve_wildcard( FutureWarning, stacklevel=find_outside_stacklevel("lsst.daf.butler"), ) - for storage in self._byName.values(): - if any(p.fullmatch(storage.datasetType.name) for p in wildcard.patterns): - result[storage.datasetType].add(None) + dataset_types = self._fetch_dataset_types() + for datasetType in dataset_types: + if any(p.fullmatch(datasetType.name) for p in wildcard.patterns): + result[datasetType].add(None) if components is not False: - for storage in self._byName.values(): - if components is None and storage.datasetType in result: + for datasetType in dataset_types: + if components is None and datasetType in result: continue try: - components_for_parent = storage.datasetType.storageClass.allComponents().keys() + components_for_parent = datasetType.storageClass.allComponents().keys() except KeyError as err: _LOG.warning( - f"Could not load storage class {err} for {storage.datasetType.name}; " + f"Could not load storage class {err} for {datasetType.name}; " "if it has components they will not be included in query results." ) continue for component_name in components_for_parent: if any( - p.fullmatch( - DatasetType.nameWithComponent(storage.datasetType.name, component_name) - ) + p.fullmatch(DatasetType.nameWithComponent(datasetType.name, component_name)) for p in wildcard.patterns ): - result[storage.datasetType].add(component_name) + result[datasetType].add(component_name) if not already_warned and components_deprecated: warnings.warn( deprecation_message, @@ -476,29 +496,93 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None: sqlalchemy.sql.select( self._static.dataset.columns.dataset_type_id, self._static.dataset.columns[self._collections.getRunForeignKeyName()], + *self._static.dataset_type.columns, ) .select_from(self._static.dataset) + .join(self._static.dataset_type) .where(self._static.dataset.columns.id == id) ) with self._db.query(sql) as sql_result: row = sql_result.mappings().fetchone() if row is None: return None - recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id]) - if recordsForType is None: - self.refresh() - recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id]) - assert recordsForType is not None, "Should be guaranteed by foreign key constraints." 
+        record = self._record_from_row(row)
+        storage: DatasetRecordStorage | None = None
+        if self._caching_context.dataset_types is not None:
+            _, storage = self._caching_context.dataset_types.get(record.dataset_type.name)
+        if storage is None:
+            storage = self._make_storage(record)
+            if self._caching_context.dataset_types is not None:
+                self._caching_context.dataset_types.add(storage.datasetType, storage)
+        assert isinstance(storage, ByDimensionsDatasetRecordStorage), "Unexpected storage class"
         return DatasetRef(
-            recordsForType.datasetType,
-            dataId=recordsForType.getDataId(id=id),
+            storage.datasetType,
+            dataId=storage.getDataId(id=id),
             id=id,
             run=self._collections[row[self._collections.getRunForeignKeyName()]].name,
         )

+    def _fetch_dataset_type_record(self, name: str) -> _DatasetTypeRecord | None:
+        """Retrieve a single dataset type record from the database.
+
+        Returns
+        -------
+        record : `_DatasetTypeRecord` or `None`
+            Dataset type information, or `None` if it is not defined.
+        """
+        c = self._static.dataset_type.columns
+        stmt = self._static.dataset_type.select().where(c.name == name)
+        with self._db.query(stmt) as sql_result:
+            row = sql_result.mappings().one_or_none()
+        if row is None:
+            return None
+        else:
+            return self._record_from_row(row)
+
+    def _record_from_row(self, row: Mapping) -> _DatasetTypeRecord:
+        name = row["name"]
+        dimensions = self._dimensions.loadDimensionGraph(row["dimensions_key"])
+        calibTableName = row["calibration_association_table"]
+        datasetType = DatasetType(
+            name, dimensions, row["storage_class"], isCalibration=(calibTableName is not None)
+        )
+        return _DatasetTypeRecord(
+            dataset_type=datasetType,
+            dataset_type_id=row["id"],
+            tag_table_name=row["tag_association_table"],
+            calib_table_name=calibTableName,
+        )
+
+    def _dataset_type_from_row(self, row: Mapping) -> DatasetType:
+        return self._record_from_row(row).dataset_type
+
+    def _fetch_dataset_types(self) -> list[DatasetType]:
+        """Fetch the list of all defined dataset types."""
+        if self._caching_context.dataset_types is not None:
+            if self._caching_context.dataset_types.full:
+                return [dataset_type for dataset_type, _ in self._caching_context.dataset_types.items()]
+        with self._db.query(self._static.dataset_type.select()) as sql_result:
+            sql_rows = sql_result.mappings().fetchall()
+        records = [self._record_from_row(row) for row in sql_rows]
+        # Cache everything and mark the cache as complete.
+        if self._caching_context.dataset_types is not None:
+            cache_data = [(record.dataset_type, self._make_storage(record)) for record in records]
+            self._caching_context.dataset_types.set(cache_data, full=True)
+        return [record.dataset_type for record in records]
+
     def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary:
         # Docstring inherited from DatasetRecordStorageManager.
-        return self._summaries.get(collection)
+        summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_from_row)
+        return summaries[collection.key]
+
+    def fetch_summaries(
+        self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None
+    ) -> Mapping[Any, CollectionSummary]:
+        # Docstring inherited from DatasetRecordStorageManager.
+ dataset_type_names: Iterable[str] | None = None + if dataset_types is not None: + dataset_type_names = set(dataset_type.name for dataset_type in dataset_types) + return self._summaries.fetch_summaries(collections, dataset_type_names, self._dataset_type_from_row) _versions: list[VersionTuple] """Schema version for this class.""" diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py index 454702b5b8..b88b80a3c5 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py @@ -32,7 +32,7 @@ __all__ = ("ByDimensionsDatasetRecordStorage",) -from collections.abc import Iterable, Iterator, Sequence, Set +from collections.abc import Callable, Iterable, Iterator, Sequence, Set from datetime import datetime from typing import TYPE_CHECKING @@ -77,9 +77,9 @@ def __init__( collections: CollectionManager, static: StaticDatasetTablesTuple, summaries: CollectionSummaryManager, - tags: sqlalchemy.schema.Table, + tags_table_factory: Callable[[], sqlalchemy.schema.Table], use_astropy_ingest_date: bool, - calibs: sqlalchemy.schema.Table | None, + calibs_table_factory: Callable[[], sqlalchemy.schema.Table] | None, ): super().__init__(datasetType=datasetType) self._dataset_type_id = dataset_type_id @@ -87,10 +87,26 @@ def __init__( self._collections = collections self._static = static self._summaries = summaries - self._tags = tags - self._calibs = calibs + self._tags_table_factory = tags_table_factory + self._calibs_table_factory = calibs_table_factory self._runKeyColumn = collections.getRunForeignKeyName() self._use_astropy = use_astropy_ingest_date + self._tags_table: sqlalchemy.schema.Table | None = None + self._calibs_table: sqlalchemy.schema.Table | None = None + + @property + def _tags(self) -> sqlalchemy.schema.Table: + if self._tags_table is None: + self._tags_table = self._tags_table_factory() + return self._tags_table + + @property + def _calibs(self) -> sqlalchemy.schema.Table | None: + if self._calibs_table is None: + if self._calibs_table_factory is None: + return None + self._calibs_table = self._calibs_table_factory() + return self._calibs_table def delete(self, datasets: Iterable[DatasetRef]) -> None: # Docstring inherited from DatasetRecordStorage. 
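The `_storage.py` hunk above swaps the eagerly resolved `tags`/`calibs` tables for factory callables that are only invoked on first access, so constructing a storage instance from a cached dataset type record no longer forces a table lookup up front. Below is a minimal, self-contained sketch of that deferred-lookup pattern; `LazyTable` and `fake_factory` are illustrative stand-ins, not part of the daf_butler API.

from collections.abc import Callable


class LazyTable:
    """Resolve a value through a factory only when it is first accessed."""

    def __init__(self, factory: Callable[[], str] | None):
        self._factory = factory
        self._value: str | None = None

    @property
    def value(self) -> str | None:
        if self._value is None:
            if self._factory is None:
                return None  # e.g. no calibs table for non-calibration dataset types
            # The lookup happens here, exactly once; in the real code this is
            # where _SpecTableFactory would reflect the table and raise
            # MissingDatabaseTableError if it does not exist.
            self._value = self._factory()
        return self._value


calls = 0


def fake_factory() -> str:
    # Stand-in for the real table-reflection call.
    global calls
    calls += 1
    return "dataset_tags_00"


lazy = LazyTable(fake_factory)
assert calls == 0                       # nothing resolved at construction time
assert lazy.value == "dataset_tags_00"
assert lazy.value == "dataset_tags_00"
assert calls == 1                       # factory ran only on the first access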
diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py index 1356576376..d051b4b38d 100644 --- a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py +++ b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py @@ -31,7 +31,7 @@ __all__ = ("CollectionSummaryManager",) -from collections.abc import Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from typing import Any, Generic, TypeVar import sqlalchemy @@ -39,16 +39,17 @@ from ...._dataset_type import DatasetType from ...._named import NamedKeyDict, NamedKeyMapping from ....dimensions import GovernorDimension, addDimensionForeignKey +from ..._caching_context import CachingContext from ..._collection_summary import CollectionSummary from ..._collection_type import CollectionType from ...interfaces import ( - ChainedCollectionRecord, CollectionManager, CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext, ) +from ...wildcards import CollectionWildcard _T = TypeVar("_T") @@ -133,6 +134,10 @@ class CollectionSummaryManager: Manager object for the dimensions in this `Registry`. tables : `CollectionSummaryTables` Struct containing the tables that hold collection summaries. + dataset_type_table : `sqlalchemy.schema.Table` + Table containing dataset type definitions. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. """ def __init__( @@ -142,13 +147,16 @@ def __init__( collections: CollectionManager, dimensions: DimensionRecordStorageManager, tables: CollectionSummaryTables[sqlalchemy.schema.Table], + dataset_type_table: sqlalchemy.schema.Table, + caching_context: CachingContext, ): self._db = db self._collections = collections self._collectionKeyName = collections.getCollectionForeignKeyName() self._dimensions = dimensions self._tables = tables - self._cache: dict[Any, CollectionSummary] = {} + self._dataset_type_table = dataset_type_table + self._caching_context = caching_context @classmethod def initialize( @@ -158,6 +166,8 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + dataset_type_table: sqlalchemy.schema.Table, + caching_context: CachingContext, ) -> CollectionSummaryManager: """Create all summary tables (or check that they have been created), returning an object to manage them. @@ -173,6 +183,10 @@ def initialize( Manager object for the collections in this `Registry`. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + dataset_type_table : `sqlalchemy.schema.Table` + Table containing dataset type definitions. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. Returns ------- @@ -194,6 +208,8 @@ def initialize( collections=collections, dimensions=dimensions, tables=tables, + dataset_type_table=dataset_type_table, + caching_context=caching_context, ) def update( @@ -237,39 +253,67 @@ def update( self._tables.dimensions[dimension], *[{self._collectionKeyName: collection.key, dimension: v} for v in values], ) - # Update the in-memory cache, too. 
These changes will remain even if - # the database inserts above are rolled back by some later exception in - # the same transaction, but that's okay: we never promise that a - # CollectionSummary has _just_ the dataset types and governor dimension - # values that are actually present, only that it is guaranteed to - # contain any dataset types or governor dimension values that _may_ be - # present. - # That guarantee (and the possibility of rollbacks) means we can't get - # away with checking the cache before we try the database inserts, - # however; if someone had attempted to insert datasets of some dataset - # type previously, and that rolled back, and we're now trying to insert - # some more datasets of that same type, it would not be okay to skip - # the DB summary table insertions because we found entries in the - # in-memory cache. - self.get(collection).update(summary) - - def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None: - """Load all collection summary information from the database. + + def fetch_summaries( + self, + collections: Iterable[CollectionRecord], + dataset_type_names: Iterable[str] | None, + dataset_type_factory: Callable[[sqlalchemy.engine.RowMapping], DatasetType], + ) -> Mapping[Any, CollectionSummary]: + """Fetch collection summaries given their names and dataset types. Parameters ---------- - dataset_types : `~collections.abc.Mapping` [`int`, `DatasetType`] - Mapping of an `int` dataset_type_id value to `DatasetType` - instance. Summaries are only loaded for dataset types that appear - in this mapping. + collections : `~collections.abc.Iterable` [`CollectionRecord`] + Collection records to query. + dataset_type_names : `~collections.abc.Iterable` [`str`] + Names of dataset types to include into returned summaries. If + `None` then all dataset types will be included. + dataset_type_factory : `Callable` + Method that takes a table row and make `DatasetType` instance out + of it. + + Returns + ------- + summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`] + Collection summaries indexed by collection record key. This mapping + will also contain all nested non-chained collections of the chained + collections. """ + summaries: dict[Any, CollectionSummary] = {} + # Check what we have in cache first. + if self._caching_context.collection_summaries is not None: + summaries, missing_keys = self._caching_context.collection_summaries.find_summaries( + [record.key for record in collections] + ) + if not missing_keys: + return summaries + else: + collections = [record for record in collections if record.key in missing_keys] + + # Need to expand all chained collections first. + non_chains: list[CollectionRecord] = [] + chains: dict[CollectionRecord, list[CollectionRecord]] = {} + for collection in collections: + if collection.type is CollectionType.CHAINED: + children = self._collections.resolve_wildcard( + CollectionWildcard.from_names([collection.name]), + flatten_chains=True, + include_chains=False, + ) + non_chains += children + chains[collection] = children + else: + non_chains.append(collection) + # Set up the SQL query we'll use to fetch all of the summary # information at once. 
- columns = [ - self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName), - self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id"), - ] - fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType + coll_col = self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName) + dataset_type_id_col = self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id") + columns = [coll_col, dataset_type_id_col] + list(self._dataset_type_table.columns) + fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType.join( + self._dataset_type_table + ) for dimension, table in self._tables.dimensions.items(): columns.append(table.columns[dimension.name].label(dimension.name)) fromClause = fromClause.join( @@ -280,72 +324,54 @@ def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None: ), isouter=True, ) + sql = sqlalchemy.sql.select(*columns).select_from(fromClause) + sql = sql.where(coll_col.in_([coll.key for coll in non_chains])) + # For caching we need to fetch complete summaries. + if self._caching_context.collection_summaries is None: + if dataset_type_names is not None: + sql = sql.where(self._dataset_type_table.columns["name"].in_(dataset_type_names)) + # Run the query and construct CollectionSummary objects from the result # rows. This will never include CHAINED collections or collections # with no datasets. - summaries: dict[Any, CollectionSummary] = {} with self._db.query(sql) as sql_result: sql_rows = sql_result.mappings().fetchall() + dataset_type_ids: dict[int, DatasetType] = {} for row in sql_rows: # Collection key should never be None/NULL; it's what we join on. # Extract that and then turn it into a collection name. collectionKey = row[self._collectionKeyName] # dataset_type_id should also never be None/NULL; it's in the first # table we joined. - if datasetType := dataset_types.get(row["dataset_type_id"]): - # See if we have a summary already for this collection; if not, - # make one. - summary = summaries.get(collectionKey) - if summary is None: - summary = CollectionSummary() - summaries[collectionKey] = summary - # Update the dimensions with the values in this row that - # aren't None/NULL (many will be in general, because these - # enter the query via LEFT OUTER JOIN). - summary.dataset_types.add(datasetType) - for dimension in self._tables.dimensions: - value = row[dimension.name] - if value is not None: - summary.governors.setdefault(dimension.name, set()).add(value) - self._cache = summaries - - def get(self, collection: CollectionRecord) -> CollectionSummary: - """Return a summary for the given collection. + dataset_type_id = row["dataset_type_id"] + if (dataset_type := dataset_type_ids.get(dataset_type_id)) is None: + dataset_type_ids[dataset_type_id] = dataset_type = dataset_type_factory(row) + # See if we have a summary already for this collection; if not, + # make one. + summary = summaries.get(collectionKey) + if summary is None: + summary = CollectionSummary() + summaries[collectionKey] = summary + # Update the dimensions with the values in this row that + # aren't None/NULL (many will be in general, because these + # enter the query via LEFT OUTER JOIN). 
+ summary.dataset_types.add(dataset_type) + for dimension in self._tables.dimensions: + value = row[dimension.name] + if value is not None: + summary.governors.setdefault(dimension.name, set()).add(value) - Parameters - ---------- - collection : `CollectionRecord` - Record describing the collection for which a summary is to be - retrieved. + # Add empty summary for any missing collection. + for collection in non_chains: + if collection.key not in summaries: + summaries[collection.key] = CollectionSummary() - Returns - ------- - summary : `CollectionSummary` - Summary of the dataset types and governor dimension values in - this collection. - """ - summary = self._cache.get(collection.key) - if summary is None: - # When we load the summary information from the database, we don't - # create summaries for CHAINED collections; those are created here - # as needed, and *never* cached - we have no good way to update - # those summaries when some a new dataset is added to a child - # colletion. - if collection.type is CollectionType.CHAINED: - assert isinstance(collection, ChainedCollectionRecord) - child_summaries = [self.get(self._collections.find(child)) for child in collection.children] - if child_summaries: - summary = CollectionSummary.union(*child_summaries) - else: - summary = CollectionSummary() - else: - # Either this collection doesn't have any datasets yet, or the - # only datasets it has were created by some other process since - # the last call to refresh. We assume the former; the user is - # responsible for calling refresh if they want to read - # concurrently-written things. We do remember this in the - # cache. - summary = CollectionSummary() - self._cache[collection.key] = summary - return summary + # Merge children into their chains summaries. + for chain, children in chains.items(): + summaries[chain.key] = CollectionSummary.union(*(summaries[child.key] for child in children)) + + if self._caching_context.collection_summaries is not None: + self._caching_context.collection_summaries.update(summaries) + + return summaries diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py index 2bc5bf30c6..837ede94e0 100644 --- a/python/lsst/daf/butler/registry/interfaces/_collections.py +++ b/python/lsst/daf/butler/registry/interfaces/_collections.py @@ -36,22 +36,24 @@ ] from abc import abstractmethod -from collections import defaultdict -from collections.abc import Iterator, Set -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable, Set +from typing import TYPE_CHECKING, Any, Generic, TypeVar from ..._timespan import Timespan -from ...dimensions import DimensionUniverse from .._collection_type import CollectionType from ..wildcards import CollectionWildcard from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: + from .._caching_context import CachingContext from ._database import Database, StaticTablesContext from ._dimensions import DimensionRecordStorageManager -class CollectionRecord: +_Key = TypeVar("_Key") + + +class CollectionRecord(Generic[_Key]): """A struct used to represent a collection in internal `Registry` APIs. User-facing code should always just use a `str` to represent collections. @@ -76,7 +78,7 @@ class CollectionRecord: participate in some subclass equality definition. 
""" - def __init__(self, key: Any, name: str, type: CollectionType): + def __init__(self, key: _Key, name: str, type: CollectionType): self.key = key self.name = name self.type = type @@ -86,7 +88,7 @@ def __init__(self, key: Any, name: str, type: CollectionType): """Name of the collection (`str`). """ - key: Any + key: _Key """The primary/foreign key value for this collection. """ @@ -111,196 +113,85 @@ def __str__(self) -> str: return self.name -class RunRecord(CollectionRecord): +class RunRecord(CollectionRecord[_Key]): """A subclass of `CollectionRecord` that adds execution information and an interface for updating it. - """ - @abstractmethod - def update(self, host: str | None = None, timespan: Timespan | None = None) -> None: - """Update the database record for this run with new execution - information. - - Values not provided will set to ``NULL`` in the database, not ignored. + Parameters + ---------- + key: `object` + Unique collection key. + name : `str` + Name of the collection. + host : `str`, optional + Name of the host or system on which this run was produced. + timespan: `Timespan`, optional + Begin and end timestamps for the period over which the run was + produced. + """ - Parameters - ---------- - host : `str`, optional - Name of the host or system on which this run was produced. - Detailed form to be set by higher-level convention; from the - `Registry` perspective, this is an entirely opaque value. - timespan : `Timespan`, optional - Begin and end timestamps for the period over which the run was - produced. `None`/``NULL`` values are interpreted as infinite - bounds. - """ - raise NotImplementedError() + host: str | None + """Name of the host or system on which this run was produced (`str` or + `None`). + """ - @property - @abstractmethod - def host(self) -> str | None: - """Return the name of the host or system on which this run was - produced (`str` or `None`). - """ - raise NotImplementedError() + timespan: Timespan + """Begin and end timestamps for the period over which the run was produced. + None`/``NULL`` values are interpreted as infinite bounds. + """ - @property - @abstractmethod - def timespan(self) -> Timespan: - """Begin and end timestamps for the period over which the run was - produced. `None`/``NULL`` values are interpreted as infinite - bounds. - """ - raise NotImplementedError() + def __init__( + self, + key: _Key, + name: str, + *, + host: str | None = None, + timespan: Timespan | None = None, + ): + super().__init__(key=key, name=name, type=CollectionType.RUN) + self.host = host + if timespan is None: + timespan = Timespan(begin=None, end=None) + self.timespan = timespan def __repr__(self) -> str: return f"RunRecord(key={self.key!r}, name={self.name!r})" -class ChainedCollectionRecord(CollectionRecord): +class ChainedCollectionRecord(CollectionRecord[_Key]): """A subclass of `CollectionRecord` that adds the list of child collections in a ``CHAINED`` collection. Parameters ---------- - key - Unique collection ID, can be the same as ``name`` if ``name`` is used - for identification. Usually this is an integer or string, but can be - other database-specific type. + key: `object` + Unique collection key. name : `str` Name of the collection. + children: Iterable[str], + Ordered sequence of names of child collections. """ - def __init__(self, key: Any, name: str, universe: DimensionUniverse): - super().__init__(key=key, name=name, type=CollectionType.CHAINED) - self._children: tuple[str, ...] 
= () - - @property - def children(self) -> tuple[str, ...]: - """The ordered search path of child collections that define this chain - (`tuple` [ `str` ]). - """ - return self._children - - def update(self, manager: CollectionManager, children: tuple[str, ...], flatten: bool) -> None: - """Redefine this chain to search the given child collections. - - This method should be used by all external code to set children. It - delegates to `_update`, which is what should be overridden by - subclasses. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. - children : `tuple` [ `str` ] - A collection search path that should be resolved to set the child - collections of this chain. - flatten : `bool` - If `True`, recursively flatten out any nested - `~CollectionType.CHAINED` collections in ``children`` first. - - Raises - ------ - ValueError - Raised when the child collections contain a cycle. - """ - children_as_wildcard = CollectionWildcard.from_names(children) - for record in manager.resolve_wildcard( - children_as_wildcard, - flatten_chains=True, - include_chains=True, - collection_types={CollectionType.CHAINED}, - ): - if record == self: - raise ValueError(f"Cycle in collection chaining when defining '{self.name}'.") - if flatten: - children = tuple( - record.name for record in manager.resolve_wildcard(children_as_wildcard, flatten_chains=True) - ) - # Delegate to derived classes to do the database updates. - self._update(manager, children) - # Update the reverse mapping (from child to parents) in the manager, - # by removing the old relationships and adding back in the new ones. - for old_child in self._children: - manager._parents_by_child[manager.find(old_child).key].discard(self.key) - for new_child in children: - manager._parents_by_child[manager.find(new_child).key].add(self.key) - # Actually set this instances sequence of children. - self._children = children - - def refresh(self, manager: CollectionManager) -> None: - """Load children from the database, using the given manager to resolve - collection primary key values into records. - - This method exists to ensure that all collections that may appear in a - chain are known to the manager before any particular chain tries to - retrieve their records from it. `ChainedCollectionRecord` subclasses - can rely on it being called sometime after their own ``__init__`` to - finish construction. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. - """ - # Clear out the old reverse mapping (from child to parents). - for child in self._children: - manager._parents_by_child[manager.find(child).key].discard(self.key) - self._children = self._load(manager) - # Update the reverse mapping (from child to parents) in the manager. - for child in self._children: - manager._parents_by_child[manager.find(child).key].add(self.key) - - @abstractmethod - def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None: - """Protected implementation hook for `update`. - - This method should be implemented by subclasses to update the database - to reflect the children given. It should never be called by anything - other than `update`, which should be used by all external code. 
- - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. - children : `tuple` [ `str` ] - A collection search path that should be resolved to set the child - collections of this chain. Guaranteed not to contain cycles. - """ - raise NotImplementedError() - - @abstractmethod - def _load(self, manager: CollectionManager) -> tuple[str, ...]: - """Protected implementation hook for `refresh`. - - This method should be implemented by subclasses to retrieve the chain's - child collections from the database and return them. It should never - be called by anything other than `refresh`, which should be used by all - external code. - - Parameters - ---------- - manager : `CollectionManager` - The object that manages this records instance and all records - instances that may appear as its children. + children: tuple[str, ...] + """The ordered search path of child collections that define this chain + (`tuple` [ `str` ]). + """ - Returns - ------- - children : `tuple` [ `str` ] - The ordered sequence of collection names that defines the chained - collection. Guaranteed not to contain cycles. - """ - raise NotImplementedError() + def __init__( + self, + key: Any, + name: str, + *, + children: Iterable[str], + ): + super().__init__(key=key, name=name, type=CollectionType.CHAINED) + self.children = tuple(children) def __repr__(self) -> str: return f"ChainedCollectionRecord(key={self.key!r}, name={self.name!r}, children={self.children!r})" -class CollectionManager(VersionedExtension): +class CollectionManager(Generic[_Key], VersionedExtension): """An interface for managing the collections (including runs) in a `Registry`. @@ -315,7 +206,6 @@ class CollectionManager(VersionedExtension): def __init__(self, *, registry_schema_version: VersionTuple | None = None) -> None: super().__init__(registry_schema_version=registry_schema_version) - self._parents_by_child: defaultdict[Any, set[Any]] = defaultdict(set) @classmethod @abstractmethod @@ -325,6 +215,7 @@ def initialize( context: StaticTablesContext, *, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> CollectionManager: """Construct an instance of the manager. @@ -339,6 +230,8 @@ def initialize( implemented with this manager. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. registry_schema_version : `VersionTuple` or `None` Schema version of this extension as defined in registry. @@ -481,7 +374,7 @@ def refresh(self) -> None: @abstractmethod def register( self, name: str, type: CollectionType, doc: str | None = None - ) -> tuple[CollectionRecord, bool]: + ) -> tuple[CollectionRecord[_Key], bool]: """Ensure that a collection of the given name and type are present in the layer this manager is associated with. @@ -547,7 +440,7 @@ def remove(self, name: str) -> None: raise NotImplementedError() @abstractmethod - def find(self, name: str) -> CollectionRecord: + def find(self, name: str) -> CollectionRecord[_Key]: """Return the collection record associated with the given name. 
Parameters @@ -576,7 +469,7 @@ def find(self, name: str) -> CollectionRecord: raise NotImplementedError() @abstractmethod - def __getitem__(self, key: Any) -> CollectionRecord: + def __getitem__(self, key: Any) -> CollectionRecord[_Key]: """Return the collection record associated with the given primary/foreign key value. @@ -614,7 +507,7 @@ def resolve_wildcard( done: set[str] | None = None, flatten_chains: bool = True, include_chains: bool | None = None, - ) -> list[CollectionRecord]: + ) -> list[CollectionRecord[_Key]]: """Iterate over collection records that match a wildcard. Parameters @@ -632,10 +525,10 @@ def resolve_wildcard( If `True` (default) recursively yield the child collections of `~CollectionType.CHAINED` collections. include_chains : `bool`, optional - If `False`, return records for `~CollectionType.CHAINED` + If `True`, return records for `~CollectionType.CHAINED` collections themselves. The default is the opposite of - ``flattenChains``: either return records for CHAINED collections or - their children, but not both. + ``flatten_chains``: either return records for CHAINED collections + or their children, but not both. Returns ------- @@ -645,7 +538,7 @@ def resolve_wildcard( raise NotImplementedError() @abstractmethod - def getDocumentation(self, key: Any) -> str | None: + def getDocumentation(self, key: _Key) -> str | None: """Retrieve the documentation string for a collection. Parameters @@ -661,7 +554,7 @@ def getDocumentation(self, key: Any) -> str | None: raise NotImplementedError() @abstractmethod - def setDocumentation(self, key: Any, doc: str | None) -> None: + def setDocumentation(self, key: _Key, doc: str | None) -> None: """Set the documentation string for a collection. Parameters @@ -673,16 +566,37 @@ def setDocumentation(self, key: Any, doc: str | None) -> None: """ raise NotImplementedError() - def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]: - """Find all CHAINED collections that directly contain the given + @abstractmethod + def getParentChains(self, key: _Key) -> set[str]: + """Find all CHAINED collection names that directly contain the given collection. Parameters ---------- key Internal primary key value for the collection. + + Returns + ------- + names : `set` [`str`] + Parent collection names. """ - for parent_key in self._parents_by_child[key]: - result = self[parent_key] - assert isinstance(result, ChainedCollectionRecord) - yield result + raise NotImplementedError() + + @abstractmethod + def update_chain( + self, record: ChainedCollectionRecord[_Key], children: Iterable[str], flatten: bool = False + ) -> ChainedCollectionRecord[_Key]: + """Update chained collection composition. + + Parameters + ---------- + record : `ChainedCollectionRecord` + Chained collection record. + children : `~collections.abc.Iterable` [`str`] + Ordered names of children collections. + flatten : `bool`, optional + If `True`, recursively flatten out any nested + `~CollectionType.CHAINED` collections in ``children`` first. + """ + raise NotImplementedError() diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py index 438bf55613..61bc9e2440 100644 --- a/python/lsst/daf/butler/registry/interfaces/_database.py +++ b/python/lsst/daf/butler/registry/interfaces/_database.py @@ -140,10 +140,6 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table: relationships. 
""" name = self._db._mangleTableName(name) - if name in self._tableNames: - _checkExistingTableDefinition( - name, spec, self._inspector.get_columns(name, schema=self._db.namespace) - ) metadata = self._db._metadata assert metadata is not None, "Guaranteed by context manager that returns this object." table = self._db._convertTableSpec(name, spec, metadata) diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py index 3424028804..ac459f3c8b 100644 --- a/python/lsst/daf/butler/registry/interfaces/_datasets.py +++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py @@ -32,7 +32,7 @@ __all__ = ("DatasetRecordStorageManager", "DatasetRecordStorage") from abc import ABC, abstractmethod -from collections.abc import Iterable, Iterator, Set +from collections.abc import Iterable, Iterator, Mapping, Set from typing import TYPE_CHECKING, Any from lsst.daf.relation import Relation @@ -45,6 +45,7 @@ from ._versioning import VersionedExtension, VersionTuple if TYPE_CHECKING: + from .._caching_context import CachingContext from .._collection_summary import CollectionSummary from ..queries import SqlQueryContext from ._collections import CollectionManager, CollectionRecord, RunRecord @@ -329,6 +330,7 @@ def initialize( *, collections: CollectionManager, dimensions: DimensionRecordStorageManager, + caching_context: CachingContext, registry_schema_version: VersionTuple | None = None, ) -> DatasetRecordStorageManager: """Construct an instance of the manager. @@ -344,6 +346,8 @@ def initialize( Manager object for the collections in this `Registry`. dimensions : `DimensionRecordStorageManager` Manager object for the dimensions in this `Registry`. + caching_context : `CachingContext` + Object controlling caching of information returned by managers. registry_schema_version : `VersionTuple` or `None` Schema version of this extension as defined in registry. @@ -487,7 +491,7 @@ def find(self, name: str) -> DatasetRecordStorage | None: raise NotImplementedError() @abstractmethod - def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]: + def register(self, datasetType: DatasetType) -> bool: """Ensure that this `Registry` can hold records for the given `DatasetType`, creating new tables as necessary. @@ -499,8 +503,6 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool Returns ------- - records : `DatasetRecordStorage` - The object representing the records for the given dataset type. inserted : `bool` `True` if the dataset type did not exist in the registry before. @@ -603,6 +605,29 @@ def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummar """ raise NotImplementedError() + @abstractmethod + def fetch_summaries( + self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None + ) -> Mapping[Any, CollectionSummary]: + """Fetch collection summaries given their names and dataset types. + + Parameters + ---------- + collections : `~collections.abc.Iterable` [`CollectionRecord`] + Collection records to query. + dataset_types : `~collections.abc.Iterable` [`DatasetType`] or `None` + Dataset types to include into returned summaries. If `None` then + all dataset types will be included. + + Returns + ------- + summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`] + Collection summaries indexed by collection record key. This mapping + will also contain all nested non-chained collections of the chained + collections. 
+ """ + raise NotImplementedError() + @abstractmethod def ingest_date_dtype(self) -> type: """Return type of the ``ingest_date`` column.""" diff --git a/python/lsst/daf/butler/registry/managers.py b/python/lsst/daf/butler/registry/managers.py index 1d80fcde51..48702c2fc9 100644 --- a/python/lsst/daf/butler/registry/managers.py +++ b/python/lsst/daf/butler/registry/managers.py @@ -45,6 +45,7 @@ from .._column_type_info import ColumnTypeInfo from .._config import Config from ..dimensions import DimensionConfig, DimensionUniverse +from ._caching_context import CachingContext from ._config import RegistryConfig from .interfaces import ( ButlerAttributeManager, @@ -353,6 +354,11 @@ class RegistryManagerInstances( and registry instances, including the dimension universe. """ + caching_context: CachingContext + """Object containing caches for for various information generated by + managers. + """ + @classmethod def initialize( cls, @@ -361,6 +367,7 @@ def initialize( *, types: RegistryManagerTypes, universe: DimensionUniverse, + caching_context: CachingContext | None = None, ) -> RegistryManagerInstances: """Construct manager instances from their types and an existing database connection. @@ -383,6 +390,8 @@ def initialize( instances : `RegistryManagerInstances` Struct containing manager instances. """ + if caching_context is None: + caching_context = CachingContext() dummy_table = ddl.TableSpec(fields=()) kwargs: dict[str, Any] = {} schema_versions = types.schema_versions @@ -396,6 +405,7 @@ def initialize( database, context, dimensions=kwargs["dimensions"], + caching_context=caching_context, registry_schema_version=schema_versions.get("collections"), ) datasets = types.datasets.initialize( @@ -404,6 +414,7 @@ def initialize( collections=kwargs["collections"], dimensions=kwargs["dimensions"], registry_schema_version=schema_versions.get("datasets"), + caching_context=caching_context, ) kwargs["datasets"] = datasets kwargs["opaque"] = types.opaque.initialize( @@ -440,6 +451,7 @@ def initialize( run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False), ingest_date_dtype=datasets.ingest_date_dtype(), ) + kwargs["caching_context"] = caching_context return cls(**kwargs) def as_dict(self) -> Mapping[str, VersionedExtension]: @@ -453,7 +465,9 @@ def as_dict(self) -> Mapping[str, VersionedExtension]: manager instance. Only existing managers are returned. 
""" instances = { - f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != "column_types" + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.name not in ("column_types", "caching_context") } return {key: value for key, value in instances.items() if value is not None} diff --git a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py index db574dde44..fc5866e8ba 100644 --- a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py +++ b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py @@ -121,6 +121,7 @@ def filter_dataset_collections( result: dict[DatasetType, list[CollectionRecord]] = { dataset_type: [] for dataset_type in dataset_types } + summaries = self._managers.datasets.fetch_summaries(collections, result.keys()) for dataset_type, filtered_collections in result.items(): for collection_record in collections: if not dataset_type.isCalibration() and collection_record.type is CollectionType.CALIBRATION: @@ -130,7 +131,7 @@ def filter_dataset_collections( f"in CALIBRATION collection {collection_record.name!r}." ) else: - collection_summary = self._managers.datasets.getCollectionSummary(collection_record) + collection_summary = summaries[collection_record.key] if collection_summary.is_compatible_with( dataset_type, governor_constraints, diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py index 733f820941..63c22a97d7 100644 --- a/python/lsst/daf/butler/registry/sql_registry.py +++ b/python/lsst/daf/butler/registry/sql_registry.py @@ -324,6 +324,13 @@ def refresh(self) -> None: with self._db.transaction(): self._managers.refresh() + @contextlib.contextmanager + def caching_context(self) -> Iterator[None]: + """Context manager that enables caching.""" + self._managers.caching_context.enable() + yield + self._managers.caching_context.disable() + @contextlib.contextmanager def transaction(self, *, savepoint: bool = False) -> Iterator[None]: """Return a context manager that represents a transaction.""" @@ -603,7 +610,7 @@ def setCollectionChain(self, parent: str, children: Any, *, flatten: bool = Fals assert isinstance(record, ChainedCollectionRecord) children = CollectionWildcard.from_expression(children).require_ordered() if children != record.children or flatten: - record.update(self._managers.collections, children, flatten=flatten) + self._managers.collections.update_chain(record, children, flatten=flatten) def getCollectionParentChains(self, collection: str) -> set[str]: """Return the CHAINED collections that directly contain the given one. @@ -618,12 +625,7 @@ def getCollectionParentChains(self, collection: str) -> set[str]: chains : `set` of `str` Set of `~CollectionType.CHAINED` collection names. """ - return { - record.name - for record in self._managers.collections.getParentChains( - self._managers.collections.find(collection).key - ) - } + return self._managers.collections.getParentChains(self._managers.collections.find(collection).key) def getCollectionDocumentation(self, collection: str) -> str | None: """Retrieve the documentation string for a collection. @@ -702,8 +704,7 @@ def registerDatasetType(self, datasetType: DatasetType) -> bool: This method cannot be called within transactions, as it needs to be able to perform its own transaction to be concurrent. 
""" - _, inserted = self._managers.datasets.register(datasetType) - return inserted + return self._managers.datasets.register(datasetType) def removeDatasetType(self, name: str | tuple[str, ...]) -> None: """Remove the named `DatasetType` from the registry. diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py index 841735a28b..1930461ec0 100644 --- a/python/lsst/daf/butler/remote_butler/_remote_butler.py +++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py @@ -158,6 +158,12 @@ def _simplify_dataId( # Assume we can treat it as a dict. return SerializedDataCoordinate(dataId=data_id) + def _caching_context(self) -> AbstractContextManager[None]: + # Docstring inherited. + # Not implemented for now, will have to think whether this needs to + # do something on client side and/or remote side. + raise NotImplementedError() + def transaction(self) -> AbstractContextManager[None]: """Will always raise NotImplementedError. Transactions are not supported by RemoteButler.