diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py
index f83a4ad347..937ff8e6c9 100644
--- a/python/lsst/daf/butler/_butler.py
+++ b/python/lsst/daf/butler/_butler.py
@@ -482,6 +482,11 @@ def get_known_repos(cls) -> set[str]:
"""
return ButlerRepoIndex.get_known_repos()
+ @abstractmethod
+ def _caching_context(self) -> AbstractContextManager[None]:
+ """Context manager that enables caching."""
+ raise NotImplementedError()
+
@abstractmethod
def transaction(self) -> AbstractContextManager[None]:
"""Context manager supporting `Butler` transactions.
diff --git a/python/lsst/daf/butler/_named.py b/python/lsst/daf/butler/_named.py
index 55850645bd..e3ff851b72 100644
--- a/python/lsst/daf/butler/_named.py
+++ b/python/lsst/daf/butler/_named.py
@@ -266,7 +266,7 @@ def freeze(self) -> NamedKeyMapping[K, V]:
to a new variable (and considering any previous references
invalidated) should allow for more accurate static type checking.
"""
- if not isinstance(self._dict, MappingProxyType):
+ if not isinstance(self._dict, MappingProxyType): # type: ignore[unreachable]
self._dict = MappingProxyType(self._dict) # type: ignore
return self
@@ -578,7 +578,7 @@ def freeze(self) -> NamedValueAbstractSet[K]:
to a new variable (and considering any previous references
invalidated) should allow for more accurate static type checking.
"""
- if not isinstance(self._mapping, MappingProxyType):
+ if not isinstance(self._mapping, MappingProxyType): # type: ignore[unreachable]
self._mapping = MappingProxyType(self._mapping) # type: ignore
return self
diff --git a/python/lsst/daf/butler/_registry_shim.py b/python/lsst/daf/butler/_registry_shim.py
index 67f50a16e1..4d2653abe0 100644
--- a/python/lsst/daf/butler/_registry_shim.py
+++ b/python/lsst/daf/butler/_registry_shim.py
@@ -102,6 +102,10 @@ def refresh(self) -> None:
# Docstring inherited from a base class.
self._registry.refresh()
+ def caching_context(self) -> contextlib.AbstractContextManager[None]:
+ # Docstring inherited from a base class.
+ return self._butler._caching_context()
+
@contextlib.contextmanager
def transaction(self, *, savepoint: bool = False) -> Iterator[None]:
# Docstring inherited from a base class.
diff --git a/python/lsst/daf/butler/direct_butler.py b/python/lsst/daf/butler/direct_butler.py
index 6b70ecb1e1..80bcf5d1ae 100644
--- a/python/lsst/daf/butler/direct_butler.py
+++ b/python/lsst/daf/butler/direct_butler.py
@@ -299,6 +299,10 @@ def isWriteable(self) -> bool:
# Docstring inherited.
return self._registry.isWriteable()
+ def _caching_context(self) -> contextlib.AbstractContextManager[None]:
+ """Context manager that enables caching."""
+ return self._registry.caching_context()
+
@contextlib.contextmanager
def transaction(self) -> Iterator[None]:
"""Context manager supporting `Butler` transactions.
diff --git a/python/lsst/daf/butler/persistence_context.py b/python/lsst/daf/butler/persistence_context.py
index b366564d45..8830fd9550 100644
--- a/python/lsst/daf/butler/persistence_context.py
+++ b/python/lsst/daf/butler/persistence_context.py
@@ -33,7 +33,7 @@
import uuid
from collections.abc import Callable, Hashable
from contextvars import Context, ContextVar, Token, copy_context
-from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast
+from typing import TYPE_CHECKING, ParamSpec, TypeVar
if TYPE_CHECKING:
from ._dataset_ref import DatasetRef
@@ -198,4 +198,4 @@ def run(self, function: Callable[_Q, _T], *args: _Q.args, **kwargs: _Q.kwargs) -
# cast the result as we know this is exactly what the return type will
# be.
result = self._ctx.run(self._functionRunner, function, *args, **kwargs) # type: ignore
- return cast(_T, result)
+ return result
diff --git a/python/lsst/daf/butler/registry/_caching_context.py b/python/lsst/daf/butler/registry/_caching_context.py
new file mode 100644
index 0000000000..2674a54f69
--- /dev/null
+++ b/python/lsst/daf/butler/registry/_caching_context.py
@@ -0,0 +1,79 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ["CachingContext"]
+
+from typing import TYPE_CHECKING
+
+from ._collection_record_cache import CollectionRecordCache
+from ._collection_summary_cache import CollectionSummaryCache
+from ._dataset_type_cache import DatasetTypeCache
+
+if TYPE_CHECKING:
+ from .interfaces import DatasetRecordStorage
+
+
+class CachingContext:
+    """Collection of caches for various types of records retrieved from the
+    database.
+
+    Notes
+    -----
+    Caching is usually disabled for most record types, but it can be
+    explicitly and temporarily enabled in some contexts (e.g. quantum graph
+    building) via a Registry method. This class is a collection of cache
+    instances which are `None` when caching is disabled. An instance of this
+    class is passed to the relevant managers, which can use it to query or
+    populate the caches when caching is enabled.
+
+    The dataset type cache is always enabled for now; this avoids the need to
+    explicitly enable caching in pipetask executors.
+    """
+
+ collection_records: CollectionRecordCache | None = None
+ """Cache for collection records (`CollectionRecordCache`)."""
+
+ collection_summaries: CollectionSummaryCache | None = None
+ """Cache for collection summary records (`CollectionSummaryCache`)."""
+
+ dataset_types: DatasetTypeCache[DatasetRecordStorage]
+ """Cache for dataset types, never disabled (`DatasetTypeCache`)."""
+
+ def __init__(self) -> None:
+ self.dataset_types = DatasetTypeCache()
+
+ def enable(self) -> None:
+        """Enable caches by initializing all cache instances."""
+ self.collection_records = CollectionRecordCache()
+ self.collection_summaries = CollectionSummaryCache()
+
+ def disable(self) -> None:
+        """Disable caches by setting all cache instances to `None`."""
+ self.collection_records = None
+ self.collection_summaries = None
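
A minimal sketch of the cache-aside pattern the managers follow with this collaborator. It constructs `CachingContext` directly (outside any `Registry`) and imports the internal modules added by this patch, so it is illustration only; the collection name is a placeholder.

    from lsst.daf.butler import CollectionType
    from lsst.daf.butler.registry._caching_context import CachingContext
    from lsst.daf.butler.registry.interfaces import CollectionRecord

    caching = CachingContext()
    assert caching.collection_records is None   # caching is disabled by default
    caching.enable()                            # normally done via Registry.caching_context()

    def find_record(name: str) -> CollectionRecord:
        cache = caching.collection_records
        if cache is not None and (record := cache.get_by_name(name)) is not None:
            return record                       # cache hit
        # Stand-in for the real database query performed by the manager.
        record = CollectionRecord[str](key=name, name=name, type=CollectionType.TAGGED)
        if cache is not None:
            cache.add(record)                   # populate the cache on a miss
        return record

    print(find_record("my_collection").name)
    caching.disable()                           # caches dropped; lookups query the database again
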
diff --git a/python/lsst/daf/butler/registry/_collection_record_cache.py b/python/lsst/daf/butler/registry/_collection_record_cache.py
new file mode 100644
index 0000000000..da00bb6a35
--- /dev/null
+++ b/python/lsst/daf/butler/registry/_collection_record_cache.py
@@ -0,0 +1,165 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("CollectionRecordCache",)
+
+from collections.abc import Iterable, Iterator
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+ from .interfaces import CollectionRecord
+
+
+class CollectionRecordCache:
+ """Cache for collection records.
+
+ Notes
+ -----
+    This class stores collection records and can retrieve them using either
+    the collection name or the collection key. One complication is that the
+    key type can be either the collection name or a distinct integer value.
+    To optimize storage when the key is the same as the collection name, this
+    class only stores the key-to-record mapping when the key is of a
+    non-string type.
+
+    In some contexts (e.g. ``resolve_wildcard``) a full list of collections is
+    needed. To signify that the cache contents can be used in such contexts,
+    the cache defines a special ``full`` flag that needs to be set by the
+    client.
+ """
+
+ def __init__(self) -> None:
+ self._by_name: dict[str, CollectionRecord] = {}
+ # This dict is only used for records whose key type is not str.
+ self._by_key: dict[Any, CollectionRecord] = {}
+ self._full = False
+
+ @property
+ def full(self) -> bool:
+ """`True` if cache holds all known collection records (`bool`)."""
+ return self._full
+
+ def add(self, record: CollectionRecord) -> None:
+ """Add one record to the cache.
+
+ Parameters
+ ----------
+ record : `CollectionRecord`
+ Collection record, replaces any existing record with the same name
+ or key.
+ """
+ # In case we replace same record name with different key, find the
+ # existing record and drop its key.
+ if (old_record := self._by_name.get(record.name)) is not None:
+ self._by_key.pop(old_record.key)
+ if (old_record := self._by_key.get(record.key)) is not None:
+ self._by_name.pop(old_record.name)
+ self._by_name[record.name] = record
+ if not isinstance(record.key, str):
+ self._by_key[record.key] = record
+
+ def set(self, records: Iterable[CollectionRecord], *, full: bool = False) -> None:
+ """Replace cache contents with the new set of records.
+
+ Parameters
+ ----------
+ records : `~collections.abc.Iterable` [`CollectionRecord`]
+ Collection records.
+ full : `bool`
+ If `True` then ``records`` contain all known collection records.
+ """
+ self.clear()
+ for record in records:
+ self._by_name[record.name] = record
+ if not isinstance(record.key, str):
+ self._by_key[record.key] = record
+ self._full = full
+
+ def clear(self) -> None:
+ """Remove all records from the cache."""
+ self._by_name = {}
+ self._by_key = {}
+ self._full = False
+
+ def discard(self, record: CollectionRecord) -> None:
+ """Remove single record from the cache.
+
+ Parameters
+ ----------
+ record : `CollectionRecord`
+ Collection record to remove.
+ """
+ self._by_name.pop(record.name, None)
+ if not isinstance(record.key, str):
+ self._by_key.pop(record.key, None)
+
+ def get_by_name(self, name: str) -> CollectionRecord | None:
+ """Return collection record given its name.
+
+ Parameters
+ ----------
+ name : `str`
+ Collection name.
+
+ Returns
+ -------
+ record : `CollectionRecord` or `None`
+ Collection record, `None` is returned if the name is not in the
+ cache.
+ """
+ return self._by_name.get(name)
+
+ def get_by_key(self, key: Any) -> CollectionRecord | None:
+ """Return collection record given its key.
+
+ Parameters
+ ----------
+ key : `Any`
+ Collection key.
+
+ Returns
+ -------
+ record : `CollectionRecord` or `None`
+ Collection record, `None` is returned if the key is not in the
+ cache.
+ """
+ if isinstance(key, str):
+ return self._by_name.get(key)
+ return self._by_key.get(key)
+
+ def records(self) -> Iterator[CollectionRecord]:
+        """Return an iterator over the records in the cache; this can only be
+        used if `full` is true.
+
+ Raises
+ ------
+ RuntimeError
+ Raised if ``self.full`` is `False`.
+ """
+ if not self._full:
+ raise RuntimeError("cannot call records() if cache is not full")
+ return iter(self._by_name.values())
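
A standalone usage sketch of the API above, assuming the internal import path from this patch and string collection keys; the collection names are made up.

    from lsst.daf.butler import CollectionType
    from lsst.daf.butler.registry._collection_record_cache import CollectionRecordCache
    from lsst.daf.butler.registry.interfaces import CollectionRecord

    cache = CollectionRecordCache()
    records = [
        CollectionRecord[str](key=name, name=name, type=CollectionType.TAGGED)
        for name in ("a", "b", "c")
    ]
    cache.set(records, full=True)              # cache now claims to hold every collection
    assert cache.full
    assert cache.get_by_name("b") is not None
    assert cache.get_by_key("b") is not None   # string keys are looked up by name
    print([record.name for record in cache.records()])  # only allowed while full
    cache.clear()
    assert cache.get_by_name("a") is None and not cache.full
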
diff --git a/python/lsst/daf/butler/registry/_collection_summary_cache.py b/python/lsst/daf/butler/registry/_collection_summary_cache.py
new file mode 100644
index 0000000000..ed5b2f2fa2
--- /dev/null
+++ b/python/lsst/daf/butler/registry/_collection_summary_cache.py
@@ -0,0 +1,86 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("CollectionSummaryCache",)
+
+from collections.abc import Iterable, Mapping
+from typing import Any
+
+from ._collection_summary import CollectionSummary
+
+
+class CollectionSummaryCache:
+ """Cache for collection summaries.
+
+ Notes
+ -----
+ This class stores `CollectionSummary` records indexed by collection keys.
+    For the cache to be usable, the records given to the `update` method have
+    to include all dataset types, i.e. the query that produces the records
+    should not be constrained by dataset type.
+ """
+
+ def __init__(self) -> None:
+ self._cache: dict[Any, CollectionSummary] = {}
+
+ def update(self, summaries: Mapping[Any, CollectionSummary]) -> None:
+ """Add records to the cache.
+
+ Parameters
+ ----------
+ summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`]
+            Summary records indexed by collection key; the records must
+            include all dataset types.
+ """
+ self._cache.update(summaries)
+
+ def find_summaries(self, keys: Iterable[Any]) -> tuple[dict[Any, CollectionSummary], set[Any]]:
+ """Return summary records given a set of keys.
+
+ Parameters
+ ----------
+ keys : `~collections.abc.Iterable` [`Any`]
+ Sequence of collection keys.
+
+ Returns
+ -------
+ summaries : `dict` [`Any`, `CollectionSummary`]
+ Dictionary of summaries indexed by collection keys, includes
+ records found in the cache.
+ missing_keys : `set` [`Any`]
+ Collection keys that are not present in the cache.
+ """
+ found = {}
+ not_found = set()
+ for key in keys:
+ if (summary := self._cache.get(key)) is not None:
+ found[key] = summary
+ else:
+ not_found.add(key)
+ return found, not_found
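
The intended lookup pattern is cache-first with a follow-up query for only the missing keys; a hedged sketch where ``query_summaries_from_db`` is a hypothetical stand-in for the manager's real query.

    from lsst.daf.butler.registry._collection_summary_cache import CollectionSummaryCache

    cache = CollectionSummaryCache()

    def query_summaries_from_db(keys):
        # Hypothetical stand-in; the real query must not be constrained by
        # dataset type so that the cached entries stay complete.
        return {}

    def get_summaries(keys):
        found, missing = cache.find_summaries(keys)
        if missing:
            fetched = query_summaries_from_db(missing)
            cache.update(fetched)    # remember the newly fetched summaries
            found.update(fetched)
        return found

    print(get_summaries(["run_a", "run_b"]))   # {} until the stand-in returns real records
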
diff --git a/python/lsst/daf/butler/registry/_dataset_type_cache.py b/python/lsst/daf/butler/registry/_dataset_type_cache.py
new file mode 100644
index 0000000000..3f1665dfa3
--- /dev/null
+++ b/python/lsst/daf/butler/registry/_dataset_type_cache.py
@@ -0,0 +1,162 @@
+# This file is part of daf_butler.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("DatasetTypeCache",)
+
+from collections.abc import Iterable, Iterator
+from typing import Generic, TypeVar
+
+from .._dataset_type import DatasetType
+
+_T = TypeVar("_T")
+
+
+class DatasetTypeCache(Generic[_T]):
+ """Cache for dataset types.
+
+ Notes
+ -----
+    This class caches the mapping of dataset type name to the corresponding
+    `DatasetType` instance. The registry manager also needs to cache the
+    corresponding "storage" instance, so this class allows an additional
+    opaque object to be stored along with the dataset type.
+
+    In some contexts (e.g. ``resolve_wildcard``) a full list of dataset types
+    is needed. To signify that the cache contents can be used in such
+    contexts, the cache defines a special ``full`` flag that needs to be set
+    by the client.
+ """
+
+ def __init__(self) -> None:
+ self._cache: dict[str, tuple[DatasetType, _T | None]] = {}
+ self._full = False
+
+ @property
+ def full(self) -> bool:
+ """`True` if cache holds all known dataset types (`bool`)."""
+ return self._full
+
+ def add(self, dataset_type: DatasetType, extra: _T | None = None) -> None:
+ """Add one record to the cache.
+
+ Parameters
+ ----------
+ dataset_type : `DatasetType`
+ Dataset type, replaces any existing dataset type with the same
+ name.
+ extra : `Any`, optional
+ Additional opaque object stored with this dataset type.
+ """
+ self._cache[dataset_type.name] = (dataset_type, extra)
+
+ def set(self, data: Iterable[DatasetType | tuple[DatasetType, _T | None]], *, full: bool = False) -> None:
+ """Replace cache contents with the new set of dataset types.
+
+ Parameters
+ ----------
+ data : `~collections.abc.Iterable`
+ Sequence of `DatasetType` instances or tuples of `DatasetType` and
+ an extra opaque object.
+ full : `bool`
+ If `True` then ``data`` contains all known dataset types.
+ """
+ self.clear()
+ for item in data:
+ if isinstance(item, DatasetType):
+ item = (item, None)
+ self._cache[item[0].name] = item
+ self._full = full
+
+ def clear(self) -> None:
+ """Remove everything from the cache."""
+ self._cache = {}
+ self._full = False
+
+ def discard(self, name: str) -> None:
+ """Remove named dataset type from the cache.
+
+ Parameters
+ ----------
+ name : `str`
+ Name of the dataset type to remove.
+ """
+ self._cache.pop(name, None)
+
+ def get(self, name: str) -> tuple[DatasetType | None, _T | None]:
+ """Return cached info given dataset type name.
+
+ Parameters
+ ----------
+ name : `str`
+ Dataset type name.
+
+ Returns
+ -------
+ dataset_type : `DatasetType` or `None`
+ Cached dataset type, `None` is returned if the name is not in the
+ cache.
+ extra : `Any` or `None`
+ Cached opaque data, `None` is returned if the name is not in the
+ cache or no extra info was stored for this dataset type.
+ """
+ item = self._cache.get(name)
+ if item is None:
+ return (None, None)
+ return item
+
+ def get_dataset_type(self, name: str) -> DatasetType | None:
+ """Return dataset type given its name.
+
+ Parameters
+ ----------
+ name : `str`
+ Dataset type name.
+
+ Returns
+ -------
+ dataset_type : `DatasetType` or `None`
+ Cached dataset type, `None` is returned if the name is not in the
+ cache.
+ """
+ item = self._cache.get(name)
+ if item is None:
+ return None
+ return item[0]
+
+ def items(self) -> Iterator[tuple[DatasetType, _T | None]]:
+        """Return an iterator over the items in the cache; this can only be
+        used if `full` is true.
+
+ Raises
+ ------
+ RuntimeError
+ Raised if ``self.full`` is `False`.
+ """
+ if not self._full:
+ raise RuntimeError("cannot call items() if cache is not full")
+ return iter(self._cache.values())
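
A small sketch of the two-slot storage (the dataset type plus an opaque "extra" object, which the managers use for their storage instances). The repository path and dataset type name are placeholders for an existing butler repo.

    from lsst.daf.butler import Butler
    from lsst.daf.butler.registry._dataset_type_cache import DatasetTypeCache

    butler = Butler("REPO_PATH")                            # assumed existing repository
    dataset_type = butler.registry.getDatasetType("raw")    # assumed registered dataset type

    cache: DatasetTypeCache[str] = DatasetTypeCache()
    cache.add(dataset_type, extra="opaque storage object")
    found, extra = cache.get("raw")
    assert found is dataset_type and extra == "opaque storage object"
    assert cache.get_dataset_type("not_there") is None
    cache.set([dataset_type], full=True)       # plain DatasetType items get extra=None
    for dt, stored_extra in cache.items():     # only valid because full=True
        print(dt.name, stored_extra)
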
diff --git a/python/lsst/daf/butler/registry/_registry.py b/python/lsst/daf/butler/registry/_registry.py
index 2f0cb3231d..21e651314c 100644
--- a/python/lsst/daf/butler/registry/_registry.py
+++ b/python/lsst/daf/butler/registry/_registry.py
@@ -118,6 +118,11 @@ def refresh(self) -> None:
"""
raise NotImplementedError()
+ @abstractmethod
+ def caching_context(self) -> contextlib.AbstractContextManager[None]:
+ """Context manager that enables caching."""
+ raise NotImplementedError()
+
@contextlib.contextmanager
@abstractmethod
def transaction(self, *, savepoint: bool = False) -> Iterator[None]:
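
For client code (e.g. quantum-graph building) the new method is meant to be used as a plain context manager; a hedged sketch with placeholder repository, collection, and dataset type names.

    from lsst.daf.butler import Butler

    butler = Butler("REPO_PATH")
    with butler.registry.caching_context():
        # Collection records and summaries fetched inside this block are
        # cached, so repeated lookups against the same chain hit the
        # database only once.
        refs = list(
            butler.registry.queryDatasets("raw", collections="some_chain", findFirst=True)
        )
    # Outside the block the caches are dropped and lookups query the database again.
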
diff --git a/python/lsst/daf/butler/registry/collections/_base.py b/python/lsst/daf/butler/registry/collections/_base.py
index f611cd630c..f39ee607e9 100644
--- a/python/lsst/daf/butler/registry/collections/_base.py
+++ b/python/lsst/daf/butler/registry/collections/_base.py
@@ -34,18 +34,18 @@
from abc import abstractmethod
from collections import namedtuple
from collections.abc import Iterable, Iterator, Set
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, Any, TypeVar, cast
import sqlalchemy
-from ..._timespan import Timespan, TimespanDatabaseRepresentation
-from ...dimensions import DimensionUniverse
+from ..._timespan import TimespanDatabaseRepresentation
from .._collection_type import CollectionType
from .._exceptions import MissingCollectionError
from ..interfaces import ChainedCollectionRecord, CollectionManager, CollectionRecord, RunRecord, VersionTuple
from ..wildcards import CollectionWildcard
if TYPE_CHECKING:
+ from .._caching_context import CachingContext
from ..interfaces import Database, DimensionRecordStorageManager
@@ -157,150 +157,10 @@ def makeCollectionChainTableSpec(collectionIdName: str, collectionIdType: type)
)
-class DefaultRunRecord(RunRecord):
- """Default `RunRecord` implementation.
-
- This method assumes the same run table definition as produced by
- `makeRunTableSpec` method. The only non-fixed name in the schema
- is the PK column name, this needs to be passed in a constructor.
-
- Parameters
- ----------
- db : `Database`
- Registry database.
- key
- Unique collection ID, can be the same as ``name`` if ``name`` is used
- for identification. Usually this is an integer or string, but can be
- other database-specific type.
- name : `str`
- Run collection name.
- table : `sqlalchemy.schema.Table`
- Table for run records.
- idColumnName : `str`
- Name of the identifying column in run table.
- host : `str`, optional
- Name of the host where run was produced.
- timespan : `Timespan`, optional
- Timespan for this run.
- """
-
- def __init__(
- self,
- db: Database,
- key: Any,
- name: str,
- *,
- table: sqlalchemy.schema.Table,
- idColumnName: str,
- host: str | None = None,
- timespan: Timespan | None = None,
- ):
- super().__init__(key=key, name=name, type=CollectionType.RUN)
- self._db = db
- self._table = table
- self._host = host
- if timespan is None:
- timespan = Timespan(begin=None, end=None)
- self._timespan = timespan
- self._idName = idColumnName
-
- def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
- # Docstring inherited from RunRecord.
- if timespan is None:
- timespan = Timespan(begin=None, end=None)
- row = {
- self._idName: self.key,
- "host": host,
- }
- self._db.getTimespanRepresentation().update(timespan, result=row)
- count = self._db.update(self._table, {self._idName: self.key}, row)
- if count != 1:
- raise RuntimeError(f"Run update affected {count} records; expected exactly one.")
- self._host = host
- self._timespan = timespan
-
- @property
- def host(self) -> str | None:
- # Docstring inherited from RunRecord.
- return self._host
-
- @property
- def timespan(self) -> Timespan:
- # Docstring inherited from RunRecord.
- return self._timespan
-
-
-class DefaultChainedCollectionRecord(ChainedCollectionRecord):
- """Default `ChainedCollectionRecord` implementation.
-
- This method assumes the same chain table definition as produced by
- `makeCollectionChainTableSpec` method. All column names in the table are
- fixed and hard-coded in the methods.
-
- Parameters
- ----------
- db : `Database`
- Registry database.
- key
- Unique collection ID, can be the same as ``name`` if ``name`` is used
- for identification. Usually this is an integer or string, but can be
- other database-specific type.
- name : `str`
- Collection name.
- table : `sqlalchemy.schema.Table`
- Table for chain relationship records.
- universe : `DimensionUniverse`
- Object managing all known dimensions.
- """
-
- def __init__(
- self,
- db: Database,
- key: Any,
- name: str,
- *,
- table: sqlalchemy.schema.Table,
- universe: DimensionUniverse,
- ):
- super().__init__(key=key, name=name, universe=universe)
- self._db = db
- self._table = table
- self._universe = universe
-
- def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
- # Docstring inherited from ChainedCollectionRecord.
- rows = []
- position = itertools.count()
- for child in manager.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
- rows.append(
- {
- "parent": self.key,
- "child": child.key,
- "position": next(position),
- }
- )
- with self._db.transaction():
- self._db.delete(self._table, ["parent"], {"parent": self.key})
- self._db.insert(self._table, *rows)
-
- def _load(self, manager: CollectionManager) -> tuple[str, ...]:
- # Docstring inherited from ChainedCollectionRecord.
- sql = (
- sqlalchemy.sql.select(
- self._table.columns.child,
- )
- .select_from(self._table)
- .where(self._table.columns.parent == self.key)
- .order_by(self._table.columns.position)
- )
- with self._db.query(sql) as sql_result:
- return tuple(manager[row[self._table.columns.child]].name for row in sql_result.mappings())
-
-
K = TypeVar("K")
-class DefaultCollectionManager(Generic[K], CollectionManager):
+class DefaultCollectionManager(CollectionManager[K]):
"""Default `CollectionManager` implementation.
This implementation uses record classes defined in this module and is
@@ -331,72 +191,34 @@ def __init__(
collectionIdName: str,
*,
dimensions: DimensionRecordStorageManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
):
super().__init__(registry_schema_version=registry_schema_version)
self._db = db
self._tables = tables
self._collectionIdName = collectionIdName
- self._records: dict[K, CollectionRecord] = {} # indexed by record ID
self._dimensions = dimensions
+ self._caching_context = caching_context
def refresh(self) -> None:
# Docstring inherited from CollectionManager.
- sql = sqlalchemy.sql.select(
- *(list(self._tables.collection.columns) + list(self._tables.run.columns))
- ).select_from(self._tables.collection.join(self._tables.run, isouter=True))
- # Put found records into a temporary instead of updating self._records
- # in place, for exception safety.
- records = []
- chains = []
- TimespanReprClass = self._db.getTimespanRepresentation()
- with self._db.query(sql) as sql_result:
- sql_rows = sql_result.mappings().fetchall()
- for row in sql_rows:
- collection_id = row[self._tables.collection.columns[self._collectionIdName]]
- name = row[self._tables.collection.columns.name]
- type = CollectionType(row["type"])
- record: CollectionRecord
- if type is CollectionType.RUN:
- record = DefaultRunRecord(
- key=collection_id,
- name=name,
- db=self._db,
- table=self._tables.run,
- idColumnName=self._collectionIdName,
- host=row[self._tables.run.columns.host],
- timespan=TimespanReprClass.extract(row),
- )
- elif type is CollectionType.CHAINED:
- record = DefaultChainedCollectionRecord(
- db=self._db,
- key=collection_id,
- table=self._tables.collection_chain,
- name=name,
- universe=self._dimensions.universe,
- )
- chains.append(record)
- else:
- record = CollectionRecord(key=collection_id, name=name, type=type)
- records.append(record)
- self._setRecordCache(records)
- for chain in chains:
- try:
- chain.refresh(self)
- except MissingCollectionError:
- # This indicates a race condition in which some other client
- # created a new collection and added it as a child of this
- # (pre-existing) chain between the time we fetched all
- # collections and the time we queried for parent-child
- # relationships.
- # Because that's some other unrelated client, we shouldn't care
- # about that parent collection anyway, so we just drop it on
- # the floor (a manual refresh can be used to get it back).
- self._removeCachedRecord(chain)
+ if self._caching_context.collection_records is not None:
+ self._caching_context.collection_records.clear()
+
+ def _fetch_all(self) -> list[CollectionRecord[K]]:
+        """Retrieve all records, populating the cache if it has not been filled yet."""
+ if self._caching_context.collection_records is not None:
+ if self._caching_context.collection_records.full:
+ return list(self._caching_context.collection_records.records())
+ records = self._fetch_by_key(None)
+ if self._caching_context.collection_records is not None:
+ self._caching_context.collection_records.set(records, full=True)
+ return records
def register(
self, name: str, type: CollectionType, doc: str | None = None
- ) -> tuple[CollectionRecord, bool]:
+ ) -> tuple[CollectionRecord[K], bool]:
# Docstring inherited from CollectionManager.
registered = False
record = self._getByName(name)
@@ -411,7 +233,7 @@ def register(
assert isinstance(inserted_or_updated, bool)
registered = inserted_or_updated
assert row is not None
- collection_id = row[self._collectionIdName]
+ collection_id = cast(K, row[self._collectionIdName])
if type is CollectionType.RUN:
TimespanReprClass = self._db.getTimespanRepresentation()
row, _ = self._db.sync(
@@ -420,25 +242,20 @@ def register(
returning=("host",) + TimespanReprClass.getFieldNames(),
)
assert row is not None
- record = DefaultRunRecord(
- db=self._db,
+ record = RunRecord[K](
key=collection_id,
name=name,
- table=self._tables.run,
- idColumnName=self._collectionIdName,
host=row["host"],
timespan=TimespanReprClass.extract(row),
)
elif type is CollectionType.CHAINED:
- record = DefaultChainedCollectionRecord(
- db=self._db,
+ record = ChainedCollectionRecord[K](
key=collection_id,
name=name,
- table=self._tables.collection_chain,
- universe=self._dimensions.universe,
+ children=[],
)
else:
- record = CollectionRecord(key=collection_id, name=name, type=type)
+ record = CollectionRecord[K](key=collection_id, name=name, type=type)
self._addCachedRecord(record)
return record, registered
@@ -453,19 +270,48 @@ def remove(self, name: str) -> None:
)
self._removeCachedRecord(record)
- def find(self, name: str) -> CollectionRecord:
+ def find(self, name: str) -> CollectionRecord[K]:
# Docstring inherited from CollectionManager.
result = self._getByName(name)
if result is None:
raise MissingCollectionError(f"No collection with name '{name}' found.")
return result
- def __getitem__(self, key: Any) -> CollectionRecord:
+ def _find_many(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
+ """Return multiple records given their names."""
+ names = list(names)
+ # To protect against potential races in cache updates.
+ records: dict[str, CollectionRecord | None] = {}
+ if self._caching_context.collection_records is not None:
+ for name in names:
+ records[name] = self._caching_context.collection_records.get_by_name(name)
+ fetch_names = [name for name, record in records.items() if record is None]
+ else:
+ fetch_names = list(names)
+ records = {name: None for name in fetch_names}
+ if fetch_names:
+ for record in self._fetch_by_name(fetch_names):
+ records[record.name] = record
+ self._addCachedRecord(record)
+ missing_names = [name for name, record in records.items() if record is None]
+ if len(missing_names) == 1:
+ raise MissingCollectionError(f"No collection with name '{missing_names[0]}' found.")
+ elif len(missing_names) > 1:
+ raise MissingCollectionError(f"No collections with names '{' '.join(missing_names)}' found.")
+ return [cast(CollectionRecord[K], records[name]) for name in names]
+
+ def __getitem__(self, key: Any) -> CollectionRecord[K]:
# Docstring inherited from CollectionManager.
- try:
- return self._records[key]
- except KeyError as err:
- raise MissingCollectionError(f"Collection with key '{key}' not found.") from err
+ if self._caching_context.collection_records is not None:
+ if (record := self._caching_context.collection_records.get_by_key(key)) is not None:
+ return record
+ if records := self._fetch_by_key([key]):
+ record = records[0]
+ if self._caching_context.collection_records is not None:
+ self._caching_context.collection_records.add(record)
+ return record
+ else:
+ raise MissingCollectionError(f"Collection with key '{key}' not found.")
def resolve_wildcard(
self,
@@ -475,13 +321,13 @@ def resolve_wildcard(
done: set[str] | None = None,
flatten_chains: bool = True,
include_chains: bool | None = None,
- ) -> list[CollectionRecord]:
+ ) -> list[CollectionRecord[K]]:
# Docstring inherited
if done is None:
done = set()
include_chains = include_chains if include_chains is not None else not flatten_chains
- def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord]:
+ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[CollectionRecord[K]]:
if record.name in done:
return
if record.type in collection_types:
@@ -490,28 +336,29 @@ def resolve_nested(record: CollectionRecord, done: set[str]) -> Iterator[Collect
yield record
if flatten_chains and record.type is CollectionType.CHAINED:
done.add(record.name)
- for name in cast(ChainedCollectionRecord, record).children:
+ for child in self._find_many(cast(ChainedCollectionRecord[K], record).children):
# flake8 can't tell that we only delete this closure when
# we're totally done with it.
- yield from resolve_nested(self.find(name), done) # noqa: F821
+ yield from resolve_nested(child, done) # noqa: F821
- result: list[CollectionRecord] = []
+ result: list[CollectionRecord[K]] = []
if wildcard.patterns is ...:
- for record in self._records.values():
+ for record in self._fetch_all():
result.extend(resolve_nested(record, done))
del resolve_nested
return result
- for name in wildcard.strings:
- result.extend(resolve_nested(self.find(name), done))
+ if wildcard.strings:
+ for record in self._find_many(wildcard.strings):
+ result.extend(resolve_nested(record, done))
if wildcard.patterns:
- for record in self._records.values():
+ for record in self._fetch_all():
if any(p.fullmatch(record.name) for p in wildcard.patterns):
result.extend(resolve_nested(record, done))
del resolve_nested
return result
- def getDocumentation(self, key: Any) -> str | None:
+ def getDocumentation(self, key: K) -> str | None:
# Docstring inherited from CollectionManager.
sql = (
sqlalchemy.sql.select(self._tables.collection.columns.doc)
@@ -521,27 +368,76 @@ def getDocumentation(self, key: Any) -> str | None:
with self._db.query(sql) as sql_result:
return sql_result.scalar()
- def setDocumentation(self, key: Any, doc: str | None) -> None:
+ def setDocumentation(self, key: K, doc: str | None) -> None:
# Docstring inherited from CollectionManager.
self._db.update(self._tables.collection, {self._collectionIdName: "key"}, {"key": key, "doc": doc})
- def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
- """Set internal record cache to contain given records,
- old cached records will be removed.
- """
- self._records = {}
- for record in records:
- self._records[record.key] = record
-
- def _addCachedRecord(self, record: CollectionRecord) -> None:
+ def _addCachedRecord(self, record: CollectionRecord[K]) -> None:
"""Add single record to cache."""
- self._records[record.key] = record
+ if self._caching_context.collection_records is not None:
+ self._caching_context.collection_records.add(record)
- def _removeCachedRecord(self, record: CollectionRecord) -> None:
+ def _removeCachedRecord(self, record: CollectionRecord[K]) -> None:
"""Remove single record from cache."""
- del self._records[record.key]
+ if self._caching_context.collection_records is not None:
+ self._caching_context.collection_records.discard(record)
- @abstractmethod
- def _getByName(self, name: str) -> CollectionRecord | None:
+ def _getByName(self, name: str) -> CollectionRecord[K] | None:
"""Find collection record given collection name."""
+ if self._caching_context.collection_records is not None:
+ if (record := self._caching_context.collection_records.get_by_name(name)) is not None:
+ return record
+ records = self._fetch_by_name([name])
+ for record in records:
+ self._addCachedRecord(record)
+ return records[0] if records else None
+
+ @abstractmethod
+ def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[K]]:
+        """Fetch collection records from the database given their names."""
+ raise NotImplementedError()
+
+ @abstractmethod
+ def _fetch_by_key(self, collection_ids: Iterable[K] | None) -> list[CollectionRecord[K]]:
+        """Fetch collection records from the database given their keys, or
+        fetch all collections if the argument is `None`.
+ """
raise NotImplementedError()
+
+ def update_chain(
+ self, chain: ChainedCollectionRecord[K], children: Iterable[str], flatten: bool = False
+ ) -> ChainedCollectionRecord[K]:
+ # Docstring inherited from CollectionManager.
+ children_as_wildcard = CollectionWildcard.from_names(children)
+ for record in self.resolve_wildcard(
+ children_as_wildcard,
+ flatten_chains=True,
+ include_chains=True,
+ collection_types={CollectionType.CHAINED},
+ ):
+ if record == chain:
+ raise ValueError(f"Cycle in collection chaining when defining '{chain.name}'.")
+ if flatten:
+ children = tuple(
+ record.name for record in self.resolve_wildcard(children_as_wildcard, flatten_chains=True)
+ )
+
+ rows = []
+ position = itertools.count()
+ names = []
+ for child in self.resolve_wildcard(CollectionWildcard.from_names(children), flatten_chains=False):
+ rows.append(
+ {
+ "parent": chain.key,
+ "child": child.key,
+ "position": next(position),
+ }
+ )
+ names.append(child.name)
+ with self._db.transaction():
+ self._db.delete(self._tables.collection_chain, ["parent"], {"parent": chain.key})
+ self._db.insert(self._tables.collection_chain, *rows)
+
+ record = ChainedCollectionRecord[K](chain.key, chain.name, children=tuple(names))
+ self._addCachedRecord(record)
+ return record
diff --git a/python/lsst/daf/butler/registry/collections/nameKey.py b/python/lsst/daf/butler/registry/collections/nameKey.py
index e5e635e61c..9fc2e32271 100644
--- a/python/lsst/daf/butler/registry/collections/nameKey.py
+++ b/python/lsst/daf/butler/registry/collections/nameKey.py
@@ -26,16 +26,17 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
-from ... import ddl
-
__all__ = ["NameKeyCollectionManager"]
+from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any
import sqlalchemy
+from ... import ddl
from ..._timespan import TimespanDatabaseRepresentation
-from ..interfaces import VersionTuple
+from .._collection_type import CollectionType
+from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
from ._base import (
CollectionTablesTuple,
DefaultCollectionManager,
@@ -44,7 +45,8 @@
)
if TYPE_CHECKING:
- from ..interfaces import CollectionRecord, Database, DimensionRecordStorageManager, StaticTablesContext
+ from .._caching_context import CachingContext
+ from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext
_KEY_FIELD_SPEC = ddl.FieldSpec("name", dtype=sqlalchemy.String, length=64, primaryKey=True)
@@ -68,7 +70,7 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) ->
)
-class NameKeyCollectionManager(DefaultCollectionManager):
+class NameKeyCollectionManager(DefaultCollectionManager[str]):
"""A `CollectionManager` implementation that uses collection names for
primary/foreign keys and aggressively loads all collection/run records in
the database into memory.
@@ -85,6 +87,7 @@ def initialize(
context: StaticTablesContext,
*,
dimensions: DimensionRecordStorageManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
) -> NameKeyCollectionManager:
# Docstring inherited from CollectionManager.
@@ -93,6 +96,7 @@ def initialize(
tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore
collectionIdName="name",
dimensions=dimensions,
+ caching_context=caching_context,
registry_schema_version=registry_schema_version,
)
@@ -152,9 +156,117 @@ def addRunForeignKey(
)
return copy
- def _getByName(self, name: str) -> CollectionRecord | None:
- # Docstring inherited from DefaultCollectionManager.
- return self._records.get(name)
+ def getParentChains(self, key: str) -> set[str]:
+ # Docstring inherited from CollectionManager.
+ table = self._tables.collection_chain
+ sql = (
+ sqlalchemy.sql.select(table.columns["parent"])
+ .select_from(table)
+ .where(table.columns["child"] == key)
+ )
+ with self._db.query(sql) as sql_result:
+ parent_names = set(sql_result.scalars().all())
+ return parent_names
+
+ def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[str]]:
+ # Docstring inherited from base class.
+ return self._fetch_by_key(names)
+
+ def _fetch_by_key(self, collection_ids: Iterable[str] | None) -> list[CollectionRecord[str]]:
+ # Docstring inherited from base class.
+ sql = sqlalchemy.sql.select(*self._tables.collection.columns, *self._tables.run.columns).select_from(
+ self._tables.collection.join(self._tables.run, isouter=True)
+ )
+
+ chain_sql = sqlalchemy.sql.select(
+ self._tables.collection_chain.columns["parent"],
+ self._tables.collection_chain.columns["position"],
+ self._tables.collection_chain.columns["child"],
+ )
+
+ records: list[CollectionRecord[str]] = []
+        # We want to keep transactions as short as possible. When we fetch
+        # everything we want to quickly fetch things into memory and finish
+        # the transaction. When we fetch just a few records we need to process
+        # the result of the first query before we can run the second one.
+ if collection_ids is not None:
+ sql = sql.where(self._tables.collection.columns[self._collectionIdName].in_(collection_ids))
+ with self._db.transaction():
+ with self._db.query(sql) as sql_result:
+ sql_rows = sql_result.mappings().fetchall()
+
+ records, chained_ids = self._rows_to_records(sql_rows)
+
+ if chained_ids:
+ # Retrieve chained collection compositions
+ chain_sql = chain_sql.where(
+ self._tables.collection_chain.columns["parent"].in_(chained_ids)
+ )
+ with self._db.query(chain_sql) as sql_result:
+ chain_rows = sql_result.mappings().fetchall()
+
+ records += self._rows_to_chains(chain_rows, chained_ids)
+
+ else:
+ with self._db.transaction():
+ with self._db.query(sql) as sql_result:
+ sql_rows = sql_result.mappings().fetchall()
+ with self._db.query(chain_sql) as sql_result:
+ chain_rows = sql_result.mappings().fetchall()
+
+ records, chained_ids = self._rows_to_records(sql_rows)
+ records += self._rows_to_chains(chain_rows, chained_ids)
+
+ return records
+
+ def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[str]], list[str]]:
+        """Convert rows returned from the collection query to a list of
+        records and a list of chained collection names.
+ """
+ records: list[CollectionRecord[str]] = []
+ TimespanReprClass = self._db.getTimespanRepresentation()
+ chained_ids: list[str] = []
+ for row in rows:
+ name = row[self._tables.collection.columns.name]
+ type = CollectionType(row["type"])
+ record: CollectionRecord[str]
+ if type is CollectionType.RUN:
+ record = RunRecord[str](
+ key=name,
+ name=name,
+ host=row[self._tables.run.columns.host],
+ timespan=TimespanReprClass.extract(row),
+ )
+ records.append(record)
+ elif type is CollectionType.CHAINED:
+                # Need to delay chained collection construction until we
+                # fetch their children's names.
+ chained_ids.append(name)
+ else:
+ record = CollectionRecord[str](key=name, name=name, type=type)
+ records.append(record)
+
+ return records, chained_ids
+
+ def _rows_to_chains(self, rows: Iterable[Mapping], chained_ids: list[str]) -> list[CollectionRecord[str]]:
+ """Convert rows returned from collection chain query to a list of
+ records.
+ """
+ chains_defs: dict[str, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
+ for row in rows:
+ chains_defs[row["parent"]].append((row["position"], row["child"]))
+
+ records: list[CollectionRecord[str]] = []
+ for name, children in chains_defs.items():
+ children_names = [child for _, child in sorted(children)]
+ record = ChainedCollectionRecord[str](
+ key=name,
+ name=name,
+ children=children_names,
+ )
+ records.append(record)
+
+ return records
@classmethod
def currentVersions(cls) -> list[VersionTuple]:
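
The chain reconstruction in `_rows_to_chains` is just a group-by-parent plus a sort on the ``position`` column; a self-contained sketch with made-up row values:

    rows = [
        {"parent": "chain1", "position": 1, "child": "run_b"},
        {"parent": "chain1", "position": 0, "child": "run_a"},
        {"parent": "chain2", "position": 0, "child": "run_c"},
    ]
    chains_defs: dict[str, list[tuple[int, str]]] = {"chain1": [], "chain2": []}
    for row in rows:
        chains_defs[row["parent"]].append((row["position"], row["child"]))
    for name, children in chains_defs.items():
        # Sorting the (position, child) pairs restores the chain order.
        print(name, [child for _, child in sorted(children)])
    # chain1 ['run_a', 'run_b']
    # chain2 ['run_c']
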
diff --git a/python/lsst/daf/butler/registry/collections/synthIntKey.py b/python/lsst/daf/butler/registry/collections/synthIntKey.py
index 8e49140c8d..3b7dbb1de4 100644
--- a/python/lsst/daf/butler/registry/collections/synthIntKey.py
+++ b/python/lsst/daf/butler/registry/collections/synthIntKey.py
@@ -30,13 +30,14 @@
__all__ = ["SynthIntKeyCollectionManager"]
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any
import sqlalchemy
from ..._timespan import TimespanDatabaseRepresentation
-from ..interfaces import CollectionRecord, VersionTuple
+from .._collection_type import CollectionType
+from ..interfaces import ChainedCollectionRecord, CollectionRecord, RunRecord, VersionTuple
from ._base import (
CollectionTablesTuple,
DefaultCollectionManager,
@@ -45,6 +46,7 @@
)
if TYPE_CHECKING:
+ from .._caching_context import CachingContext
from ..interfaces import Database, DimensionRecordStorageManager, StaticTablesContext
@@ -73,43 +75,11 @@ def _makeTableSpecs(TimespanReprClass: type[TimespanDatabaseRepresentation]) ->
)
-class SynthIntKeyCollectionManager(DefaultCollectionManager):
+class SynthIntKeyCollectionManager(DefaultCollectionManager[int]):
"""A `CollectionManager` implementation that uses synthetic primary key
(auto-incremented integer) for collections table.
-
- Most of the logic, including caching policy, is implemented in the base
- class, this class only adds customizations specific to this particular
- table schema.
-
- Parameters
- ----------
- db : `Database`
- Interface to the underlying database engine and namespace.
- tables : `NamedTuple`
- Named tuple of SQLAlchemy table objects.
- collectionIdName : `str`
- Name of the column in collections table that identifies it (PK).
- dimensions : `DimensionRecordStorageManager`
- Manager object for the dimensions in this `Registry`.
"""
- def __init__(
- self,
- db: Database,
- tables: CollectionTablesTuple,
- collectionIdName: str,
- dimensions: DimensionRecordStorageManager,
- registry_schema_version: VersionTuple | None = None,
- ):
- super().__init__(
- db=db,
- tables=tables,
- collectionIdName=collectionIdName,
- dimensions=dimensions,
- registry_schema_version=registry_schema_version,
- )
- self._nameCache: dict[str, CollectionRecord] = {} # indexed by collection name
-
@classmethod
def initialize(
cls,
@@ -117,6 +87,7 @@ def initialize(
context: StaticTablesContext,
*,
dimensions: DimensionRecordStorageManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
) -> SynthIntKeyCollectionManager:
# Docstring inherited from CollectionManager.
@@ -125,6 +96,7 @@ def initialize(
tables=context.addTableTuple(_makeTableSpecs(db.getTimespanRepresentation())), # type: ignore
collectionIdName="collection_id",
dimensions=dimensions,
+ caching_context=caching_context,
registry_schema_version=registry_schema_version,
)
@@ -184,29 +156,134 @@ def addRunForeignKey(
)
return copy
- def _setRecordCache(self, records: Iterable[CollectionRecord]) -> None:
- """Set internal record cache to contain given records,
- old cached records will be removed.
+ def getParentChains(self, key: int) -> set[str]:
+ # Docstring inherited from CollectionManager.
+ chain = self._tables.collection_chain
+ collection = self._tables.collection
+ sql = (
+ sqlalchemy.sql.select(collection.columns["name"])
+ .select_from(collection)
+ .join(chain, onclause=collection.columns[self._collectionIdName] == chain.columns["parent"])
+ .where(chain.columns["child"] == key)
+ )
+ with self._db.query(sql) as sql_result:
+ parent_names = set(sql_result.scalars().all())
+ return parent_names
+
+ def _fetch_by_name(self, names: Iterable[str]) -> list[CollectionRecord[int]]:
+ # Docstring inherited from base class.
+ return self._fetch("name", names)
+
+ def _fetch_by_key(self, collection_ids: Iterable[int] | None) -> list[CollectionRecord[int]]:
+ # Docstring inherited from base class.
+ return self._fetch(self._collectionIdName, collection_ids)
+
+ def _fetch(
+ self, column_name: str, collections: Iterable[int | str] | None
+ ) -> list[CollectionRecord[int]]:
+ collection_chain = self._tables.collection_chain
+ collection = self._tables.collection
+ sql = sqlalchemy.sql.select(*collection.columns, *self._tables.run.columns).select_from(
+ collection.join(self._tables.run, isouter=True)
+ )
+
+ chain_sql = (
+ sqlalchemy.sql.select(
+ collection_chain.columns["parent"],
+ collection_chain.columns["position"],
+ collection.columns["name"].label("child_name"),
+ )
+ .select_from(collection_chain)
+ .join(
+ collection,
+ onclause=collection_chain.columns["child"] == collection.columns[self._collectionIdName],
+ )
+ )
+
+ records: list[CollectionRecord[int]] = []
+        # We want to keep transactions as short as possible. When we fetch
+        # everything we want to quickly fetch things into memory and finish
+        # the transaction. When we fetch just a few records we need to process
+        # the first query before we can run the second one.
+ if collections is not None:
+ sql = sql.where(collection.columns[column_name].in_(collections))
+ with self._db.transaction():
+ with self._db.query(sql) as sql_result:
+ sql_rows = sql_result.mappings().fetchall()
+
+ records, chained_ids = self._rows_to_records(sql_rows)
+
+ if chained_ids:
+ chain_sql = chain_sql.where(collection_chain.columns["parent"].in_(list(chained_ids)))
+
+ with self._db.query(chain_sql) as sql_result:
+ chain_rows = sql_result.mappings().fetchall()
+
+ records += self._rows_to_chains(chain_rows, chained_ids)
+
+ else:
+ with self._db.transaction():
+ with self._db.query(sql) as sql_result:
+ sql_rows = sql_result.mappings().fetchall()
+ with self._db.query(chain_sql) as sql_result:
+ chain_rows = sql_result.mappings().fetchall()
+
+ records, chained_ids = self._rows_to_records(sql_rows)
+ records += self._rows_to_chains(chain_rows, chained_ids)
+
+ return records
+
+ def _rows_to_records(self, rows: Iterable[Mapping]) -> tuple[list[CollectionRecord[int]], dict[int, str]]:
+        """Convert rows returned from the collection query to a list of
+        records and a dict of chained collection names.
+ """
+ records: list[CollectionRecord[int]] = []
+ chained_ids: dict[int, str] = {}
+ TimespanReprClass = self._db.getTimespanRepresentation()
+ for row in rows:
+ key: int = row[self._collectionIdName]
+ name: str = row[self._tables.collection.columns.name]
+ type = CollectionType(row["type"])
+ record: CollectionRecord[int]
+ if type is CollectionType.RUN:
+ record = RunRecord[int](
+ key=key,
+ name=name,
+ host=row[self._tables.run.columns.host],
+ timespan=TimespanReprClass.extract(row),
+ )
+ records.append(record)
+ elif type is CollectionType.CHAINED:
+                # Need to delay chained collection construction until we
+                # fetch their children's names.
+ chained_ids[key] = name
+ else:
+ record = CollectionRecord[int](key=key, name=name, type=type)
+ records.append(record)
+ return records, chained_ids
+
+ def _rows_to_chains(
+ self, rows: Iterable[Mapping], chained_ids: dict[int, str]
+ ) -> list[CollectionRecord[int]]:
+ """Convert rows returned from collection chain query to a list of
+ records.
"""
- self._records = {}
- self._nameCache = {}
- for record in records:
- self._records[record.key] = record
- self._nameCache[record.name] = record
-
- def _addCachedRecord(self, record: CollectionRecord) -> None:
- """Add single record to cache."""
- self._records[record.key] = record
- self._nameCache[record.name] = record
-
- def _removeCachedRecord(self, record: CollectionRecord) -> None:
- """Remove single record from cache."""
- del self._records[record.key]
- del self._nameCache[record.name]
-
- def _getByName(self, name: str) -> CollectionRecord | None:
- # Docstring inherited from DefaultCollectionManager.
- return self._nameCache.get(name)
+ chains_defs: dict[int, list[tuple[int, str]]] = {chain_id: [] for chain_id in chained_ids}
+ for row in rows:
+ chains_defs[row["parent"]].append((row["position"], row["child_name"]))
+
+ records: list[CollectionRecord[int]] = []
+ for key, children in chains_defs.items():
+ name = chained_ids[key]
+ children_names = [child for _, child in sorted(children)]
+ record = ChainedCollectionRecord[int](
+ key=key,
+ name=name,
+ children=children_names,
+ )
+ records.append(record)
+
+ return records
@classmethod
def currentVersions(cls) -> list[VersionTuple]:
diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py
index d52d337a2b..692d8585a3 100644
--- a/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py
+++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_manager.py
@@ -4,9 +4,11 @@
__all__ = ("ByDimensionsDatasetRecordStorageManagerUUID",)
+import dataclasses
import logging
import warnings
from collections import defaultdict
+from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any
import sqlalchemy
@@ -30,6 +32,7 @@
)
if TYPE_CHECKING:
+ from ..._caching_context import CachingContext
from ...interfaces import (
CollectionManager,
CollectionRecord,
@@ -54,6 +57,34 @@ class MissingDatabaseTableError(RuntimeError):
"""Exception raised when a table is not found in a database."""
+@dataclasses.dataclass
+class _DatasetTypeRecord:
+ """Contents of a single dataset type record."""
+
+ dataset_type: DatasetType
+ dataset_type_id: int
+ tag_table_name: str
+ calib_table_name: str | None
+
+
+class _SpecTableFactory:
+    """Factory for `sqlalchemy.schema.Table` instances that builds table
+    instances from the provided `ddl.TableSpec` definition and verifies that
+    the table exists in the database.
+ """
+
+ def __init__(self, db: Database, name: str, spec: ddl.TableSpec):
+ self._db = db
+ self._name = name
+ self._spec = spec
+
+ def __call__(self) -> sqlalchemy.schema.Table:
+ table = self._db.getExistingTable(self._name, self._spec)
+ if table is None:
+ raise MissingDatabaseTableError(f"Table {self._name} is missing from database schema.")
+ return table
+
+
class ByDimensionsDatasetRecordStorageManagerBase(DatasetRecordStorageManager):
"""A manager class for datasets that uses one dataset-collection table for
each group of dataset types that share the same dimensions.
@@ -90,6 +121,8 @@ class ByDimensionsDatasetRecordStorageManagerBase(DatasetRecordStorageManager):
tables used by this class.
summaries : `CollectionSummaryManager`
Structure containing tables that summarize the contents of collections.
+ caching_context : `CachingContext`
+ Object controlling caching of information returned by managers.
"""
def __init__(
@@ -100,6 +133,7 @@ def __init__(
dimensions: DimensionRecordStorageManager,
static: StaticDatasetTablesTuple,
summaries: CollectionSummaryManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
):
super().__init__(registry_schema_version=registry_schema_version)
@@ -108,8 +142,7 @@ def __init__(
self._dimensions = dimensions
self._static = static
self._summaries = summaries
- self._byName: dict[str, ByDimensionsDatasetRecordStorage] = {}
- self._byId: dict[int, ByDimensionsDatasetRecordStorage] = {}
+ self._caching_context = caching_context
@classmethod
def initialize(
@@ -119,6 +152,7 @@ def initialize(
*,
collections: CollectionManager,
dimensions: DimensionRecordStorageManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
) -> DatasetRecordStorageManager:
# Docstring inherited from DatasetRecordStorageManager.
@@ -131,6 +165,8 @@ def initialize(
context,
collections=collections,
dimensions=dimensions,
+ dataset_type_table=static.dataset_type,
+ caching_context=caching_context,
)
return cls(
db=db,
@@ -138,6 +174,7 @@ def initialize(
dimensions=dimensions,
static=static,
summaries=summaries,
+ caching_context=caching_context,
registry_schema_version=registry_schema_version,
)
@@ -205,60 +242,34 @@ def addDatasetForeignKey(
def refresh(self) -> None:
# Docstring inherited from DatasetRecordStorageManager.
- byName: dict[str, ByDimensionsDatasetRecordStorage] = {}
- byId: dict[int, ByDimensionsDatasetRecordStorage] = {}
- dataset_types: dict[int, DatasetType] = {}
- c = self._static.dataset_type.columns
- with self._db.query(self._static.dataset_type.select()) as sql_result:
- sql_rows = sql_result.mappings().fetchall()
- for row in sql_rows:
- name = row[c.name]
- dimensions = self._dimensions.loadDimensionGraph(row[c.dimensions_key])
- calibTableName = row[c.calibration_association_table]
- datasetType = DatasetType(
- name, dimensions, row[c.storage_class], isCalibration=(calibTableName is not None)
+ if self._caching_context.dataset_types is not None:
+ self._caching_context.dataset_types.clear()
+
+ def _make_storage(self, record: _DatasetTypeRecord) -> ByDimensionsDatasetRecordStorage:
+ """Create storage instance for a dataset type record."""
+ tags_spec = makeTagTableSpec(record.dataset_type, type(self._collections), self.getIdColumnType())
+ tags_table_factory = _SpecTableFactory(self._db, record.tag_table_name, tags_spec)
+ calibs_table_factory = None
+ if record.calib_table_name is not None:
+ calibs_spec = makeCalibTableSpec(
+ record.dataset_type,
+ type(self._collections),
+ self._db.getTimespanRepresentation(),
+ self.getIdColumnType(),
)
- tags = self._db.getExistingTable(
- row[c.tag_association_table],
- makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()),
- )
- if tags is None:
- raise MissingDatabaseTableError(
- f"Table {row[c.tag_association_table]} is missing from database schema."
- )
- if calibTableName is not None:
- calibs = self._db.getExistingTable(
- row[c.calibration_association_table],
- makeCalibTableSpec(
- datasetType,
- type(self._collections),
- self._db.getTimespanRepresentation(),
- self.getIdColumnType(),
- ),
- )
- if calibs is None:
- raise MissingDatabaseTableError(
- f"Table {row[c.calibration_association_table]} is missing from database schema."
- )
- else:
- calibs = None
- storage = self._recordStorageType(
- db=self._db,
- datasetType=datasetType,
- static=self._static,
- summaries=self._summaries,
- tags=tags,
- calibs=calibs,
- dataset_type_id=row["id"],
- collections=self._collections,
- use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai,
- )
- byName[datasetType.name] = storage
- byId[storage._dataset_type_id] = storage
- dataset_types[row["id"]] = datasetType
- self._byName = byName
- self._byId = byId
- self._summaries.refresh(dataset_types)
+ calibs_table_factory = _SpecTableFactory(self._db, record.calib_table_name, calibs_spec)
+ storage = self._recordStorageType(
+ db=self._db,
+ datasetType=record.dataset_type,
+ static=self._static,
+ summaries=self._summaries,
+ tags_table_factory=tags_table_factory,
+ calibs_table_factory=calibs_table_factory,
+ dataset_type_id=record.dataset_type_id,
+ collections=self._collections,
+ use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai,
+ )
+ return storage
def remove(self, name: str) -> None:
# Docstring inherited from DatasetRecordStorageManager.
@@ -281,31 +292,48 @@ def remove(self, name: str) -> None:
def find(self, name: str) -> DatasetRecordStorage | None:
# Docstring inherited from DatasetRecordStorageManager.
- return self._byName.get(name)
+ if self._caching_context.dataset_types is not None:
+ _, storage = self._caching_context.dataset_types.get(name)
+ if storage is not None:
+ return storage
+ else:
+                # On the first cache miss, populate the cache with the complete
+                # list of dataset types (if that has not been done yet).
+ if not self._caching_context.dataset_types.full:
+ self._fetch_dataset_types()
+ # Try again
+ _, storage = self._caching_context.dataset_types.get(name)
+ if self._caching_context.dataset_types.full:
+ # If not in cache then dataset type is not defined.
+ return storage
+ record = self._fetch_dataset_type_record(name)
+ if record is not None:
+ storage = self._make_storage(record)
+ if self._caching_context.dataset_types is not None:
+ self._caching_context.dataset_types.add(storage.datasetType, storage)
+ return storage
+ else:
+ return None
- def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]:
+ def register(self, datasetType: DatasetType) -> bool:
# Docstring inherited from DatasetRecordStorageManager.
if datasetType.isComponent():
raise ValueError(
f"Component dataset types can not be stored in registry. Rejecting {datasetType.name}"
)
- storage = self._byName.get(datasetType.name)
- if storage is None:
+ record = self._fetch_dataset_type_record(datasetType.name)
+ if record is None:
dimensionsKey = self._dimensions.saveDimensionGraph(datasetType.dimensions)
tagTableName = makeTagTableName(datasetType, dimensionsKey)
- calibTableName = (
- makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None
- )
- # The order is important here, we want to create tables first and
- # only register them if this operation is successful. We cannot
- # wrap it into a transaction because database class assumes that
- # DDL is not transaction safe in general.
- tags = self._db.ensureTableExists(
+ self._db.ensureTableExists(
tagTableName,
makeTagTableSpec(datasetType, type(self._collections), self.getIdColumnType()),
)
+ calibTableName = (
+ makeCalibTableName(datasetType, dimensionsKey) if datasetType.isCalibration() else None
+ )
if calibTableName is not None:
- calibs = self._db.ensureTableExists(
+ self._db.ensureTableExists(
calibTableName,
makeCalibTableSpec(
datasetType,
@@ -314,8 +342,6 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool
self.getIdColumnType(),
),
)
- else:
- calibs = None
row, inserted = self._db.sync(
self._static.dataset_type,
keys={"name": datasetType.name},
@@ -331,28 +357,25 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool
},
returning=["id", "tag_association_table"],
)
- assert row is not None
- storage = self._recordStorageType(
- db=self._db,
- datasetType=datasetType,
- static=self._static,
- summaries=self._summaries,
- tags=tags,
- calibs=calibs,
- dataset_type_id=row["id"],
- collections=self._collections,
- use_astropy_ingest_date=self.ingest_date_dtype() is ddl.AstropyTimeNsecTai,
- )
- self._byName[datasetType.name] = storage
- self._byId[storage._dataset_type_id] = storage
+ # Make sure that cache is updated
+ if self._caching_context.dataset_types is not None and row is not None:
+ record = _DatasetTypeRecord(
+ dataset_type=datasetType,
+ dataset_type_id=row["id"],
+ tag_table_name=tagTableName,
+ calib_table_name=calibTableName,
+ )
+ storage = self._make_storage(record)
+ self._caching_context.dataset_types.add(datasetType, storage)
else:
- if datasetType != storage.datasetType:
+ if datasetType != record.dataset_type:
raise ConflictingDefinitionError(
f"Given dataset type {datasetType} is inconsistent "
- f"with database definition {storage.datasetType}."
+ f"with database definition {record.dataset_type}."
)
inserted = False
- return storage, bool(inserted)
+
+ return bool(inserted)
def resolve_wildcard(
self,
@@ -406,15 +429,13 @@ def resolve_wildcard(
raise TypeError(
"Universal wildcard '...' is not permitted for dataset types in this context."
)
- for storage in self._byName.values():
- result[storage.datasetType].add(None)
+ for datasetType in self._fetch_dataset_types():
+ result[datasetType].add(None)
if components:
try:
- result[storage.datasetType].update(
- storage.datasetType.storageClass.allComponents().keys()
- )
+ result[datasetType].update(datasetType.storageClass.allComponents().keys())
if (
- storage.datasetType.storageClass.allComponents()
+ datasetType.storageClass.allComponents()
and not already_warned
and components_deprecated
):
@@ -426,7 +447,7 @@ def resolve_wildcard(
already_warned = True
except KeyError as err:
_LOG.warning(
- f"Could not load storage class {err} for {storage.datasetType.name}; "
+ f"Could not load storage class {err} for {datasetType.name}; "
"if it has components they will not be included in query results.",
)
elif wildcard.patterns:
@@ -438,29 +459,28 @@ def resolve_wildcard(
FutureWarning,
stacklevel=find_outside_stacklevel("lsst.daf.butler"),
)
- for storage in self._byName.values():
- if any(p.fullmatch(storage.datasetType.name) for p in wildcard.patterns):
- result[storage.datasetType].add(None)
+ dataset_types = self._fetch_dataset_types()
+ for datasetType in dataset_types:
+ if any(p.fullmatch(datasetType.name) for p in wildcard.patterns):
+ result[datasetType].add(None)
if components is not False:
- for storage in self._byName.values():
- if components is None and storage.datasetType in result:
+ for datasetType in dataset_types:
+ if components is None and datasetType in result:
continue
try:
- components_for_parent = storage.datasetType.storageClass.allComponents().keys()
+ components_for_parent = datasetType.storageClass.allComponents().keys()
except KeyError as err:
_LOG.warning(
- f"Could not load storage class {err} for {storage.datasetType.name}; "
+ f"Could not load storage class {err} for {datasetType.name}; "
"if it has components they will not be included in query results."
)
continue
for component_name in components_for_parent:
if any(
- p.fullmatch(
- DatasetType.nameWithComponent(storage.datasetType.name, component_name)
- )
+ p.fullmatch(DatasetType.nameWithComponent(datasetType.name, component_name))
for p in wildcard.patterns
):
- result[storage.datasetType].add(component_name)
+ result[datasetType].add(component_name)
if not already_warned and components_deprecated:
warnings.warn(
deprecation_message,
@@ -476,29 +496,93 @@ def getDatasetRef(self, id: DatasetId) -> DatasetRef | None:
sqlalchemy.sql.select(
self._static.dataset.columns.dataset_type_id,
self._static.dataset.columns[self._collections.getRunForeignKeyName()],
+ *self._static.dataset_type.columns,
)
.select_from(self._static.dataset)
+ .join(self._static.dataset_type)
.where(self._static.dataset.columns.id == id)
)
with self._db.query(sql) as sql_result:
row = sql_result.mappings().fetchone()
if row is None:
return None
- recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id])
- if recordsForType is None:
- self.refresh()
- recordsForType = self._byId.get(row[self._static.dataset.columns.dataset_type_id])
- assert recordsForType is not None, "Should be guaranteed by foreign key constraints."
+ record = self._record_from_row(row)
+ storage: DatasetRecordStorage | None = None
+ if self._caching_context.dataset_types is not None:
+ _, storage = self._caching_context.dataset_types.get(record.dataset_type.name)
+ if storage is None:
+ storage = self._make_storage(record)
+ if self._caching_context.dataset_types is not None:
+ self._caching_context.dataset_types.add(storage.datasetType, storage)
+ assert isinstance(storage, ByDimensionsDatasetRecordStorage), "Not expected storage class"
return DatasetRef(
- recordsForType.datasetType,
- dataId=recordsForType.getDataId(id=id),
+ storage.datasetType,
+ dataId=storage.getDataId(id=id),
id=id,
run=self._collections[row[self._collections.getRunForeignKeyName()]].name,
)
+ def _fetch_dataset_type_record(self, name: str) -> _DatasetTypeRecord | None:
+ """Retrieve all dataset types defined in database.
+
+ Yields
+ ------
+ dataset_types : `_DatasetTypeRecord`
+ Information from a single database record.
+ """
+ c = self._static.dataset_type.columns
+ stmt = self._static.dataset_type.select().where(c.name == name)
+ with self._db.query(stmt) as sql_result:
+ row = sql_result.mappings().one_or_none()
+ if row is None:
+ return None
+ else:
+ return self._record_from_row(row)
+
+ def _record_from_row(self, row: Mapping) -> _DatasetTypeRecord:
+ name = row["name"]
+ dimensions = self._dimensions.loadDimensionGraph(row["dimensions_key"])
+ calibTableName = row["calibration_association_table"]
+ datasetType = DatasetType(
+ name, dimensions, row["storage_class"], isCalibration=(calibTableName is not None)
+ )
+ return _DatasetTypeRecord(
+ dataset_type=datasetType,
+ dataset_type_id=row["id"],
+ tag_table_name=row["tag_association_table"],
+ calib_table_name=calibTableName,
+ )
+
+ def _dataset_type_from_row(self, row: Mapping) -> DatasetType:
+ return self._record_from_row(row).dataset_type
+
+ def _fetch_dataset_types(self) -> list[DatasetType]:
+ """Fetch list of all defined dataset types."""
+ if self._caching_context.dataset_types is not None:
+ if self._caching_context.dataset_types.full:
+ return [dataset_type for dataset_type, _ in self._caching_context.dataset_types.items()]
+ with self._db.query(self._static.dataset_type.select()) as sql_result:
+ sql_rows = sql_result.mappings().fetchall()
+ records = [self._record_from_row(row) for row in sql_rows]
+ # Cache everything and specify that cache is complete.
+ if self._caching_context.dataset_types is not None:
+ cache_data = [(record.dataset_type, self._make_storage(record)) for record in records]
+ self._caching_context.dataset_types.set(cache_data, full=True)
+ return [record.dataset_type for record in records]
+
def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummary:
# Docstring inherited from DatasetRecordStorageManager.
- return self._summaries.get(collection)
+ summaries = self._summaries.fetch_summaries([collection], None, self._dataset_type_from_row)
+ return summaries[collection.key]
+
+ def fetch_summaries(
+ self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None
+ ) -> Mapping[Any, CollectionSummary]:
+ # Docstring inherited from DatasetRecordStorageManager.
+ dataset_type_names: Iterable[str] | None = None
+ if dataset_types is not None:
+ dataset_type_names = set(dataset_type.name for dataset_type in dataset_types)
+ return self._summaries.fetch_summaries(collections, dataset_type_names, self._dataset_type_from_row)
_versions: list[VersionTuple]
"""Schema version for this class."""
diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
index 454702b5b8..b88b80a3c5 100644
--- a/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
+++ b/python/lsst/daf/butler/registry/datasets/byDimensions/_storage.py
@@ -32,7 +32,7 @@
__all__ = ("ByDimensionsDatasetRecordStorage",)
-from collections.abc import Iterable, Iterator, Sequence, Set
+from collections.abc import Callable, Iterable, Iterator, Sequence, Set
from datetime import datetime
from typing import TYPE_CHECKING
@@ -77,9 +77,9 @@ def __init__(
collections: CollectionManager,
static: StaticDatasetTablesTuple,
summaries: CollectionSummaryManager,
- tags: sqlalchemy.schema.Table,
+ tags_table_factory: Callable[[], sqlalchemy.schema.Table],
use_astropy_ingest_date: bool,
- calibs: sqlalchemy.schema.Table | None,
+ calibs_table_factory: Callable[[], sqlalchemy.schema.Table] | None,
):
super().__init__(datasetType=datasetType)
self._dataset_type_id = dataset_type_id
@@ -87,10 +87,26 @@ def __init__(
self._collections = collections
self._static = static
self._summaries = summaries
- self._tags = tags
- self._calibs = calibs
+ self._tags_table_factory = tags_table_factory
+ self._calibs_table_factory = calibs_table_factory
self._runKeyColumn = collections.getRunForeignKeyName()
self._use_astropy = use_astropy_ingest_date
+ self._tags_table: sqlalchemy.schema.Table | None = None
+ self._calibs_table: sqlalchemy.schema.Table | None = None
+
+ @property
+ def _tags(self) -> sqlalchemy.schema.Table:
+ if self._tags_table is None:
+ self._tags_table = self._tags_table_factory()
+ return self._tags_table
+
+ @property
+ def _calibs(self) -> sqlalchemy.schema.Table | None:
+ if self._calibs_table is None:
+ if self._calibs_table_factory is None:
+ return None
+ self._calibs_table = self._calibs_table_factory()
+ return self._calibs_table
def delete(self, datasets: Iterable[DatasetRef]) -> None:
# Docstring inherited from DatasetRecordStorage.
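Reviewer note: the storage class now receives table factories and resolves them lazily, memoizing the result on first access. A minimal sketch of that pattern with hypothetical names and plain strings standing in for `sqlalchemy` tables (not the daf_butler classes).

# Illustrative lazy, memoized lookup of required and optional resources.
from collections.abc import Callable


class LazyTables:
    def __init__(
        self,
        tags_factory: Callable[[], str],
        calibs_factory: Callable[[], str] | None,
    ) -> None:
        self._tags_factory = tags_factory
        self._calibs_factory = calibs_factory
        self._tags: str | None = None
        self._calibs: str | None = None

    @property
    def tags(self) -> str:
        # Resolved on first use; repeated access reuses the cached object.
        if self._tags is None:
            self._tags = self._tags_factory()
        return self._tags

    @property
    def calibs(self) -> str | None:
        # Legitimately None when no factory was supplied (non-calibration type).
        if self._calibs is None:
            if self._calibs_factory is None:
                return None
            self._calibs = self._calibs_factory()
        return self._calibs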
diff --git a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py
index 1356576376..d051b4b38d 100644
--- a/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py
+++ b/python/lsst/daf/butler/registry/datasets/byDimensions/summaries.py
@@ -31,7 +31,7 @@
__all__ = ("CollectionSummaryManager",)
-from collections.abc import Iterable, Mapping
+from collections.abc import Callable, Iterable, Mapping
from typing import Any, Generic, TypeVar
import sqlalchemy
@@ -39,16 +39,17 @@
from ...._dataset_type import DatasetType
from ...._named import NamedKeyDict, NamedKeyMapping
from ....dimensions import GovernorDimension, addDimensionForeignKey
+from ..._caching_context import CachingContext
from ..._collection_summary import CollectionSummary
from ..._collection_type import CollectionType
from ...interfaces import (
- ChainedCollectionRecord,
CollectionManager,
CollectionRecord,
Database,
DimensionRecordStorageManager,
StaticTablesContext,
)
+from ...wildcards import CollectionWildcard
_T = TypeVar("_T")
@@ -133,6 +134,10 @@ class CollectionSummaryManager:
Manager object for the dimensions in this `Registry`.
tables : `CollectionSummaryTables`
Struct containing the tables that hold collection summaries.
+ dataset_type_table : `sqlalchemy.schema.Table`
+ Table containing dataset type definitions.
+ caching_context : `CachingContext`
+ Object controlling caching of information returned by managers.
"""
def __init__(
@@ -142,13 +147,16 @@ def __init__(
collections: CollectionManager,
dimensions: DimensionRecordStorageManager,
tables: CollectionSummaryTables[sqlalchemy.schema.Table],
+ dataset_type_table: sqlalchemy.schema.Table,
+ caching_context: CachingContext,
):
self._db = db
self._collections = collections
self._collectionKeyName = collections.getCollectionForeignKeyName()
self._dimensions = dimensions
self._tables = tables
- self._cache: dict[Any, CollectionSummary] = {}
+ self._dataset_type_table = dataset_type_table
+ self._caching_context = caching_context
@classmethod
def initialize(
@@ -158,6 +166,8 @@ def initialize(
*,
collections: CollectionManager,
dimensions: DimensionRecordStorageManager,
+ dataset_type_table: sqlalchemy.schema.Table,
+ caching_context: CachingContext,
) -> CollectionSummaryManager:
"""Create all summary tables (or check that they have been created),
returning an object to manage them.
@@ -173,6 +183,10 @@ def initialize(
Manager object for the collections in this `Registry`.
dimensions : `DimensionRecordStorageManager`
Manager object for the dimensions in this `Registry`.
+ dataset_type_table : `sqlalchemy.schema.Table`
+ Table containing dataset type definitions.
+ caching_context : `CachingContext`
+ Object controlling caching of information returned by managers.
Returns
-------
@@ -194,6 +208,8 @@ def initialize(
collections=collections,
dimensions=dimensions,
tables=tables,
+ dataset_type_table=dataset_type_table,
+ caching_context=caching_context,
)
def update(
@@ -237,39 +253,67 @@ def update(
self._tables.dimensions[dimension],
*[{self._collectionKeyName: collection.key, dimension: v} for v in values],
)
- # Update the in-memory cache, too. These changes will remain even if
- # the database inserts above are rolled back by some later exception in
- # the same transaction, but that's okay: we never promise that a
- # CollectionSummary has _just_ the dataset types and governor dimension
- # values that are actually present, only that it is guaranteed to
- # contain any dataset types or governor dimension values that _may_ be
- # present.
- # That guarantee (and the possibility of rollbacks) means we can't get
- # away with checking the cache before we try the database inserts,
- # however; if someone had attempted to insert datasets of some dataset
- # type previously, and that rolled back, and we're now trying to insert
- # some more datasets of that same type, it would not be okay to skip
- # the DB summary table insertions because we found entries in the
- # in-memory cache.
- self.get(collection).update(summary)
-
- def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None:
- """Load all collection summary information from the database.
+
+ def fetch_summaries(
+ self,
+ collections: Iterable[CollectionRecord],
+ dataset_type_names: Iterable[str] | None,
+ dataset_type_factory: Callable[[sqlalchemy.engine.RowMapping], DatasetType],
+ ) -> Mapping[Any, CollectionSummary]:
+ """Fetch collection summaries given their names and dataset types.
Parameters
----------
- dataset_types : `~collections.abc.Mapping` [`int`, `DatasetType`]
- Mapping of an `int` dataset_type_id value to `DatasetType`
- instance. Summaries are only loaded for dataset types that appear
- in this mapping.
+ collections : `~collections.abc.Iterable` [`CollectionRecord`]
+ Collection records to query.
+        dataset_type_names : `~collections.abc.Iterable` [`str`] or `None`
+            Names of dataset types to include in the returned summaries. If
+            `None` then all dataset types will be included.
+        dataset_type_factory : `Callable`
+            Method that takes a table row and makes a `DatasetType` instance
+            out of it.
+
+ Returns
+ -------
+        summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`]
+            Collection summaries indexed by collection record key. This
+            mapping also includes summaries for all nested non-chained
+            collections of any chained collections that were queried.
"""
+ summaries: dict[Any, CollectionSummary] = {}
+ # Check what we have in cache first.
+ if self._caching_context.collection_summaries is not None:
+ summaries, missing_keys = self._caching_context.collection_summaries.find_summaries(
+ [record.key for record in collections]
+ )
+ if not missing_keys:
+ return summaries
+ else:
+ collections = [record for record in collections if record.key in missing_keys]
+
+ # Need to expand all chained collections first.
+ non_chains: list[CollectionRecord] = []
+ chains: dict[CollectionRecord, list[CollectionRecord]] = {}
+ for collection in collections:
+ if collection.type is CollectionType.CHAINED:
+ children = self._collections.resolve_wildcard(
+ CollectionWildcard.from_names([collection.name]),
+ flatten_chains=True,
+ include_chains=False,
+ )
+ non_chains += children
+ chains[collection] = children
+ else:
+ non_chains.append(collection)
+
# Set up the SQL query we'll use to fetch all of the summary
# information at once.
- columns = [
- self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName),
- self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id"),
- ]
- fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType
+ coll_col = self._tables.datasetType.columns[self._collectionKeyName].label(self._collectionKeyName)
+ dataset_type_id_col = self._tables.datasetType.columns.dataset_type_id.label("dataset_type_id")
+ columns = [coll_col, dataset_type_id_col] + list(self._dataset_type_table.columns)
+ fromClause: sqlalchemy.sql.expression.FromClause = self._tables.datasetType.join(
+ self._dataset_type_table
+ )
for dimension, table in self._tables.dimensions.items():
columns.append(table.columns[dimension.name].label(dimension.name))
fromClause = fromClause.join(
@@ -280,72 +324,54 @@ def refresh(self, dataset_types: Mapping[int, DatasetType]) -> None:
),
isouter=True,
)
+
sql = sqlalchemy.sql.select(*columns).select_from(fromClause)
+ sql = sql.where(coll_col.in_([coll.key for coll in non_chains]))
+ # For caching we need to fetch complete summaries.
+ if self._caching_context.collection_summaries is None:
+ if dataset_type_names is not None:
+ sql = sql.where(self._dataset_type_table.columns["name"].in_(dataset_type_names))
+
# Run the query and construct CollectionSummary objects from the result
# rows. This will never include CHAINED collections or collections
# with no datasets.
- summaries: dict[Any, CollectionSummary] = {}
with self._db.query(sql) as sql_result:
sql_rows = sql_result.mappings().fetchall()
+ dataset_type_ids: dict[int, DatasetType] = {}
for row in sql_rows:
# Collection key should never be None/NULL; it's what we join on.
# Extract that and then turn it into a collection name.
collectionKey = row[self._collectionKeyName]
# dataset_type_id should also never be None/NULL; it's in the first
# table we joined.
- if datasetType := dataset_types.get(row["dataset_type_id"]):
- # See if we have a summary already for this collection; if not,
- # make one.
- summary = summaries.get(collectionKey)
- if summary is None:
- summary = CollectionSummary()
- summaries[collectionKey] = summary
- # Update the dimensions with the values in this row that
- # aren't None/NULL (many will be in general, because these
- # enter the query via LEFT OUTER JOIN).
- summary.dataset_types.add(datasetType)
- for dimension in self._tables.dimensions:
- value = row[dimension.name]
- if value is not None:
- summary.governors.setdefault(dimension.name, set()).add(value)
- self._cache = summaries
-
- def get(self, collection: CollectionRecord) -> CollectionSummary:
- """Return a summary for the given collection.
+ dataset_type_id = row["dataset_type_id"]
+ if (dataset_type := dataset_type_ids.get(dataset_type_id)) is None:
+ dataset_type_ids[dataset_type_id] = dataset_type = dataset_type_factory(row)
+ # See if we have a summary already for this collection; if not,
+ # make one.
+ summary = summaries.get(collectionKey)
+ if summary is None:
+ summary = CollectionSummary()
+ summaries[collectionKey] = summary
+ # Update the dimensions with the values in this row that
+ # aren't None/NULL (many will be in general, because these
+ # enter the query via LEFT OUTER JOIN).
+ summary.dataset_types.add(dataset_type)
+ for dimension in self._tables.dimensions:
+ value = row[dimension.name]
+ if value is not None:
+ summary.governors.setdefault(dimension.name, set()).add(value)
- Parameters
- ----------
- collection : `CollectionRecord`
- Record describing the collection for which a summary is to be
- retrieved.
+ # Add empty summary for any missing collection.
+ for collection in non_chains:
+ if collection.key not in summaries:
+ summaries[collection.key] = CollectionSummary()
- Returns
- -------
- summary : `CollectionSummary`
- Summary of the dataset types and governor dimension values in
- this collection.
- """
- summary = self._cache.get(collection.key)
- if summary is None:
- # When we load the summary information from the database, we don't
- # create summaries for CHAINED collections; those are created here
- # as needed, and *never* cached - we have no good way to update
- # those summaries when some a new dataset is added to a child
- # colletion.
- if collection.type is CollectionType.CHAINED:
- assert isinstance(collection, ChainedCollectionRecord)
- child_summaries = [self.get(self._collections.find(child)) for child in collection.children]
- if child_summaries:
- summary = CollectionSummary.union(*child_summaries)
- else:
- summary = CollectionSummary()
- else:
- # Either this collection doesn't have any datasets yet, or the
- # only datasets it has were created by some other process since
- # the last call to refresh. We assume the former; the user is
- # responsible for calling refresh if they want to read
- # concurrently-written things. We do remember this in the
- # cache.
- summary = CollectionSummary()
- self._cache[collection.key] = summary
- return summary
+        # Merge child summaries into their parent chains' summaries.
+ for chain, children in chains.items():
+ summaries[chain.key] = CollectionSummary.union(*(summaries[child.key] for child in children))
+
+ if self._caching_context.collection_summaries is not None:
+ self._caching_context.collection_summaries.update(summaries)
+
+ return summaries
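Reviewer note: fetch_summaries() above expands chained collections, runs a single batched query for the leaf collections, fills in empty summaries, and unions child summaries into each chain. A minimal sketch of that flow, with sets standing in for CollectionSummary and hypothetical names throughout.

# Illustrative flow only; sets play the role of CollectionSummary here.
from collections.abc import Callable, Iterable


def fetch_summaries(
    collections: Iterable[str],
    children_of: Callable[[str], list[str] | None],
    query_leaf_summaries: Callable[[list[str]], dict[str, set[str]]],
) -> dict[str, set[str]]:
    non_chains: list[str] = []
    chains: dict[str, list[str]] = {}
    for name in collections:
        kids = children_of(name)  # None for non-chained collections
        if kids is None:
            non_chains.append(name)
        else:
            non_chains.extend(kids)
            chains[name] = kids
    summaries = query_leaf_summaries(non_chains)  # single batched query
    for name in non_chains:
        summaries.setdefault(name, set())  # collections with no datasets yet
    for chain, kids in chains.items():
        summaries[chain] = set().union(*(summaries[k] for k in kids))
    return summaries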
diff --git a/python/lsst/daf/butler/registry/interfaces/_collections.py b/python/lsst/daf/butler/registry/interfaces/_collections.py
index 2bc5bf30c6..837ede94e0 100644
--- a/python/lsst/daf/butler/registry/interfaces/_collections.py
+++ b/python/lsst/daf/butler/registry/interfaces/_collections.py
@@ -36,22 +36,24 @@
]
from abc import abstractmethod
-from collections import defaultdict
-from collections.abc import Iterator, Set
-from typing import TYPE_CHECKING, Any
+from collections.abc import Iterable, Set
+from typing import TYPE_CHECKING, Any, Generic, TypeVar
from ..._timespan import Timespan
-from ...dimensions import DimensionUniverse
from .._collection_type import CollectionType
from ..wildcards import CollectionWildcard
from ._versioning import VersionedExtension, VersionTuple
if TYPE_CHECKING:
+ from .._caching_context import CachingContext
from ._database import Database, StaticTablesContext
from ._dimensions import DimensionRecordStorageManager
-class CollectionRecord:
+_Key = TypeVar("_Key")
+
+
+class CollectionRecord(Generic[_Key]):
"""A struct used to represent a collection in internal `Registry` APIs.
User-facing code should always just use a `str` to represent collections.
@@ -76,7 +78,7 @@ class CollectionRecord:
participate in some subclass equality definition.
"""
- def __init__(self, key: Any, name: str, type: CollectionType):
+ def __init__(self, key: _Key, name: str, type: CollectionType):
self.key = key
self.name = name
self.type = type
@@ -86,7 +88,7 @@ def __init__(self, key: Any, name: str, type: CollectionType):
"""Name of the collection (`str`).
"""
- key: Any
+ key: _Key
"""The primary/foreign key value for this collection.
"""
@@ -111,196 +113,85 @@ def __str__(self) -> str:
return self.name
-class RunRecord(CollectionRecord):
+class RunRecord(CollectionRecord[_Key]):
"""A subclass of `CollectionRecord` that adds execution information and
an interface for updating it.
- """
- @abstractmethod
- def update(self, host: str | None = None, timespan: Timespan | None = None) -> None:
- """Update the database record for this run with new execution
- information.
-
- Values not provided will set to ``NULL`` in the database, not ignored.
+ Parameters
+ ----------
+    key : `object`
+ Unique collection key.
+ name : `str`
+ Name of the collection.
+ host : `str`, optional
+ Name of the host or system on which this run was produced.
+    timespan : `Timespan`, optional
+ Begin and end timestamps for the period over which the run was
+ produced.
+ """
- Parameters
- ----------
- host : `str`, optional
- Name of the host or system on which this run was produced.
- Detailed form to be set by higher-level convention; from the
- `Registry` perspective, this is an entirely opaque value.
- timespan : `Timespan`, optional
- Begin and end timestamps for the period over which the run was
- produced. `None`/``NULL`` values are interpreted as infinite
- bounds.
- """
- raise NotImplementedError()
+ host: str | None
+ """Name of the host or system on which this run was produced (`str` or
+ `None`).
+ """
- @property
- @abstractmethod
- def host(self) -> str | None:
- """Return the name of the host or system on which this run was
- produced (`str` or `None`).
- """
- raise NotImplementedError()
+ timespan: Timespan
+ """Begin and end timestamps for the period over which the run was produced.
+    `None`/``NULL`` values are interpreted as infinite bounds.
+ """
- @property
- @abstractmethod
- def timespan(self) -> Timespan:
- """Begin and end timestamps for the period over which the run was
- produced. `None`/``NULL`` values are interpreted as infinite
- bounds.
- """
- raise NotImplementedError()
+ def __init__(
+ self,
+ key: _Key,
+ name: str,
+ *,
+ host: str | None = None,
+ timespan: Timespan | None = None,
+ ):
+ super().__init__(key=key, name=name, type=CollectionType.RUN)
+ self.host = host
+ if timespan is None:
+ timespan = Timespan(begin=None, end=None)
+ self.timespan = timespan
def __repr__(self) -> str:
return f"RunRecord(key={self.key!r}, name={self.name!r})"
-class ChainedCollectionRecord(CollectionRecord):
+class ChainedCollectionRecord(CollectionRecord[_Key]):
"""A subclass of `CollectionRecord` that adds the list of child collections
in a ``CHAINED`` collection.
Parameters
----------
- key
- Unique collection ID, can be the same as ``name`` if ``name`` is used
- for identification. Usually this is an integer or string, but can be
- other database-specific type.
+    key : `object`
+ Unique collection key.
name : `str`
Name of the collection.
+    children : `~collections.abc.Iterable` [`str`]
+ Ordered sequence of names of child collections.
"""
- def __init__(self, key: Any, name: str, universe: DimensionUniverse):
- super().__init__(key=key, name=name, type=CollectionType.CHAINED)
- self._children: tuple[str, ...] = ()
-
- @property
- def children(self) -> tuple[str, ...]:
- """The ordered search path of child collections that define this chain
- (`tuple` [ `str` ]).
- """
- return self._children
-
- def update(self, manager: CollectionManager, children: tuple[str, ...], flatten: bool) -> None:
- """Redefine this chain to search the given child collections.
-
- This method should be used by all external code to set children. It
- delegates to `_update`, which is what should be overridden by
- subclasses.
-
- Parameters
- ----------
- manager : `CollectionManager`
- The object that manages this records instance and all records
- instances that may appear as its children.
- children : `tuple` [ `str` ]
- A collection search path that should be resolved to set the child
- collections of this chain.
- flatten : `bool`
- If `True`, recursively flatten out any nested
- `~CollectionType.CHAINED` collections in ``children`` first.
-
- Raises
- ------
- ValueError
- Raised when the child collections contain a cycle.
- """
- children_as_wildcard = CollectionWildcard.from_names(children)
- for record in manager.resolve_wildcard(
- children_as_wildcard,
- flatten_chains=True,
- include_chains=True,
- collection_types={CollectionType.CHAINED},
- ):
- if record == self:
- raise ValueError(f"Cycle in collection chaining when defining '{self.name}'.")
- if flatten:
- children = tuple(
- record.name for record in manager.resolve_wildcard(children_as_wildcard, flatten_chains=True)
- )
- # Delegate to derived classes to do the database updates.
- self._update(manager, children)
- # Update the reverse mapping (from child to parents) in the manager,
- # by removing the old relationships and adding back in the new ones.
- for old_child in self._children:
- manager._parents_by_child[manager.find(old_child).key].discard(self.key)
- for new_child in children:
- manager._parents_by_child[manager.find(new_child).key].add(self.key)
- # Actually set this instances sequence of children.
- self._children = children
-
- def refresh(self, manager: CollectionManager) -> None:
- """Load children from the database, using the given manager to resolve
- collection primary key values into records.
-
- This method exists to ensure that all collections that may appear in a
- chain are known to the manager before any particular chain tries to
- retrieve their records from it. `ChainedCollectionRecord` subclasses
- can rely on it being called sometime after their own ``__init__`` to
- finish construction.
-
- Parameters
- ----------
- manager : `CollectionManager`
- The object that manages this records instance and all records
- instances that may appear as its children.
- """
- # Clear out the old reverse mapping (from child to parents).
- for child in self._children:
- manager._parents_by_child[manager.find(child).key].discard(self.key)
- self._children = self._load(manager)
- # Update the reverse mapping (from child to parents) in the manager.
- for child in self._children:
- manager._parents_by_child[manager.find(child).key].add(self.key)
-
- @abstractmethod
- def _update(self, manager: CollectionManager, children: tuple[str, ...]) -> None:
- """Protected implementation hook for `update`.
-
- This method should be implemented by subclasses to update the database
- to reflect the children given. It should never be called by anything
- other than `update`, which should be used by all external code.
-
- Parameters
- ----------
- manager : `CollectionManager`
- The object that manages this records instance and all records
- instances that may appear as its children.
- children : `tuple` [ `str` ]
- A collection search path that should be resolved to set the child
- collections of this chain. Guaranteed not to contain cycles.
- """
- raise NotImplementedError()
-
- @abstractmethod
- def _load(self, manager: CollectionManager) -> tuple[str, ...]:
- """Protected implementation hook for `refresh`.
-
- This method should be implemented by subclasses to retrieve the chain's
- child collections from the database and return them. It should never
- be called by anything other than `refresh`, which should be used by all
- external code.
-
- Parameters
- ----------
- manager : `CollectionManager`
- The object that manages this records instance and all records
- instances that may appear as its children.
+ children: tuple[str, ...]
+ """The ordered search path of child collections that define this chain
+ (`tuple` [ `str` ]).
+ """
- Returns
- -------
- children : `tuple` [ `str` ]
- The ordered sequence of collection names that defines the chained
- collection. Guaranteed not to contain cycles.
- """
- raise NotImplementedError()
+ def __init__(
+ self,
+ key: Any,
+ name: str,
+ *,
+ children: Iterable[str],
+ ):
+ super().__init__(key=key, name=name, type=CollectionType.CHAINED)
+ self.children = tuple(children)
def __repr__(self) -> str:
return f"ChainedCollectionRecord(key={self.key!r}, name={self.name!r}, children={self.children!r})"
-class CollectionManager(VersionedExtension):
+class CollectionManager(Generic[_Key], VersionedExtension):
"""An interface for managing the collections (including runs) in a
`Registry`.
@@ -315,7 +206,6 @@ class CollectionManager(VersionedExtension):
def __init__(self, *, registry_schema_version: VersionTuple | None = None) -> None:
super().__init__(registry_schema_version=registry_schema_version)
- self._parents_by_child: defaultdict[Any, set[Any]] = defaultdict(set)
@classmethod
@abstractmethod
@@ -325,6 +215,7 @@ def initialize(
context: StaticTablesContext,
*,
dimensions: DimensionRecordStorageManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
) -> CollectionManager:
"""Construct an instance of the manager.
@@ -339,6 +230,8 @@ def initialize(
implemented with this manager.
dimensions : `DimensionRecordStorageManager`
Manager object for the dimensions in this `Registry`.
+ caching_context : `CachingContext`
+ Object controlling caching of information returned by managers.
registry_schema_version : `VersionTuple` or `None`
Schema version of this extension as defined in registry.
@@ -481,7 +374,7 @@ def refresh(self) -> None:
@abstractmethod
def register(
self, name: str, type: CollectionType, doc: str | None = None
- ) -> tuple[CollectionRecord, bool]:
+ ) -> tuple[CollectionRecord[_Key], bool]:
"""Ensure that a collection of the given name and type are present
in the layer this manager is associated with.
@@ -547,7 +440,7 @@ def remove(self, name: str) -> None:
raise NotImplementedError()
@abstractmethod
- def find(self, name: str) -> CollectionRecord:
+ def find(self, name: str) -> CollectionRecord[_Key]:
"""Return the collection record associated with the given name.
Parameters
@@ -576,7 +469,7 @@ def find(self, name: str) -> CollectionRecord:
raise NotImplementedError()
@abstractmethod
- def __getitem__(self, key: Any) -> CollectionRecord:
+ def __getitem__(self, key: Any) -> CollectionRecord[_Key]:
"""Return the collection record associated with the given
primary/foreign key value.
@@ -614,7 +507,7 @@ def resolve_wildcard(
done: set[str] | None = None,
flatten_chains: bool = True,
include_chains: bool | None = None,
- ) -> list[CollectionRecord]:
+ ) -> list[CollectionRecord[_Key]]:
"""Iterate over collection records that match a wildcard.
Parameters
@@ -632,10 +525,10 @@ def resolve_wildcard(
If `True` (default) recursively yield the child collections of
`~CollectionType.CHAINED` collections.
include_chains : `bool`, optional
- If `False`, return records for `~CollectionType.CHAINED`
+ If `True`, return records for `~CollectionType.CHAINED`
collections themselves. The default is the opposite of
- ``flattenChains``: either return records for CHAINED collections or
- their children, but not both.
+ ``flatten_chains``: either return records for CHAINED collections
+ or their children, but not both.
Returns
-------
@@ -645,7 +538,7 @@ def resolve_wildcard(
raise NotImplementedError()
@abstractmethod
- def getDocumentation(self, key: Any) -> str | None:
+ def getDocumentation(self, key: _Key) -> str | None:
"""Retrieve the documentation string for a collection.
Parameters
@@ -661,7 +554,7 @@ def getDocumentation(self, key: Any) -> str | None:
raise NotImplementedError()
@abstractmethod
- def setDocumentation(self, key: Any, doc: str | None) -> None:
+ def setDocumentation(self, key: _Key, doc: str | None) -> None:
"""Set the documentation string for a collection.
Parameters
@@ -673,16 +566,37 @@ def setDocumentation(self, key: Any, doc: str | None) -> None:
"""
raise NotImplementedError()
- def getParentChains(self, key: Any) -> Iterator[ChainedCollectionRecord]:
- """Find all CHAINED collections that directly contain the given
+ @abstractmethod
+ def getParentChains(self, key: _Key) -> set[str]:
+ """Find all CHAINED collection names that directly contain the given
collection.
Parameters
----------
key
Internal primary key value for the collection.
+
+ Returns
+ -------
+ names : `set` [`str`]
+ Parent collection names.
"""
- for parent_key in self._parents_by_child[key]:
- result = self[parent_key]
- assert isinstance(result, ChainedCollectionRecord)
- yield result
+ raise NotImplementedError()
+
+ @abstractmethod
+ def update_chain(
+ self, record: ChainedCollectionRecord[_Key], children: Iterable[str], flatten: bool = False
+ ) -> ChainedCollectionRecord[_Key]:
+ """Update chained collection composition.
+
+ Parameters
+ ----------
+ record : `ChainedCollectionRecord`
+ Chained collection record.
+ children : `~collections.abc.Iterable` [`str`]
+ Ordered names of children collections.
+ flatten : `bool`, optional
+ If `True`, recursively flatten out any nested
+ `~CollectionType.CHAINED` collections in ``children`` first.
+ """
+ raise NotImplementedError()
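Reviewer note: the record classes become simple generic value objects, and chain composition is updated through the manager rather than by mutating the record. A sketch of that shape with hypothetical names, not the real interfaces.

# Illustrative value-object records with a generic key type.
from dataclasses import dataclass
from typing import Generic, TypeVar

_Key = TypeVar("_Key")


@dataclass(frozen=True)
class Record(Generic[_Key]):
    key: _Key
    name: str


@dataclass(frozen=True)
class ChainRecord(Record[_Key]):
    children: tuple[str, ...] = ()


def update_chain(record: ChainRecord[int], children: list[str]) -> ChainRecord[int]:
    # A real manager would persist the new composition first; callers should
    # use the returned record instead of mutating the old one.
    return ChainRecord(key=record.key, name=record.name, children=tuple(children))


updated = update_chain(ChainRecord(key=1, name="chain"), ["raw", "calib"])
assert updated.children == ("raw", "calib")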
diff --git a/python/lsst/daf/butler/registry/interfaces/_database.py b/python/lsst/daf/butler/registry/interfaces/_database.py
index 438bf55613..61bc9e2440 100644
--- a/python/lsst/daf/butler/registry/interfaces/_database.py
+++ b/python/lsst/daf/butler/registry/interfaces/_database.py
@@ -140,10 +140,6 @@ def addTable(self, name: str, spec: ddl.TableSpec) -> sqlalchemy.schema.Table:
relationships.
"""
name = self._db._mangleTableName(name)
- if name in self._tableNames:
- _checkExistingTableDefinition(
- name, spec, self._inspector.get_columns(name, schema=self._db.namespace)
- )
metadata = self._db._metadata
assert metadata is not None, "Guaranteed by context manager that returns this object."
table = self._db._convertTableSpec(name, spec, metadata)
diff --git a/python/lsst/daf/butler/registry/interfaces/_datasets.py b/python/lsst/daf/butler/registry/interfaces/_datasets.py
index 3424028804..ac459f3c8b 100644
--- a/python/lsst/daf/butler/registry/interfaces/_datasets.py
+++ b/python/lsst/daf/butler/registry/interfaces/_datasets.py
@@ -32,7 +32,7 @@
__all__ = ("DatasetRecordStorageManager", "DatasetRecordStorage")
from abc import ABC, abstractmethod
-from collections.abc import Iterable, Iterator, Set
+from collections.abc import Iterable, Iterator, Mapping, Set
from typing import TYPE_CHECKING, Any
from lsst.daf.relation import Relation
@@ -45,6 +45,7 @@
from ._versioning import VersionedExtension, VersionTuple
if TYPE_CHECKING:
+ from .._caching_context import CachingContext
from .._collection_summary import CollectionSummary
from ..queries import SqlQueryContext
from ._collections import CollectionManager, CollectionRecord, RunRecord
@@ -329,6 +330,7 @@ def initialize(
*,
collections: CollectionManager,
dimensions: DimensionRecordStorageManager,
+ caching_context: CachingContext,
registry_schema_version: VersionTuple | None = None,
) -> DatasetRecordStorageManager:
"""Construct an instance of the manager.
@@ -344,6 +346,8 @@ def initialize(
Manager object for the collections in this `Registry`.
dimensions : `DimensionRecordStorageManager`
Manager object for the dimensions in this `Registry`.
+ caching_context : `CachingContext`
+ Object controlling caching of information returned by managers.
registry_schema_version : `VersionTuple` or `None`
Schema version of this extension as defined in registry.
@@ -487,7 +491,7 @@ def find(self, name: str) -> DatasetRecordStorage | None:
raise NotImplementedError()
@abstractmethod
- def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool]:
+ def register(self, datasetType: DatasetType) -> bool:
"""Ensure that this `Registry` can hold records for the given
`DatasetType`, creating new tables as necessary.
@@ -499,8 +503,6 @@ def register(self, datasetType: DatasetType) -> tuple[DatasetRecordStorage, bool
Returns
-------
- records : `DatasetRecordStorage`
- The object representing the records for the given dataset type.
inserted : `bool`
`True` if the dataset type did not exist in the registry before.
@@ -603,6 +605,29 @@ def getCollectionSummary(self, collection: CollectionRecord) -> CollectionSummar
"""
raise NotImplementedError()
+ @abstractmethod
+ def fetch_summaries(
+ self, collections: Iterable[CollectionRecord], dataset_types: Iterable[DatasetType] | None = None
+ ) -> Mapping[Any, CollectionSummary]:
+ """Fetch collection summaries given their names and dataset types.
+
+ Parameters
+ ----------
+ collections : `~collections.abc.Iterable` [`CollectionRecord`]
+ Collection records to query.
+ dataset_types : `~collections.abc.Iterable` [`DatasetType`] or `None`
+            Dataset types to include in the returned summaries. If `None` then
+ all dataset types will be included.
+
+ Returns
+ -------
+ summaries : `~collections.abc.Mapping` [`Any`, `CollectionSummary`]
+            Collection summaries indexed by collection record key. This
+            mapping also includes summaries for all nested non-chained
+            collections of any chained collections that were queried.
+ """
+ raise NotImplementedError()
+
@abstractmethod
def ingest_date_dtype(self) -> type:
"""Return type of the ``ingest_date`` column."""
diff --git a/python/lsst/daf/butler/registry/managers.py b/python/lsst/daf/butler/registry/managers.py
index 1d80fcde51..48702c2fc9 100644
--- a/python/lsst/daf/butler/registry/managers.py
+++ b/python/lsst/daf/butler/registry/managers.py
@@ -45,6 +45,7 @@
from .._column_type_info import ColumnTypeInfo
from .._config import Config
from ..dimensions import DimensionConfig, DimensionUniverse
+from ._caching_context import CachingContext
from ._config import RegistryConfig
from .interfaces import (
ButlerAttributeManager,
@@ -353,6 +354,11 @@ class RegistryManagerInstances(
and registry instances, including the dimension universe.
"""
+ caching_context: CachingContext
+ """Object containing caches for for various information generated by
+ managers.
+ """
+
@classmethod
def initialize(
cls,
@@ -361,6 +367,7 @@ def initialize(
*,
types: RegistryManagerTypes,
universe: DimensionUniverse,
+ caching_context: CachingContext | None = None,
) -> RegistryManagerInstances:
"""Construct manager instances from their types and an existing
database connection.
@@ -383,6 +390,8 @@ def initialize(
instances : `RegistryManagerInstances`
Struct containing manager instances.
"""
+ if caching_context is None:
+ caching_context = CachingContext()
dummy_table = ddl.TableSpec(fields=())
kwargs: dict[str, Any] = {}
schema_versions = types.schema_versions
@@ -396,6 +405,7 @@ def initialize(
database,
context,
dimensions=kwargs["dimensions"],
+ caching_context=caching_context,
registry_schema_version=schema_versions.get("collections"),
)
datasets = types.datasets.initialize(
@@ -404,6 +414,7 @@ def initialize(
collections=kwargs["collections"],
dimensions=kwargs["dimensions"],
registry_schema_version=schema_versions.get("datasets"),
+ caching_context=caching_context,
)
kwargs["datasets"] = datasets
kwargs["opaque"] = types.opaque.initialize(
@@ -440,6 +451,7 @@ def initialize(
run_key_spec=types.collections.addRunForeignKey(dummy_table, primaryKey=False, nullable=False),
ingest_date_dtype=datasets.ingest_date_dtype(),
)
+ kwargs["caching_context"] = caching_context
return cls(**kwargs)
def as_dict(self) -> Mapping[str, VersionedExtension]:
@@ -453,7 +465,9 @@ def as_dict(self) -> Mapping[str, VersionedExtension]:
manager instance. Only existing managers are returned.
"""
instances = {
- f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.name != "column_types"
+ f.name: getattr(self, f.name)
+ for f in dataclasses.fields(self)
+ if f.name not in ("column_types", "caching_context")
}
return {key: value for key, value in instances.items() if value is not None}
diff --git a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py
index db574dde44..fc5866e8ba 100644
--- a/python/lsst/daf/butler/registry/queries/_sql_query_backend.py
+++ b/python/lsst/daf/butler/registry/queries/_sql_query_backend.py
@@ -121,6 +121,7 @@ def filter_dataset_collections(
result: dict[DatasetType, list[CollectionRecord]] = {
dataset_type: [] for dataset_type in dataset_types
}
+ summaries = self._managers.datasets.fetch_summaries(collections, result.keys())
for dataset_type, filtered_collections in result.items():
for collection_record in collections:
if not dataset_type.isCalibration() and collection_record.type is CollectionType.CALIBRATION:
@@ -130,7 +131,7 @@ def filter_dataset_collections(
f"in CALIBRATION collection {collection_record.name!r}."
)
else:
- collection_summary = self._managers.datasets.getCollectionSummary(collection_record)
+ collection_summary = summaries[collection_record.key]
if collection_summary.is_compatible_with(
dataset_type,
governor_constraints,
diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py
index 733f820941..63c22a97d7 100644
--- a/python/lsst/daf/butler/registry/sql_registry.py
+++ b/python/lsst/daf/butler/registry/sql_registry.py
@@ -324,6 +324,13 @@ def refresh(self) -> None:
with self._db.transaction():
self._managers.refresh()
+ @contextlib.contextmanager
+ def caching_context(self) -> Iterator[None]:
+ """Context manager that enables caching."""
+        self._managers.caching_context.enable()
+        try:
+            yield
+        finally:
+            self._managers.caching_context.disable()
+
@contextlib.contextmanager
def transaction(self, *, savepoint: bool = False) -> Iterator[None]:
"""Return a context manager that represents a transaction."""
@@ -603,7 +610,7 @@ def setCollectionChain(self, parent: str, children: Any, *, flatten: bool = Fals
assert isinstance(record, ChainedCollectionRecord)
children = CollectionWildcard.from_expression(children).require_ordered()
if children != record.children or flatten:
- record.update(self._managers.collections, children, flatten=flatten)
+ self._managers.collections.update_chain(record, children, flatten=flatten)
def getCollectionParentChains(self, collection: str) -> set[str]:
"""Return the CHAINED collections that directly contain the given one.
@@ -618,12 +625,7 @@ def getCollectionParentChains(self, collection: str) -> set[str]:
chains : `set` of `str`
Set of `~CollectionType.CHAINED` collection names.
"""
- return {
- record.name
- for record in self._managers.collections.getParentChains(
- self._managers.collections.find(collection).key
- )
- }
+ return self._managers.collections.getParentChains(self._managers.collections.find(collection).key)
def getCollectionDocumentation(self, collection: str) -> str | None:
"""Retrieve the documentation string for a collection.
@@ -702,8 +704,7 @@ def registerDatasetType(self, datasetType: DatasetType) -> bool:
This method cannot be called within transactions, as it needs to be
able to perform its own transaction to be concurrent.
"""
- _, inserted = self._managers.datasets.register(datasetType)
- return inserted
+ return self._managers.datasets.register(datasetType)
def removeDatasetType(self, name: str | tuple[str, ...]) -> None:
"""Remove the named `DatasetType` from the registry.
diff --git a/python/lsst/daf/butler/remote_butler/_remote_butler.py b/python/lsst/daf/butler/remote_butler/_remote_butler.py
index 841735a28b..1930461ec0 100644
--- a/python/lsst/daf/butler/remote_butler/_remote_butler.py
+++ b/python/lsst/daf/butler/remote_butler/_remote_butler.py
@@ -158,6 +158,12 @@ def _simplify_dataId(
# Assume we can treat it as a dict.
return SerializedDataCoordinate(dataId=data_id)
+ def _caching_context(self) -> AbstractContextManager[None]:
+ # Docstring inherited.
+        # Not implemented for now; we still need to decide whether this should
+        # do anything on the client side and/or the remote side.
+ raise NotImplementedError()
+
def transaction(self) -> AbstractContextManager[None]:
"""Will always raise NotImplementedError.
Transactions are not supported by RemoteButler.