Skip to content

Commit

Permalink
Add RemoteButler.find_dataset and corresponding server implementation
Browse files Browse the repository at this point in the history
This is not the final interface for client/server but is there
to have something in place.
  • Loading branch information
timj committed Oct 27, 2023
1 parent c856cef commit 2a50477
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 6 deletions.
67 changes: 63 additions & 4 deletions python/lsst/daf/butler/remote_butler/_remote_butler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,20 @@
from .._butler_config import ButlerConfig
from .._config import Config
from .._dataset_existence import DatasetExistence
from .._dataset_ref import DatasetIdGenEnum, DatasetRef
from .._dataset_ref import DatasetIdGenEnum, DatasetRef, SerializedDatasetRef
from .._dataset_type import DatasetType, SerializedDatasetType
from .._deferredDatasetHandle import DeferredDatasetHandle
from .._file_dataset import FileDataset
from .._limited_butler import LimitedButler
from .._storage_class import StorageClass
from .._timespan import Timespan
from ..datastore import DatasetRefURIs
from ..dimensions import DataId, DimensionConfig, DimensionUniverse
from ..registry import CollectionArgType, Registry, RegistryDefaults
from ..dimensions import DataCoordinate, DataId, DimensionConfig, DimensionUniverse, SerializedDataCoordinate
from ..registry import CollectionArgType, NoDefaultCollectionError, Registry, RegistryDefaults
from ..registry.wildcards import CollectionWildcard
from ..transfers import RepoExportContext
from ._config import RemoteButlerConfigModel
from .server import FindDatasetModel


class RemoteButler(Butler):
Expand Down Expand Up @@ -101,6 +103,39 @@ def dimensions(self) -> DimensionUniverse:
self._dimensions = DimensionUniverse(config)
return self._dimensions

def _simplify_dataId(
self, dataId: DataId | None, **kwargs: dict[str, int | str]
) -> SerializedDataCoordinate | None:
"""Take a generic Data ID and convert it to a serializable form.
Parameters
----------
dataId : `dict`, `None`, `DataCoordinate`
The data ID to serialize.
**kwargs : `dict`
Additional values that should be included if this is not
a `DataCoordinate`.
Returns
-------
data_id : `SerializedDataCoordinate` or `None`
A serializable form.
"""
if dataId is None and not kwargs:
return None
if isinstance(dataId, DataCoordinate):
return dataId.to_simple()

if dataId is None:
data_id = kwargs
elif kwargs:
# Change variable because DataId is immutable and mypy complains.
data_id = dict(dataId)
data_id.update(kwargs)

# Assume we can treat it as a dict.
return SerializedDataCoordinate(dataId=data_id)

def getDatasetType(self, name: str) -> DatasetType:
# Docstring inherited.
raise NotImplementedError()
Expand Down Expand Up @@ -196,7 +231,31 @@ def find_dataset(
datastore_records: bool = False,
**kwargs: Any,
) -> DatasetRef | None:
raise NotImplementedError()
if collections is None:
if not self.collections:
raise NoDefaultCollectionError(
"No collections provided to find_dataset, and no defaults from butler construction."
)
collections = self.collections
# Temporary hack. Assume strings for collections. In future
# want to construct CollectionWildcard and filter it through collection
# cache to generate list of collection names.
wildcards = CollectionWildcard.from_expression(collections)

if isinstance(datasetType, DatasetType):
datasetType = datasetType.name

query = FindDatasetModel(
dataId=self._simplify_dataId(dataId, **kwargs), collections=wildcards.strings
)

path = f"find_dataset/{datasetType}"
response = self._client.post(
self._get_url(path), json=query.model_dump(mode="json", exclude_unset=True)
)
response.raise_for_status()

return DatasetRef.from_simple(SerializedDatasetRef(**response.json()), universe=self.dimensions)

def retrieveArtifacts(
self,
Expand Down
1 change: 1 addition & 0 deletions python/lsst/daf/butler/remote_butler/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@

from ._factory import *
from ._server import *
from ._server_models import *
53 changes: 52 additions & 1 deletion python/lsst/daf/butler/remote_butler/server/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,16 @@

from fastapi import Depends, FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from lsst.daf.butler import Butler, SerializedDatasetType
from lsst.daf.butler import (
Butler,
DataCoordinate,
SerializedDataCoordinate,
SerializedDatasetRef,
SerializedDatasetType,
)

from ._factory import Factory
from ._server_models import FindDatasetModel

BUTLER_ROOT = "ci_hsc_gen3/DATA"

Expand All @@ -56,6 +63,26 @@ def factory_dependency() -> Factory:
return Factory(butler=_make_global_butler())


def unpack_dataId(butler: Butler, data_id: SerializedDataCoordinate | None) -> DataCoordinate | None:
"""Convert the serialized dataId back to full DataCoordinate.
Parameters
----------
butler : `lsst.daf.butler.Butler`
The butler to use for registry and universe.
data_id : `SerializedDataCoordinate` or `None`
The serialized form.
Returns
-------
dataId : `DataCoordinate` or `None`
The DataId usable by registry.
"""
if data_id is None:
return None
return DataCoordinate.from_simple(data_id, registry=butler.registry)


@app.get("/butler/v1/universe", response_model=dict[str, Any])
def get_dimension_universe(factory: Factory = Depends(factory_dependency)) -> dict[str, Any]:
"""Allow remote client to get dimensions definition."""
Expand All @@ -78,3 +105,27 @@ def get_dataset_type(
butler = factory.create_butler()
datasetType = butler.get_dataset_type(dataset_type_name)
return datasetType.to_simple()


# Not yet supported: TimeSpan is not yet a pydantic model.
# collections parameter assumes client-side has resolved regexes.
@app.post(
"/butler/v1/find_dataset/{dataset_type}",
summary="Retrieve this dataset definition from collection, dataset type, and dataId",
response_model=SerializedDatasetRef,
response_model_exclude_unset=True,
response_model_exclude_defaults=True,
response_model_exclude_none=True,
)
def find_dataset(
dataset_type: str,
query: FindDatasetModel,
factory: Factory = Depends(factory_dependency),
) -> SerializedDatasetRef | None:
collection_query = query.collections if query.collections else None

butler = factory.create_butler()
ref = butler.find_dataset(
dataset_type, dataId=unpack_dataId(butler, query.dataId), collections=collection_query
)
return ref.to_simple() if ref else None
11 changes: 11 additions & 0 deletions python/lsst/daf/butler/remote_butler/server/_server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Models used for client/server communication."""

__all__ = ["FindDatasetModel"]

from lsst.daf.butler import SerializedDataCoordinate

from ..._compat import _BaseModelCompat


class FindDatasetModel(_BaseModelCompat):
dataId: SerializedDataCoordinate
collections: list[str]
13 changes: 12 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import os.path
import unittest
import uuid

try:
# Failing to import any of these should disable the tests.
Expand All @@ -37,7 +38,8 @@
TestClient = None
app = None

from lsst.daf.butler import Butler
from lsst.daf.butler import Butler, DatasetRef
from lsst.daf.butler.tests import DatastoreMock
from lsst.daf.butler.tests.utils import MetricTestRepo, makeTestTempDir, removeTestTempDir

TESTDIR = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -68,6 +70,9 @@ def setUpClass(cls):
# Override the server's Butler initialization to point at our test repo
server_butler = Butler.from_config(cls.root, writeable=True)

# Not yet testing butler.get()
DatastoreMock.apply(server_butler)

def create_factory_dependency():
return Factory(butler=server_butler)

Expand All @@ -79,6 +84,7 @@ def create_factory_dependency():

# Populate the test server.
server_butler.import_(filename=os.path.join(TESTDIR, "data", "registry", "base.yaml"))
server_butler.import_(filename=os.path.join(TESTDIR, "data", "registry", "datasets-uuid.yaml"))

@classmethod
def tearDownClass(cls):
Expand All @@ -98,6 +104,11 @@ def test_get_dataset_type(self):
bias_type = self.butler.get_dataset_type("bias")
self.assertEqual(bias_type.name, "bias")

def test_find_dataset(self):
ref = self.butler.find_dataset("bias", collections="imported_g", detector=1, instrument="Cam1")
self.assertIsInstance(ref, DatasetRef)
self.assertEqual(ref.id, uuid.UUID("e15ab039-bc8b-4135-87c5-90902a7c0b22"))


if __name__ == "__main__":
unittest.main()

0 comments on commit 2a50477

Please sign in to comment.