From 5656879ad565e58117790c3ed82438c6cbd60876 Mon Sep 17 00:00:00 2001 From: Peter Bull Date: Wed, 28 Aug 2024 11:41:09 -0400 Subject: [PATCH] Initial ADLS gen2 support (#453) * minimal ADLS gen2 support * add rigs back * Make mocked tests work with adls * add rigs back; make explicit no dirs * Update testing and hns key * format * update mocked tests * windows agnostic * set gen2 var in CI * new adls fucntionality; better tests and instantiation * Code review comments * Tweak HISTORY.md * TEMP: debug test code * don't close non-existent file * Revert "TEMP: debug test code" This reverts commit bb36a52753277a7197bbda7204602a885b8d626a. --------- Co-authored-by: Jay Qi <2721979+jayqi@users.noreply.github.com> --- .env.example | 4 + .github/workflows/tests.yml | 1 + .gitignore | 2 +- CONTRIBUTING.md | 9 ++ HISTORY.md | 1 + cloudpathlib/azure/azblobclient.py | 223 +++++++++++++++++++++----- cloudpathlib/azure/azblobpath.py | 16 +- cloudpathlib/cloudpath.py | 2 +- docs/docs/authentication.md | 7 + pyproject.toml | 2 +- tests/conftest.py | 87 ++++++++-- tests/mock_clients/mock_adls_gen2.py | 110 +++++++++++++ tests/mock_clients/mock_azureblob.py | 144 +++++++++++------ tests/test_azure_specific.py | 126 ++++++++++++++- tests/test_client.py | 6 +- tests/test_cloudpath_file_io.py | 8 +- tests/test_cloudpath_instantiation.py | 2 +- tests/test_cloudpath_manipulation.py | 10 +- tests/test_local.py | 9 +- 19 files changed, 635 insertions(+), 134 deletions(-) create mode 100644 tests/mock_clients/mock_adls_gen2.py diff --git a/.env.example b/.env.example index be1da87d..7e6fbfaf 100644 --- a/.env.example +++ b/.env.example @@ -5,6 +5,10 @@ AWS_SECRET_ACCESS_KEY=your_secret_access_key AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=your_account_name;AccountKey=your_account_key;EndpointSuffix=core.windows.net +# if testing with ADLS Gen2 storage, set credentials for that account here +AZURE_STORAGE_GEN2_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=your_account_name;AccountKey=your_account_key;EndpointSuffix=core.windows.net + + GOOGLE_APPLICATION_CREDENTIALS=.gscreds.json # or GCP_PROJECT_ID=your_project_id diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9d85a275..edd56b28 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -102,6 +102,7 @@ jobs: env: LIVE_AZURE_CONTAINER: ${{ secrets.LIVE_AZURE_CONTAINER }} AZURE_STORAGE_CONNECTION_STRING: ${{ secrets.AZURE_STORAGE_CONNECTION_STRING }} + AZURE_STORAGE_GEN2_CONNECTION_STRING: ${{ secrets.AZURE_STORAGE_GEN2_CONNECTION_STRING }} LIVE_GS_BUCKET: ${{ secrets.LIVE_GS_BUCKET }} LIVE_S3_BUCKET: ${{ secrets.LIVE_S3_BUCKET }} AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} diff --git a/.gitignore b/.gitignore index d542b01d..59c8c813 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ docs/docs/changelog.md docs/docs/contributing.md # perf output -perf-results.csv +perf-*.csv ## GitHub Python .gitignore ## # https://github.com/github/gitignore/blob/master/Python.gitignore diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c309ea2f..54c98962 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,6 +81,15 @@ Finally, you may want to run your tests against live servers to ensure that the make test-live-cloud ``` +#### Azure live backend tests + +For Azure, you can test both against Azure Blob Storage backends and Azure Data Lake Storage Gen2 backends. 
To run these tests, you need to set connection strings for both of the backends by setting the following environment variables (in your `.env` file for local development). If `AZURE_STORAGE_GEN2_CONNECTION_STRING` is not set, only the blob storage backend will be tested. To set up a storage account with ADLS Gen2, go through the normal creation flow for a storage account in the Azure portal and select "Enable Hierarchical Namespace" in the "Advanced" tab of the settings when configuring the account. + +```bash +AZURE_STORAGE_CONNECTION_STRING=your_connection_string +AZURE_STORAGE_GEN2_CONNECTION_STRING=your_connection_string +``` + You can copy `.env.example` to `.env` and fill in the credentials and bucket/container names for the providers you want to test against. **Note that the live tests will create and delete files on the cloud provider.** You can also skip providers you do not have accounts for by commenting them out in the `rig` and `s3_like_rig` variables defined at the end of `tests/conftest.py`. diff --git a/HISTORY.md b/HISTORY.md index 0ec4abf1..563f64bb 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -15,6 +15,7 @@ - Changed `LocalClient` so that client instances using the default storage access the default local storage directory through the `get_default_storage_dir` rather than having an explicit reference to the path set at instantiation. This means that calling `get_default_storage_dir` will reset the local storage for all clients using the default local storage, whether the client has already been instantiated or is instantiated after resetting. This fixes unintuitive behavior where `reset_local_storage` did not reset local storage when using the default client. (Issue [#414](https://github.com/drivendataorg/cloudpathlib/issues/414)) - Added a new `local_storage_dir` property to `LocalClient`. This will return the current local storage directory used by that client instance. by reference through the `get_default_ rather than with an explicit. 
+- Added Azure Data Lake Storage Gen2 support (Issue [#161](https://github.com/drivendataorg/cloudpathlib/issues/161), PR [#450](https://github.com/drivendataorg/cloudpathlib/pull/450)), thanks to [@M0dEx](https://github.com/M0dEx) for PR [#447](https://github.com/drivendataorg/cloudpathlib/pull/447) and PR [#449](https://github.com/drivendataorg/cloudpathlib/pull/449) ## v0.18.1 (2024-02-26) diff --git a/cloudpathlib/azure/azblobclient.py b/cloudpathlib/azure/azblobclient.py index f161a02d..98189378 100644 --- a/cloudpathlib/azure/azblobclient.py +++ b/cloudpathlib/azure/azblobclient.py @@ -1,7 +1,7 @@ from datetime import datetime, timedelta import mimetypes import os -from pathlib import Path, PurePosixPath +from pathlib import Path from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union @@ -14,13 +14,17 @@ try: from azure.core.exceptions import ResourceNotFoundError + from azure.core.credentials import AzureNamedKeyCredential from azure.storage.blob import ( + BlobPrefix, BlobSasPermissions, BlobServiceClient, BlobProperties, ContentSettings, generate_blob_sas, ) + + from azure.storage.filedatalake import DataLakeServiceClient, FileProperties except ModuleNotFoundError: implementation_registry["azure"].dependencies_loaded = False @@ -39,6 +43,7 @@ def __init__( credential: Optional[Any] = None, connection_string: Optional[str] = None, blob_service_client: Optional["BlobServiceClient"] = None, + data_lake_client: Optional["DataLakeServiceClient"] = None, file_cache_mode: Optional[Union[str, FileCacheMode]] = None, local_cache_dir: Optional[Union[str, os.PathLike]] = None, content_type_method: Optional[Callable] = mimetypes.guess_type, @@ -50,12 +55,13 @@ def __init__( - Environment variable `""AZURE_STORAGE_CONNECTION_STRING"` containing connecting string with account credentials. See [Azure Storage SDK documentation]( https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python#copy-your-credentials-from-the-azure-portal). - - Account URL via `account_url`, authenticated either with an embedded SAS token, or with - credentials passed to `credentials`. - Connection string via `connection_string`, authenticated either with an embedded SAS token or with credentials passed to `credentials`. + - Account URL via `account_url`, authenticated either with an embedded SAS token, or with + credentials passed to `credentials`. - Instantiated and already authenticated [`BlobServiceClient`]( - https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobserviceclient?view=azure-python). + https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobserviceclient?view=azure-python) or + [`DataLakeServiceClient`](https://learn.microsoft.com/en-us/python/api/azure-storage-file-datalake/azure.storage.filedatalake.datalakeserviceclient). If multiple methods are used, priority order is reverse of list above (later in list takes priority). If no methods are used, a [`MissingCredentialsError`][cloudpathlib.exceptions.MissingCredentialsError] @@ -76,6 +82,10 @@ def __init__( https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python#copy-your-credentials-from-the-azure-portal). blob_service_client (Optional[BlobServiceClient]): Instantiated [`BlobServiceClient`]( https://docs.microsoft.com/en-us/python/api/azure-storage-blob/azure.storage.blob.blobserviceclient?view=azure-python). 
+ data_lake_client (Optional[DataLakeServiceClient]): Instantiated [`DataLakeServiceClient`]( + https://learn.microsoft.com/en-us/python/api/azure-storage-file-datalake/azure.storage.filedatalake.datalakeserviceclient). + If None and `blob_service_client` is passed, we will create based on that. + Otherwise, will create based on passed credential, account_url, connection_string, or AZURE_STORAGE_CONNECTION_STRING env var file_cache_mode (Optional[Union[str, FileCacheMode]]): How often to clear the file cache; see [the caching docs](https://cloudpathlib.drivendata.org/stable/caching/) for more information about the options in cloudpathlib.eums.FileCacheMode. @@ -94,27 +104,101 @@ def __init__( if connection_string is None: connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING", None) + self.data_lake_client = None # only needs to end up being set if HNS is enabled + if blob_service_client is not None: self.service_client = blob_service_client + + # create from blob service client if not passed + if data_lake_client is None: + self.data_lake_client = DataLakeServiceClient( + account_url=self.service_client.url.replace(".blob.", ".dfs.", 1), + credential=AzureNamedKeyCredential( + blob_service_client.credential.account_name, + blob_service_client.credential.account_key, + ), + ) + else: + self.data_lake_client = data_lake_client + + elif data_lake_client is not None: + self.data_lake_client = data_lake_client + + if blob_service_client is None: + self.service_client = BlobServiceClient( + account_url=self.data_lake_client.url.replace(".dfs.", ".blob.", 1), + credential=AzureNamedKeyCredential( + data_lake_client.credential.account_name, + data_lake_client.credential.account_key, + ), + ) + elif connection_string is not None: self.service_client = BlobServiceClient.from_connection_string( conn_str=connection_string, credential=credential ) + self.data_lake_client = DataLakeServiceClient.from_connection_string( + conn_str=connection_string, credential=credential + ) elif account_url is not None: - self.service_client = BlobServiceClient(account_url=account_url, credential=credential) + if ".dfs." in account_url: + self.service_client = BlobServiceClient( + account_url=account_url.replace(".dfs.", ".blob."), credential=credential + ) + self.data_lake_client = DataLakeServiceClient( + account_url=account_url, credential=credential + ) + elif ".blob." in account_url: + self.service_client = BlobServiceClient( + account_url=account_url, credential=credential + ) + self.data_lake_client = DataLakeServiceClient( + account_url=account_url.replace(".blob.", ".dfs."), credential=credential + ) + else: + # assume default to blob; HNS not supported + self.service_client = BlobServiceClient( + account_url=account_url, credential=credential + ) + else: raise MissingCredentialsError( "AzureBlobClient does not support anonymous instantiation. " "Credentials are required; see docs for options." 
) - def _get_metadata(self, cloud_path: AzureBlobPath) -> Union["BlobProperties", Dict[str, Any]]: - blob = self.service_client.get_blob_client( - container=cloud_path.container, blob=cloud_path.blob - ) - properties = blob.get_blob_properties() + self._hns_enabled = None + + def _check_hns(self) -> Optional[bool]: + if self._hns_enabled is None: + account_info = self.service_client.get_account_information() # type: ignore + self._hns_enabled = account_info.get("is_hns_enabled", False) # type: ignore + + return self._hns_enabled + + def _get_metadata( + self, cloud_path: AzureBlobPath + ) -> Union["BlobProperties", "FileProperties", Dict[str, Any]]: + if self._check_hns(): + + # works on both files and directories + fsc = self.data_lake_client.get_file_system_client(cloud_path.container) # type: ignore + + if fsc is not None: + properties = fsc.get_file_client(cloud_path.blob).get_file_properties() - properties["content_type"] = properties.content_settings.content_type + # no content settings on directory + properties["content_type"] = properties.get( + "content_settings", {"content_type": None} + ).get("content_type") + + else: + blob = self.service_client.get_blob_client( + container=cloud_path.container, blob=cloud_path.blob + ) + properties = blob.get_blob_properties() + + properties["content_type"] = properties.content_settings.content_type return properties @@ -155,8 +239,17 @@ def _is_file_or_dir(self, cloud_path: AzureBlobPath) -> Optional[str]: return "dir" try: - self._get_metadata(cloud_path) - return "file" + meta = self._get_metadata(cloud_path) + + # if hns, has is_directory property; else if not hns, _get_metadata will raise if not a file + return ( + "dir" + if meta.get("is_directory", False) + or meta.get("metadata", {}).get("hdi_isfolder", False) + else "file" + ) + + # thrown if not HNS and file does not exist _or_ is dir; check if is dir instead except ResourceNotFoundError: prefix = cloud_path.blob if prefix and not prefix.endswith("/"): @@ -181,17 +274,14 @@ def _exists(self, cloud_path: AzureBlobPath) -> bool: def _list_dir( self, cloud_path: AzureBlobPath, recursive: bool = False ) -> Iterable[Tuple[AzureBlobPath, bool]]: - # shortcut if listing all available containers if not cloud_path.container: - if recursive: - raise NotImplementedError( - "Cannot recursively list all containers and contents; you can get all the containers then recursively list each separately." 
- ) + for container in self.service_client.list_containers(): + yield self.CloudPath(f"az://{container.name}"), True - yield from ( - (self.CloudPath(f"az://{c.name}"), True) - for c in self.service_client.list_containers() - ) + if not recursive: + continue + + yield from self._list_dir(self.CloudPath(f"az://{container.name}"), recursive=True) return container_client = self.service_client.get_container_client(cloud_path.container) @@ -200,30 +290,29 @@ def _list_dir( if prefix and not prefix.endswith("/"): prefix += "/" - yielded_dirs = set() - - # NOTE: Not recursive may be slower than necessary since it just filters - # the recursive implementation - for o in container_client.list_blobs(name_starts_with=prefix): - # get directory from this path - for parent in PurePosixPath(o.name[len(prefix) :]).parents: - # if we haven't surfaced this directory already - if parent not in yielded_dirs and str(parent) != ".": - # skip if not recursive and this is beyond our depth - if not recursive and "/" in str(parent): - continue - - yield ( - self.CloudPath(f"az://{cloud_path.container}/{prefix}{parent}"), - True, # is a directory - ) - yielded_dirs.add(parent) + if self._check_hns(): + file_system_client = self.data_lake_client.get_file_system_client(cloud_path.container) # type: ignore + paths = file_system_client.get_paths(path=cloud_path.blob, recursive=recursive) - # skip file if not recursive and this is beyond our depth - if not recursive and "/" in o.name[len(prefix) :]: - continue + for path in paths: + yield self.CloudPath(f"az://{cloud_path.container}/{path.name}"), path.is_directory - yield (self.CloudPath(f"az://{cloud_path.container}/{o.name}"), False) # is a file + else: + if not recursive: + blobs = container_client.walk_blobs(name_starts_with=prefix) + else: + blobs = container_client.list_blobs(name_starts_with=prefix) + + for blob in blobs: + # walk_blobs returns folders with a trailing slash + blob_path = blob.name.rstrip("/") + blob_cloud_path = self.CloudPath(f"az://{cloud_path.container}/{blob_path}") + + yield blob_cloud_path, ( + isinstance(blob, BlobPrefix) + if not recursive + else False # no folders from list_blobs in non-hns storage accounts + ) def _move_file( self, src: AzureBlobPath, dst: AzureBlobPath, remove_src: bool = True @@ -238,6 +327,16 @@ def _move_file( metadata=dict(last_modified=str(datetime.utcnow().timestamp())) ) + # we can use rename API when the same account on adls gen2 + elif remove_src and (src.client is dst.client) and self._check_hns(): + fsc = self.data_lake_client.get_file_system_client(src.container) # type: ignore + + if src.is_dir(): + fsc.get_directory_client(src.blob).rename_directory(f"{dst.container}/{dst.blob}") + else: + dst.parent.mkdir(parents=True, exist_ok=True) + fsc.get_file_client(src.blob).rename_file(f"{dst.container}/{dst.blob}") + else: target = self.service_client.get_blob_client(container=dst.container, blob=dst.blob) @@ -250,9 +349,34 @@ def _move_file( return dst + def _mkdir( + self, cloud_path: AzureBlobPath, parents: bool = False, exist_ok: bool = False + ) -> None: + if self._check_hns(): + file_system_client = self.data_lake_client.get_file_system_client(cloud_path.container) # type: ignore + directory_client = file_system_client.get_directory_client(cloud_path.blob) + + if not exist_ok and directory_client.exists(): + raise FileExistsError(f"Directory already exists: {cloud_path}") + + if not parents: + if not self._exists(cloud_path.parent): + raise FileNotFoundError( + f"Parent directory does not exist 
({cloud_path.parent}). To create parent directories, use `parents=True`." + ) + + directory_client.create_directory() + else: + # consistent with other mkdir no-op behavior on other backends if not supported + pass + def _remove(self, cloud_path: AzureBlobPath, missing_ok: bool = True) -> None: file_or_dir = self._is_file_or_dir(cloud_path) if file_or_dir == "dir": + if self._check_hns(): + _hns_rmtree(self.data_lake_client, cloud_path.container, cloud_path.blob) + return + blobs = [ b.blob for b, is_dir in self._list_dir(cloud_path, recursive=True) if not is_dir ] @@ -313,4 +437,15 @@ def _generate_presigned_url( return url +def _hns_rmtree(data_lake_client, container, directory): + """Stateless implementation so can be used in test suite cleanup as well. + + If hierarchical namespace is enabled, delete the directory and all its contents. + (The non-HNS version is implemented in `_remove`, but will leave empty folders in HNS). + """ + file_system_client = data_lake_client.get_file_system_client(container) + directory_client = file_system_client.get_directory_client(directory) + directory_client.delete_directory() + + AzureBlobClient.AzureBlobPath = AzureBlobClient.CloudPath # type: ignore diff --git a/cloudpathlib/azure/azblobpath.py b/cloudpathlib/azure/azblobpath.py index e6777ab7..265cfd81 100644 --- a/cloudpathlib/azure/azblobpath.py +++ b/cloudpathlib/azure/azblobpath.py @@ -3,6 +3,8 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING +from cloudpathlib.exceptions import CloudPathIsADirectoryError + try: from azure.core.exceptions import ResourceNotFoundError except ImportError: @@ -44,8 +46,7 @@ def is_file(self) -> bool: return self.client._is_file_or_dir(self) == "file" def mkdir(self, parents=False, exist_ok=False): - # not possible to make empty directory on blob storage - pass + self.client._mkdir(self, parents=parents, exist_ok=exist_ok) def touch(self, exist_ok: bool = True): if self.exists(): @@ -84,6 +85,17 @@ def stat(self): ) ) + def replace(self, target: "AzureBlobPath") -> "AzureBlobPath": + try: + return super().replace(target) + + # we can rename directories on ADLS Gen2 + except CloudPathIsADirectoryError: + if self.client._check_hns(): + return self.client._move_file(self, target) + else: + raise + @property def container(self) -> str: return self._no_prefix.split("/", 1)[0] diff --git a/cloudpathlib/cloudpath.py b/cloudpathlib/cloudpath.py index 5cd92708..d7bf391b 100644 --- a/cloudpathlib/cloudpath.py +++ b/cloudpathlib/cloudpath.py @@ -251,7 +251,7 @@ def client(self): def __del__(self) -> None: # make sure that file handle to local path is closed - if self._handle is not None: + if self._handle is not None and self._local.exists(): self._handle.close() # ensure file removed from cache when cloudpath object deleted diff --git a/docs/docs/authentication.md b/docs/docs/authentication.md index 76c7b1b3..36018532 100644 --- a/docs/docs/authentication.md +++ b/docs/docs/authentication.md @@ -211,6 +211,13 @@ client.set_as_default_client() cp3 = CloudPath("s3://cloudpathlib-test-bucket/") ``` +## Accessing Azure DataLake Storage Gen2 (ADLS Gen2) storage with hierarchical namespace enabled + +Some Azure storage accounts are configured with "hierarchical namespace" enabled. This means that the storage account is backed by the Azure DataLake Storage Gen2 product rather than Azure Blob Storage. For many operations, the two are the same and one can use the Azure Blob Storage API. 
However, for some operations, a developer will need to use the Azure DataLake Storage API. The `AzureBlobClient` class implemented in cloudpathlib is designed to detect if hierarchical namespace is enabled and use the Azure DataLake Storage API in the places where it is necessary or it provides a performance improvement. Usually, a user of cloudpathlib will not need to know if hierarchical namespace is enabled and the storage account is backed by Azure DataLake Storage Gen2 or Azure Blob Storage. + +If needed, the Azure SDK provided `DataLakeServiceClient` object can be accessed via the `AzureBlobClient.data_lake_client`. The Azure SDK provided `BlobServiceClient` object can be accessed via `AzureBlobClient.service_client`. + + ## Pickling `CloudPath` objects You can pickle and unpickle `CloudPath` objects normally, for example: diff --git a/pyproject.toml b/pyproject.toml index 63974887..c7f6dcdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ ] [project.optional-dependencies] -azure = ["azure-storage-blob>=12"] +azure = ["azure-storage-blob>=12", "azure-storage-file-datalake>=12"] gs = ["google-cloud-storage"] s3 = ["boto3>=1.34.0"] all = ["cloudpathlib[azure]", "cloudpathlib[gs]", "cloudpathlib[s3]"] diff --git a/tests/conftest.py b/tests/conftest.py index bf680ece..301ffe87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,13 @@ import os from pathlib import Path, PurePosixPath import shutil +from tempfile import TemporaryDirectory from typing import Dict, Optional from azure.storage.blob import BlobServiceClient +from azure.storage.filedatalake import ( + DataLakeServiceClient, +) import boto3 import botocore from dotenv import find_dotenv, load_dotenv @@ -26,8 +30,10 @@ LocalS3Path, ) import cloudpathlib.azure.azblobclient +from cloudpathlib.azure.azblobclient import _hns_rmtree import cloudpathlib.s3.s3client -from .mock_clients.mock_azureblob import mocked_client_class_factory, DEFAULT_CONTAINER_NAME +from .mock_clients.mock_azureblob import MockBlobServiceClient, DEFAULT_CONTAINER_NAME +from .mock_clients.mock_adls_gen2 import MockedDataLakeServiceClient from .mock_clients.mock_gs import ( mocked_client_class_factory as mocked_gsclient_class_factory, DEFAULT_GS_BUCKET_NAME, @@ -109,17 +115,20 @@ def create_test_dir_name(request) -> str: return test_dir -@fixture() -def azure_rig(request, monkeypatch, assets_dir): +def _azure_fixture(conn_str_env_var, adls_gen2, request, monkeypatch, assets_dir): drive = os.getenv("LIVE_AZURE_CONTAINER", DEFAULT_CONTAINER_NAME) test_dir = create_test_dir_name(request) live_server = os.getenv("USE_LIVE_CLOUD") == "1" + connection_kwargs = dict() + tmpdir = TemporaryDirectory() + if live_server: # Set up test assets - blob_service_client = BlobServiceClient.from_connection_string( - os.getenv("AZURE_STORAGE_CONNECTION_STRING") + blob_service_client = BlobServiceClient.from_connection_string(os.getenv(conn_str_env_var)) + data_lake_service_client = DataLakeServiceClient.from_connection_string( + os.getenv(conn_str_env_var) ) test_files = [ f for f in assets_dir.glob("**/*") if f.is_file() and f.name not in UPLOAD_IGNORE_LIST @@ -130,13 +139,25 @@ def azure_rig(request, monkeypatch, assets_dir): blob=str(f"{test_dir}/{PurePosixPath(test_file.relative_to(assets_dir))}"), ) blob_client.upload_blob(test_file.read_bytes(), overwrite=True) + + connection_kwargs["connection_string"] = os.getenv(conn_str_env_var) else: - monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", "") - # Mock cloud SDK + # pass 
key mocked params to clients via connection string + monkeypatch.setenv( + "AZURE_STORAGE_CONNECTION_STRING", f"{Path(tmpdir.name) / test_dir};{adls_gen2}" + ) + monkeypatch.setenv("AZURE_STORAGE_GEN2_CONNECTION_STRING", "") + monkeypatch.setattr( cloudpathlib.azure.azblobclient, "BlobServiceClient", - mocked_client_class_factory(test_dir), + MockBlobServiceClient, + ) + + monkeypatch.setattr( + cloudpathlib.azure.azblobclient, + "DataLakeServiceClient", + MockedDataLakeServiceClient, ) rig = CloudProviderTestRig( @@ -145,19 +166,47 @@ def azure_rig(request, monkeypatch, assets_dir): drive=drive, test_dir=test_dir, live_server=live_server, + required_client_kwargs=connection_kwargs, ) - rig.client_class().set_as_default_client() # set default client + rig.client_class(**connection_kwargs).set_as_default_client() # set default client + + # add flag for adls gen2 rig to skip some tests + rig.is_adls_gen2 = adls_gen2 + rig.connection_string = os.getenv(conn_str_env_var) # used for client instantiation tests yield rig rig.client_class._default_client = None # reset default client if live_server: - # Clean up test dir - container_client = blob_service_client.get_container_client(drive) - to_delete = container_client.list_blobs(name_starts_with=test_dir) - container_client.delete_blobs(*to_delete) + if blob_service_client.get_account_information().get("is_hns_enabled", False): + _hns_rmtree(data_lake_service_client, drive, test_dir) + + else: + # Clean up test dir + container_client = blob_service_client.get_container_client(drive) + to_delete = container_client.list_blobs(name_starts_with=test_dir) + to_delete = sorted(to_delete, key=lambda b: len(b.name.split("/")), reverse=True) + + container_client.delete_blobs(*to_delete) + + else: + tmpdir.cleanup() + + +@fixture() +def azure_rig(request, monkeypatch, assets_dir): + yield from _azure_fixture( + "AZURE_STORAGE_CONNECTION_STRING", False, request, monkeypatch, assets_dir + ) + + +@fixture() +def azure_gen2_rig(request, monkeypatch, assets_dir): + yield from _azure_fixture( + "AZURE_STORAGE_GEN2_CONNECTION_STRING", True, request, monkeypatch, assets_dir + ) @fixture() @@ -420,10 +469,20 @@ def local_s3_rig(request, monkeypatch, assets_dir): rig.client_class.reset_default_storage_dir() # reset local storage directory +# create azure fixtures for both blob and gen2 storage +azure_rigs = fixture_union( + "azure_rigs", + [ + azure_rig, # azure_rig0 + azure_gen2_rig, # azure_rig1 + ], +) + rig = fixture_union( "rig", [ - azure_rig, + azure_rig, # azure_rig0 + azure_gen2_rig, # azure_rig1 gs_rig, s3_rig, custom_s3_rig, diff --git a/tests/mock_clients/mock_adls_gen2.py b/tests/mock_clients/mock_adls_gen2.py new file mode 100644 index 00000000..aefdb735 --- /dev/null +++ b/tests/mock_clients/mock_adls_gen2.py @@ -0,0 +1,110 @@ +from datetime import datetime +from pathlib import Path, PurePosixPath +from shutil import rmtree +from azure.core.exceptions import ResourceNotFoundError +from azure.storage.filedatalake import FileProperties + +from tests.mock_clients.mock_azureblob import _JsonCache, DEFAULT_CONTAINER_NAME + + +class MockedDataLakeServiceClient: + def __init__(self, test_dir, adls): + # root is parent of the test specific directort + self.root = test_dir.parent + self.test_dir = test_dir + self.adls = adls + self.metadata_cache = _JsonCache(self.root / ".metadata") + + @classmethod + def from_connection_string(cls, conn_str, credential): + # configured in conftest.py + test_dir, adls = conn_str.split(";") + adls = adls == "True" + test_dir = 
Path(test_dir) + return cls(test_dir, adls) + + def get_file_system_client(self, file_system): + return MockedFileSystemClient(self.root, self.metadata_cache) + + +class MockedFileSystemClient: + def __init__(self, root, metadata_cache): + self.root = root + self.metadata_cache = metadata_cache + + def get_file_client(self, key): + return MockedFileClient(key, self.root, self.metadata_cache) + + def get_directory_client(self, key): + return MockedDirClient(key, self.root) + + def get_paths(self, path, recursive=False): + yield from ( + MockedFileClient( + PurePosixPath(f.relative_to(self.root)), self.root, self.metadata_cache + ).get_file_properties() + for f in (self.root / path).glob("**/*" if recursive else "*") + ) + + +class MockedFileClient: + def __init__(self, key, root, metadata_cache) -> None: + self.key = key + self.root = root + self.metadata_cache = metadata_cache + + def get_file_properties(self): + path = self.root / self.key + + if path.exists() and path.is_dir(): + fp = FileProperties( + **{ + "name": self.key, + "size": 0, + "ETag": "etag", + "Last-Modified": datetime.fromtimestamp(path.stat().st_mtime), + "metadata": {"hdi_isfolder": True}, + } + ) + fp["is_directory"] = True # not part of object def, but still in API responses... + return fp + + elif path.exists(): + fp = FileProperties( + **{ + "name": self.key, + "size": path.stat().st_size, + "ETag": "etag", + "Last-Modified": datetime.fromtimestamp(path.stat().st_mtime), + "metadata": {"hdi_isfolder": False}, + "Content-Type": self.metadata_cache.get(self.root / self.key, None), + } + ) + + fp["is_directory"] = False + return fp + else: + raise ResourceNotFoundError + + def rename_file(self, new_name): + new_path = self.root / new_name[len(DEFAULT_CONTAINER_NAME + "/") :] + (self.root / self.key).rename(new_path) + + +class MockedDirClient: + def __init__(self, key, root) -> None: + self.key = key + self.root = root + + def delete_directory(self): + rmtree(self.root / self.key) + + def exists(self): + return (self.root / self.key).exists() + + def create_directory(self): + (self.root / self.key).mkdir(parents=True, exist_ok=True) + + def rename_directory(self, new_name): + new_path = self.root / new_name[len(DEFAULT_CONTAINER_NAME + "/") :] + (self.root / self.key).rename(new_path) diff --git a/tests/mock_clients/mock_azureblob.py b/tests/mock_clients/mock_azureblob.py index b07aeb0a..f99e0d4a 100644 --- a/tests/mock_clients/mock_azureblob.py +++ b/tests/mock_clients/mock_azureblob.py @@ -1,8 +1,8 @@ from collections import namedtuple from datetime import datetime +import json from pathlib import Path, PurePosixPath import shutil -from tempfile import TemporaryDirectory from azure.storage.blob import BlobProperties @@ -17,51 +17,86 @@ DEFAULT_CONTAINER_NAME = "container" -def mocked_client_class_factory(test_dir: str): - class MockBlobServiceClient: - def __init__(self, *args, **kwargs): - # copy test assets for reference in tests without affecting assets - self.tmp = TemporaryDirectory() - self.tmp_path = Path(self.tmp.name) / "test_case_copy" - shutil.copytree(TEST_ASSETS, self.tmp_path / test_dir) - - self.metadata_cache = {} - - @classmethod - def from_connection_string(cls, *args, **kwargs): - return cls() - - @property - def account_name(self) -> str: - """Returns well-known account name used by Azurite - See: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio%2Cblob-storage#well-known-storage-account-and-key - """ - return "devstoreaccount1" - - @property - def 
credential(self): - """Returns well-known account key used by Azurite - See: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio%2Cblob-storage#well-known-storage-account-and-key - """ - return SharedKeyCredentialPolicy( - self.account_name, - "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", - ) +class _JsonCache: + """Used to mock file metadata store on cloud storage; saves/writes to disk so + different clients can access the same metadata store. + """ + + def __init__(self, path: Path): + self.path = path + + # initialize to empty + with self.path.open("w") as f: + json.dump({}, f) + + def __getitem__(self, key): + with self.path.open("r") as f: + return json.load(f)[str(key)] + + def __setitem__(self, key, value): + with self.path.open("r") as f: + data = json.load(f) + + with self.path.open("w") as f: + data[str(key)] = value + json.dump(data, f) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + +class MockBlobServiceClient: + def __init__(self, test_dir, adls): + # copy test assets for reference in tests without affecting assets + shutil.copytree(TEST_ASSETS, test_dir, dirs_exist_ok=True) - def __del__(self): - self.tmp.cleanup() + # root is parent of the test specific directory + self.root = test_dir.parent + self.test_dir = test_dir - def get_blob_client(self, container, blob): - return MockBlobClient(self.tmp_path, blob, service_client=self) + self.metadata_cache = _JsonCache(self.root / ".metadata") + self.adls_gen2 = adls - def get_container_client(self, container): - return MockContainerClient(self.tmp_path, container_name=container) + @classmethod + def from_connection_string(cls, conn_str, credential): + # configured in conftest.py + test_dir, adls = conn_str.split(";") + adls = adls == "True" + test_dir = Path(test_dir) + return cls(test_dir, adls) - def list_containers(self): - Container = namedtuple("Container", "name") - return [Container(name=DEFAULT_CONTAINER_NAME)] + @property + def account_name(self) -> str: + """Returns well-known account name used by Azurite + See: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio%2Cblob-storage#well-known-storage-account-and-key + """ + return "devstoreaccount1" + + @property + def credential(self): + """Returns well-known account key used by Azurite + See: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio%2Cblob-storage#well-known-storage-account-and-key + """ + return SharedKeyCredentialPolicy( + self.account_name, + "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==", + ) + + def get_blob_client(self, container, blob): + return MockBlobClient(self.root, blob, service_client=self) - return MockBlobServiceClient + def get_container_client(self, container): + return MockContainerClient(self.root, container_name=container) + + def list_containers(self): + Container = namedtuple("Container", "name") + return [Container(name=DEFAULT_CONTAINER_NAME)] + + def get_account_information(self): + return {"is_hns_enabled": self.adls_gen2} class MockBlobClient: @@ -86,6 +121,7 @@ def get_blob_properties(self): "content_type": self.service_client.metadata_cache.get( self.root / self.key, None ), + "metadata": dict(), } ) else: @@ -148,24 +184,30 @@ def exists(self): def list_blobs(self, name_starts_with=None): return mock_item_paged(self.root, name_starts_with) + def walk_blobs(self, 
name_starts_with=None): + return mock_item_paged(self.root, name_starts_with, recursive=False) + def delete_blobs(self, *blobs): for blob in blobs: (self.root / blob).unlink() delete_empty_parents_up_to_root(path=self.root / blob, root=self.root) -def mock_item_paged(root, name_starts_with=None): +def mock_item_paged(root, name_starts_with=None, recursive=True): items = [] - if not name_starts_with: - name_starts_with = "" - for f in root.glob("**/*"): - if ( - (not f.name.startswith(".")) - and f.is_file() - and (root / name_starts_with) in [f, *f.parents] - ): - items.append((PurePosixPath(f), f)) + if recursive: + items = [ + (PurePosixPath(f), f) + for f in root.glob("**/*") + if ( + (not f.name.startswith(".")) + and f.is_file() + and (root / name_starts_with) in [f, *f.parents] + ) + ] + else: + items = [(PurePosixPath(f), f) for f in (root / name_starts_with).iterdir()] for mocked, local in items: # BlobProperties diff --git a/tests/test_azure_specific.py b/tests/test_azure_specific.py index f229ce61..474525e2 100644 --- a/tests/test_azure_specific.py +++ b/tests/test_azure_specific.py @@ -1,11 +1,21 @@ import os -from azure.storage.blob import StorageStreamDownloader +from azure.core.credentials import AzureNamedKeyCredential +from azure.storage.blob import ( + BlobServiceClient, + StorageStreamDownloader, +) + +from azure.storage.filedatalake import DataLakeServiceClient import pytest from urllib.parse import urlparse, parse_qs from cloudpathlib import AzureBlobClient, AzureBlobPath -from cloudpathlib.exceptions import MissingCredentialsError +from cloudpathlib.exceptions import ( + CloudPathIsADirectoryError, + DirectoryNotEmptyError, + MissingCredentialsError, +) from cloudpathlib.local import LocalAzureBlobClient, LocalAzureBlobPath from .mock_clients.mock_azureblob import MockStorageStreamDownloader @@ -32,8 +42,8 @@ def test_azureblobpath_nocreds(client_class, monkeypatch): client_class() -def test_as_url(azure_rig): - p: AzureBlobPath = azure_rig.create_cloud_path("dir_0/file0_0.txt") +def test_as_url(azure_rigs): + p: AzureBlobPath = azure_rigs.create_cloud_path("dir_0/file0_0.txt") public_url = str(p.as_url()) public_parts = urlparse(public_url) @@ -50,8 +60,8 @@ def test_as_url(azure_rig): assert "sig" in query_params -def test_partial_download(azure_rig, monkeypatch): - p: AzureBlobPath = azure_rig.create_cloud_path("dir_0/file0_0.txt") +def test_partial_download(azure_rigs, monkeypatch): + p: AzureBlobPath = azure_rigs.create_cloud_path("dir_0/file0_0.txt") # no partial after successful download p.read_text() # downloads @@ -69,7 +79,7 @@ def _patched(self, buffer): buffer.write(b"partial") raise Exception("boom") - if azure_rig.live_server: + if azure_rigs.live_server: m.setattr(StorageStreamDownloader, "readinto", _patched) else: m.setattr(MockStorageStreamDownloader, "readinto", _patched) @@ -79,3 +89,105 @@ def _patched(self, buffer): assert not p._local.exists() assert not p.client._partial_filename(p._local).exists() + + +def test_client_instantiation(azure_rigs, monkeypatch): + # don't use creds from env vars for these tests + monkeypatch.delenv("AZURE_STORAGE_CONNECTION_STRING") + + if not azure_rigs.live_server: + return + + bsc = BlobServiceClient.from_connection_string(azure_rigs.connection_string) + dlsc = DataLakeServiceClient.from_connection_string(azure_rigs.connection_string) + + def _check_access(az_client, gen2=False): + """Check API access by listing.""" + assert len(list(az_client.service_client.list_containers())) > 0 + + if gen2: + assert 
len(list(az_client.data_lake_client.list_file_systems())) > 0 + + # test just BlobServiceClient passed + cl = azure_rigs.client_class(blob_service_client=bsc) + _check_access(cl, gen2=azure_rigs.is_adls_gen2) + + cl = azure_rigs.client_class(data_lake_client=dlsc) + _check_access(cl, gen2=azure_rigs.is_adls_gen2) + + cl = azure_rigs.client_class(blob_service_client=bsc, data_lake_client=dlsc) + _check_access(cl, gen2=azure_rigs.is_adls_gen2) + + cl = azure_rigs.client_class( + account_url=bsc.url, + credential=AzureNamedKeyCredential( + bsc.credential.account_name, bsc.credential.account_key + ), + ) + _check_access(cl, gen2=azure_rigs.is_adls_gen2) + + cl = azure_rigs.client_class( + account_url=dlsc.url, + credential=AzureNamedKeyCredential( + bsc.credential.account_name, bsc.credential.account_key + ), + ) + _check_access(cl, gen2=azure_rigs.is_adls_gen2) + + +def test_adls_gen2_mkdir(azure_gen2_rig): + """Since directories can be created on gen2, we should test mkdir, rmdir, rmtree, and unlink + all work as expected. + """ + p = azure_gen2_rig.create_cloud_path("new_dir") + + # mkdir + p.mkdir() + assert p.exists() and p.is_dir() + # rmdir does not throw + p.rmdir() + + # mkdir + p.mkdir() + p.mkdir(exist_ok=True) # ensure not raises + + with pytest.raises(FileExistsError): + p.mkdir(exist_ok=False) + + # touch file + (p / "file.txt").write_text("content") + # rmdir throws - not empty + with pytest.raises(DirectoryNotEmptyError): + p.rmdir() + + # rmtree works + p.rmtree() + assert not p.exists() + + # mkdir + p2 = p / "nested" + + with pytest.raises(FileNotFoundError): + p2.mkdir() + + p2.mkdir(parents=True) + assert p2.exists() + + with pytest.raises(CloudPathIsADirectoryError): + p2.unlink() + + +def test_adls_gen2_rename(azure_gen2_rig): + # rename file + p = azure_gen2_rig.create_cloud_path("file.txt") + p.write_text("content") + p2 = p.rename(azure_gen2_rig.create_cloud_path("file2.txt")) + assert not p.exists() + assert p2.exists() + + # rename dir + p = azure_gen2_rig.create_cloud_path("dir") + p.mkdir() + p2 = p.rename(azure_gen2_rig.create_cloud_path("dir2")) + assert not p.exists() + assert p2.exists() diff --git a/tests/test_client.py b/tests/test_client.py index 00b6f270..a665a5a6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -9,7 +9,7 @@ def test_default_client_instantiation(rig): - if not getattr(rig, "is_custom_s3", False): + if not getattr(rig, "is_custom_s3", False) and not (getattr(rig, "is_adls_gen2", False)): # Skip resetting the default client for custom S3 endpoint, but keep the other tests, # since they're still useful. 
rig.client_class._default_client = None @@ -43,7 +43,7 @@ def test_default_client_instantiation(rig): def test_different_clients(rig): p = rig.create_cloud_path("dir_0/file0_0.txt") - new_client = rig.client_class() + new_client = rig.client_class(**rig.required_client_kwargs) p2 = new_client.CloudPath(f"{rig.cloud_prefix}{rig.drive}/{rig.test_dir}/dir_0/file0_0.txt") assert p.client is not p2.client @@ -102,7 +102,7 @@ def my_content_type(path): mimes.append((".potato", "application/potato")) # see if testing custom s3 endpoint, make sure to pass the url to the constructor - kwargs = {} + kwargs = rig.required_client_kwargs.copy() custom_endpoint = os.getenv("CUSTOM_S3_ENDPOINT", "https://s3.us-west-1.drivendatabws.com") if ( rig.client_class is S3Client diff --git a/tests/test_cloudpath_file_io.py b/tests/test_cloudpath_file_io.py index 34c9c913..7dc5b149 100644 --- a/tests/test_cloudpath_file_io.py +++ b/tests/test_cloudpath_file_io.py @@ -40,8 +40,9 @@ def test_file_discovery(rig): with pytest.raises(CloudPathIsADirectoryError): p3.unlink() - with pytest.raises(CloudPathIsADirectoryError): - p3.rename(rig.create_cloud_path("dir_2/")) + if not getattr(rig, "is_adls_gen2", False): + with pytest.raises(CloudPathIsADirectoryError): + p3.rename(rig.create_cloud_path("dir_2/")) with pytest.raises(DirectoryNotEmptyError): p3.rmdir() @@ -360,7 +361,8 @@ def test_file_read_writes(rig, tmp_path): assert datetime.fromtimestamp(p.stat().st_mtime) > before_touch # no-op - p.mkdir() + if not getattr(rig, "is_adls_gen2", False): + p.mkdir() assert p.etag is not None diff --git a/tests/test_cloudpath_instantiation.py b/tests/test_cloudpath_instantiation.py index 64951495..de139593 100644 --- a/tests/test_cloudpath_instantiation.py +++ b/tests/test_cloudpath_instantiation.py @@ -77,7 +77,7 @@ def test_instantiation_errors(rig): def test_idempotency(rig): rig.client_class._default_client = None - client = rig.client_class() + client = rig.client_class(**rig.required_client_kwargs) p = client.CloudPath(f"{rig.cloud_prefix}{rig.drive}/{rig.test_dir}/dir_0/file0_0.txt") p2 = CloudPath(p) diff --git a/tests/test_cloudpath_manipulation.py b/tests/test_cloudpath_manipulation.py index a6aad166..aaf4098c 100644 --- a/tests/test_cloudpath_manipulation.py +++ b/tests/test_cloudpath_manipulation.py @@ -43,7 +43,7 @@ def test_no_op_actions(rig): assert path.is_absolute() -def test_relative_to(rig, azure_rig, gs_rig): +def test_relative_to(rig, azure_rigs, gs_rig): assert rig.create_cloud_path("bucket/path/to/file.txt").relative_to( rig.create_cloud_path("bucket/path") ) == PurePosixPath("to/file.txt") @@ -59,7 +59,7 @@ def test_relative_to(rig, azure_rig, gs_rig): with pytest.raises(ValueError): assert rig.create_cloud_path("a/b/c/d.file").relative_to(PurePosixPath("/a/b/c")) - other_rig = azure_rig if rig.cloud_prefix != azure_rig.cloud_prefix else gs_rig + other_rig = azure_rigs if rig.cloud_prefix != azure_rigs.cloud_prefix else gs_rig path = CloudPath(f"{rig.cloud_prefix}bucket/path/to/file.txt") other_cloud_path = CloudPath(f"{other_rig.cloud_prefix}bucket/path") with pytest.raises(ValueError): @@ -118,9 +118,9 @@ def test_joins(rig): def test_with_segments(rig): - assert rig.create_cloud_path("a/b/c/d").with_segments( - "x", "y", "z" - ) == rig.client_class().CloudPath(f"{rig.cloud_prefix}x/y/z") + assert rig.create_cloud_path("a/b/c/d").with_segments("x", "y", "z") == rig.client_class( + **rig.required_client_kwargs + ).CloudPath(f"{rig.cloud_prefix}x/y/z") def test_is_junction(rig): diff --git 
a/tests/test_local.py b/tests/test_local.py index a983fdd3..15f1b6f9 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -38,7 +38,14 @@ def test_interface(cloud_class, local_class): assert type(cloud_attr) is type(local_attr) if callable(cloud_attr): - assert signature(cloud_attr).parameters == signature(local_attr).parameters + # does not check type annotations, which can vary semantically, but are the same (e.g., Self != AzureBlobPath) + assert all( + a.name == b.name + for a, b in zip( + signature(cloud_attr).parameters.values(), + signature(local_attr).parameters.values(), + ) + ) @pytest.mark.parametrize("client_class", [LocalAzureBlobClient, LocalGSClient, LocalS3Client])