Commit a503d5e

Make mocked tests work with adls

pjbull committed Jul 28, 2024
1 parent 7389b3b
Showing 6 changed files with 126 additions and 39 deletions.
docs/docs/authentication.md (7 additions, 0 deletions)
@@ -211,6 +211,13 @@ client.set_as_default_client()
 cp3 = CloudPath("s3://cloudpathlib-test-bucket/")
 ```

+## Accessing Azure DataLake Storage Gen2 (ADLS Gen2) storage with hierarchical namespace enabled
+
+Some Azure storage accounts are configured with "hierarchical namespace" enabled. This means that the storage account is backed by the Azure DataLake Storage Gen2 product rather than Azure Blob Storage. For many operations, the two are the same and one can use the Azure Blob Storage API. However, for some operations, a developer will need to use the Azure DataLake Storage API. The `AzureBlobClient` class implemented in cloudpathlib is designed to detect if hierarchical namespace is enabled and to use the Azure DataLake Storage API in the places where it is necessary or provides a performance improvement. Usually, a user of cloudpathlib will not need to know whether hierarchical namespace is enabled and the storage account is backed by Azure DataLake Storage Gen2 or Azure Blob Storage.
+
+If needed, the Azure SDK-provided `DataLakeServiceClient` object can be accessed via `AzureBlobClient.data_lake_client`. The Azure SDK-provided `BlobServiceClient` object can be accessed via `AzureBlobClient.blob_client`.
+
+
 ## Pickling `CloudPath` objects

 You can pickle and unpickle `CloudPath` objects normally, for example:
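As a companion to the new docs text, here is a minimal sketch of reaching the underlying SDK objects; the connection string is a placeholder, and the attribute names follow the docs paragraph above:

```python
from cloudpathlib import AzureBlobClient

# hypothetical connection string; any supported auth method works
client = AzureBlobClient(connection_string="<your-connection-string>")

# the Azure SDK objects described in the docs above
blob_service = client.blob_client       # azure.storage.blob.BlobServiceClient
adls_service = client.data_lake_client  # azure.storage.filedatalake.DataLakeServiceClient
```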
tests/conftest.py (35 additions, 20 deletions)
@@ -31,6 +31,7 @@
 from cloudpathlib.azure.azblobclient import _hns_rmtree
 import cloudpathlib.s3.s3client
 from .mock_clients.mock_azureblob import mocked_client_class_factory, DEFAULT_CONTAINER_NAME
+from .mock_clients.mock_adls_gen2 import mocked_adls_factory
 from .mock_clients.mock_gs import (
     mocked_client_class_factory as mocked_gsclient_class_factory,
     DEFAULT_GS_BUCKET_NAME,
@@ -112,7 +113,9 @@ def create_test_dir_name(request) -> str:
     return test_dir


-def azure_rig_factory(conn_str_env_var="AZURE_STORAGE_CONNECTION_STRING"):
+def azure_rig_factory(conn_str_env_var):
+    adls_gen2 = conn_str_env_var == "AZURE_STORAGE_GEN2_CONNECTION_STRING"
+
     @fixture()
     def azure_rig(request, monkeypatch, assets_dir):
         drive = os.getenv("LIVE_AZURE_CONTAINER", DEFAULT_CONTAINER_NAME)
Expand Down Expand Up @@ -141,11 +144,21 @@ def azure_rig(request, monkeypatch, assets_dir):
                     blob_client.upload_blob(test_file.read_bytes(), overwrite=True)
         else:
             monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", "")
-            # Mock cloud SDK
+            monkeypatch.setenv("AZURE_STORAGE_GEN2_CONNECTION_STRING", "")
+
+            # need shared client so both blob and adls APIs can point to same temp directory
+            shared_client = mocked_client_class_factory(test_dir, adls_gen2=adls_gen2)()
+
             monkeypatch.setattr(
                 cloudpathlib.azure.azblobclient,
                 "BlobServiceClient",
-                mocked_client_class_factory(test_dir),
+                shared_client,
             )
+
+            monkeypatch.setattr(
+                cloudpathlib.azure.azblobclient,
+                "DataLakeServiceClient",
+                mocked_adls_factory(test_dir, shared_client),
+            )

         rig = CloudProviderTestRig(
@@ -423,26 +436,28 @@ def local_s3_rig(request, monkeypatch, assets_dir):


 azure_rig = azure_rig_factory("AZURE_STORAGE_CONNECTION_STRING")
+azure_gen2_rig = azure_rig_factory("AZURE_STORAGE_GEN2_CONNECTION_STRING")

-# create azure fixtures for both blob and gen2 storage depending on which live services are configured in
-# the environment variables
-azure_fixtures = [azure_rig]
-
-# explicitly test gen2 if configured
-if os.getenv("AZURE_STORAGE_GEN2_CONNECTION_STRING"):
-    azure_gen2_rig = azure_rig_factory("AZURE_STORAGE_GEN2_CONNECTION_STRING")
-    azure_fixtures.append(azure_gen2_rig)
+# create azure fixtures for both blob and gen2 storage
+azure_rigs = fixture_union(
+    "azure_rigs",
+    [
+        azure_rig,  # azure_rig0
+        azure_gen2_rig,  # azure_rig1
+    ],
+)

 rig = fixture_union(
     "rig",
-    azure_fixtures
-    + [
-        gs_rig,
-        s3_rig,
-        custom_s3_rig,
-        local_azure_rig,
-        local_s3_rig,
-        local_gs_rig,
+    [
+        azure_rig,  # azure_rig0
+        azure_gen2_rig,  # azure_rig1
+        # gs_rig,
+        # s3_rig,
+        # custom_s3_rig,
+        # local_azure_rig,
+        # local_s3_rig,
+        # local_gs_rig,
     ],
 )

@@ -451,6 +466,6 @@ def local_s3_rig(request, monkeypatch, assets_dir):
"s3_like_rig",
[
s3_rig,
custom_s3_rig,
# custom_s3_rig,
],
)
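For readers unfamiliar with `fixture_union` (from the pytest-cases package): a test that requests the union fixture is parametrized over its members, so the azure tests above run once against the blob-backed rig and once against the gen2 rig. A self-contained sketch with illustrative names, not code from this repo:

```python
from pytest_cases import fixture, fixture_union


@fixture()
def rig_a():
    return "a"


@fixture()
def rig_b():
    return "b"


# a test that requests `both_rigs` runs once per member fixture
both_rigs = fixture_union("both_rigs", [rig_a, rig_b])


def test_runs_once_per_rig(both_rigs):
    assert both_rigs in ("a", "b")
```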
tests/mock_clients/mock_adls_gen2.py (new file, 52 additions)
@@ -0,0 +1,52 @@
from azure.storage.filedatalake import FileProperties

from .mock_azureblob import mocked_client_class_factory


def mocked_adls_factory(test_dir, blob_service_client):
    """Just wrap and use `MockBlobServiceClient` where needed to mock ADLS Gen2"""

    class MockedDataLakeServiceClient:
        def __init__(self, blob_service_client):
            self.blob_service_client = blob_service_client

        @classmethod
        def from_connection_string(cls, *args, **kwargs):
            return cls(mocked_client_class_factory(test_dir, adls_gen2=True)())

        def get_file_system_client(self, file_system):
            return MockedFileSystemClient(self.blob_service_client)

    return MockedDataLakeServiceClient


class MockedFileSystemClient:
    def __init__(self, blob_service_client):
        self.blob_service_client = blob_service_client

    def get_file_client(self, key):
        return MockedFileClient(key, self.blob_service_client)


class MockedFileClient:
    def __init__(self, key, blob_service_client) -> None:
        self.key = key
        self.blob_service_client = blob_service_client

    def get_file_properties(self):
        path = self.blob_service_client.tmp_path / self.key

        if path.exists() and path.is_dir():
            return FileProperties(
                **{
                    "name": path.name,
                    "size": 0,
                    "etag": "etag",
                    "last_modified": path.stat().st_mtime,
                    "metadata": {"hdi_isfolder": True},
                }
            )

        # fall back to blob properties for files
        else:
            return self.blob_service_client.get_blob_client("", self.key).get_blob_properties()
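To make the delegation concrete, a hypothetical walk through the mock above — `"container"` and `"dir_0"` are placeholder names, and `shared_client` is assumed to be a `MockBlobServiceClient` whose temp directory contains a `dir_0` folder:

```python
# build the mocked ADLS client class around the shared blob mock
MockedDataLakeServiceClient = mocked_adls_factory("test_dir", shared_client)
adls = MockedDataLakeServiceClient(shared_client)

file_client = adls.get_file_system_client("container").get_file_client("dir_0")
props = file_client.get_file_properties()

# directories are flagged the way ADLS Gen2 marks them: hdi_isfolder metadata
assert props.metadata.get("hdi_isfolder", False)
```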
tests/mock_clients/mock_azureblob.py (25 additions, 12 deletions)
@@ -17,15 +17,18 @@
 DEFAULT_CONTAINER_NAME = "container"


-def mocked_client_class_factory(test_dir: str):
+def mocked_client_class_factory(test_dir: str, adls_gen2: bool = False, tmp_dir: Path = None):
+    """If tmp_dir is not None, use that one so that it can be shared with a MockedDataLakeServiceClient."""
+
     class MockBlobServiceClient:
         def __init__(self, *args, **kwargs):
             # copy test assets for reference in tests without affecting assets
-            self.tmp = TemporaryDirectory()
+            self.tmp = TemporaryDirectory() if not tmp_dir else tmp_dir
             self.tmp_path = Path(self.tmp.name) / "test_case_copy"
             shutil.copytree(TEST_ASSETS, self.tmp_path / test_dir)

             self.metadata_cache = {}
+            self.adls_gen2 = adls_gen2

         @classmethod
         def from_connection_string(cls, *args, **kwargs):
@@ -61,6 +64,9 @@ def list_containers(self):
             Container = namedtuple("Container", "name")
             return [Container(name=DEFAULT_CONTAINER_NAME)]

+        def get_account_information(self):
+            return {"is_hns_enabled": self.adls_gen2}
+
     return MockBlobServiceClient
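The `get_account_information` addition mirrors the real `BlobServiceClient.get_account_information()`, whose response dictionary includes an `is_hns_enabled` key. A hedged sketch of the kind of check this lets callers perform (not necessarily cloudpathlib's exact implementation):

```python
def check_hns(service_client) -> bool:
    # works against the real SDK client and the mock alike,
    # since both expose "is_hns_enabled" in the account info
    account_info = service_client.get_account_information()
    return account_info.get("is_hns_enabled", False)
```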


@@ -86,6 +92,7 @@ def get_blob_properties(self):
"content_type": self.service_client.metadata_cache.get(
self.root / self.key, None
),
"metadata": dict(),
}
)
else:
@@ -148,24 +155,30 @@ def exists(self):
     def list_blobs(self, name_starts_with=None):
         return mock_item_paged(self.root, name_starts_with)

+    def walk_blobs(self, name_starts_with=None):
+        return mock_item_paged(self.root, name_starts_with, recursive=False)
+
     def delete_blobs(self, *blobs):
         for blob in blobs:
             (self.root / blob).unlink()
             delete_empty_parents_up_to_root(path=self.root / blob, root=self.root)


-def mock_item_paged(root, name_starts_with=None):
-    items = []
-
+def mock_item_paged(root, name_starts_with=None, recursive=True):
     if not name_starts_with:
         name_starts_with = ""
-    for f in root.glob("**/*"):
-        if (
-            (not f.name.startswith("."))
-            and f.is_file()
-            and (root / name_starts_with) in [f, *f.parents]
-        ):
-            items.append((PurePosixPath(f), f))
+    if recursive:
+        items = [
+            (PurePosixPath(f), f)
+            for f in root.glob("**/*")
+            if (
+                (not f.name.startswith("."))
+                and f.is_file()
+                and (root / name_starts_with) in [f, *f.parents]
+            )
+        ]
+    else:
+        items = [(PurePosixPath(f), f) for f in (root / name_starts_with).iterdir()]

     for mocked, local in items:
         # BlobProperties
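For orientation on the `recursive` flag: in the real Azure SDK, `ContainerClient.list_blobs` pages every blob under a prefix at any depth, while `walk_blobs` returns a single level of the hierarchy (yielding `BlobPrefix` items for virtual directories). A sketch of the behavior the mock imitates; the connection string, container, and prefix are placeholders:

```python
from azure.storage.blob import ContainerClient

container = ContainerClient.from_connection_string("<conn-str>", "container")

# recursive: every blob under dir_0/, at any depth
all_blobs = [b.name for b in container.list_blobs(name_starts_with="dir_0/")]

# one level: immediate children only, including virtual directory prefixes
one_level = [b.name for b in container.walk_blobs(name_starts_with="dir_0/")]
```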
tests/test_azure_specific.py (5 additions, 5 deletions)
@@ -32,8 +32,8 @@ def test_azureblobpath_nocreds(client_class, monkeypatch):
         client_class()


-def test_as_url(azure_rig):
-    p: AzureBlobPath = azure_rig.create_cloud_path("dir_0/file0_0.txt")
+def test_as_url(azure_rigs):
+    p: AzureBlobPath = azure_rigs.create_cloud_path("dir_0/file0_0.txt")

     public_url = str(p.as_url())
     public_parts = urlparse(public_url)
@@ -50,8 +50,8 @@ def test_as_url(azure_rig):
assert "sig" in query_params


def test_partial_download(azure_rig, monkeypatch):
p: AzureBlobPath = azure_rig.create_cloud_path("dir_0/file0_0.txt")
def test_partial_download(azure_rigs, monkeypatch):
p: AzureBlobPath = azure_rigs.create_cloud_path("dir_0/file0_0.txt")

# no partial after successful download
p.read_text() # downloads
@@ -69,7 +69,7 @@ def _patched(self, buffer):
         buffer.write(b"partial")
         raise Exception("boom")

-    if azure_rig.live_server:
+    if azure_rigs.live_server:
         m.setattr(StorageStreamDownloader, "readinto", _patched)
     else:
         m.setattr(MockStorageStreamDownloader, "readinto", _patched)
tests/test_cloudpath_manipulation.py (2 additions, 2 deletions)
@@ -43,7 +43,7 @@ def test_no_op_actions(rig):
     assert path.is_absolute()


-def test_relative_to(rig, azure_rig, gs_rig):
+def test_relative_to(rig, azure_rigs, gs_rig):
     assert rig.create_cloud_path("bucket/path/to/file.txt").relative_to(
         rig.create_cloud_path("bucket/path")
     ) == PurePosixPath("to/file.txt")
@@ -59,7 +59,7 @@ def test_relative_to(rig, azure_rig, gs_rig):

     with pytest.raises(ValueError):
         assert rig.create_cloud_path("a/b/c/d.file").relative_to(PurePosixPath("/a/b/c"))
-    other_rig = azure_rig if rig.cloud_prefix != azure_rig.cloud_prefix else gs_rig
+    other_rig = azure_rigs if rig.cloud_prefix != azure_rigs.cloud_prefix else gs_rig
     path = CloudPath(f"{rig.cloud_prefix}bucket/path/to/file.txt")
     other_cloud_path = CloudPath(f"{other_rig.cloud_prefix}bucket/path")
     with pytest.raises(ValueError):
