diff --git a/backend/dataall/core/environment/db/environment_repositories.py b/backend/dataall/core/environment/db/environment_repositories.py index 7aca43b40..d5836203f 100644 --- a/backend/dataall/core/environment/db/environment_repositories.py +++ b/backend/dataall/core/environment/db/environment_repositories.py @@ -310,3 +310,16 @@ def query_all_active_environments(session) -> List[Environment]: @staticmethod def query_environment_groups(session, uri): return session.query(EnvironmentGroup).filter(EnvironmentGroup.environmentUri == uri).all() + + @staticmethod + def get_environment_consumption_role_by_name(session, uri, IAMRoleName): + return ( + session.query(ConsumptionRole) + .filter( + and_( + ConsumptionRole.environmentUri == uri, + ConsumptionRole.IAMRoleName == IAMRoleName, + ) + ) + .first() + ) diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py index f49302ae4..d61d65d29 100644 --- a/backend/dataall/core/environment/services/environment_service.py +++ b/backend/dataall/core/environment/services/environment_service.py @@ -1147,3 +1147,9 @@ def resolve_consumption_role_policies(uri, IAMRoleName): region=environment.region, resource_prefix=environment.resourcePrefix, ).get_all_policies() + + @staticmethod + @ResourcePolicyService.has_resource_permission(environment_permissions.GET_ENVIRONMENT) + def get_consumption_role_by_name(uri, IAMRoleName): + with get_context().db_engine.scoped_session() as session: + return EnvironmentRepository.get_environment_consumption_role_by_name(session, uri, IAMRoleName) diff --git a/backend/dataall/core/stacks/api/types.py b/backend/dataall/core/stacks/api/types.py index 4cd624de4..979de769e 100644 --- a/backend/dataall/core/stacks/api/types.py +++ b/backend/dataall/core/stacks/api/types.py @@ -21,6 +21,7 @@ gql.Field(name='region', type=gql.NonNullableType(gql.String)), gql.Field(name='status', type=gql.String), 
gql.Field(name='stackid', type=gql.String), + gql.Field(name='updated', type=gql.AWSDateTime), gql.Field(name='link', type=gql.String, resolver=resolve_link), gql.Field(name='outputs', type=gql.String, resolver=resolve_outputs), gql.Field(name='resources', type=gql.String, resolver=resolve_resources), diff --git a/tests_new/integration_tests/aws_clients/iam.py b/tests_new/integration_tests/aws_clients/iam.py index eb0e1b885..73050c748 100644 --- a/tests_new/integration_tests/aws_clients/iam.py +++ b/tests_new/integration_tests/aws_clients/iam.py @@ -16,19 +16,11 @@ def __init__(self, session=boto3.Session(), region=os.environ.get('AWS_REGION', self._region = region def get_role(self, role_name): - try: - role = self._client.get_role(RoleName=role_name) - return role - except Exception as e: - log.info(f'Error occurred: {e}') - return None + role = self._client.get_role(RoleName=role_name) + return role def delete_role(self, role_name): - try: - self._client.delete_role(RoleName=role_name) - except Exception as e: - log.error(e) - raise e + self._client.delete_role(RoleName=role_name) def create_role(self, account_id, role_name, test_role_name): policy_doc = { @@ -47,16 +39,12 @@ def create_role(self, account_id, role_name, test_role_name): } ], } - try: - role = self._client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=json.dumps(policy_doc), - Description='Role for Lambda function', - ) - return role - except Exception as e: - log.error(e) - raise e + role = self._client.create_role( + RoleName=role_name, + AssumeRolePolicyDocument=json.dumps(policy_doc), + Description='Role for Lambda function', + ) + return role def get_consumption_role(self, account_id, role_name, test_role_name): role = self.get_role(role_name) diff --git a/tests_new/integration_tests/aws_clients/utils.py b/tests_new/integration_tests/aws_clients/utils.py new file mode 100644 index 000000000..3a6d0dffd --- /dev/null +++ b/tests_new/integration_tests/aws_clients/utils.py @@ 
-0,0 +1,18 @@ +import json +import boto3 + +from tests_new.integration_tests.aws_clients.sts import StsClient + + +def get_group_session(credentials_str): + credentials = json.loads(credentials_str) + return boto3.Session( + aws_access_key_id=credentials['AccessKey'], + aws_secret_access_key=credentials['SessionKey'], + aws_session_token=credentials['sessionToken'], + ) + + +def get_role_session(session, role_arn, region): + sts_client = StsClient(session=session, region=region) + return sts_client.get_role_session(role_arn) diff --git a/tests_new/integration_tests/core/environment/global_conftest.py b/tests_new/integration_tests/core/environment/global_conftest.py index fcc1db311..79bb9c1b6 100644 --- a/tests_new/integration_tests/core/environment/global_conftest.py +++ b/tests_new/integration_tests/core/environment/global_conftest.py @@ -1,4 +1,6 @@ import logging +from datetime import datetime + import pytest import boto3 @@ -12,6 +14,7 @@ ) from integration_tests.core.organizations.queries import create_organization from integration_tests.core.stack.utils import check_stack_ready +from tests_new.integration_tests.core.environment.utils import update_env_stack log = logging.getLogger(__name__) @@ -178,3 +181,30 @@ def get_or_create_persistent_env(env_name, client, group, testdata): @pytest.fixture(scope='session') def persistent_env1(client1, group1, testdata): return get_or_create_persistent_env('persistent_env1', client1, group1, testdata) + + +@pytest.fixture(scope='session') +def updated_persistent_env1(client1, group1, persistent_env1): + update_env_stack(client1, persistent_env1) + return get_environment(client1, persistent_env1.environmentUri) + + +@pytest.fixture(scope='session') +def persistent_cross_acc_env_1(client5, group5, testdata): + return get_or_create_persistent_env('persistent_cross_acc_env_1', client5, group5, testdata) + + +@pytest.fixture(scope='session') +def updated_persistent_cross_acc_env_1(client5, group5, persistent_cross_acc_env_1): 
+ update_env_stack(client5, persistent_cross_acc_env_1) + return get_environment(client5, persistent_cross_acc_env_1.environmentUri) + + +@pytest.fixture(scope='session') +def persistent_cross_acc_env_1_integration_role_arn(persistent_cross_acc_env_1): + return f'arn:aws:iam::{persistent_cross_acc_env_1.AwsAccountId}:role/dataall-integration-tests-role-{persistent_cross_acc_env_1.region}' + + +@pytest.fixture(scope='session') +def persistent_cross_acc_env_1_aws_client(persistent_cross_acc_env_1, persistent_cross_acc_env_1_integration_role_arn): + return get_environment_aws_session(persistent_cross_acc_env_1_integration_role_arn, persistent_cross_acc_env_1) diff --git a/tests_new/integration_tests/core/environment/queries.py b/tests_new/integration_tests/core/environment/queries.py index acdf2ca79..7a1e3eba2 100644 --- a/tests_new/integration_tests/core/environment/queries.py +++ b/tests_new/integration_tests/core/environment/queries.py @@ -33,6 +33,7 @@ accountid region stackid + updated link outputs resources @@ -244,6 +245,33 @@ def add_consumption_role(client, env_uri, group_uri, consumption_role_name, iam_ return response.data.addConsumptionRoleToEnvironment +def list_environment_consumption_roles(client, env_uri, filter): + query = { + 'operationName': 'listEnvironmentConsumptionRoles', + 'variables': {'environmentUri': env_uri, 'filter': filter}, + 'query': """ + query listEnvironmentConsumptionRoles($environmentUri: String!, $filter: ConsumptionRoleFilter) { + listEnvironmentConsumptionRoles(environmentUri: $environmentUri, filter: $filter) { + count + page + pages + hasNext + hasPrevious + nodes { + consumptionRoleUri + consumptionRoleName + environmentUri + groupUri + IAMRoleArn + } + } + } + """, + } + response = client.query(query=query) + return response.data.listEnvironmentConsumptionRoles + + def remove_consumption_role(client, env_uri, consumption_role_uri): query = { 'operationName': 'removeConsumptionRoleFromEnvironment', diff --git 
a/tests_new/integration_tests/core/environment/test_environment.py b/tests_new/integration_tests/core/environment/test_environment.py index 256fd3327..a3bf50497 100644 --- a/tests_new/integration_tests/core/environment/test_environment.py +++ b/tests_new/integration_tests/core/environment/test_environment.py @@ -12,10 +12,13 @@ remove_consumption_role, remove_group_from_env, ) -from integration_tests.core.stack.queries import update_stack -from integration_tests.core.stack.utils import check_stack_in_progress, check_stack_ready + from integration_tests.errors import GqlError +from integration_tests.core.environment.utils import update_env_stack +from integration_tests.core.stack.queries import get_stack + + log = logging.getLogger(__name__) @@ -51,15 +54,18 @@ def test_list_envs_invited(client2, session_env1, session_env2, session_id): def test_persistent_env_update(client1, persistent_env1): - # wait for stack to get to a final state before triggering an update - stack_uri = persistent_env1.stack.stackUri - env_uri = persistent_env1.environmentUri - check_stack_ready(client1, env_uri, stack_uri) - update_stack(client1, env_uri, 'environment') - # wait for stack to move to "in_progress" state - check_stack_in_progress(client1, env_uri, stack_uri) - stack = check_stack_ready(client1, env_uri, stack_uri) - assert_that(stack.status).is_equal_to('UPDATE_COMPLETE') + stack = get_stack( + client1, + persistent_env1.environmentUri, + persistent_env1.stack.stackUri, + persistent_env1.environmentUri, + target_type='environment', + ) + updated_before = datetime.fromisoformat(stack.updated) + stack = update_env_stack(client1, persistent_env1) + assert_that(stack).contains_entry(status='UPDATE_COMPLETE') + updated = datetime.fromisoformat(stack.updated) + assert_that(updated).is_greater_than_or_equal_to(updated_before) def test_invite_group_on_env_no_org(client1, session_env2, group4): diff --git a/tests_new/integration_tests/core/environment/utils.py 
b/tests_new/integration_tests/core/environment/utils.py index 99e55c959..e84ddf0ce 100644 --- a/tests_new/integration_tests/core/environment/utils.py +++ b/tests_new/integration_tests/core/environment/utils.py @@ -1,5 +1,6 @@ from integration_tests.core.environment.queries import update_environment from integration_tests.core.stack.utils import check_stack_ready, check_stack_in_progress +from integration_tests.core.stack.queries import update_stack def set_env_params(client, env, **new_params): @@ -34,3 +35,14 @@ def set_env_params(client, env, **new_params): ) check_stack_in_progress(client, env_uri, stack_uri) check_stack_ready(client, env_uri, stack_uri) + + +def update_env_stack(client, env): + stack_uri = env.stack.stackUri + env_uri = env.environmentUri + # wait for stack to get to a final state before triggering an update + check_stack_ready(client, env_uri, stack_uri) + update_stack(client, env_uri, 'environment') + # wait for stack to move to "in_progress" state + check_stack_in_progress(client, env_uri, stack_uri) + return check_stack_ready(client, env_uri, stack_uri) diff --git a/tests_new/integration_tests/core/stack/queries.py b/tests_new/integration_tests/core/stack/queries.py index ef4910a11..9757bf693 100644 --- a/tests_new/integration_tests/core/stack/queries.py +++ b/tests_new/integration_tests/core/stack/queries.py @@ -44,6 +44,7 @@ def get_stack(client, env_uri, stack_uri, target_uri, target_type): accountid region stackid + updated link outputs resources diff --git a/tests_new/integration_tests/modules/s3_datasets/aws_clients.py b/tests_new/integration_tests/modules/s3_datasets/aws_clients.py index a7b883d32..020d74d2d 100644 --- a/tests_new/integration_tests/modules/s3_datasets/aws_clients.py +++ b/tests_new/integration_tests/modules/s3_datasets/aws_clients.py @@ -104,7 +104,7 @@ def list_bucket_objects(self, bucket_name): return self._client.list_objects(Bucket=bucket_name) except ClientError as e: logging.error(f'Error listing objects in S3: 
{e}') - raise + raise e def list_accesspoint_folder_objects(self, access_point, folder_name): try: diff --git a/tests_new/integration_tests/modules/s3_datasets/global_conftest.py b/tests_new/integration_tests/modules/s3_datasets/global_conftest.py index 8de97095d..4153ad9b3 100644 --- a/tests_new/integration_tests/modules/s3_datasets/global_conftest.py +++ b/tests_new/integration_tests/modules/s3_datasets/global_conftest.py @@ -1,4 +1,6 @@ import logging +import time + import pytest import boto3 import json @@ -19,6 +21,7 @@ from tests_new.integration_tests.modules.datasets_base.queries import list_datasets from integration_tests.modules.s3_datasets.aws_clients import S3Client, KMSClient, GlueClient, LakeFormationClient +from integration_tests.core.stack.queries import update_stack log = logging.getLogger(__name__) @@ -398,7 +401,15 @@ def temp_s3_dataset1(client1, group1, org1, session_env1, session_id, testdata): def get_or_create_persistent_s3_dataset( - dataset_name, client, group, env, autoApprovalEnabled=False, bucket=None, kms_alias='', glue_database='' + dataset_name, + client, + group, + env, + autoApprovalEnabled=False, + bucket=None, + kms_alias='', + glue_database='', + withContent=False, ): dataset_name = dataset_name or 'persistent_s3_dataset1' s3_datasets = list_datasets(client, term=dataset_name).nodes @@ -431,6 +442,9 @@ def get_or_create_persistent_s3_dataset( tags=[dataset_name], autoApprovalEnabled=autoApprovalEnabled, ) + if withContent: + create_tables(client, s3_dataset) + create_folders(client, s3_dataset) if s3_dataset.stack.status in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']: return s3_dataset @@ -441,7 +455,21 @@ def get_or_create_persistent_s3_dataset( @pytest.fixture(scope='session') def persistent_s3_dataset1(client1, group1, persistent_env1, testdata): - return get_or_create_persistent_s3_dataset('persistent_s3_dataset1', client1, group1, persistent_env1) + return get_or_create_persistent_s3_dataset( + 'persistent_s3_dataset1', client1, 
group1, persistent_env1, withContent=True + ) + + +@pytest.fixture(scope='session') +def updated_persistent_s3_dataset1(client1, persistent_s3_dataset1): + target_type = 'dataset' + stack_uri = persistent_s3_dataset1.stack.stackUri + env_uri = persistent_s3_dataset1.environment.environmentUri + dataset_uri = persistent_s3_dataset1.datasetUri + update_stack(client1, dataset_uri, target_type) + time.sleep(120) + check_stack_ready(client1, env_uri=env_uri, stack_uri=stack_uri, target_uri=dataset_uri, target_type=target_type) + return get_dataset(client1, dataset_uri) @pytest.fixture(scope='session') diff --git a/tests_new/integration_tests/modules/share_base/conftest.py b/tests_new/integration_tests/modules/share_base/conftest.py index a44ebe722..af0935f51 100644 --- a/tests_new/integration_tests/modules/share_base/conftest.py +++ b/tests_new/integration_tests/modules/share_base/conftest.py @@ -1,16 +1,24 @@ import pytest from tests_new.integration_tests.aws_clients.iam import IAMClient -from tests_new.integration_tests.core.environment.queries import add_consumption_role, remove_consumption_role +from tests_new.integration_tests.core.environment.queries import ( + add_consumption_role, + remove_consumption_role, + list_environment_consumption_roles, +) from tests_new.integration_tests.modules.share_base.queries import ( create_share_object, delete_share_object, get_share_object, revoke_share_items, + submit_share_object, + approve_share_object, + add_share_item, ) from tests_new.integration_tests.modules.share_base.utils import check_share_ready -test_cons_role_name = 'dataall-test-ShareTestConsumptionRole' +test_session_cons_role_name = 'dataall-test-ShareTestConsumptionRole' +test_persistent_cons_role_name = 'dataall-test-PersistentConsumptionRole' def revoke_all_possible(client, shareUri): @@ -34,24 +42,39 @@ def clean_up_share(client, shareUri): delete_share_object(client, shareUri) -@pytest.fixture(scope='session') -def consumption_role_1(client5, group5, 
session_cross_acc_env_1, session_cross_acc_env_1_aws_client): - iam_client = IAMClient(session=session_cross_acc_env_1_aws_client, region=session_cross_acc_env_1['region']) +def create_consumption_role(client, group, environment, environment_client, iam_role_name, cons_role_name): + iam_client = IAMClient(session=environment_client, region=environment['region']) role = iam_client.get_consumption_role( - session_cross_acc_env_1.AwsAccountId, - test_cons_role_name, - f'dataall-integration-tests-role-{session_cross_acc_env_1.region}', + environment.AwsAccountId, + iam_role_name, + f'dataall-integration-tests-role-{environment.region}', + ) + return add_consumption_role( + client, + environment.environmentUri, + group, + cons_role_name, + role['Role']['Arn'], ) - consumption_role = add_consumption_role( + + +# --------------SESSION PARAM FIXTURES---------------------------- + + +@pytest.fixture(scope='session') +def session_consumption_role_1(client5, group5, session_cross_acc_env_1, session_cross_acc_env_1_aws_client): + consumption_role = create_consumption_role( client5, - session_cross_acc_env_1.environmentUri, group5, - 'ShareTestConsumptionRole', - role['Role']['Arn'], + session_cross_acc_env_1, + session_cross_acc_env_1_aws_client, + test_session_cons_role_name, + 'SessionConsRole1', ) yield consumption_role remove_consumption_role(client5, session_cross_acc_env_1.environmentUri, consumption_role.consumptionRoleUri) - iam_client.delete_consumption_role(role['Role']['RoleName']) + iam_client = IAMClient(session=session_cross_acc_env_1_aws_client, region=session_cross_acc_env_1['region']) + iam_client.delete_consumption_role(test_session_cons_role_name) @pytest.fixture(scope='session') @@ -116,14 +139,14 @@ def session_share_consrole_1( session_s3_dataset1_tables, session_s3_dataset1_folders, group5, - consumption_role_1, + session_consumption_role_1, ): share1cr = create_share_object( client=client5, dataset_or_item_params={'datasetUri': 
session_s3_dataset1.datasetUri}, environmentUri=session_cross_acc_env_1.environmentUri, groupUri=group5, - principalId=consumption_role_1.consumptionRoleUri, + principalId=session_consumption_role_1.consumptionRoleUri, principalType='ConsumptionRole', requestPurpose='test create share object', attachMissingPolicies=True, @@ -143,14 +166,14 @@ def session_share_consrole_2( session_imported_sse_s3_dataset1_tables, session_imported_sse_s3_dataset1_folders, group5, - consumption_role_1, + session_consumption_role_1, ): share2cr = create_share_object( client=client5, dataset_or_item_params={'datasetUri': session_imported_sse_s3_dataset1.datasetUri}, environmentUri=session_cross_acc_env_1.environmentUri, groupUri=group5, - principalId=consumption_role_1.consumptionRoleUri, + principalId=session_consumption_role_1.consumptionRoleUri, principalType='ConsumptionRole', requestPurpose='test create share object', attachMissingPolicies=True, @@ -163,11 +186,11 @@ def session_share_consrole_2( @pytest.fixture(params=['Group', 'ConsumptionRole']) -def principal1(request, group5, consumption_role_1): +def principal1(request, group5, session_consumption_role_1): if request.param == 'Group': yield group5, request.param else: - yield consumption_role_1.consumptionRoleUri, request.param + yield session_consumption_role_1.consumptionRoleUri, request.param @pytest.fixture(params=['Group', 'ConsumptionRole']) @@ -199,3 +222,101 @@ def share_params_all( yield session_share_1, session_s3_dataset1 else: yield session_share_consrole_1, session_s3_dataset1 + + +# --------------PERSISTENT FIXTURES---------------------------- + + +@pytest.fixture(scope='session') +def persistent_consumption_role_1(client5, group5, persistent_cross_acc_env_1, persistent_cross_acc_env_1_aws_client): + consumption_roles_result = list_environment_consumption_roles( + client5, + persistent_cross_acc_env_1.environmentUri, + {'term': 'PersistentConsRole1'}, + ) + + if consumption_roles_result.count == 0: + 
consumption_role = create_consumption_role( + client5, + group5, + persistent_cross_acc_env_1, + persistent_cross_acc_env_1_aws_client, + test_persistent_cons_role_name, + 'PersistentConsRole1', + ) + yield consumption_role + else: + yield consumption_roles_result.nodes[0] + + +@pytest.fixture(scope='session') +def persistent_group_share_1( + client5, + client1, + updated_persistent_env1, + updated_persistent_cross_acc_env_1, + updated_persistent_s3_dataset1, + group5, +): + share1 = create_share_object( + client=client5, + dataset_or_item_params={'datasetUri': updated_persistent_s3_dataset1.datasetUri}, + environmentUri=updated_persistent_cross_acc_env_1.environmentUri, + groupUri=group5, + principalId=group5, + principalType='Group', + requestPurpose='create persistent share object', + attachMissingPolicies=True, + permissions=['Read'], + ) + share1 = get_share_object(client5, share1.shareUri) + + if share1.status == 'Draft': + items = share1['items'].nodes + for item in items: + add_share_item(client5, share1.shareUri, item.itemUri, item.itemType) + submit_share_object(client5, share1.shareUri) + approve_share_object(client1, share1.shareUri) + check_share_ready(client5, share1.shareUri) + yield get_share_object(client5, share1.shareUri) + + +@pytest.fixture(scope='session') +def persistent_role_share_1( + client5, + client1, + updated_persistent_env1, + updated_persistent_cross_acc_env_1, + updated_persistent_s3_dataset1, + group5, + persistent_consumption_role_1, +): + share1 = create_share_object( + client=client5, + dataset_or_item_params={'datasetUri': updated_persistent_s3_dataset1.datasetUri}, + environmentUri=updated_persistent_cross_acc_env_1.environmentUri, + groupUri=group5, + principalId=persistent_consumption_role_1.consumptionRoleUri, + principalType='ConsumptionRole', + requestPurpose='create persistent share object', + attachMissingPolicies=True, + permissions=['Read'], + ) + share1 = get_share_object(client5, share1.shareUri) + + if 
share1.status == 'Draft': + items = share1['items'].nodes + for item in items: + add_share_item(client5, share1.shareUri, item.itemUri, item.itemType) + submit_share_object(client5, share1.shareUri) + approve_share_object(client1, share1.shareUri) + check_share_ready(client5, share1.shareUri) + yield get_share_object(client5, share1.shareUri) + + +@pytest.fixture(params=['Group', 'ConsumptionRole']) +def persistent_share_params_main(request, persistent_role_share_1, persistent_group_share_1): + if request.param == 'Group': + yield persistent_group_share_1 + else: + yield persistent_role_share_1 diff --git a/tests_new/integration_tests/modules/share_base/queries.py b/tests_new/integration_tests/modules/share_base/queries.py index 5e69543fc..cbde3a598 100644 --- a/tests_new/integration_tests/modules/share_base/queries.py +++ b/tests_new/integration_tests/modules/share_base/queries.py @@ -176,6 +176,21 @@ def add_share_item(client, shareUri: str, itemUri: str, itemType: str): return response.data.addSharedItem.shareItemUri +def remove_share_item(client, shareItemUri: str): + query = { + 'operationName': 'removeSharedItem', + 'variables': {'shareItemUri': shareItemUri}, + 'query': f""" + mutation removeSharedItem($shareItemUri: String!) 
{{ + removeSharedItem(shareItemUri: $shareItemUri) + }} + """, + } + + response = client.query(query=query) + return response.data.removeSharedItem + + def verify_share_items(client, shareUri: str, shareItemsUris: List[str]): query = { 'operationName': 'verifyItemsShareObject', diff --git a/tests_new/integration_tests/modules/share_base/shared_test_functions.py b/tests_new/integration_tests/modules/share_base/shared_test_functions.py new file mode 100644 index 000000000..25f8de1da --- /dev/null +++ b/tests_new/integration_tests/modules/share_base/shared_test_functions.py @@ -0,0 +1,203 @@ +from assertpy import assert_that +from botocore.exceptions import ClientError + +from tests_new.integration_tests.aws_clients.utils import get_group_session, get_role_session +from tests_new.integration_tests.core.environment.queries import get_environment_access_token +from tests_new.integration_tests.modules.share_base.queries import ( + get_share_object, + get_s3_consumption_data, + verify_share_items, + revoke_share_items, + add_share_item, + submit_share_object, + approve_share_object, + remove_share_item, +) +from tests_new.integration_tests.modules.share_base.utils import ( + check_share_items_verified, + check_share_ready, +) +from tests_new.integration_tests.aws_clients.athena import AthenaClient +from tests_new.integration_tests.modules.s3_datasets.aws_clients import S3Client +from tests_new.integration_tests.modules.s3_datasets.queries import get_folder + +ALL_S3_SHARABLE_TYPES_NAMES = [ + 'Table', + 'StorageLocation', + 'S3Bucket', +] + + +def add_all_items_to_share(client, shareUri): + updated_share = get_share_object(client, shareUri) + items = updated_share['items'].nodes + for item in items: + assert_that(add_share_item(client, shareUri, item.itemUri, item.itemType)).is_not_none() + updated_share = get_share_object(client, shareUri) + items = updated_share['items'].nodes + assert_that(items).extracting('status').contains_only('PendingApproval') + + +def 
delete_all_non_shared_items(client, shareUri): + updated_share = get_share_object(client, shareUri) + items = updated_share['items'].nodes + for item in items: + if item.status in [ + 'Revoke_Succeeded', + 'PendingApproval', + 'Share_Rejected', + 'Share_Failed', + ]: + assert_that(remove_share_item(client, item.shareItemUri)).is_true() + + +def check_submit_share_object(client, shareUri, dataset): + submit_share_object(client, shareUri) + updated_share = get_share_object(client, shareUri) + if dataset.autoApprovalEnabled: + assert_that(updated_share.status).is_equal_to('Approved') + else: + assert_that(updated_share.status).is_equal_to('Submitted') + + +def check_approve_share_object(client, shareUri): + approve_share_object(client, shareUri) + updated_share = get_share_object(client, shareUri, {'isShared': True}) + assert_that(updated_share.status).is_equal_to('Approved') + items = updated_share['items'].nodes + assert_that(items).extracting('status').contains_only('Share_Approved') + + +def check_share_succeeded(client, shareUri, check_contains_all_item_types=False): + check_share_ready(client, shareUri) + updated_share = get_share_object(client, shareUri, {'isShared': True}) + items = updated_share['items'].nodes + + assert_that(updated_share.status).is_equal_to('Processed') + for item in items: + assert_that(item.status).is_equal_to('Share_Succeeded') + assert_that(item.healthStatus).is_equal_to('Healthy') + if check_contains_all_item_types: + assert_that(items).extracting('itemType').contains(*ALL_S3_SHARABLE_TYPES_NAMES) + + +def check_verify_share_items(client, shareUri): + share = get_share_object(client, shareUri, {'isShared': True}) + items = share['items'].nodes + times = [item.lastVerificationTime for item in items] + verify_share_items(client, shareUri, [item.shareItemUri for item in items]) + check_share_items_verified(client, shareUri) + updated_share = get_share_object(client, shareUri, {'isShared': True}) + items = updated_share['items'].nodes + 
assert_that(items).extracting('status').contains_only('Share_Succeeded') + assert_that(items).extracting('healthStatus').contains_only('Healthy') + assert_that(items).extracting('lastVerificationTime').does_not_contain(*times) + + +def check_table_access( + athena_client, glue_db, table_name, workgroup, athena_workgroup_output_location, should_have_access +): + query = 'SELECT * FROM {}.{}'.format(glue_db, table_name) + if should_have_access: + state = athena_client.execute_query(query, workgroup, athena_workgroup_output_location) + assert_that(state).is_equal_to('SUCCEEDED') + else: + state = athena_client.execute_query(query, workgroup, athena_workgroup_output_location) + assert_that(state).is_equal_to('FAILED') + + +def check_bucket_access(client, s3_client, bucket_name, should_have_access): + if should_have_access: + assert_that(s3_client.bucket_exists(bucket_name)).is_true() + assert_that(s3_client.list_bucket_objects(bucket_name)).is_not_none() + else: + assert_that(s3_client.bucket_exists(bucket_name)).is_false() + assert_that(s3_client.list_bucket_objects).raises(ClientError).when_called_with(bucket_name).contains( + 'AccessDenied' + ) + + +def check_accesspoint_access(client, s3_client, access_point_arn, item_uri, should_have_access): + if should_have_access: + folder = get_folder(client, item_uri) + assert_that(s3_client.list_accesspoint_folder_objects(access_point_arn, folder.S3Prefix + '/')).is_not_none() + else: + assert_that(get_folder).raises(Exception).when_called_with(client, item_uri).contains( + 'is not authorized to perform: GET_DATASET_FOLDER' + ) + + +def check_share_items_access( + client, + group, + shareUri, + consumption_role, + env_client, +): + share = get_share_object(client, shareUri) + dataset = share.dataset + principal_type = share.principal.principalType + if principal_type == 'Group': + credentials_str = get_environment_access_token(client, share.environment.environmentUri, group) + session = get_group_session(credentials_str) + 
elif principal_type == 'ConsumptionRole': + session = get_role_session(env_client, consumption_role.IAMRoleArn, dataset.region) + else: + raise Exception('wrong principal type') + + s3_client = S3Client(session, dataset.region) + athena_client = AthenaClient(session, dataset.region) + + consumption_data = get_s3_consumption_data(client, shareUri) + items = share['items'].nodes + + glue_db = consumption_data.sharedGlueDatabase + access_point_arn = ( + f'arn:aws:s3:{dataset.region}:{dataset.AwsAccountId}:accesspoint/{consumption_data.s3AccessPointName}' + ) + if principal_type == 'Group': + workgroup = athena_client.get_env_work_group(share.environment.name) + athena_workgroup_output_location = None + else: + workgroup = 'primary' + athena_workgroup_output_location = ( + f's3://dataset-{dataset.datasetUri}-session-query-results/athenaqueries/primary/' + ) + + for item in items: + should_have_access = item.status == 'Share_Succeeded' + if item.itemType == 'Table': + check_table_access( + athena_client, glue_db, item.itemName, workgroup, athena_workgroup_output_location, should_have_access + ) + elif item.itemType == 'S3Bucket': + check_bucket_access(client, s3_client, item.itemName, should_have_access) + elif item.itemType == 'StorageLocation': + check_accesspoint_access(client, s3_client, access_point_arn, item.itemUri, should_have_access) + + +def revoke_and_check_all_shared_items(client, shareUri, check_contains_all_item_types=False): + share = get_share_object(client, shareUri, {'isShared': True}) + items = share['items'].nodes + + shareItemUris = [item.shareItemUri for item in items] + revoke_share_items(client, shareUri, shareItemUris) + + updated_share = get_share_object(client, shareUri, {'isShared': True}) + assert_that(updated_share.status).is_equal_to('Revoked') + items = updated_share['items'].nodes + + assert_that(items).extracting('status').contains_only('Revoke_Approved') + if check_contains_all_item_types: + 
assert_that(items).extracting('itemType').contains(*ALL_S3_SHARABLE_TYPES_NAMES) + + +def check_all_items_revoke_job_succeeded(client, shareUri, check_contains_all_item_types=False): + check_share_ready(client, shareUri) + updated_share = get_share_object(client, shareUri) + items = updated_share['items'].nodes + + assert_that(updated_share.status).is_equal_to('Processed') + assert_that(items).extracting('status').contains_only('Revoke_Succeeded') + if check_contains_all_item_types: + assert_that(items).extracting('itemType').contains(*ALL_S3_SHARABLE_TYPES_NAMES) diff --git a/tests_new/integration_tests/modules/share_base/test_new_crossacc_s3_share.py b/tests_new/integration_tests/modules/share_base/test_new_crossacc_s3_share.py index 035d2e254..fb0eb40e4 100644 --- a/tests_new/integration_tests/modules/share_base/test_new_crossacc_s3_share.py +++ b/tests_new/integration_tests/modules/share_base/test_new_crossacc_s3_share.py @@ -1,9 +1,6 @@ import pytest from assertpy import assert_that -from tests_new.integration_tests.aws_clients.athena import AthenaClient -from tests_new.integration_tests.modules.s3_datasets.aws_clients import S3Client -from tests_new.integration_tests.modules.s3_datasets.queries import get_folder from tests_new.integration_tests.modules.share_base.conftest import clean_up_share from tests_new.integration_tests.modules.share_base.queries import ( create_share_object, @@ -11,26 +8,24 @@ add_share_item, get_share_object, reject_share_object, - approve_share_object, - revoke_share_items, delete_share_object, - verify_share_items, update_share_request_reason, update_share_reject_reason, - get_s3_consumption_data, ) from tests_new.integration_tests.modules.share_base.utils import ( check_share_ready, - check_share_items_verified, - get_group_session, - get_role_session, ) -ALL_S3_SHARABLE_TYPES_NAMES = [ - 'Table', - 'StorageLocation', - 'S3Bucket', -] +from tests_new.integration_tests.modules.share_base.shared_test_functions import ( + 
check_share_items_access, + check_verify_share_items, + revoke_and_check_all_shared_items, + check_all_items_revoke_job_succeeded, + add_all_items_to_share, + check_submit_share_object, + check_approve_share_object, + check_share_succeeded, +) def test_create_and_delete_share_object(client5, session_cross_acc_env_1, session_s3_dataset1, principal1, group5): @@ -149,131 +144,51 @@ def test_change_share_purpose(client5, share_params_main): @pytest.mark.dependency(name='share_submitted') def test_submit_object(client5, share_params_all): share, dataset = share_params_all - updated_share = get_share_object(client5, share.shareUri) - items = updated_share['items'].nodes - for item in items: - add_share_item(client5, share.shareUri, item.itemUri, item.itemType) - - submit_share_object(client5, share.shareUri) - updated_share = get_share_object(client5, share.shareUri) - if dataset.autoApprovalEnabled: - assert_that(updated_share.status).is_equal_to('Approved') - else: - assert_that(updated_share.status).is_equal_to('Submitted') + add_all_items_to_share(client5, share.shareUri) + check_submit_share_object(client5, share.shareUri, dataset) @pytest.mark.dependency(name='share_approved', depends=['share_submitted']) def test_approve_share(client1, share_params_main): share, dataset = share_params_main - approve_share_object(client1, share.shareUri) - - updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - assert_that(updated_share.status).is_equal_to('Approved') - items = updated_share['items'].nodes - assert_that(items).extracting('status').contains_only('Share_Approved') + check_approve_share_object(client1, share.shareUri) @pytest.mark.dependency(name='share_succeeded', depends=['share_approved']) def test_share_succeeded(client1, share_params_main): share, dataset = share_params_main - check_share_ready(client1, share.shareUri) - updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - items = updated_share['items'].nodes 
- - assert_that(updated_share.status).is_equal_to('Processed') - for item in items: - assert_that(item.status).is_equal_to('Share_Succeeded') - assert_that(item.healthStatus).is_equal_to('Healthy') - assert_that(items).extracting('itemType').contains(*ALL_S3_SHARABLE_TYPES_NAMES) + check_share_succeeded(client1, share.shareUri, check_contains_all_item_types=True) @pytest.mark.dependency(name='share_verified', depends=['share_succeeded']) def test_verify_share_items(client1, share_params_main): share, dataset = share_params_main - updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - items = updated_share['items'].nodes - times = [item.lastVerificationTime for item in items] - verify_share_items(client1, share.shareUri, [item.shareItemUri for item in items]) - check_share_items_verified(client1, share.shareUri) - updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - items = updated_share['items'].nodes - assert_that(items).extracting('status').contains_only('Share_Succeeded') - assert_that(items).extracting('healthStatus').contains_only('Healthy') - assert_that(items).extracting('lastVerificationTime').does_not_contain(*times) + check_verify_share_items(client1, share.shareUri) @pytest.mark.dependency(depends=['share_verified']) -def test_check_item_access(client5, session_cross_acc_env_1_aws_client, share_params_main, group5, consumption_role_1): +def test_check_item_access( + client5, session_cross_acc_env_1_aws_client, share_params_main, group5, session_consumption_role_1 +): share, dataset = share_params_main - principal_type = share.principal.principalType - if principal_type == 'Group': - session = get_group_session(client5, share.environment.environmentUri, group5) - elif principal_type == 'ConsumptionRole': - session = get_role_session(session_cross_acc_env_1_aws_client, consumption_role_1.IAMRoleArn, dataset.region) - else: - raise Exception('wrong principal type') - - s3_client = S3Client(session, 
dataset.region) - athena_client = AthenaClient(session, dataset.region) - - consumption_data = get_s3_consumption_data(client5, share.shareUri) - updated_share = get_share_object(client5, share.shareUri, {'isShared': True}) - items = updated_share['items'].nodes - - glue_db = consumption_data.sharedGlueDatabase - access_point_arn = ( - f'arn:aws:s3:{dataset.region}:{dataset.AwsAccountId}:accesspoint/{consumption_data.s3AccessPointName}' + check_share_items_access( + client5, group5, share.shareUri, session_consumption_role_1, session_cross_acc_env_1_aws_client ) - if principal_type == 'Group': - workgroup = athena_client.get_env_work_group(updated_share.environment.name) - athena_workgroup_output_location = None - else: - workgroup = 'primary' - athena_workgroup_output_location = ( - f's3://dataset-{dataset.datasetUri}-session-query-results/athenaqueries/primary/' - ) - - for item in items: - if item.itemType == 'Table': - # nosemgrep-next-line:noexec - query = 'SELECT * FROM {}.{}'.format(glue_db, item.itemName) - state = athena_client.execute_query(query, workgroup, athena_workgroup_output_location) - assert_that(state).is_equal_to('SUCCEEDED') - elif item.itemType == 'S3Bucket': - assert_that(s3_client.bucket_exists(item.itemName)).is_not_none() - assert_that(s3_client.list_bucket_objects(item.itemName)).is_not_none() - elif item.itemType == 'StorageLocation': - folder = get_folder(client5, item.itemUri) - assert_that( - s3_client.list_accesspoint_folder_objects(access_point_arn, folder.S3Prefix + '/') - ).is_not_none() @pytest.mark.dependency(name='share_revoked', depends=['share_succeeded']) def test_revoke_share(client1, share_params_main): share, dataset = share_params_main check_share_ready(client1, share.shareUri) - updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - items = updated_share['items'].nodes - - shareItemUris = [item.shareItemUri for item in items] - revoke_share_items(client1, share.shareUri, shareItemUris) - - 
updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - assert_that(updated_share.status).is_equal_to('Revoked') - items = updated_share['items'].nodes - - assert_that(items).extracting('status').contains_only('Revoke_Approved') - assert_that(items).extracting('itemType').contains(*ALL_S3_SHARABLE_TYPES_NAMES) + revoke_and_check_all_shared_items(client1, share.shareUri, check_contains_all_item_types=True) @pytest.mark.dependency(name='share_revoke_succeeded', depends=['share_revoked']) -def test_revoke_succeeded(client1, share_params_main): +def test_revoke_succeeded( + client1, client5, session_cross_acc_env_1_aws_client, share_params_main, group5, session_consumption_role_1 +): share, dataset = share_params_main - check_share_ready(client1, share.shareUri) - updated_share = get_share_object(client1, share.shareUri, {'isShared': True}) - items = updated_share['items'].nodes - - assert_that(updated_share.status).is_equal_to('Processed') - assert_that(items).extracting('status').contains_only('Revoke_Succeeded') - assert_that(items).extracting('itemType').contains(*ALL_S3_SHARABLE_TYPES_NAMES) + check_all_items_revoke_job_succeeded(client1, share.shareUri, check_contains_all_item_types=True) + check_share_items_access( + client5, group5, share.shareUri, session_consumption_role_1, session_cross_acc_env_1_aws_client + ) diff --git a/tests_new/integration_tests/modules/share_base/test_persistent_crossacc_share.py b/tests_new/integration_tests/modules/share_base/test_persistent_crossacc_share.py new file mode 100644 index 000000000..737586f16 --- /dev/null +++ b/tests_new/integration_tests/modules/share_base/test_persistent_crossacc_share.py @@ -0,0 +1,99 @@ +from tests_new.integration_tests.modules.share_base.utils import check_share_ready +from tests_new.integration_tests.modules.share_base.shared_test_functions import ( + check_share_items_access, + check_verify_share_items, + revoke_and_check_all_shared_items, + 
check_all_items_revoke_job_succeeded, + add_all_items_to_share, + check_submit_share_object, + check_approve_share_object, + check_share_succeeded, + delete_all_non_shared_items, +) + +""" +1. Update persistent envs and datasets used for shares (made in fixtures) +2. Share verification test +3. Check item access test +4. Revoke share test +5. Check no access left +6. Add all items back to share +7. Share approved/processed successfully +8. Share verification test +9. Check item access test +""" + + +def test_verify_share_items(client5, persistent_share_params_main): + check_verify_share_items(client5, persistent_share_params_main.shareUri) + + +def test_check_share_items_access( + client5, group5, persistent_share_params_main, persistent_consumption_role_1, persistent_cross_acc_env_1_aws_client +): + check_share_items_access( + client5, + group5, + persistent_share_params_main.shareUri, + persistent_consumption_role_1, + persistent_cross_acc_env_1_aws_client, + ) + + +def test_revoke_share(client1, persistent_share_params_main): + check_share_ready(client1, persistent_share_params_main.shareUri) + revoke_and_check_all_shared_items( + client1, persistent_share_params_main.shareUri, check_contains_all_item_types=True + ) + + +def test_revoke_succeeded( + client1, + client5, + group5, + persistent_share_params_main, + persistent_consumption_role_1, + persistent_cross_acc_env_1_aws_client, +): + check_all_items_revoke_job_succeeded( + client1, persistent_share_params_main.shareUri, check_contains_all_item_types=True + ) + check_share_items_access( + client5, + group5, + persistent_share_params_main.shareUri, + persistent_consumption_role_1, + persistent_cross_acc_env_1_aws_client, + ) + + +def test_delete_all_nonshared_items(client5, persistent_share_params_main): + check_share_ready(client5, persistent_share_params_main.shareUri) + delete_all_non_shared_items(client5, persistent_share_params_main.shareUri) + + +def test_add_items_back_to_share(client5, 
persistent_share_params_main): + check_share_ready(client5, persistent_share_params_main.shareUri) + add_all_items_to_share(client5, persistent_share_params_main.shareUri) + + +def test_submit_share(client5, persistent_share_params_main, persistent_s3_dataset1): + check_submit_share_object(client5, persistent_share_params_main.shareUri, persistent_s3_dataset1) + + +def test_approve_share(client1, persistent_share_params_main): + check_approve_share_object(client1, persistent_share_params_main.shareUri) + + +def test_re_share_succeeded( + client5, persistent_share_params_main, persistent_consumption_role_1, persistent_cross_acc_env_1_aws_client +): + check_share_succeeded(client5, persistent_share_params_main.shareUri, check_contains_all_item_types=True) + check_verify_share_items(client5, persistent_share_params_main.shareUri) + check_share_items_access( + client5, + persistent_share_params_main.group, + persistent_share_params_main.shareUri, + persistent_consumption_role_1, + persistent_cross_acc_env_1_aws_client, + ) diff --git a/tests_new/integration_tests/modules/share_base/utils.py b/tests_new/integration_tests/modules/share_base/utils.py index e1eaee334..91887bedc 100644 --- a/tests_new/integration_tests/modules/share_base/utils.py +++ b/tests_new/integration_tests/modules/share_base/utils.py @@ -1,9 +1,3 @@ -import json - -import boto3 - -from tests_new.integration_tests.aws_clients.sts import StsClient -from tests_new.integration_tests.core.environment.queries import get_environment_access_token from tests_new.integration_tests.modules.share_base.queries import get_share_object from tests_new.integration_tests.utils import poller @@ -31,18 +25,3 @@ def check_share_ready(client, shareUri): @poller(check_success=lambda share: is_all_items_verified(share), timeout=600) def check_share_items_verified(client, shareUri): return get_share_object(client, shareUri) - - -def get_group_session(client, env_uri, group): - credentials = 
json.loads(get_environment_access_token(client, env_uri, group)) - - return boto3.Session( - aws_access_key_id=credentials['AccessKey'], - aws_secret_access_key=credentials['SessionKey'], - aws_session_token=credentials['sessionToken'], - ) - - -def get_role_session(session, role_arn, region): - sts_client = StsClient(session=session, region=region) - return sts_client.get_role_session(role_arn)