diff --git a/python/lsst/daf/butler/_query_all_datasets.py b/python/lsst/daf/butler/_query_all_datasets.py
index 1432798dda..e8bed6c9e1 100644
--- a/python/lsst/daf/butler/_query_all_datasets.py
+++ b/python/lsst/daf/butler/_query_all_datasets.py
@@ -113,30 +113,33 @@ def query_all_datasets(
         raise InvalidQueryError("Can not use wildcards in collections when find_first=True")
 
     dataset_type_query = list(ensure_iterable(args.name))
 
-    dataset_type_collections = _filter_collections_and_dataset_types(
-        butler, args.collections, dataset_type_query
-    )
-    limit = args.limit
-    for dt, filtered_collections in sorted(dataset_type_collections.items()):
-        _LOG.debug("Querying dataset type %s", dt)
-        results = (
-            query.datasets(dt, filtered_collections, find_first=args.find_first)
-            .where(args.data_id, args.where, args.kwargs, bind=args.bind)
-            .limit(limit)
+    with butler.registry.caching_context():
+        dataset_type_collections = _filter_collections_and_dataset_types(
+            butler, args.collections, dataset_type_query
         )
-        if args.with_dimension_records:
-            results = results.with_dimension_records()
-
-        for page in results._iter_pages():
-            if limit is not None:
-                # Track how much of the limit has been used up by each query.
-                limit -= len(page)
-
-            yield DatasetsPage(dataset_type=dt, data=page)
-            if limit is not None and limit <= 0:
-                break
+        limit = args.limit
+        for dt, filtered_collections in sorted(dataset_type_collections.items()):
+            _LOG.debug("Querying dataset type %s", dt)
+            results = (
+                query.datasets(dt, filtered_collections, find_first=args.find_first)
+                .where(args.data_id, args.where, args.kwargs, bind=args.bind)
+                .limit(limit)
+            )
+            if args.with_dimension_records:
+                results = results.with_dimension_records()
+
+            for page in results._iter_pages():
+                if limit is not None:
+                    # Track how much of the limit has been used up by each
+                    # query.
+                    limit -= len(page)
+
+                yield DatasetsPage(dataset_type=dt, data=page)
+
+                if limit is not None and limit <= 0:
+                    break
 
 
 def _filter_collections_and_dataset_types(
diff --git a/python/lsst/daf/butler/direct_butler/_direct_butler.py b/python/lsst/daf/butler/direct_butler/_direct_butler.py
index 69972d8856..4095fd2af5 100644
--- a/python/lsst/daf/butler/direct_butler/_direct_butler.py
+++ b/python/lsst/daf/butler/direct_butler/_direct_butler.py
@@ -1426,18 +1426,19 @@ def removeRuns(self, names: Iterable[str], unstore: bool = True) -> None:
         names = list(names)
         refs: list[DatasetRef] = []
         all_dataset_types = [dt.name for dt in self._registry.queryDatasetTypes(...)]
-        for name in names:
-            collectionType = self._registry.getCollectionType(name)
-            if collectionType is not CollectionType.RUN:
-                raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
-            with self.query() as query:
-                # Work out the dataset types that are relevant.
-                collections_info = self.collections.query_info(name, include_summary=True)
-                filtered_dataset_types = self.collections._filter_dataset_types(
-                    all_dataset_types, collections_info
-                )
-                for dt in filtered_dataset_types:
-                    refs.extend(query.datasets(dt, collections=name))
+        with self._caching_context():
+            for name in names:
+                collectionType = self._registry.getCollectionType(name)
+                if collectionType is not CollectionType.RUN:
+                    raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
+                with self.query() as query:
+                    # Work out the dataset types that are relevant.
+                    collections_info = self.collections.query_info(name, include_summary=True)
+                    filtered_dataset_types = self.collections._filter_dataset_types(
+                        all_dataset_types, collections_info
+                    )
+                    for dt in filtered_dataset_types:
+                        refs.extend(query.datasets(dt, collections=name))
         with self._datastore.transaction(), self._registry.transaction():
             if unstore:
                 self._datastore.trash(refs)
diff --git a/python/lsst/daf/butler/registry/sql_registry.py b/python/lsst/daf/butler/registry/sql_registry.py
index edc35da3d2..d5dd11acbe 100644
--- a/python/lsst/daf/butler/registry/sql_registry.py
+++ b/python/lsst/daf/butler/registry/sql_registry.py
@@ -2332,17 +2332,16 @@ def _query_driver(
         default_data_id: DataCoordinate,
     ) -> Iterator[DirectQueryDriver]:
         """Set up a `QueryDriver` instance for query execution."""
-        with self.caching_context():
-            driver = DirectQueryDriver(
-                self._db,
-                self.dimensions,
-                self._managers,
-                self.dimension_record_cache,
-                default_collections=default_collections,
-                default_data_id=default_data_id,
-            )
-            with driver:
-                yield driver
+        driver = DirectQueryDriver(
+            self._db,
+            self.dimensions,
+            self._managers,
+            self.dimension_record_cache,
+            default_collections=default_collections,
+            default_data_id=default_data_id,
+        )
+        with driver:
+            yield driver
 
     def queryDatasetAssociations(
         self,
diff --git a/python/lsst/daf/butler/script/exportCalibs.py b/python/lsst/daf/butler/script/exportCalibs.py
index 07be56c19b..da6eaedf0e 100644
--- a/python/lsst/daf/butler/script/exportCalibs.py
+++ b/python/lsst/daf/butler/script/exportCalibs.py
@@ -73,16 +73,19 @@ def find_calibration_datasets(
         raise RuntimeError(f"Collection {collection.name} is not a CALIBRATION collection.")
 
     exportDatasets = []
-    for calibType in datasetTypes:
-        with butler.query() as query:
-            results = query.datasets(calibType, collections=collection.name, find_first=False)
-
-            try:
-                refs = list(results.with_dimension_records())
-            except Exception as e:
-                e.add_note(f"Error from querying dataset type {calibType} and collection {collection.name}")
-                raise
-            exportDatasets.extend(refs)
+    with butler.registry.caching_context():
+        for calibType in datasetTypes:
+            with butler.query() as query:
+                results = query.datasets(calibType, collections=collection.name, find_first=False)
+
+                try:
+                    refs = list(results.with_dimension_records())
+                except Exception as e:
+                    e.add_note(
+                        f"Error from querying dataset type {calibType} and collection {collection.name}"
+                    )
+                    raise
+                exportDatasets.extend(refs)
 
     return exportDatasets
diff --git a/python/lsst/daf/butler/script/removeRuns.py b/python/lsst/daf/butler/script/removeRuns.py
index cf16d1812e..9c29172823 100644
--- a/python/lsst/daf/butler/script/removeRuns.py
+++ b/python/lsst/daf/butler/script/removeRuns.py
@@ -86,29 +86,34 @@ def _getCollectionInfo(
         The dataset types and and how many will be removed.
     """
     butler = Butler.from_config(repo)
-    try:
-        collections = butler.collections.query_info(
-            collection, CollectionType.RUN, include_chains=False, include_parents=True, include_summary=True
-        )
-    except MissingCollectionError:
-        # Act as if no collections matched.
-        collections = []
-    dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(...)]
-    dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections))
-
-    runs = []
-    datasets: dict[str, int] = defaultdict(int)
-    for collection_info in collections:
-        assert collection_info.type == CollectionType.RUN and collection_info.parents is not None
-        runs.append(RemoveRun(collection_info.name, list(collection_info.parents)))
-        with butler.query() as query:
-            for dt in dataset_types:
-                results = query.datasets(dt, collections=collection_info.name)
-                count = results.count(exact=False)
-                if count:
-                    datasets[dt] += count
-
-    return runs, {k: datasets[k] for k in sorted(datasets.keys())}
+    with butler.registry.caching_context():
+        try:
+            collections = butler.collections.query_info(
+                collection,
+                CollectionType.RUN,
+                include_chains=False,
+                include_parents=True,
+                include_summary=True,
+            )
+        except MissingCollectionError:
+            # Act as if no collections matched.
+            collections = []
+        dataset_types = [dt.name for dt in butler.registry.queryDatasetTypes(...)]
+        dataset_types = list(butler.collections._filter_dataset_types(dataset_types, collections))
+
+        runs = []
+        datasets: dict[str, int] = defaultdict(int)
+        for collection_info in collections:
+            assert collection_info.type == CollectionType.RUN and collection_info.parents is not None
+            runs.append(RemoveRun(collection_info.name, list(collection_info.parents)))
+            with butler.query() as query:
+                for dt in dataset_types:
+                    results = query.datasets(dt, collections=collection_info.name)
+                    count = results.count(exact=False)
+                    if count:
+                        datasets[dt] += count
+
+        return runs, {k: datasets[k] for k in sorted(datasets.keys())}
 
 
 def removeRuns(