Commit 96a58ad

Update testing and hns key

pjbull committed Jul 28, 2024
1 parent cee1761 commit 96a58ad
Showing 6 changed files with 79 additions and 72 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -83,7 +83,7 @@ make test-live-cloud
 
 #### Azure live backend tests
 
-For Azure, you can test both against Azure Blob Storage backends and Azure Data Lake Storage Gen2 backends. To run these tests, you need to set connection strings for both of the backends by setting the following environment variables (in your `.env` file for local development). If `AZURE_STORAGE_GEN2_CONNECTION_STRING` is not set, only the blob storage backend will be tested.
+For Azure, you can test both against Azure Blob Storage backends and Azure Data Lake Storage Gen2 backends. To run these tests, you need to set connection strings for both of the backends by setting the following environment variables (in your `.env` file for local development). If `AZURE_STORAGE_GEN2_CONNECTION_STRING` is not set, only the blob storage backend will be tested. To set up a storage account with ADLS Gen2, go through the normal creation flow for a storage account in the Azure portal and select "Enable Hierarchical Namespace" in the "Advanced" tab of the settings when configuring the account.
 
 ```bash
 AZURE_STORAGE_CONNECTION_STRING=your_connection_string
 AZURE_STORAGE_GEN2_CONNECTION_STRING=your_connection_string
 ```
9 changes: 6 additions & 3 deletions cloudpathlib/azure/azblobclient.py
@@ -139,18 +139,21 @@ def __init__(
         self.hns_cache: Dict[str, bool] = {}
 
     def _check_hns(self, cloud_path: AzureBlobPath) -> bool:
-        if cloud_path.container not in self.hns_cache:
+        hns_key = self.service_client.account_name + "__" + cloud_path.container
+
+        if hns_key not in self.hns_cache:
             hns_enabled: bool = self.service_client.get_account_information().get(
                 "is_hns_enabled", False
             )  # type: ignore
-            self.hns_cache[cloud_path.container] = hns_enabled
+            self.hns_cache[hns_key] = hns_enabled
 
-        return self.hns_cache[cloud_path.container]
+        return self.hns_cache[hns_key]
 
     def _get_metadata(
         self, cloud_path: AzureBlobPath
     ) -> Union["BlobProperties", "FileProperties", Dict[str, Any]]:
         if self._check_hns(cloud_path):
+
             # works on both files and directories
             fsc = self.data_lake_client.get_file_system_client(cloud_path.container)  # type: ignore
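The point of the new key: `hns_cache` used to be keyed on container name alone, so two clients pointed at different storage accounts whose containers happen to share a name would reuse one cached answer. A minimal sketch of the behavior, not repo code (the account and container names here are made up):

```python
# Sketch only: why the HNS cache is now keyed on account name + container.
hns_cache: dict = {}

def check_hns(account_name: str, container: str, is_hns_enabled: bool) -> bool:
    key = account_name + "__" + container  # account-scoped key, as in this commit
    if key not in hns_cache:
        hns_cache[key] = is_hns_enabled  # stands in for get_account_information()
    return hns_cache[key]

assert check_hns("blobaccount", "data", False) is False
# Under the old container-only key, this second call would have hit the
# cached False from "blobaccount" and mis-reported the gen2 account:
assert check_hns("gen2account", "data", True) is True
```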
132 changes: 68 additions & 64 deletions tests/conftest.py
@@ -113,81 +113,88 @@ def create_test_dir_name(request) -> str:
     return test_dir
 
 
-def azure_rig_factory(conn_str_env_var):
-    adls_gen2 = conn_str_env_var == "AZURE_STORAGE_GEN2_CONNECTION_STRING"
-
-    @fixture()
-    def azure_rig(request, monkeypatch, assets_dir):
-        drive = os.getenv("LIVE_AZURE_CONTAINER", DEFAULT_CONTAINER_NAME)
-        test_dir = create_test_dir_name(request)
+def _azure_fixture(conn_str_env_var, adls_gen2, request, monkeypatch, assets_dir):
+    drive = os.getenv("LIVE_AZURE_CONTAINER", DEFAULT_CONTAINER_NAME)
+    test_dir = create_test_dir_name(request)
 
-        live_server = os.getenv("USE_LIVE_CLOUD") == "1"
+    live_server = os.getenv("USE_LIVE_CLOUD") == "1"
 
-        if live_server:
-            # Set up test assets
-            blob_service_client = BlobServiceClient.from_connection_string(
-                os.getenv(conn_str_env_var)
-            )
-            data_lake_service_client = DataLakeServiceClient.from_connection_string(
-                os.getenv(conn_str_env_var)
-            )
-            test_files = [
-                f
-                for f in assets_dir.glob("**/*")
-                if f.is_file() and f.name not in UPLOAD_IGNORE_LIST
-            ]
-            for test_file in test_files:
-                blob_client = blob_service_client.get_blob_client(
-                    container=drive,
-                    blob=str(f"{test_dir}/{PurePosixPath(test_file.relative_to(assets_dir))}"),
-                )
-                blob_client.upload_blob(test_file.read_bytes(), overwrite=True)
-        else:
-            monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", "")
-            monkeypatch.setenv("AZURE_STORAGE_GEN2_CONNECTION_STRING", "")
+    if live_server:
+        # Set up test assets
+        blob_service_client = BlobServiceClient.from_connection_string(
+            os.getenv(conn_str_env_var)
+        )
+        data_lake_service_client = DataLakeServiceClient.from_connection_string(
+            os.getenv(conn_str_env_var)
+        )
+        test_files = [
+            f
+            for f in assets_dir.glob("**/*")
+            if f.is_file() and f.name not in UPLOAD_IGNORE_LIST
+        ]
+        for test_file in test_files:
+            blob_client = blob_service_client.get_blob_client(
+                container=drive,
+                blob=str(f"{test_dir}/{PurePosixPath(test_file.relative_to(assets_dir))}"),
+            )
+            blob_client.upload_blob(test_file.read_bytes(), overwrite=True)
+    else:
+        monkeypatch.setenv("AZURE_STORAGE_CONNECTION_STRING", "")
+        monkeypatch.setenv("AZURE_STORAGE_GEN2_CONNECTION_STRING", "")
 
-            # need shared client so both blob and adls APIs can point to same temp directory
-            shared_client = mocked_client_class_factory(test_dir, adls_gen2=adls_gen2)()
+        # need shared client so both blob and adls APIs can point to same temp directory
+        shared_client = mocked_client_class_factory(test_dir, adls_gen2=adls_gen2)()
 
-            monkeypatch.setattr(
-                cloudpathlib.azure.azblobclient,
-                "BlobServiceClient",
-                shared_client,
-            )
+        monkeypatch.setattr(
+            cloudpathlib.azure.azblobclient,
+            "BlobServiceClient",
+            shared_client,
+        )
 
-            monkeypatch.setattr(
-                cloudpathlib.azure.azblobclient,
-                "DataLakeServiceClient",
-                mocked_adls_factory(test_dir, shared_client),
-            )
+        monkeypatch.setattr(
+            cloudpathlib.azure.azblobclient,
+            "DataLakeServiceClient",
+            mocked_adls_factory(test_dir, shared_client),
+        )
 
-        rig = CloudProviderTestRig(
-            path_class=AzureBlobPath,
-            client_class=AzureBlobClient,
-            drive=drive,
-            test_dir=test_dir,
-            live_server=live_server,
-        )
+    rig = CloudProviderTestRig(
+        path_class=AzureBlobPath,
+        client_class=AzureBlobClient,
+        drive=drive,
+        test_dir=test_dir,
+        live_server=live_server,
+        required_client_kwargs=dict(connection_string=os.getenv(conn_str_env_var)),  # switch on/off adls gen2
+    )
 
-        rig.client_class().set_as_default_client()  # set default client
+    rig.client_class(connection_string=os.getenv(conn_str_env_var)).set_as_default_client()  # set default client
 
-        # add flag for adls gen2 rig to skip some tests
-        rig.is_adls_gen2 = adls_gen2
+    # add flag for adls gen2 rig to skip some tests
+    rig.is_adls_gen2 = adls_gen2
 
-        yield rig
+    yield rig
 
-        rig.client_class._default_client = None  # reset default client
+    rig.client_class._default_client = None  # reset default client
 
-        if live_server:
-            if blob_service_client.get_account_information().get("is_hns_enabled", False):
-                _hns_rmtree(data_lake_service_client, drive, test_dir)
+    if live_server:
+        if blob_service_client.get_account_information().get("is_hns_enabled", False):
+            _hns_rmtree(data_lake_service_client, drive, test_dir)
 
-            else:
-                # Clean up test dir
-                container_client = blob_service_client.get_container_client(drive)
-                to_delete = container_client.list_blobs(name_starts_with=test_dir)
-                to_delete = sorted(to_delete, key=lambda b: len(b.name.split("/")), reverse=True)
+        else:
+            # Clean up test dir
+            container_client = blob_service_client.get_container_client(drive)
+            to_delete = container_client.list_blobs(name_starts_with=test_dir)
+            to_delete = sorted(to_delete, key=lambda b: len(b.name.split("/")), reverse=True)
 
-                container_client.delete_blobs(*to_delete)
+            container_client.delete_blobs(*to_delete)
 
-    return azure_rig
+
+@fixture()
+def azure_rig(request, monkeypatch, assets_dir):
+    yield from _azure_fixture("AZURE_STORAGE_CONNECTION_STRING", False, request, monkeypatch, assets_dir)
+
+
+@fixture()
+def azure_gen2_rig(request, monkeypatch, assets_dir):
+    yield from _azure_fixture("AZURE_STORAGE_GEN2_CONNECTION_STRING", True, request, monkeypatch, assets_dir)
 
 
 @fixture()
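The hunk above replaces a closure-based fixture factory with a plain generator plus two thin `@fixture` wrappers that delegate via `yield from`. A minimal sketch of the pattern, with illustrative names rather than the repo's:

```python
from pytest import fixture

def _rig(flavor):
    rig = f"rig-{flavor}"  # setup runs before the yield
    yield rig              # the value the test receives
    # teardown code after the yield runs once the test finishes

@fixture()
def blob_rig():
    yield from _rig("blob")  # delegates setup, the yielded value, and teardown

@fixture()
def gen2_rig():
    yield from _rig("gen2")

def test_blob(blob_rig):
    assert blob_rig == "rig-blob"
```

Unlike the old factory, which returned a fixture object that then had to be bound to a module-level name (the assignments deleted in the next hunk), the wrappers are ordinary module-level fixtures that pytest collects directly.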
@@ -435,9 +442,6 @@ def local_s3_rig(request, monkeypatch, assets_dir):
     rig.client_class.reset_default_storage_dir()  # reset local storage directory
 
 
-azure_rig = azure_rig_factory("AZURE_STORAGE_CONNECTION_STRING")
-azure_gen2_rig = azure_rig_factory("AZURE_STORAGE_GEN2_CONNECTION_STRING")
-
 # create azure fixtures for both blob and gen2 storage
 azure_rigs = fixture_union(
     "azure_rigs",
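`fixture_union` comes from pytest_cases: it defines a parametrized fixture that resolves to each member fixture in turn, so a test requesting the union runs once per member. A sketch under that assumption (the member list is inferred from the fixtures above, since the call is truncated here):

```python
from pytest_cases import fixture_union

# Assumed members: any test requesting `azure_rigs` is collected once
# per member fixture, so it runs against both Azure backends.
azure_rigs = fixture_union("azure_rigs", ["azure_rig", "azure_gen2_rig"])
```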
4 changes: 2 additions & 2 deletions tests/test_client.py
@@ -9,7 +9,7 @@
 
 
 def test_default_client_instantiation(rig):
-    if not getattr(rig, "is_custom_s3", False):
+    if not getattr(rig, "is_custom_s3", False) and not (getattr(rig, "is_adls_gen2", False)):
         # Skip resetting the default client for custom S3 endpoint, but keep the other tests,
         # since they're still useful.
         rig.client_class._default_client = None
@@ -43,7 +43,7 @@ def test_default_client_instantiation(rig):
 def test_different_clients(rig):
     p = rig.create_cloud_path("dir_0/file0_0.txt")
 
-    new_client = rig.client_class()
+    new_client = rig.client_class(**rig.required_client_kwargs)
     p2 = new_client.CloudPath(f"{rig.cloud_prefix}{rig.drive}/{rig.test_dir}/dir_0/file0_0.txt")
 
     assert p.client is not p2.client
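This change and the matching edits in the two files below share one pattern: bare `rig.client_class()` calls become `rig.client_class(**rig.required_client_kwargs)`, so rigs that need credentials to construct a client (here, the gen2 rig's connection string) still instantiate correctly. A minimal sketch of the idea, with illustrative names, not the repo's classes:

```python
class FakeRig:
    """Illustrative stand-in for the test rig."""
    def __init__(self, client_class, required_client_kwargs=None):
        self.client_class = client_class
        self.required_client_kwargs = required_client_kwargs or {}

def make_client(rig):
    # An empty dict for rigs with no required kwargs; the gen2 rig supplies
    # dict(connection_string=...) so instantiation still authenticates.
    return rig.client_class(**rig.required_client_kwargs)
```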
2 changes: 1 addition & 1 deletion tests/test_cloudpath_instantiation.py
@@ -77,7 +77,7 @@ def test_instantiation_errors(rig):
 def test_idempotency(rig):
     rig.client_class._default_client = None
 
-    client = rig.client_class()
+    client = rig.client_class(**rig.required_client_kwargs)
     p = client.CloudPath(f"{rig.cloud_prefix}{rig.drive}/{rig.test_dir}/dir_0/file0_0.txt")
 
     p2 = CloudPath(p)
2 changes: 1 addition & 1 deletion tests/test_cloudpath_manipulation.py
@@ -120,7 +120,7 @@ def test_joins(rig):
 def test_with_segments(rig):
     assert rig.create_cloud_path("a/b/c/d").with_segments(
         "x", "y", "z"
-    ) == rig.client_class().CloudPath(f"{rig.cloud_prefix}x/y/z")
+    ) == rig.client_class(**rig.required_client_kwargs).CloudPath(f"{rig.cloud_prefix}x/y/z")
 
 
 def test_is_junction(rig):
