From a8d44d4758c9c9167e9b83de10606eee926d7e89 Mon Sep 17 00:00:00 2001 From: Tim Jenness Date: Thu, 7 Nov 2024 17:28:27 -0700 Subject: [PATCH] Add support for using a LabeledButlerFactory to retrieve the butler --- python/lsst/daf/butler/_butler.py | 23 +++++++++++++------ .../daf/butler/_labeled_butler_factory.py | 15 +++++++++++- tests/test_simpleButler.py | 7 ++++++ 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/python/lsst/daf/butler/_butler.py b/python/lsst/daf/butler/_butler.py index 8b12776e8b..e6cecf32d7 100644 --- a/python/lsst/daf/butler/_butler.py +++ b/python/lsst/daf/butler/_butler.py @@ -63,6 +63,7 @@ from ._dataset_type import DatasetType from ._deferredDatasetHandle import DeferredDatasetHandle from ._file_dataset import FileDataset + from ._labeled_butler_factory import LabeledButlerFactoryProtocol from ._storage_class import StorageClass from ._timespan import Timespan from .datastore import DatasetRefURIs @@ -583,27 +584,35 @@ def parse_dataset_uri(cls, uri: str) -> tuple[str, DatasetId]: return label, dataset_id @classmethod - def get_dataset_from_uri(cls, uri: str) -> DatasetRef | None: + def get_dataset_from_uri( + cls, uri: str, factory: LabeledButlerFactoryProtocol | None = None + ) -> DatasetRef | None: """Get the dataset associated with the given dataset URI. Parameters ---------- uri : `str` The URI associated with a dataset. + factory : `LabeledButlerFactoryProtocol` or `None`, optional + Bound factory function that will be given the butler label + and return a `Butler`. Returns ------- ref : `DatasetRef` or `None` The dataset associated with that URI, or `None` if the UUID is valid but the dataset is not known to this butler. - - Notes - ----- - It might be possible to pass in an optional ``LabeledButlerFactory`` - but how would a caller know the right access token to supply? 
""" label, dataset_id = cls.parse_dataset_uri(uri) - butler = cls.from_config(label) + butler: Butler | None = None + if factory is not None: + # If the label is not recognized, it might be a path. + try: + butler = factory(label) + except KeyError: + pass + if butler is None: + butler = cls.from_config(label) return butler.get_dataset(dataset_id) @abstractmethod diff --git a/python/lsst/daf/butler/_labeled_butler_factory.py b/python/lsst/daf/butler/_labeled_butler_factory.py index 40f887ed0a..3941722527 100644 --- a/python/lsst/daf/butler/_labeled_butler_factory.py +++ b/python/lsst/daf/butler/_labeled_butler_factory.py @@ -25,9 +25,10 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -__all__ = ("LabeledButlerFactory",) +__all__ = ("LabeledButlerFactory", "LabeledButlerFactoryProtocol") from collections.abc import Callable, Mapping +from typing import Protocol from lsst.resources import ResourcePathExpression @@ -42,6 +43,12 @@ instance.""" +class LabeledButlerFactoryProtocol(Protocol): + """Callable to retrieve a butler from a label.""" + + def __call__(self, label: str) -> Butler: ... + + class LabeledButlerFactory: """Factory for efficiently instantiating Butler instances from the repository index file. This is intended for use from long-lived services @@ -83,6 +90,12 @@ def __init__(self, repositories: Mapping[str, str] | None = None) -> None: # This may be overridden by unit tests. self._preload_direct_butler_cache = True + def bind(self, access_token: str | None) -> LabeledButlerFactoryProtocol: + def create(label: str) -> Butler: + return self.create_butler(label=label, access_token=access_token) + + return create + def create_butler(self, *, label: str, access_token: str | None) -> Butler: """Create a Butler instance. 
diff --git a/tests/test_simpleButler.py b/tests/test_simpleButler.py index 0a44c39da5..ef8275e166 100644 --- a/tests/test_simpleButler.py +++ b/tests/test_simpleButler.py @@ -48,6 +48,7 @@ DatasetId, DatasetRef, DatasetType, + LabeledButlerFactory, StorageClass, Timespan, ) @@ -907,6 +908,9 @@ def test_dataset_uris(self): index_file.write(f"{label}: {config_dir}\n") index_file.flush() with mock_env({"DAF_BUTLER_REPOSITORY_INDEX": index_file.name}): + butler_factory = LabeledButlerFactory() + factory = butler_factory.bind(access_token=None) + for dataset_uri in ( f"ivo://rubin/{config_dir}/{ref.id}", f"ivo://rubin/{config_dir}/butler.yaml/{ref.id}", @@ -916,6 +920,9 @@ def test_dataset_uris(self): ref2 = Butler.get_dataset_from_uri(dataset_uri) self.assertEqual(ref, ref2) + ref2 = Butler.get_dataset_from_uri(dataset_uri, factory=factory) + self.assertEqual(ref, ref2) + # Non existent dataset. missing_id = str(ref.id).replace("2", "3") no_ref = Butler.get_dataset_from_uri(f"butler://{label}/{missing_id}")