From 1f28d8e2854f6cc02f96885320b47d87daec122c Mon Sep 17 00:00:00 2001 From: dlpzx <71252798+dlpzx@users.noreply.github.com> Date: Mon, 4 Dec 2023 14:51:40 +0100 Subject: [PATCH 1/5] Remove unnecessary MANAGE_ORGANIZATIONS check (#887) ### Feature or Bugfix - Bugfix ### Detail - Remove unnecessary permission check in list environments in the organization ### Relates - #842 ### Security Please answer the questions below briefly where applicable, or write `N/A`. Based on [OWASP 10](https://owasp.org/Top10/en/). - Does this PR introduce or modify any input fields or queries - this includes fetching data from storage outside the application (e.g. a database, an S3 bucket)? - Is the input sanitized? - What precautions are you taking before deserializing the data you consume? - Is injection prevented by parametrizing queries? - Have you ensured no `eval` or similar functions are used? - Does this PR introduce any functionality or component that requires authorization? - How have you ensured it respects the existing AuthN/AuthZ mechanisms? - Are you logging failed auth attempts? - Are you using or adding any cryptographic features? - Do you use a standard proven implementations? - Are the used keys controlled by the customer? Where are they stored? - Are you introducing any new policies/roles/users? - Have you used the least-privilege principle? How? By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. --- .../dataall/core/organizations/api/queries.py | 9 ----- .../core/organizations/api/resolvers.py | 13 ------- .../db/organization_repositories.py | 36 ----------------- tests/core/organizations/test_organization.py | 39 ------------------- 4 files changed, 97 deletions(-) diff --git a/backend/dataall/core/organizations/api/queries.py b/backend/dataall/core/organizations/api/queries.py index 32625e8d9..26e171879 100644 --- a/backend/dataall/core/organizations/api/queries.py +++ b/backend/dataall/core/organizations/api/queries.py @@ -23,15 +23,6 @@ test_scope='Organization', ) -listOrganizationInvitedGroups = gql.QueryField( - name='listOrganizationInvitedGroups', - type=gql.Ref('GroupSearchResult'), - args=[ - gql.Argument(name='organizationUri', type=gql.NonNullableType(gql.String)), - gql.Argument(name='filter', type=gql.Ref('GroupFilter')), - ], - resolver=list_organization_invited_groups, -) listOrganizationGroups = gql.QueryField( name='listOrganizationGroups', diff --git a/backend/dataall/core/organizations/api/resolvers.py b/backend/dataall/core/organizations/api/resolvers.py index 458a2ddb4..f866b7c72 100644 --- a/backend/dataall/core/organizations/api/resolvers.py +++ b/backend/dataall/core/organizations/api/resolvers.py @@ -107,19 +107,6 @@ def remove_group(context: Context, source, organizationUri=None, groupUri=None): return organization -def list_organization_invited_groups( - context: Context, source, organizationUri=None, filter=None -): - if filter is None: - filter = {} - with context.engine.scoped_session() as session: - return Organization.paginated_organization_invited_groups( - session=session, - uri=organizationUri, - data=filter, - ) - - def list_organization_groups( context: Context, source, organizationUri=None, filter=None ): diff --git a/backend/dataall/core/organizations/db/organization_repositories.py b/backend/dataall/core/organizations/db/organization_repositories.py index 9574ea24e..057496afa 100644 --- a/backend/dataall/core/organizations/db/organization_repositories.py +++ 
b/backend/dataall/core/organizations/db/organization_repositories.py @@ -151,7 +151,6 @@ def query_organization_environments(session, uri, filter) -> Query: return query @staticmethod - @has_tenant_permission(permissions.MANAGE_ORGANIZATIONS) @has_resource_permission(permissions.GET_ORGANIZATION) def paginated_organization_environments(session, uri, data=None) -> dict: return paginate( @@ -288,7 +287,6 @@ def query_organization_groups(session, uri, filter) -> Query: return query @staticmethod - @has_tenant_permission(permissions.MANAGE_ORGANIZATIONS) @has_resource_permission(permissions.GET_ORGANIZATION) def paginated_organization_groups(session, uri, data=None) -> dict: return paginate( @@ -297,40 +295,6 @@ def paginated_organization_groups(session, uri, data=None) -> dict: page_size=data.get('pageSize', 10), ).to_dict() - @staticmethod - def query_organization_invited_groups(session, organization, filter) -> Query: - query = ( - session.query(models.OrganizationGroup) - .join( - models.Organization, - models.OrganizationGroup.organizationUri == models.Organization.organizationUri, - ) - .filter( - and_( - models.Organization.organizationUri == organization.organizationUri, - models.OrganizationGroup.groupUri != models.Organization.SamlGroupName, - ) - ) - ) - if filter and filter.get('term'): - query = query.filter( - or_( - models.OrganizationGroup.groupUri.ilike('%' + filter.get('term') + '%'), - ) - ) - return query - - @staticmethod - @has_tenant_permission(permissions.MANAGE_ORGANIZATIONS) - @has_resource_permission(permissions.GET_ORGANIZATION) - def paginated_organization_invited_groups(session, uri, data=None) -> dict: - organization = Organization.get_organization_by_uri(session, uri) - return paginate( - query=Organization.query_organization_invited_groups(session, organization, data), - page=data.get('page', 1), - page_size=data.get('pageSize', 10), - ).to_dict() - @staticmethod def count_organization_invited_groups(session, uri, group) -> int: groups = ( diff --git a/tests/core/organizations/test_organization.py b/tests/core/organizations/test_organization.py index 40b304ff9..4f826070d 100644 --- a/tests/core/organizations/test_organization.py +++ b/tests/core/organizations/test_organization.py @@ -222,25 +222,6 @@ def test_group_invitation(db, client, org1, group2, user, group3, group, env): assert response.data.getOrganization.userRoleInOrganization == 'Invited' assert response.data.getOrganization.stats.groups == 1 - response = client.query( - """ - query listOrganizationInvitedGroups($organizationUri: String!, $filter:GroupFilter){ - listOrganizationInvitedGroups(organizationUri:$organizationUri, filter:$filter){ - count - nodes{ - groupUri - name - } - } - } - """, - username=user.username, - groups=[group.name, group2.name], - organizationUri=org1.organizationUri, - filter={}, - ) - - assert response.data.listOrganizationInvitedGroups.count == 1 response = client.query( """ @@ -303,26 +284,6 @@ def test_group_invitation(db, client, org1, group2, user, group3, group, env): print(response) assert response.data.removeGroupFromOrganization - response = client.query( - """ - query listOrganizationInvitedGroups($organizationUri: String!, $filter:GroupFilter){ - listOrganizationInvitedGroups(organizationUri:$organizationUri, filter:$filter){ - count - nodes{ - groupUri - name - } - } - } - """, - username=user.username, - groups=[group.name, group2.name], - organizationUri=org1.organizationUri, - filter={}, - ) - - assert response.data.listOrganizationInvitedGroups.count == 0 
- response = client.query( """ query listOrganizationGroups($organizationUri: String!, $filter:GroupFilter){ From e7c87dfcdaf8cab0f4141c630ce6b4be74655dd5 Mon Sep 17 00:00:00 2001 From: nikpodsh <124577300+nikpodsh@users.noreply.github.com> Date: Wed, 6 Dec 2023 09:37:11 +0100 Subject: [PATCH 2/5] Add additional checks for dataset importing (#883) ### Feature or Bugfix Feature ### Detail Added additional checks for dataset importing (see #614). Implemented checks that fail if: 1) The bucket of the dataset being imported is encrypted with an AWS managed key, but a KMS key is provided 2) The bucket of the dataset being imported is encrypted with a KMS key, but no KMS key is provided 3) The bucket of the dataset being imported is encrypted with a KMS key, but the provided KMS key is wrong 4) The bucket of an imported dataset is encrypted with an AWS managed key, but the user is trying to create a share request for an account other than the dataset's account From #614: > User can forget to grant the pivotRole access to modify the key policy -> All folder/bucket shares will fail when trying to update the key policy The check for this case is not implemented. Rationale: there are two types of pivotRole, the manual pivotRole (1) and the cdk pivotRole (2). 1) The manual pivotRole has `kms:PutKeyPolicy` on '*' resources in the environment. No action needed. 2) When the dataset stack is deployed (for imported datasets), it triggers an update of the environment stack. The environment stack updates the policies for all imported keys and datasets and eventually sets `kms:PutKeyPolicy` on the imported key so that the pivotRole can change the policy. Thus, users can only hit this error if they try to create a share before the stack update has succeeded. ### Relates #614 ### Testing Added unit tests for the checks in this PR. Added an example of testing data.all AWS clients by mocking the boto3 client. Usually, AWS clients are mocked in the integration tests; this test can be used as an example of how to mock boto3 instead. ### Security Please answer the questions below briefly where applicable, or write `N/A`. Based on [OWASP 10](https://owasp.org/Top10/en/). - Are you introducing any new policies/roles/users? Yes - adding `s3:GetEncryptionConfiguration` to the pivotRole to fetch the encryption configuration of environment buckets. It is needed to implement the additional checks. The EncryptionConfiguration only provides information about the encryption (encryption type and key) and does not allow the pivotRole to read the content of the bucket. By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
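For reference, a minimal sketch of the decision logic behind import checks 1-3 (illustrative only: the function, parameter and exception names are assumptions, not the actual service code; check 4 is enforced at share-request time and is out of scope here):

```python
def validate_imported_bucket_encryption(kms_alias, s3_encryption, bucket_kms_id, resolve_alias):
    """Illustrative sketch of the import checks described above.

    s3_encryption  -- 'AES256' for an AWS managed key (SSE-S3) or 'aws:kms' for a KMS key
    bucket_kms_id  -- key id taken from the bucket encryption configuration, or None
    resolve_alias  -- callable that returns the key id behind a KMS alias, or None
    """
    alias_provided = kms_alias not in (None, '', 'Undefined', 'SSE-S3')
    if not alias_provided:
        # Check 2: the bucket is KMS-encrypted but no KMS alias was provided
        if s3_encryption != 'AES256':
            raise ValueError('Bucket is encrypted with a KMS key: a KmsAlias is required')
        return True
    # Check 1: a KMS alias was provided but the bucket uses the AWS managed key
    if s3_encryption == 'AES256':
        raise ValueError('Bucket is encrypted with SSE-S3: KmsAlias must not be provided')
    key_id = resolve_alias(f'alias/{kms_alias}')
    if key_id is None:
        raise LookupError(f'KMS key with alias {kms_alias} cannot be found in the account')
    # Check 3: the alias resolves to a key other than the one encrypting the bucket
    if key_id != bucket_kms_id:
        raise ValueError('KmsAlias does not match the KMS key used to encrypt the bucket')
    return True
```

In the service itself these branches raise data.all's own exception types and read the bucket configuration through the S3 client (see the diff below).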
--------- Co-authored-by: dlpzx --- .../cdk/pivot_role_core_policies/s3.py | 3 +- .../modules/dataset_sharing/aws/kms_client.py | 34 +++--- .../aws/s3_dataset_bucket_policy_client.py | 86 ++++++++++++++ .../modules/datasets/aws/s3_dataset_client.py | 105 +++++------------ .../datasets/services/dataset_service.py | 64 +++++++---- .../datasets/tasks/bucket_policy_updater.py | 2 +- .../datasets_base/db/dataset_repositories.py | 78 ++++++------- deploy/pivot_role/pivotRole.yaml | 1 + tests/conftest.py | 25 ++++ tests/modules/datasets/test_dataset.py | 16 ++- .../test_import_dataset_check_unit.py | 108 ++++++++++++++++++ tests/modules/datasets/test_share.py | 3 +- 12 files changed, 365 insertions(+), 160 deletions(-) create mode 100644 backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py create mode 100644 tests/modules/datasets/test_import_dataset_check_unit.py diff --git a/backend/dataall/core/environment/cdk/pivot_role_core_policies/s3.py b/backend/dataall/core/environment/cdk/pivot_role_core_policies/s3.py index 360fcc03b..096eae06a 100644 --- a/backend/dataall/core/environment/cdk/pivot_role_core_policies/s3.py +++ b/backend/dataall/core/environment/cdk/pivot_role_core_policies/s3.py @@ -17,7 +17,8 @@ def get_statements(self): actions=[ 's3:ListAllMyBuckets', 's3:GetBucketLocation', - 's3:PutBucketTagging' + 's3:PutBucketTagging', + 's3:GetEncryptionConfiguration' ], resources=['*'], ), diff --git a/backend/dataall/modules/dataset_sharing/aws/kms_client.py b/backend/dataall/modules/dataset_sharing/aws/kms_client.py index 5642a9013..33abdbbeb 100644 --- a/backend/dataall/modules/dataset_sharing/aws/kms_client.py +++ b/backend/dataall/modules/dataset_sharing/aws/kms_client.py @@ -1,6 +1,7 @@ import logging from dataall.base.aws.sts import SessionHelper +from botocore.exceptions import ClientError log = logging.getLogger(__name__) @@ -20,10 +21,10 @@ def put_key_policy(self, key_id: str, policy: str): PolicyName=self._DEFAULT_POLICY_NAME, Policy=policy, ) - except Exception as e: - log.error( - f'Failed to attach policy to KMS key {key_id} on {self._account_id} : {e} ' - ) + except ClientError as e: + if e.response['Error']['Code'] == 'AccessDenied': + raise Exception(f'Data.all Environment Pivot Role does not have kms:PutKeyPolicy Permission for key id {key_id}: {e}') + log.error(f'Failed to attach policy to KMS key {key_id} on {self._account_id}: {e} ') raise e def get_key_policy(self, key_id: str): @@ -32,10 +33,11 @@ def get_key_policy(self, key_id: str): KeyId=key_id, PolicyName=self._DEFAULT_POLICY_NAME, ) - except Exception as e: - log.error( - f'Failed to get kms key policy of key {key_id} : {e}' - ) + except ClientError as e: + if e.response['Error']['Code'] == 'AccessDenied': + raise Exception( + f'Data.all Environment Pivot Role does not have kms:GetKeyPolicy Permission for key id {key_id}: {e}') + log.error(f'Failed to get kms key policy of key {key_id}: {e}') return None else: return response['Policy'] @@ -45,10 +47,10 @@ def get_key_id(self, key_alias: str): response = self._client.describe_key( KeyId=key_alias, ) - except Exception as e: - log.error( - f'Failed to get kms key id of {key_alias} : {e}' - ) + except ClientError as e: + if e.response['Error']['Code'] == 'AccessDenied': + raise Exception(f'Data.all Environment Pivot Role does not have kms:DescribeKey Permission for key {key_alias}: {e}') + log.error(f'Failed to get kms key id of {key_alias}: {e}') return None else: return response['KeyMetadata']['KeyId'] @@ -62,10 +64,10 @@ def 
check_key_exists(self, key_alias: str): if key_alias in key_aliases: key_exist = True break - except Exception as e: - log.error( - f'Failed to list kms key aliases in account {self._account_id}: {e}' - ) + except ClientError as e: + if e.response['Error']['Code'] == 'AccessDenied': + raise Exception(f'Data.all Environment Pivot Role does not have kms:ListAliases Permission in account {self._account_id}: {e}') + log.error(f'Failed to list KMS key aliases in account {self._account_id}: {e}') return None else: return key_exist diff --git a/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py b/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py new file mode 100644 index 000000000..104302498 --- /dev/null +++ b/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py @@ -0,0 +1,86 @@ +import json +import logging + +from botocore.exceptions import ClientError + +from dataall.base.aws.sts import SessionHelper +from dataall.modules.datasets_base.db.dataset_models import Dataset + +log = logging.getLogger(__name__) + + +class S3DatasetBucketPolicyClient: + def __init__(self, dataset: Dataset): + session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) + self._client = session.client('s3') + self._dataset = dataset + + def get_bucket_policy(self): + dataset = self._dataset + try: + policy = self._client.get_bucket_policy(Bucket=dataset.S3BucketName)['Policy'] + log.info(f'Current bucket policy---->:{policy}') + policy = json.loads(policy) + except ClientError as err: + if err.response['Error']['Code'] == 'NoSuchBucketPolicy': + log.info(f"No policy attached to '{dataset.S3BucketName}'") + + elif err.response['Error']['Code'] == 'NoSuchBucket': + log.error(f'Bucket deleted {dataset.S3BucketName}') + + elif err.response['Error']['Code'] == 'AccessDenied': + log.error( + f'Access denied in {dataset.AwsAccountId} ' + f'(s3:{err.operation_name}, ' + f"resource='{dataset.S3BucketName}')" + ) + else: + log.exception( + f"Failed to get '{dataset.S3BucketName}' policy in {dataset.AwsAccountId}" + ) + policy = { + 'Version': '2012-10-17', + 'Statement': [ + { + 'Sid': 'OwnerAccount', + 'Effect': 'Allow', + 'Action': ['s3:*'], + 'Resource': [ + f'arn:aws:s3:::{dataset.S3BucketName}', + f'arn:aws:s3:::{dataset.S3BucketName}/*', + ], + 'Principal': { + 'AWS': f'arn:aws:iam::{dataset.AwsAccountId}:root' + }, + } + ], + } + + return policy + + def put_bucket_policy(self, policy): + dataset = self._dataset + update_policy_report = { + 'datasetUri': dataset.datasetUri, + 'bucketName': dataset.S3BucketName, + 'accountId': dataset.AwsAccountId, + } + try: + policy_json = json.dumps(policy) if isinstance(policy, dict) else policy + log.info( + f"Putting new bucket policy on '{dataset.S3BucketName}' policy {policy_json}" + ) + response = self._client.put_bucket_policy( + Bucket=dataset.S3BucketName, Policy=policy_json + ) + log.info(f'Bucket Policy updated: {response}') + update_policy_report.update({'status': 'SUCCEEDED'}) + except ClientError as e: + log.error( + f'Failed to update bucket policy ' + f"on '{dataset.S3BucketName}' policy {policy} " + f'due to {e} ' + ) + update_policy_report.update({'status': 'FAILED'}) + + return update_policy_report diff --git a/backend/dataall/modules/datasets/aws/s3_dataset_client.py b/backend/dataall/modules/datasets/aws/s3_dataset_client.py index 0caf872ce..f68fea385 100644 --- a/backend/dataall/modules/datasets/aws/s3_dataset_client.py +++ b/backend/dataall/modules/datasets/aws/s3_dataset_client.py @@ -17,104 
+17,51 @@ def __init__(self, dataset: Dataset): It first starts a session assuming the pivot role, then we define another session assuming the dataset role from the pivot role """ - pivot_role_session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) - session = SessionHelper.get_session(base_session=pivot_role_session, role_arn=dataset.IAMDatasetAdminRoleArn) - self._client = session.client( + self._pivot_role_session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) + self._client = self._pivot_role_session.client('s3') + self._dataset = dataset + + def _get_dataset_role_client(self): + session = SessionHelper.get_session(base_session=self.pivot_role_session, role_arn=self.dataset.IAMDatasetAdminRoleArn) + dataset_client = session.client( 's3', - region_name=dataset.region, + region_name=self._dataset.region, config=Config(signature_version='s3v4', s3={'addressing_style': 'virtual'}), ) - self._dataset = dataset + return dataset_client def get_file_upload_presigned_url(self, data): dataset = self._dataset + client = self._get_dataset_role_client() try: - self._client.get_bucket_acl( + client.get_bucket_acl( Bucket=dataset.S3BucketName, ExpectedBucketOwner=dataset.AwsAccountId ) - response = self._client.generate_presigned_post( + response = client.generate_presigned_post( Bucket=dataset.S3BucketName, Key=data.get('prefix', 'uploads') + '/' + data.get('fileName'), ExpiresIn=15 * 60, ) - return json.dumps(response) + except ClientError as e: raise e - -class S3DatasetBucketPolicyClient: - def __init__(self, dataset: Dataset): - session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) - self._client = session.client('s3') - self._dataset = dataset - - def get_bucket_policy(self): + def get_bucket_encryption(self) -> (str, str): dataset = self._dataset try: - policy = self._client.get_bucket_policy(Bucket=dataset.S3BucketName)['Policy'] - log.info(f'Current bucket policy---->:{policy}') - policy = json.loads(policy) - except ClientError as err: - if err.response['Error']['Code'] == 'NoSuchBucketPolicy': - log.info(f"No policy attached to '{dataset.S3BucketName}'") - - elif err.response['Error']['Code'] == 'NoSuchBucket': - log.error(f'Bucket deleted {dataset.S3BucketName}') - - elif err.response['Error']['Code'] == 'AccessDenied': - log.error( - f'Access denied in {dataset.AwsAccountId} ' - f'(s3:{err.operation_name}, ' - f"resource='{dataset.S3BucketName}')" - ) - else: - log.exception( - f"Failed to get '{dataset.S3BucketName}' policy in {dataset.AwsAccountId}" - ) - policy = { - 'Version': '2012-10-17', - 'Statement': [ - { - 'Sid': 'OwnerAccount', - 'Effect': 'Allow', - 'Action': ['s3:*'], - 'Resource': [ - f'arn:aws:s3:::{dataset.S3BucketName}', - f'arn:aws:s3:::{dataset.S3BucketName}/*', - ], - 'Principal': { - 'AWS': f'arn:aws:iam::{dataset.AwsAccountId}:root' - }, - } - ], - } + response = self._client.get_bucket_encryption( + Bucket=dataset.S3BucketName, + ExpectedBucketOwner=dataset.AwsAccountId + ) + rule = response['ServerSideEncryptionConfiguration']['Rules'][0] + encryption = rule['ApplyServerSideEncryptionByDefault'] + s3_encryption = encryption['SSEAlgorithm'] + kms_id = encryption.get('KMSMasterKeyID').split("/")[-1] if encryption.get('KMSMasterKeyID') else None - return policy + return s3_encryption, kms_id - def put_bucket_policy(self, policy): - dataset = self._dataset - update_policy_report = { - 'datasetUri': dataset.datasetUri, - 'bucketName': dataset.S3BucketName, - 'accountId': dataset.AwsAccountId, - } - try: - policy_json = 
json.dumps(policy) if isinstance(policy, dict) else policy - log.info( - f"Putting new bucket policy on '{dataset.S3BucketName}' policy {policy_json}" - ) - response = self._client.put_bucket_policy( - Bucket=dataset.S3BucketName, Policy=policy_json - ) - log.info(f'Bucket Policy updated: {response}') - update_policy_report.update({'status': 'SUCCEEDED'}) except ClientError as e: - log.error( - f'Failed to update bucket policy ' - f"on '{dataset.S3BucketName}' policy {policy} " - f'due to {e} ' - ) - update_policy_report.update({'status': 'FAILED'}) - - return update_policy_report + if e.response['Error']['Code'] == 'AccessDenied': + raise Exception(f'Data.all Environment Pivot Role does not have s3:GetEncryptionConfiguration Permission for {dataset.S3BucketName} bucket: {e}') + raise Exception(f'Cannot fetch the bucket encryption configuration for {dataset.S3BucketName}: {e}') diff --git a/backend/dataall/modules/datasets/services/dataset_service.py b/backend/dataall/modules/datasets/services/dataset_service.py index fa6bcfe8a..b60ee1ffe 100644 --- a/backend/dataall/modules/datasets/services/dataset_service.py +++ b/backend/dataall/modules/datasets/services/dataset_service.py @@ -18,7 +18,6 @@ from dataall.modules.catalog.db.glossary_repositories import GlossaryRepository from dataall.modules.datasets.db.dataset_bucket_repositories import DatasetBucketRepository from dataall.modules.vote.db.vote_repositories import VoteRepository -from dataall.base.db.exceptions import AWSResourceNotFound, UnauthorizedOperation from dataall.modules.dataset_sharing.db.share_object_models import ShareObject from dataall.modules.dataset_sharing.db.share_object_repositories import ShareObjectRepository from dataall.modules.dataset_sharing.services.share_permissions import SHARE_OBJECT_APPROVER @@ -52,26 +51,42 @@ def check_dataset_account(session, environment): return True @staticmethod - def check_imported_resources(environment, data): - kms_alias = data.get('KmsKeyAlias') - if kms_alias not in [None, "Undefined", "", "SSE-S3"]: - key_exists = KmsClient(account_id=environment.AwsAccountId, region=environment.region).check_key_exists( + def check_imported_resources(dataset: Dataset): + kms_alias = dataset.KmsAlias + + s3_encryption, kms_id = S3DatasetClient(dataset).get_bucket_encryption() + if kms_alias not in [None, "Undefined", "", "SSE-S3"]: # user-defined KMS encryption + if s3_encryption == 'AES256': + raise exceptions.InvalidInput( + param_name='KmsAlias', + param_value=dataset.KmsAlias, + constraint=f'empty, Bucket {dataset.S3BucketName} is encrypted with AWS managed key (SSE-S3). KmsAlias {kms_alias} should NOT be provided as input parameter.' 
+ ) + + key_exists = KmsClient(account_id=dataset.AwsAccountId, region=dataset.region).check_key_exists( key_alias=f"alias/{kms_alias}" ) if not key_exists: raise exceptions.AWSResourceNotFound( action=IMPORT_DATASET, - message=f'KMS key with alias={kms_alias} cannot be found - Please check if KMS Key Alias exists in account {environment.AwsAccountId}', + message=f'KMS key with alias={kms_alias} cannot be found - Please check if KMS Key Alias exists in account {dataset.AwsAccountId}', ) - key_id = KmsClient(account_id=environment.AwsAccountId, region=environment.region).get_key_id( + key_id = KmsClient(account_id=dataset.AwsAccountId, region=dataset.region).get_key_id( key_alias=f"alias/{kms_alias}" ) - if not key_id: - raise exceptions.AWSResourceNotFound( - action=IMPORT_DATASET, - message=f'Data.all Environment Pivot Role does not have kms:DescribeKey Permission to KMS key with alias={kms_alias}', + + if key_id != kms_id: + raise exceptions.InvalidInput( + param_name='KmsAlias', + param_value=dataset.KmsAlias, + constraint=f'the KMS Alias of the KMS key used to encrypt the Bucket {dataset.S3BucketName}. Provide the correct KMS Alias as input parameter.' ) + + else: # user-defined S3 encryption + if s3_encryption != 'AES256': + raise exceptions.RequiredParameter(param_name='KmsAlias') + return True @staticmethod @@ -83,21 +98,26 @@ def create_dataset(uri, admin_group, data: dict): with context.db_engine.scoped_session() as session: environment = EnvironmentService.get_environment_by_uri(session, uri) DatasetService.check_dataset_account(session=session, environment=environment) - if data.get('imported', False): - DatasetService.check_imported_resources(environment=environment, data=data) + dataset = DatasetRepository.build_dataset( + username=context.username, + env=environment, + data=data + ) + + if dataset.imported: + DatasetService.check_imported_resources(dataset) dataset = DatasetRepository.create_dataset( session=session, - username=context.username, - uri=uri, - data=data, + env=environment, + dataset=dataset, ) DatasetBucketRepository.create_dataset_bucket(session, dataset, data) ResourcePolicy.attach_resource_policy( session=session, - group=data['SamlAdminGroupName'], + group=dataset.SamlAdminGroupName, permissions=DATASET_ALL, resource_uri=dataset.datasetUri, resource_type=Dataset.__name__, @@ -195,8 +215,6 @@ def update_dataset(uri: str, data: dict): dataset = DatasetRepository.get_dataset_by_uri(session, uri) environment = EnvironmentService.get_environment_by_uri(session, dataset.environmentUri) DatasetService.check_dataset_account(session=session, environment=environment) - if data.get('imported', False): - DatasetService.check_imported_resources(environment=environment, data=data) username = get_context().username dataset: Dataset = DatasetRepository.get_dataset_by_uri(session, uri) @@ -207,6 +225,10 @@ def update_dataset(uri: str, data: dict): if data.get('KmsAlias') not in ["Undefined"]: dataset.KmsAlias = "SSE-S3" if data.get('KmsAlias') == "" else data.get('KmsAlias') dataset.importedKmsKey = False if data.get('KmsAlias') == "" else True + + if data.get('imported', False): + DatasetService.check_imported_resources(dataset) + if data.get('stewards') and data.get('stewards') != dataset.stewards: if data.get('stewards') != dataset.SamlAdminGroupName: DatasetService._transfer_stewardship_to_new_stewards( @@ -303,7 +325,7 @@ def start_crawler(uri: str, data: dict = None): crawler = DatasetCrawler(dataset).get_crawler() if not crawler: - raise AWSResourceNotFound( + raise 
exceptions.AWSResourceNotFound( action=CRAWL_DATASET, message=f'Crawler {dataset.GlueCrawlerName} can not be found', ) @@ -371,7 +393,7 @@ def delete_dataset(uri: str, delete_from_aws: bool = False): ) shares = ShareObjectRepository.list_dataset_shares_with_existing_shared_items(session, uri) if shares: - raise UnauthorizedOperation( + raise exceptions.UnauthorizedOperation( action=DELETE_DATASET, message=f'Dataset {dataset.name} is shared with other teams. ' 'Revoke all dataset shares before deletion.', diff --git a/backend/dataall/modules/datasets/tasks/bucket_policy_updater.py b/backend/dataall/modules/datasets/tasks/bucket_policy_updater.py index a2f995371..0de0ad66a 100644 --- a/backend/dataall/modules/datasets/tasks/bucket_policy_updater.py +++ b/backend/dataall/modules/datasets/tasks/bucket_policy_updater.py @@ -7,7 +7,7 @@ from dataall.base.db import get_engine from dataall.modules.dataset_sharing.db.share_object_repositories import ShareObjectRepository -from dataall.modules.datasets.aws.s3_dataset_client import S3DatasetBucketPolicyClient +from dataall.modules.datasets.aws.s3_dataset_bucket_policy_client import S3DatasetBucketPolicyClient from dataall.modules.datasets_base.db.dataset_models import Dataset root = logging.getLogger() diff --git a/backend/dataall/modules/datasets_base/db/dataset_repositories.py b/backend/dataall/modules/datasets_base/db/dataset_repositories.py index 29fbef4e1..9bbaf2489 100644 --- a/backend/dataall/modules/datasets_base/db/dataset_repositories.py +++ b/backend/dataall/modules/datasets_base/db/dataset_repositories.py @@ -3,6 +3,7 @@ from sqlalchemy import and_, or_ from sqlalchemy.orm import Query from dataall.core.activity.db.activity_models import Activity +from dataall.core.environment.db.environment_models import Environment from dataall.core.environment.services.environment_service import EnvironmentService from dataall.core.organizations.db.organization_repositories import Organization from dataall.base.db import paginate @@ -21,6 +22,39 @@ class DatasetRepository(EnvironmentResource): """DAO layer for Datasets""" + @classmethod + def build_dataset(cls, username: str, env: Environment, data: dict) -> Dataset: + """Builds a datasets based on the request data, but doesn't save it in the database""" + dataset = Dataset( + label=data.get('label'), + owner=username, + description=data.get('description', 'No description provided'), + tags=data.get('tags', []), + AwsAccountId=env.AwsAccountId, + SamlAdminGroupName=data['SamlAdminGroupName'], + region=env.region, + S3BucketName='undefined', + GlueDatabaseName='undefined', + IAMDatasetAdminRoleArn='undefined', + IAMDatasetAdminUserArn='undefined', + KmsAlias='undefined', + environmentUri=env.environmentUri, + organizationUri=env.organizationUri, + language=data.get('language', Language.English.value), + confidentiality=data.get( + 'confidentiality', ConfidentialityClassification.Unclassified.value + ), + topics=data.get('topics', []), + businessOwnerEmail=data.get('businessOwnerEmail'), + businessOwnerDelegationEmails=data.get('businessOwnerDelegationEmails', []), + stewards=data.get('stewards') + if data.get('stewards') + else data['SamlAdminGroupName'], + ) + cls._set_dataset_aws_resources(dataset, data, env) + cls._set_import_data(dataset, data) + return dataset + @staticmethod def get_dataset_by_uri(session, dataset_uri) -> Dataset: dataset: Dataset = session.query(Dataset).get(dataset_uri) @@ -41,55 +75,19 @@ def count_resources(session, environment, group_uri) -> int: ) @staticmethod - def 
create_dataset( - session, - username: str, - uri: str, - data: dict = None, - ) -> Dataset: - environment = EnvironmentService.get_environment_by_uri(session, uri) - + def create_dataset(session, env: Environment, dataset: Dataset): organization = Organization.get_organization_by_uri( - session, environment.organizationUri + session, env.organizationUri ) - dataset = Dataset( - label=data.get('label'), - owner=username, - description=data.get('description', 'No description provided'), - tags=data.get('tags', []), - AwsAccountId=environment.AwsAccountId, - SamlAdminGroupName=data['SamlAdminGroupName'], - region=environment.region, - S3BucketName='undefined', - GlueDatabaseName='undefined', - IAMDatasetAdminRoleArn='undefined', - IAMDatasetAdminUserArn='undefined', - KmsAlias='undefined', - environmentUri=environment.environmentUri, - organizationUri=environment.organizationUri, - language=data.get('language', Language.English.value), - confidentiality=data.get( - 'confidentiality', ConfidentialityClassification.Unclassified.value - ), - topics=data.get('topics', []), - businessOwnerEmail=data.get('businessOwnerEmail'), - businessOwnerDelegationEmails=data.get('businessOwnerDelegationEmails', []), - stewards=data.get('stewards') - if data.get('stewards') - else data['SamlAdminGroupName'], - ) session.add(dataset) session.commit() - DatasetRepository._set_dataset_aws_resources(dataset, data, environment) - DatasetRepository._set_import_data(dataset, data) - activity = Activity( action='dataset:create', label='dataset:create', - owner=username, - summary=f'{username} created dataset {dataset.name} in {environment.name} on organization {organization.name}', + owner=dataset.owner, + summary=f'{dataset.owner} created dataset {dataset.name} in {env.name} on organization {organization.name}', targetUri=dataset.datasetUri, targetType='dataset', ) diff --git a/deploy/pivot_role/pivotRole.yaml b/deploy/pivot_role/pivotRole.yaml index ced8b4af9..0d5a356b8 100644 --- a/deploy/pivot_role/pivotRole.yaml +++ b/deploy/pivot_role/pivotRole.yaml @@ -57,6 +57,7 @@ Resources: - 's3:ListAllMyBuckets' - 's3:GetBucketLocation' - 's3:PutBucketTagging' + - 's3:GetEncryptionConfiguration' Effect: Allow Resource: '*' - Sid: ManagedBuckets diff --git a/tests/conftest.py b/tests/conftest.py index c803c3350..4a8bb8b16 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from unittest.mock import MagicMock import pytest from dataall.base.db import get_engine, create_schema_and_tables, Engine @@ -176,3 +177,27 @@ def patch_check_env(module_mocker): module_mocker.patch( 'dataall.core.environment.api.resolvers.get_pivot_role_as_part_of_environment', return_value=False ) + + +@pytest.fixture(scope='function') +def mock_aws_client(module_mocker): + aws_client = MagicMock() + session_helper = MagicMock() + session = MagicMock() + + # there can be other mocker clients + module_mocker.patch( + 'dataall.modules.datasets.aws.s3_dataset_client.SessionHelper', + session_helper + ) + + module_mocker.patch( + 'dataall.modules.dataset_sharing.aws.kms_client.SessionHelper', + session_helper + ) + + session_helper.get_session.return_value = session + session_helper.remote_session.return_value = session + session.client.return_value = aws_client + + yield aws_client diff --git a/tests/modules/datasets/test_dataset.py b/tests/modules/datasets/test_dataset.py index 1ef7cf549..25009be87 100644 --- a/tests/modules/datasets/test_dataset.py +++ 
b/tests/modules/datasets/test_dataset.py @@ -10,6 +10,20 @@ from dataall.modules.datasets_base.db.dataset_models import DatasetStorageLocation, DatasetTable, Dataset from tests.core.stacks.test_stack import update_stack_query +mocked_key_id = 'some_key' + + +@pytest.fixture(scope='module', autouse=True) +def mock_s3_client(module_mocker): + s3_client = MagicMock() + module_mocker.patch( + 'dataall.modules.datasets.services.dataset_service.S3DatasetClient', + s3_client + ) + + s3_client().get_bucket_encryption.return_value = ('aws:kms', mocked_key_id) + yield s3_client + @pytest.fixture(scope='module') def dataset1( @@ -25,7 +39,7 @@ def dataset1( kms_client ) - kms_client().get_key_id.return_value = {"some_key"} + kms_client().get_key_id.return_value = mocked_key_id d = dataset(org=org_fixture, env=env_fixture, name='dataset1', owner=env_fixture.owner, group=group.name) print(d) diff --git a/tests/modules/datasets/test_import_dataset_check_unit.py b/tests/modules/datasets/test_import_dataset_check_unit.py new file mode 100644 index 000000000..9abee5bc0 --- /dev/null +++ b/tests/modules/datasets/test_import_dataset_check_unit.py @@ -0,0 +1,108 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from dataall.base.db.exceptions import RequiredParameter, InvalidInput, UnauthorizedOperation, AWSResourceNotFound +from dataall.modules.datasets.services.dataset_service import DatasetService +from dataall.modules.datasets_base.db.dataset_models import Dataset + + +def test_s3_managed_bucket_import(mock_aws_client): + dataset = Dataset(KmsAlias=None) + + mock_encryption_bucket(mock_aws_client, 'AES256', None) + + assert DatasetService.check_imported_resources(dataset) + + +def test_s3_managed_bucket_but_bucket_encrypted_with_kms(mock_aws_client): + dataset = Dataset(KmsAlias=None) + + mock_encryption_bucket(mock_aws_client, 'aws:kms', 'any') + with pytest.raises(RequiredParameter): + DatasetService.check_imported_resources(dataset) + + +def test_s3_managed_bucket_but_alias_provided(mock_aws_client): + dataset = Dataset(KmsAlias='Key') + + mock_encryption_bucket(mock_aws_client, 'AES256', None) + with pytest.raises(InvalidInput): + DatasetService.check_imported_resources(dataset) + + +def test_kms_encrypted_bucket_but_key_not_exist(mock_aws_client): + alias = 'alias' + dataset = Dataset(KmsAlias=alias) + mock_encryption_bucket(mock_aws_client, 'aws:kms', 'any') + mock_existing_alias(mock_aws_client) + + with pytest.raises(AWSResourceNotFound): + DatasetService.check_imported_resources(dataset) + + +def test_kms_encrypted_bucket_but_key_is_wrong(mock_aws_client): + alias = 'key_alias' + kms_id = 'kms_id' + dataset = Dataset(KmsAlias=alias) + mock_encryption_bucket(mock_aws_client, 'aws:kms', 'wrong') + mock_existing_alias(mock_aws_client, f'alias/{alias}') + mock_key_id(mock_aws_client, kms_id) + + with pytest.raises(InvalidInput): + DatasetService.check_imported_resources(dataset) + + +def test_kms_encrypted_bucket_imported(mock_aws_client): + alias = 'key_alias' + kms_id = 'kms_id' + dataset = Dataset(KmsAlias=alias) + mock_encryption_bucket(mock_aws_client, 'aws:kms', kms_id) + mock_existing_alias(mock_aws_client, f'alias/{alias}') + mock_key_id(mock_aws_client, kms_id) + + assert DatasetService.check_imported_resources(dataset) + + +def mock_encryption_bucket(mock_aws_client, algorithm, kms_id=None): + response = f""" + {{ + "ServerSideEncryptionConfiguration": {{ + "Rules": [ + {{ + "ApplyServerSideEncryptionByDefault": {{ + "SSEAlgorithm": "{algorithm}", + "KMSMasterKeyID": 
"{kms_id}" + }}, + "BucketKeyEnabled": true + }} + ] + }} + }} + """ + mock_aws_client.get_bucket_encryption.return_value = json.loads(response) + + +def mock_existing_alias(mock_aws_client, existing_alias='unknown'): + paginator = MagicMock() + mock_aws_client.get_paginator.return_value = paginator + response = f""" + {{ + "Aliases": [ {{ + "AliasName": "{existing_alias}" + }} ] + }} + """ + paginator.paginate.return_value = [json.loads(response)] + + +def mock_key_id(mock_aws_client, key_id): + response = f""" + {{ + "KeyMetadata": {{ + "KeyId": "{key_id}" + }} + }} + """ + mock_aws_client.describe_key.return_value = json.loads(response) diff --git a/tests/modules/datasets/test_share.py b/tests/modules/datasets/test_share.py index 5ff64b965..a9677269e 100644 --- a/tests/modules/datasets/test_share.py +++ b/tests/modules/datasets/test_share.py @@ -1,5 +1,7 @@ import random import typing +from unittest.mock import MagicMock + import pytest from dataall.core.environment.db.environment_models import Environment, EnvironmentGroup @@ -429,7 +431,6 @@ def create_share_object(client, username, group, groupUri, environmentUri, datas print('Create share request response: ', response) return response - def get_share_object(client, user, group, shareUri, filter): q = """ query getShareObject($shareUri: String!, $filter: ShareableObjectFilter) { From 2e0fd3973cff95475b0ead097478b72d95a651b4 Mon Sep 17 00:00:00 2001 From: dlpzx <71252798+dlpzx@users.noreply.github.com> Date: Wed, 6 Dec 2023 09:38:54 +0100 Subject: [PATCH 3/5] Add SCP error handling in Quicksight identity region checks (#896) ### Feature or Bugfix - Feature - Bugfix ### Detail There is no API to obtain the Quicksight identity region used for an account, we obtain it form the error logs of the response of describe_groups. However, it does not take into account AccessDenied errors based on SCPs. A more detailed description of the issue can be found in #851 This PR: - handles AccessDenied errors based on SCPs and retries other Quicksight identity regions - fixes some methods for registering users that should be using the Quicksight client in the identity region. ### Relates - #851 ### Security Please answer the questions below briefly where applicable, or write `N/A`. Based on [OWASP 10](https://owasp.org/Top10/en/). --> `N/A` - Does this PR introduce or modify any input fields or queries - this includes fetching data from storage outside the application (e.g. a database, an S3 bucket)? - Is the input sanitized? - What precautions are you taking before deserializing the data you consume? - Is injection prevented by parametrizing queries? - Have you ensured no `eval` or similar functions are used? - Does this PR introduce any functionality or component that requires authorization? - How have you ensured it respects the existing AuthN/AuthZ mechanisms? - Are you logging failed auth attempts? - Are you using or adding any cryptographic features? - Do you use a standard proven implementations? - Are the used keys controlled by the customer? Where are they stored? - Are you introducing any new policies/roles/users? - Have you used the least-privilege principle? How? By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. 
--- backend/dataall/base/aws/quicksight.py | 60 +++++++++++++------ .../aws/dashboard_quicksight_client.py | 14 ++--- .../cdk/pivot_role_dashboards_policy.py | 1 + .../share_managers/lf_share_manager.py | 4 +- .../services/share_object_service.py | 2 +- .../datasets/services/dataset_service.py | 6 +- deploy/pivot_role/pivotRole.yaml | 1 + 7 files changed, 60 insertions(+), 28 deletions(-) diff --git a/backend/dataall/base/aws/quicksight.py b/backend/dataall/base/aws/quicksight.py index e7c59dfb0..d03b467a3 100644 --- a/backend/dataall/base/aws/quicksight.py +++ b/backend/dataall/base/aws/quicksight.py @@ -10,6 +10,23 @@ class QuicksightClient: DEFAULT_GROUP_NAME = 'dataall' + QUICKSIGHT_IDENTITY_REGIONS = [ + {"name": 'US East (N. Virginia)', "code": 'us-east-1'}, + {"name": 'US East (Ohio)', "code": 'us-east-2'}, + {"name": 'US West (Oregon)', "code": 'us-west-2'}, + {"name": 'Europe (Frankfurt)', "code": 'eu-central-1'}, + {"name": 'Europe (Stockholm)', "code": 'eu-north-1'}, + {"name": 'Europe (Ireland)', "code": 'eu-west-1'}, + {"name": 'Europe (London)', "code": 'eu-west-2'}, + {"name": 'Europe (Paris)', "code": 'eu-west-3'}, + {"name": 'Asia Pacific (Singapore)', "code": 'ap-southeast-1'}, + {"name": 'Asia Pacific (Sydney)', "code": 'ap-southeast-2'}, + {"name": 'Asia Pacific (Tokyo)', "code": 'ap-northeast-1'}, + {"name": 'Asia Pacific (Seoul)', "code": 'ap-northeast-2'}, + {"name": 'South America (São Paulo)', "code": 'sa-east-1'}, + {"name": 'Canada (Central)', "code": 'ca-central-1'}, + {"name": 'Asia Pacific (Mumbai)', "code": 'ap-south-1'}, + ] def __init__(self): pass @@ -37,21 +54,29 @@ def get_identity_region(AwsAccountId): the region quicksight uses as identity region """ identity_region_rex = re.compile('Please use the (?P.*) endpoint.') - identity_region = 'us-east-1' - client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=identity_region) - try: - response = client.describe_group( - AwsAccountId=AwsAccountId, GroupName=QuicksightClient.DEFAULT_GROUP_NAME, Namespace='default' - ) - except client.exceptions.AccessDeniedException as e: - match = identity_region_rex.findall(str(e)) - if match: - identity_region = match[0] - else: - raise e - except client.exceptions.ResourceNotFoundException: - pass - return identity_region + scp = 'with an explicit deny in a service control policy' + index = 0 + while index < len(QuicksightClient.QUICKSIGHT_IDENTITY_REGIONS): + try: + identity_region = QuicksightClient.QUICKSIGHT_IDENTITY_REGIONS[index].get("code") + index += 1 + client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=identity_region) + response = client.describe_account_settings(AwsAccountId=AwsAccountId) + logger.info(f'Returning identity region = {identity_region} for account {AwsAccountId}') + return identity_region + except client.exceptions.AccessDeniedException as e: + if scp in str(e): + logger.info(f'Quicksight SCP found in {identity_region} for account {AwsAccountId}. 
Trying next region...') + else: + logger.info(f'Quicksight identity region is not {identity_region}, selecting correct region endpoint...') + match = identity_region_rex.findall(str(e)) + if match: + identity_region = match[0] + logger.info(f'Returning identity region = {identity_region} for account {AwsAccountId}') + return identity_region + else: + raise e + raise Exception(f'Quicksight subscription is inactive or the identity region has SCPs preventing access from data.all to account {AwsAccountId}') @staticmethod def get_quicksight_client_in_identity_region(AwsAccountId): @@ -99,10 +124,11 @@ def check_quicksight_enterprise_subscription(AwsAccountId, region=None): return False @staticmethod - def create_quicksight_group(AwsAccountId, GroupName=DEFAULT_GROUP_NAME): + def create_quicksight_group(AwsAccountId, region, GroupName=DEFAULT_GROUP_NAME): """Creates a Quicksight group called GroupName Args: AwsAccountId(str): aws account + region: aws region GroupName(str): name of the QS group Returns:dict @@ -113,7 +139,7 @@ def create_quicksight_group(AwsAccountId, GroupName=DEFAULT_GROUP_NAME): if not group: if GroupName == QuicksightClient.DEFAULT_GROUP_NAME: logger.info(f'Initializing data.all default group = {GroupName}') - QuicksightClient.check_quicksight_enterprise_subscription(AwsAccountId) + QuicksightClient.check_quicksight_enterprise_subscription(AwsAccountId, region) logger.info(f'Attempting to create Quicksight group `{GroupName}...') response = client.create_group( diff --git a/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py b/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py index 118da357e..fa3a8d578 100644 --- a/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py +++ b/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py @@ -17,18 +17,18 @@ class DashboardQuicksightClient: _DEFAULT_GROUP_NAME = QuicksightClient.DEFAULT_GROUP_NAME def __init__(self, username, aws_account_id, region='eu-west-1'): - session = SessionHelper.remote_session(accountid=aws_account_id) - self._client = session.client('quicksight', region_name=region) self._account_id = aws_account_id self._region = region self._username = username + self._client = QuicksightClient.get_quicksight_client(aws_account_id, region) def register_user_in_group(self, group_name, user_role='READER'): - QuicksightClient.create_quicksight_group(self._account_id, group_name) + identity_region_client = QuicksightClient.get_quicksight_client_in_identity_region(self._account_id) + QuicksightClient.create_quicksight_group(AwsAccountId=self._account_id, region=self._region, GroupName=group_name) user = self._describe_user() if user is not None: - self._client.update_user( + identity_region_client.update_user( UserName=self._username, AwsAccountId=self._account_id, Namespace='default', @@ -36,7 +36,7 @@ def register_user_in_group(self, group_name, user_role='READER'): Role=user_role, ) else: - self._client.register_user( + identity_region_client.register_user( UserName=self._username, Email=self._username, AwsAccountId=self._account_id, @@ -45,13 +45,13 @@ def register_user_in_group(self, group_name, user_role='READER'): UserRole=user_role, ) - response = self._client.list_user_groups( + response = identity_region_client.list_user_groups( UserName=self._username, AwsAccountId=self._account_id, Namespace='default' ) log.info(f'list_user_groups for {self._username}: {response})') if group_name not in [g['GroupName'] for g in response['GroupList']]: 
log.warning(f'Adding {self._username} to Quicksight group {group_name} on {self._account_id}') - self._client.create_group_membership( + identity_region_client.create_group_membership( MemberName=self._username, GroupName=group_name, AwsAccountId=self._account_id, diff --git a/backend/dataall/modules/dashboards/cdk/pivot_role_dashboards_policy.py b/backend/dataall/modules/dashboards/cdk/pivot_role_dashboards_policy.py index beb4f3fa3..c87279faf 100644 --- a/backend/dataall/modules/dashboards/cdk/pivot_role_dashboards_policy.py +++ b/backend/dataall/modules/dashboards/cdk/pivot_role_dashboards_policy.py @@ -32,6 +32,7 @@ def get_statements(self): 'quicksight:GetAuthCode', 'quicksight:CreateGroupMembership', 'quicksight:DescribeAccountSubscription', + 'quicksight:DescribeAccountSettings', ], resources=[ f'arn:aws:quicksight:*:{self.account}:group/default/*', diff --git a/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py b/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py index d1e92e43b..df0c8a09d 100644 --- a/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py +++ b/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py @@ -66,7 +66,9 @@ def get_share_principals(self) -> [str]: dashboard_enabled = EnvironmentService.get_boolean_env_param(self.session, self.target_environment, "dashboardsEnabled") if dashboard_enabled: - group = QuicksightClient.create_quicksight_group(AwsAccountId=self.target_environment.AwsAccountId) + group = QuicksightClient.create_quicksight_group( + AwsAccountId=self.target_environment.AwsAccountId, region=self.target_environment.region + ) if group and group.get('Group'): group_arn = group.get('Group').get('Arn') if group_arn: diff --git a/backend/dataall/modules/dataset_sharing/services/share_object_service.py b/backend/dataall/modules/dataset_sharing/services/share_object_service.py index d6dd11d23..40520f13d 100644 --- a/backend/dataall/modules/dataset_sharing/services/share_object_service.py +++ b/backend/dataall/modules/dataset_sharing/services/share_object_service.py @@ -185,7 +185,7 @@ def submit_share_object(cls, uri: str): if dashboard_enabled: share_table_items = ShareObjectRepository.find_all_share_items(session, uri, ShareableType.Table.value) if share_table_items: - QuicksightClient.check_quicksight_enterprise_subscription(AwsAccountId=env.AwsAccountId) + QuicksightClient.check_quicksight_enterprise_subscription(AwsAccountId=env.AwsAccountId, region=env.region) cls._run_transitions(session, share, states, ShareObjectActions.Submit) diff --git a/backend/dataall/modules/datasets/services/dataset_service.py b/backend/dataall/modules/datasets/services/dataset_service.py index b60ee1ffe..f0756f64f 100644 --- a/backend/dataall/modules/datasets/services/dataset_service.py +++ b/backend/dataall/modules/datasets/services/dataset_service.py @@ -44,9 +44,11 @@ def check_dataset_account(session, environment): dashboards_enabled = EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled") if dashboards_enabled: quicksight_subscription = QuicksightClient.check_quicksight_enterprise_subscription( - AwsAccountId=environment.AwsAccountId) + AwsAccountId=environment.AwsAccountId, region=environment.region) if quicksight_subscription: - group = QuicksightClient.create_quicksight_group(AwsAccountId=environment.AwsAccountId) + group = QuicksightClient.create_quicksight_group( + AwsAccountId=environment.AwsAccountId, 
region=environment.region + ) return True if group else False return True diff --git a/deploy/pivot_role/pivotRole.yaml b/deploy/pivot_role/pivotRole.yaml index 0d5a356b8..26435d897 100644 --- a/deploy/pivot_role/pivotRole.yaml +++ b/deploy/pivot_role/pivotRole.yaml @@ -399,6 +399,7 @@ Resources: - "quicksight:GetAuthCode" - "quicksight:CreateGroupMembership" - "quicksight:DescribeAccountSubscription" + - "quicksight:DescribeAccountSettings" Resource: - !Sub "arn:aws:quicksight:*:${AWS::AccountId}:group/default/*" - !Sub "arn:aws:quicksight:*:${AWS::AccountId}:user/default/*" From 5061ecb8f46121c9501d0d82edf18a21be711785 Mon Sep 17 00:00:00 2001 From: dlpzx <71252798+dlpzx@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:19:57 +0100 Subject: [PATCH 4/5] Update CodeBuild images to Linux2 standard5.0 (node16 to node18) + Update Docker images to use AmazonLinux:2023 (node18 and Python3.9) (#889) ### Feature or Bugfix - Bugfix ### Detail The purpose of this PR is to upgrade any compute resource that uses node16 to node18. - CodeBuild images: [Amazon Linux 2 x86_64 standard:4.0 uses node16](https://docs.aws.amazon.com/codebuild/latest/userguide/available-runtimes.html), which is already deprecated. In this PR we update the CodeBuild images to use Amazon Linux 2 x86_64 standard:5.0 instead. - Docker images: In this PR we replace the AmazonLinux2 images with [AmazonLinux2023](https://docs.aws.amazon.com/linux/al2023/ug/what-is-amazon-linux.html), the next generation of Amazon Linux from Amazon Web Services. In AmazonLinux2023 the default Python version installed is 3.9. For this reason we also upgrade the Python version in this PR. ### Relates #782 ### Security Please answer the questions below briefly where applicable, or write `N/A`. Based on [OWASP 10](https://owasp.org/Top10/en/). N/A - Does this PR introduce or modify any input fields or queries - this includes fetching data from storage outside the application (e.g. a database, an S3 bucket)? - Is the input sanitized? - What precautions are you taking before deserializing the data you consume? - Is injection prevented by parametrizing queries? - Have you ensured no `eval` or similar functions are used? - Does this PR introduce any functionality or component that requires authorization? - How have you ensured it respects the existing AuthN/AuthZ mechanisms? - Are you logging failed auth attempts? - Are you using or adding any cryptographic features? - Do you use a standard proven implementations? - Are the used keys controlled by the customer? Where are they stored? - Are you introducing any new policies/roles/users? - Have you used the least-privilege principle? How? By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
--- .../cdk/datapipelines_pipeline.py | 4 +- backend/docker/dev/Dockerfile | 43 +++++++++------ backend/docker/prod/ecs/Dockerfile | 54 +++++++++++-------- backend/docker/prod/lambda/Dockerfile | 36 ++++++++----- deploy/stacks/container.py | 14 ++--- deploy/stacks/dbmigration.py | 2 +- deploy/stacks/pipeline.py | 38 ++++++------- docker-compose.yaml | 2 +- .../userguide/docker/prod/Dockerfile | 14 ++--- frontend/docker/prod/Dockerfile | 12 ++--- 10 files changed, 125 insertions(+), 94 deletions(-) diff --git a/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py b/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py index ea8d34f3e..f967458cd 100644 --- a/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py +++ b/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py @@ -264,7 +264,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): id=f'{pipeline.name}-build-{env.stage}', environment=codebuild.BuildEnvironment( privileged=True, - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_3, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, environment_variables=PipelineStack.make_environment_variables( pipeline=pipeline, pipeline_environment=env, @@ -335,7 +335,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): id=f'{pipeline.name}-build-{env.stage}', environment=codebuild.BuildEnvironment( privileged=True, - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_3, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, environment_variables=PipelineStack.make_environment_variables( pipeline=pipeline, pipeline_environment=env, diff --git a/backend/docker/dev/Dockerfile b/backend/docker/dev/Dockerfile index 2aba2a8a4..3fc7bc29d 100644 --- a/backend/docker/dev/Dockerfile +++ b/backend/docker/dev/Dockerfile @@ -1,24 +1,30 @@ -FROM public.ecr.aws/amazonlinux/amazonlinux:2 +FROM public.ecr.aws/amazonlinux/amazonlinux:2023 -ARG NODE_VERSION=16 +ARG NODE_VERSION=18 ARG NVM_VERSION=v0.37.2 -ARG PYTHON_VERSION=python3.8 +ARG PYTHON_VERSION=python3.9 -RUN yum clean all -RUN yum -y install shadow-utils wget -RUN yum -y install openssl-devel bzip2-devel libffi-devel postgresql-devel gcc unzip tar gzip -RUN amazon-linux-extras install $PYTHON_VERSION -RUN yum -y install python38-devel -RUN yum -y install git +# Clean cache +RUN dnf clean all -RUN /bin/bash -c "ln -s /usr/bin/${PYTHON_VERSION} /usr/bin/python3" +# Installing libraries +RUN dnf -y install -y \ + shadow-utils wget openssl-devel bzip2-devel libffi-devel \ + postgresql-devel gcc unzip tar gzip + +# Install Python +RUN dnf install $PYTHON_VERSION +RUN dnf -y install python3-pip python3-devel git RUN useradd -m app +## Add source WORKDIR /build +# Configuring path RUN touch ~/.bashrc +# Install AWS CLI RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" RUN unzip awscliv2.zip RUN ./aws/install @@ -27,9 +33,11 @@ COPY ./docker/dev/wait-for-it.sh /build/wait-for-it.sh RUN chmod +x /build/wait-for-it.sh RUN chown -R app:root /build/wait-for-it.sh +## Add source WORKDIR /dataall RUN touch ~/.bashrc +# Configuring Node and CDK RUN curl -o- https://raw.githubusercontent.com/creationix/nvm/$NVM_VERSION/install.sh | bash RUN /bin/bash -c ". ~/.nvm/nvm.sh && \ nvm install $NODE_VERSION && nvm use $NODE_VERSION && \ @@ -46,17 +54,20 @@ $PATH" >> ~/.bashrc && \ RUN /bin/bash -c ". 
~/.nvm/nvm.sh && cdk --version" -COPY ./requirements.txt dh.requirements.txt +# App specific requirements +COPY ./requirements.txt requirements.txt COPY ./dataall/base/cdkproxy/requirements.txt cdk.requirements.txt -COPY ./dataall /dataall +# Install App requirements +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install setuptools" +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r requirements.txt" +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r cdk.requirements.txt" + +# App code +COPY ./dataall /dataall ADD ./cdkproxymain.py /cdkproxymain.py ADD ./local_graphql_server.py /local_graphql_server.py -RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -U pip " -RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r dh.requirements.txt" -RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r cdk.requirements.txt" - WORKDIR / ENTRYPOINT [ "/bin/bash", "-c", ". ~/.nvm/nvm.sh && uvicorn cdkproxymain:app --host 0.0.0.0 --port 8080" ] diff --git a/backend/docker/prod/ecs/Dockerfile b/backend/docker/prod/ecs/Dockerfile index aadf853ab..83af5d7bd 100644 --- a/backend/docker/prod/ecs/Dockerfile +++ b/backend/docker/prod/ecs/Dockerfile @@ -1,24 +1,28 @@ -FROM public.ecr.aws/amazonlinux/amazonlinux:2 +FROM public.ecr.aws/amazonlinux/amazonlinux:2023 -ARG NODE_VERSION=16 +ARG NODE_VERSION=18 ARG NVM_VERSION=v0.37.2 ARG DEEQU_VERSION=2.0.0-spark-3.1 -ARG PYTHON_VERSION=python3.8 +ARG PYTHON_VERSION=python3.9 + +# Clean cache +RUN dnf upgrade -y;\ + find /var/tmp -name "*.rpm" -print -delete ;\ + find /tmp -name "*.rpm" -print -delete ;\ + dnf autoremove -y; \ + dnf clean all; rm -rfv /var/cache/dnf # Installing libraries -RUN yum upgrade -y \ - && find /var/tmp -name "*.rpm" -print -delete \ - && find /tmp -name "*.rpm" -print -delete \ - && yum autoremove -y \ - && yum clean all \ - && rm -rfv /var/cache/yum \ - && yum install -y \ +RUN dnf -y install \ shadow-utils wget openssl-devel bzip2-devel libffi-devel \ - postgresql-devel gcc unzip tar gzip \ - && amazon-linux-extras install $PYTHON_VERSION \ - && yum install -y python38-devel git \ - && /bin/bash -c "ln -s /usr/bin/${PYTHON_VERSION} /usr/bin/python3" \ - && curl https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip -o /tmp/awscliv2.zip \ + postgresql-devel gcc unzip tar gzip + +# Install Python +RUN dnf install $PYTHON_VERSION +RUN dnf -y install python3-pip python3-devel git + +# Install AWS CLI +RUN curl https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip -o /tmp/awscliv2.zip \ && unzip -q /tmp/awscliv2.zip -d /opt \ && /opt/aws/install --update -i /usr/local/aws-cli -b /usr/local/bin \ && rm /tmp/awscliv2.zip \ @@ -33,8 +37,9 @@ RUN curl -o- https://raw.githubusercontent.com/creationix/nvm/$NVM_VERSION/insta && /bin/bash -c ". ~/.nvm/nvm.sh && \ nvm install $NODE_VERSION && nvm use $NODE_VERSION && \ npm install -g aws-cdk && \ - nvm alias default node && nvm cache clear" \ - && echo export PATH="\ + nvm alias default node && nvm cache clear" + +RUN echo export PATH="\ /root/.nvm/versions/node/${NODE_VERSION}/bin:\ $(${PYTHON_VERSION} -m site --user-base)/bin:\ $(python3 -m site --user-base)/bin:\ @@ -42,22 +47,25 @@ RUN curl -o- https://raw.githubusercontent.com/creationix/nvm/$NVM_VERSION/insta echo "nvm use ${NODE_VERSION} 1> /dev/null" >> ~/.bashrc \ && /bin/bash -c ". 
~/.nvm/nvm.sh && cdk --version" -RUN $PYTHON_VERSION -m pip install -U pip - -# App specific -ADD backend/requirements.txt /dh.requirements.txt +# App specific requirements +ADD backend/requirements.txt /requirements.txt ADD backend/dataall/base/cdkproxy/requirements.txt /cdk.requirements.txt -RUN /bin/bash -c "pip3.8 install -r /dh.requirements.txt" \ - && /bin/bash -c "pip3.8 install -r /cdk.requirements.txt" +# Install App requirements +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install setuptools" +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r requirements.txt" +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r cdk.requirements.txt" +# App code ADD backend/dataall /dataall VOLUME ["/dataall"] ADD backend/cdkproxymain.py /cdkproxymain.py +# App configuration file ENV config_location="/config.json" COPY config.json /config.json +# Glue profiling jobs jars RUN mkdir -p dataall/modules/datasets/cdk/assets/glueprofilingjob/jars/ ADD https://repo1.maven.org/maven2/com/amazon/deequ/deequ/$DEEQU_VERSION/deequ-$DEEQU_VERSION.jar /dataall/modules/datasets/cdk/assets/glueprofilingjob/jars/ diff --git a/backend/docker/prod/lambda/Dockerfile b/backend/docker/prod/lambda/Dockerfile index 74609e98c..4ba78a8a7 100644 --- a/backend/docker/prod/lambda/Dockerfile +++ b/backend/docker/prod/lambda/Dockerfile @@ -1,28 +1,38 @@ -FROM public.ecr.aws/amazonlinux/amazonlinux:2 +FROM public.ecr.aws/amazonlinux/amazonlinux:2023 ARG FUNCTION_DIR="/home/app/" -ARG PYTHON_VERSION=python3.8 +ARG PYTHON_VERSION=python3.9 -RUN yum upgrade -y;\ +# Clean cache +RUN dnf upgrade -y;\ find /var/tmp -name "*.rpm" -print -delete ;\ find /tmp -name "*.rpm" -print -delete ;\ - yum autoremove -y; \ - yum clean packages; yum clean headers; yum clean metadata; yum clean all; rm -rfv /var/cache/yum + dnf autoremove -y; \ + dnf clean all; rm -rfv /var/cache/dnf -RUN yum -y install shadow-utils wget -RUN yum -y install openssl-devel bzip2-devel libffi-devel postgresql-devel gcc unzip tar gzip -RUN amazon-linux-extras install $PYTHON_VERSION -RUN yum -y install python38-devel +# Install libraries +RUN dnf -y install \ + shadow-utils wget openssl-devel bzip2-devel libffi-devel \ + postgresql-devel gcc unzip tar gzip -## Add your source +# Install Python +RUN dnf install $PYTHON_VERSION +RUN dnf -y install python3-pip python3-devel + +## Add source WORKDIR ${FUNCTION_DIR} +# App specific requirements COPY backend/requirements.txt ./requirements.txt -RUN $PYTHON_VERSION -m pip install -U pip -RUN $PYTHON_VERSION -m pip install -r requirements.txt -t . +# Install App requirements +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install setuptools" +RUN /bin/bash -c "${PYTHON_VERSION} -m pip install -r requirements.txt" + +# App code COPY backend/. ./ +# App configuration file ENV config_location="config.json" COPY config.json ./config.json @@ -30,5 +40,5 @@ COPY config.json ./config.json RUN $PYTHON_VERSION -m pip install awslambdaric --target ${FUNCTION_DIR} # Command can be overwritten by providing a different command in the template directly. 
-ENTRYPOINT [ "python3.8", "-m", "awslambdaric" ] +ENTRYPOINT [ "python3.9", "-m", "awslambdaric" ] CMD ["auth_handler.handler"] diff --git a/deploy/stacks/container.py b/deploy/stacks/container.py index 25d1775e3..1c0c6a85e 100644 --- a/deploy/stacks/container.py +++ b/deploy/stacks/container.py @@ -81,7 +81,7 @@ def __init__( container_definitions=[ecs.CfnTaskDefinition.ContainerDefinitionProperty( image=cdkproxy_image.image_name, name=cdkproxy_container_name, - command=['python3.8', '-m', 'dataall.core.stacks.tasks.cdkproxy'], + command=['python3.9', '-m', 'dataall.core.stacks.tasks.cdkproxy'], environment=[ ecs.CfnTaskDefinition.KeyValuePairProperty( name="AWS_REGION", @@ -156,7 +156,7 @@ def __init__( stacks_updater, stacks_updater_task_def = self.set_scheduled_task( cluster=cluster, - command=['python3.8', '-m', 'dataall.core.environment.tasks.env_stacks_updater'], + command=['python3.9', '-m', 'dataall.core.environment.tasks.env_stacks_updater'], container_id=f'container', ecr_repository=ecr_repository, environment=self._create_env('INFO'), @@ -213,7 +213,7 @@ def __init__( def add_catalog_indexer_task(self): catalog_indexer_task, catalog_indexer_task_def = self.set_scheduled_task( cluster=self.ecs_cluster, - command=['python3.8', '-m', 'dataall.modules.catalog.tasks.catalog_indexer_task'], + command=['python3.9', '-m', 'dataall.modules.catalog.tasks.catalog_indexer_task'], container_id=f'container', ecr_repository=self._ecr_repository, environment=self._create_env('INFO'), @@ -251,7 +251,7 @@ def add_share_management_task(self): repository=self._ecr_repository, tag=self._cdkproxy_image_tag ), environment=self._create_env('DEBUG'), - command=['python3.8', '-m', 'dataall.modules.dataset_sharing.tasks.share_manager_task'], + command=['python3.9', '-m', 'dataall.modules.dataset_sharing.tasks.share_manager_task'], logging=ecs.LogDriver.aws_logs( stream_prefix='task', log_group=self.create_log_group( @@ -281,7 +281,7 @@ def add_subscription_task(self): subscriptions_task, subscription_task_def = self.set_scheduled_task( cluster=self.ecs_cluster, command=[ - 'python3.8', + 'python3.9', '-m', 'dataall.modules.datasets.tasks.dataset_subscription_task', ], @@ -306,7 +306,7 @@ def add_subscription_task(self): def add_bucket_policy_updater_task(self): update_bucket_policies_task, update_bucket_task_def = self.set_scheduled_task( cluster=self.ecs_cluster, - command=['python3.8', '-m', 'dataall.modules.datasets.tasks.bucket_policy_updater'], + command=['python3.9', '-m', 'dataall.modules.datasets.tasks.bucket_policy_updater'], container_id=f'container', ecr_repository=self._ecr_repository, environment=self._create_env('DEBUG'), @@ -328,7 +328,7 @@ def add_bucket_policy_updater_task(self): def add_sync_dataset_table_task(self): sync_tables_task, sync_tables_task_def = self.set_scheduled_task( cluster=self.ecs_cluster, - command=['python3.8', '-m', 'dataall.modules.datasets.tasks.tables_syncer'], + command=['python3.9', '-m', 'dataall.modules.datasets.tasks.tables_syncer'], container_id=f'container', ecr_repository=self._ecr_repository, environment=self._create_env('INFO'), diff --git a/deploy/stacks/dbmigration.py b/deploy/stacks/dbmigration.py index d71320ebe..bb28c2e36 100644 --- a/deploy/stacks/dbmigration.py +++ b/deploy/stacks/dbmigration.py @@ -141,7 +141,7 @@ def __init__( id=f'DBMigrationCBProject{envname}', project_name=f'{resource_prefix}-{envname}-dbmigration', environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_3, + 
build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), role=self.build_project_role, build_spec=codebuild.BuildSpec.from_object( diff --git a/deploy/stacks/pipeline.py b/deploy/stacks/pipeline.py index 538216a4b..e961b666a 100644 --- a/deploy/stacks/pipeline.py +++ b/deploy/stacks/pipeline.py @@ -137,7 +137,7 @@ def __init__( 'Synth', input=source, build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'aws codeartifact login --tool npm --repository {self.codeartifact.codeartifact_npm_repo_name} --domain {self.codeartifact.codeartifact_domain_name} --domain-owner {self.codeartifact.domain.attr_owner}', @@ -430,7 +430,7 @@ def set_quality_gate_stage(self): pipelines.CodeBuildStep( id='ValidateDBMigrations', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'aws codeartifact login --tool pip --repository {self.codeartifact.codeartifact_pip_repo_name} --domain {self.codeartifact.codeartifact_domain_name} --domain-owner {self.codeartifact.domain.attr_owner}', @@ -447,7 +447,7 @@ def set_quality_gate_stage(self): pipelines.CodeBuildStep( id='SecurityChecks', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'aws codeartifact login --tool pip --repository {self.codeartifact.codeartifact_pip_repo_name} --domain {self.codeartifact.codeartifact_domain_name} --domain-owner {self.codeartifact.domain.attr_owner}', @@ -462,7 +462,7 @@ def set_quality_gate_stage(self): pipelines.CodeBuildStep( id='Lint', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'aws codeartifact login --tool pip --repository {self.codeartifact.codeartifact_pip_repo_name} --domain {self.codeartifact.codeartifact_domain_name} --domain-owner {self.codeartifact.domain.attr_owner}', @@ -484,7 +484,7 @@ def set_quality_gate_stage(self): pipelines.CodeBuildStep( id='IntegrationTests', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), partial_build_spec=codebuild.BuildSpec.from_object( dict( @@ -518,7 +518,7 @@ def set_quality_gate_stage(self): pipelines.CodeBuildStep( id='UploadCodeToS3', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ 'mkdir -p source_build', @@ -538,7 +538,7 @@ def set_quality_gate_stage(self): pipelines.CodeBuildStep( id='UploadCodeToS3', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ 'mkdir -p source_build', @@ -576,7 +576,7 @@ def set_ecr_stage( pipelines.CodeBuildStep( id='LambdaImage', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, privileged=True, environment_variables={ 'REPOSITORY_URI': codebuild.BuildEnvironmentVariable( @@ -594,7 +594,7 @@ def set_ecr_stage( pipelines.CodeBuildStep( id='ECSImage', 
build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, privileged=True, environment_variables={ 'REPOSITORY_URI': codebuild.BuildEnvironmentVariable( @@ -660,7 +660,7 @@ def set_db_migration_stage( pipelines.CodeBuildStep( id='MigrateDB', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ 'mkdir ~/.aws/ && touch ~/.aws/config', @@ -690,7 +690,7 @@ def set_stacks_updater_stage( pipelines.CodeBuildStep( id='StacksUpdater', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ 'mkdir ~/.aws/ && touch ~/.aws/config', @@ -730,7 +730,7 @@ def set_cloudfront_stage(self, target_env): pipelines.CodeBuildStep( id='DeployFrontEnd', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, compute_type=codebuild.ComputeType.LARGE, ), commands=[ @@ -752,6 +752,7 @@ def set_cloudfront_stage(self, target_env): 'pip install beautifulsoup4', 'python deploy/configs/frontend_config.py', 'export AWS_DEFAULT_REGION=us-east-1', + 'export AWS_REGION=us-east-1', f"export distributionId=$(aws ssm get-parameter --name /dataall/{target_env['envname']}/CloudfrontDistributionId --profile buildprofile --output text --query 'Parameter.Value')", f"export bucket=$(aws ssm get-parameter --name /dataall/{target_env['envname']}/CloudfrontDistributionBucket --profile buildprofile --output text --query 'Parameter.Value')", 'export NODE_OPTIONS="--max-old-space-size=6144"', @@ -781,7 +782,7 @@ def set_cloudfront_stage(self, target_env): pipelines.CodeBuildStep( id='UpdateDocumentation', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'aws codeartifact login --tool pip --repository {self.codeartifact.codeartifact_pip_repo_name} --domain {self.codeartifact.codeartifact_domain_name} --domain-owner {self.codeartifact.domain.attr_owner}', @@ -789,6 +790,7 @@ def set_cloudfront_stage(self, target_env): '. 
./.env.assumed_role', 'aws sts get-caller-identity', 'export AWS_DEFAULT_REGION=us-east-1', + 'export AWS_REGION=us-east-1', f"export distributionId=$(aws ssm get-parameter --name /dataall/{target_env['envname']}/cloudfront/docs/user/CloudfrontDistributionId --output text --query 'Parameter.Value')", f"export bucket=$(aws ssm get-parameter --name /dataall/{target_env['envname']}/cloudfront/docs/user/CloudfrontDistributionBucket --output text --query 'Parameter.Value')", 'cd documentation/userguide', @@ -806,7 +808,7 @@ def cw_rum_config_action(self, target_env): return pipelines.CodeBuildStep( id='ConfigureRUM', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'export envname={target_env["envname"]}', @@ -832,7 +834,7 @@ def cognito_config_action(self, target_env): return pipelines.CodeBuildStep( id='ConfigureCognito', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), commands=[ f'export envname={target_env["envname"]}', @@ -875,7 +877,7 @@ def set_albfront_stage(self, target_env, repository_name): pipelines.CodeBuildStep( id='FrontendImage', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, compute_type=codebuild.ComputeType.LARGE, privileged=True, environment_variables={ @@ -915,7 +917,7 @@ def set_albfront_stage(self, target_env, repository_name): pipelines.CodeBuildStep( id='UserGuideImage', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, compute_type=codebuild.ComputeType.LARGE, privileged=True, environment_variables={ @@ -961,7 +963,7 @@ def set_release_stage( pipelines.CodeBuildStep( id='GitRelease', build_environment=codebuild.BuildEnvironment( - build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_4, + build_image=codebuild.LinuxBuildImage.AMAZON_LINUX_2_5, ), partial_build_spec=codebuild.BuildSpec.from_object( dict( diff --git a/docker-compose.yaml b/docker-compose.yaml index e10f021ee..9495269ab 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -32,7 +32,7 @@ services: build: context: ./backend dockerfile: docker/dev/Dockerfile - entrypoint: /bin/bash -c "../build/wait-for-it.sh elasticsearch:9200 -t 30 && python3.8 local_graphql_server.py" + entrypoint: /bin/bash -c "../build/wait-for-it.sh elasticsearch:9200 -t 30 && python3.9 local_graphql_server.py" expose: - 5000 ports: diff --git a/documentation/userguide/docker/prod/Dockerfile b/documentation/userguide/docker/prod/Dockerfile index 1a11f64ae..f06ed9436 100644 --- a/documentation/userguide/docker/prod/Dockerfile +++ b/documentation/userguide/docker/prod/Dockerfile @@ -1,14 +1,14 @@ -FROM public.ecr.aws/amazonlinux/amazonlinux:2 +FROM public.ecr.aws/amazonlinux/amazonlinux:2023 -ARG NODE_VERSION=16 -ARG PYTHON_VERSION=3.8 +ARG NODE_VERSION=18 +ARG PYTHON_VERSION=3.9 ARG NGINX_VERSION=1.12 ARG ENVSUBST_VERSION=v1.1.0 -RUN yum -y install shadow-utils wget -RUN yum -y install openssl-devel bzip2-devel libffi-devel postgresql-devel gcc unzip tar gzip -RUN amazon-linux-extras install python$PYTHON_VERSION -RUN amazon-linux-extras install nginx$NGINX_VERSION +RUN dnf -y install shadow-utils wget +RUN dnf -y install openssl-devel bzip2-devel libffi-devel 
postgresql-devel gcc unzip tar gzip +RUN dnf install python$PYTHON_VERSION +RUN dnf install nginx$NGINX_VERSION RUN touch ~/.bashrc diff --git a/frontend/docker/prod/Dockerfile b/frontend/docker/prod/Dockerfile index 1a4e85ff4..8aa2683b5 100644 --- a/frontend/docker/prod/Dockerfile +++ b/frontend/docker/prod/Dockerfile @@ -1,15 +1,15 @@ -FROM public.ecr.aws/amazonlinux/amazonlinux:2 +FROM public.ecr.aws/amazonlinux/amazonlinux:2023 ARG REACT_APP_STAGE ARG DOMAIN -ARG NODE_VERSION=16 +ARG NODE_VERSION=18 ARG NGINX_VERSION=1.12 ARG NVM_VERSION=v0.37.0 -RUN yum update -y && \ - yum install -y tar gzip openssl && \ - yum clean all -y -RUN amazon-linux-extras install nginx$NGINX_VERSION +RUN dnf update -y && \ + dnf install -y tar gzip openssl && \ + dnf clean all -y +RUN dnf install nginx$NGINX_VERSION RUN touch ~/.bashrc From 94c93d9ca10f4a23c904fb7248df891dfe3e051a Mon Sep 17 00:00:00 2001 From: Noah Paige <69586985+noah-paige@users.noreply.github.com> Date: Fri, 8 Dec 2023 08:43:34 -0500 Subject: [PATCH 5/5] Byo vpc mlstudio (#894) ### Feature or Bugfix - Feature ### Detail - Enable SageMaker Studio Domain to be deployed in a already provisioned VPC ### Relates - https://github.com/awslabs/aws-dataall/issues/795 ### Security Please answer the questions below briefly where applicable, or write `N/A`. Based on [OWASP 10](https://owasp.org/Top10/en/). - Does this PR introduce or modify any input fields or queries - this includes fetching data from storage outside the application (e.g. a database, an S3 bucket)? - Is the input sanitized? - What precautions are you taking before deserializing the data you consume? - Is injection prevented by parametrizing queries? - Have you ensured no `eval` or similar functions are used? - Does this PR introduce any functionality or component that requires authorization? - How have you ensured it respects the existing AuthN/AuthZ mechanisms? - Are you logging failed auth attempts? - Are you using or adding any cryptographic features? - Do you use a standard proven implementations? - Are the used keys controlled by the customer? Where are they stored? - Are you introducing any new policies/roles/users? - Have you used the least-privilege principle? How? By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. 
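For reviewers, a minimal self-contained sketch of the BYO-VPC validation this change relies on. The semantics follow the new `EC2.check_vpc_exists` helper in the diff below; the function name, the direct boto3 client, and the example IDs are illustrative assumptions rather than the exact patch code (which assumes the CDK look-up role via SessionHelper):

```python
# Illustrative sketch only: verifies that a caller-supplied VPC exists and that
# every supplied subnet belongs to it, mirroring the new EC2.check_vpc_exists
# helper. Requires AWS credentials; IDs below are placeholders.
import boto3
from botocore.exceptions import ClientError


def validate_byo_vpc(vpc_id: str, subnet_ids: list, region: str = 'eu-west-1'):
    ec2 = boto3.client('ec2', region_name=region)
    try:
        ec2.describe_vpcs(VpcIds=[vpc_id])
    except ClientError as e:
        raise Exception(f'VPCNotFound: {vpc_id}') from e
    try:
        response = ec2.describe_subnets(
            Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}],
            SubnetIds=subnet_ids,
        )
    except ClientError as e:
        raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') from e
    if len(response['Subnets']) != len(subnet_ids):
        raise Exception(f'Not all subnets {subnet_ids} are within VPC {vpc_id}')


# Example (placeholder IDs):
# validate_byo_vpc('vpc-0123456789abcdef0', ['subnet-0aaaaaaaaaaaaaaaa', 'subnet-0bbbbbbbbbbbbbbbb'])
```

The resolver only runs this check when ML Studio is enabled and both a VPC and subnets are supplied; otherwise the stack falls back to the default or a newly created VPC, as shown in the diff.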
--- backend/dataall/base/aws/ec2_client.py | 65 +++++++ .../dataall/base/utils/naming_convention.py | 1 + .../core/environment/api/input_types.py | 15 +- .../dataall/core/environment/api/resolvers.py | 28 ++- .../services/environment_resource_manager.py | 15 +- .../services/environment_service.py | 24 +-- .../dashboards/db/dashboard_repositories.py | 2 +- backend/dataall/modules/mlstudio/__init__.py | 5 +- .../dataall/modules/mlstudio/api/queries.py | 10 + .../dataall/modules/mlstudio/api/resolvers.py | 9 +- backend/dataall/modules/mlstudio/api/types.py | 24 +++ .../modules/mlstudio/aws/ec2_client.py | 27 --- .../mlstudio/aws/sagemaker_studio_client.py | 20 +- .../mlstudio/cdk/mlstudio_extension.py | 160 ++++++++-------- .../modules/mlstudio/db/mlstudio_models.py | 15 +- .../mlstudio/db/mlstudio_repositories.py | 95 ++++++++-- .../mlstudio/services/mlstudio_service.py | 99 +++++++++- ...f5de322f_update_sagemaker_studio_domain.py | 178 ++++++++++++++++++ .../components/EnvironmentMLStudio.js | 156 +++++++++++++++ .../modules/Environments/components/index.js | 1 + .../views/EnvironmentCreateForm.js | 69 ++++++- .../Environments/views/EnvironmentEditForm.js | 76 +++++++- .../Environments/views/EnvironmentView.js | 10 + .../MLStudio/getEnvironmentMLStudioDomain.js | 23 +++ .../src/services/graphql/MLStudio/index.js | 1 + frontend/src/services/graphql/index.js | 1 + tests/core/conftest.py | 1 - tests/core/environments/test_environment.py | 29 +-- tests/core/vpc/test_vpc.py | 4 +- tests/modules/mlstudio/cdk/conftest.py | 20 +- .../cdk/test_sagemaker_studio_stack.py | 9 +- tests/modules/mlstudio/conftest.py | 122 +++++++++++- .../modules/mlstudio/test_sagemaker_studio.py | 146 +++++++++++++- 33 files changed, 1207 insertions(+), 253 deletions(-) create mode 100644 backend/dataall/base/aws/ec2_client.py delete mode 100644 backend/dataall/modules/mlstudio/aws/ec2_client.py create mode 100644 backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py create mode 100644 frontend/src/modules/Environments/components/EnvironmentMLStudio.js create mode 100644 frontend/src/services/graphql/MLStudio/getEnvironmentMLStudioDomain.js create mode 100644 frontend/src/services/graphql/MLStudio/index.js diff --git a/backend/dataall/base/aws/ec2_client.py b/backend/dataall/base/aws/ec2_client.py new file mode 100644 index 000000000..06bd62c7a --- /dev/null +++ b/backend/dataall/base/aws/ec2_client.py @@ -0,0 +1,65 @@ +import logging + +from dataall.base.aws.sts import SessionHelper +from botocore.exceptions import ClientError + +log = logging.getLogger(__name__) + + +class EC2: + + @staticmethod + def get_client(account_id: str, region: str, role=None): + session = SessionHelper.remote_session(accountid=account_id, role=role) + return session.client('ec2', region_name=region) + + @staticmethod + def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): + log.info("Check that default VPC exists..") + client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = client.describe_vpcs( + Filters=[{'Name': 'isDefault', 'Values': ['true']}] + ) + vpcs = response['Vpcs'] + log.info(f"Default VPCs response: {vpcs}") + if vpcs: + vpc_id = vpcs[0]['VpcId'] + subnetIds = EC2._get_vpc_subnets(AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role) + if subnetIds: + return vpc_id, subnetIds + return False + + @staticmethod + def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None): + client = EC2.get_client(account_id=AwsAccountId, 
region=region, role=role) + response = client.describe_subnets( + Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}] + ) + return [subnet['SubnetId'] for subnet in response['Subnets']] + + @staticmethod + def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): + try: + ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = ec2.describe_vpcs(VpcIds=[vpc_id]) + except ClientError as e: + log.exception(f'VPC Id {vpc_id} Not Found: {e}') + raise Exception(f'VPCNotFound: {vpc_id}') + + try: + if subnet_ids: + response = ec2.describe_subnets( + Filters=[ + { + 'Name': 'vpc-id', + 'Values': [vpc_id] + }, + ], + SubnetIds=subnet_ids + ) + except ClientError as e: + log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') + raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') + + if not subnet_ids or len(response['Subnets']) != len(subnet_ids): + raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') diff --git a/backend/dataall/base/utils/naming_convention.py b/backend/dataall/base/utils/naming_convention.py index 3501fa71b..262964560 100644 --- a/backend/dataall/base/utils/naming_convention.py +++ b/backend/dataall/base/utils/naming_convention.py @@ -10,6 +10,7 @@ class NamingConventionPattern(Enum): GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63} GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52} NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63} + MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63} DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63} OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27} OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31} diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 9b618d0e5..27188f4ed 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -28,13 +28,11 @@ gql.Argument('description', gql.String), gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)), gql.Argument('region', gql.NonNullableType(gql.String)), - gql.Argument('vpcId', gql.String), - gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), - gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), - gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) - + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), + gql.Argument('vpcId', gql.String), + gql.Argument('subnetIds', gql.ArrayType(gql.String)) ], ) @@ -45,11 +43,10 @@ gql.Argument('description', gql.String), gql.Argument('tags', gql.ArrayType(gql.String)), gql.Argument('SamlGroupName', gql.String), - gql.Argument('vpcId', gql.String), - gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), - gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), gql.Argument('resourcePrefix', gql.String), - gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), + gql.Argument('vpcId', gql.String), + gql.Argument('subnetIds', gql.ArrayType(gql.String)) ], ) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index 7f3e4c765..06878cdfc 
100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -20,6 +20,7 @@ from dataall.core.stacks.aws.cloudformation import CloudFormation from dataall.core.stacks.db.stack_repositories import Stack from dataall.core.vpc.db.vpc_repositories import Vpc +from dataall.base.aws.ec2_client import EC2 from dataall.base.db import exceptions from dataall.core.permissions import permissions from dataall.base.feature_toggle_checker import is_feature_enabled @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs): return True if ssm_param == "True" else False -def check_environment(context: Context, source, account_id, region): +def check_environment(context: Context, source, account_id, region, data): """ Checks necessary resources for environment deployment. - Check CDKToolkit exists in Account assuming cdk_look_up_role - Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False @@ -71,11 +72,25 @@ def check_environment(context: Context, source, account_id, region): action='CHECK_PIVOT_ROLE', message='Pivot Role has not been created in the Environment AWS Account', ) + mlStudioEnabled = None + for parameter in data.get("parameters", []): + if parameter['key'] == 'mlStudiosEnabled': + mlStudioEnabled = parameter['value'] + + if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []): + log.info("Check if ML Studio VPC Exists in the Account") + EC2.check_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + vpc_id=data.get("vpcId", None), + subnet_ids=data.get('subnetIds', []), + ) return cdk_role_name -def create_environment(context: Context, source, input=None): +def create_environment(context: Context, source, input={}): if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups: raise exceptions.UnauthorizedOperation( action=permissions.LINK_ENVIRONMENT, @@ -85,8 +100,10 @@ def create_environment(context: Context, source, input=None): with context.engine.scoped_session() as session: cdk_role_name = check_environment(context, source, account_id=input.get('AwsAccountId'), - region=input.get('region') + region=input.get('region'), + data=input ) + input['cdk_role_name'] = cdk_role_name env = EnvironmentService.create_environment( session=session, @@ -119,7 +136,8 @@ def update_environment( environment = EnvironmentService.get_environment_by_uri(session, environmentUri) cdk_role_name = check_environment(context, source, account_id=environment.AwsAccountId, - region=environment.region + region=environment.region, + data=input ) previous_resource_prefix = environment.resourcePrefix @@ -130,7 +148,7 @@ def update_environment( data=input, ) - if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment): + if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input): stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment diff --git a/backend/dataall/core/environment/services/environment_resource_manager.py b/backend/dataall/core/environment/services/environment_resource_manager.py index bc74f01bf..f5c2551fa 100644 --- a/backend/dataall/core/environment/services/environment_resource_manager.py +++ b/backend/dataall/core/environment/services/environment_resource_manager.py @@ -12,7 +12,11 @@ def delete_env(session, environment): pass @staticmethod - def update_env(session, environment): + def 
create_env(session, environment, **kwargs): + pass + + @staticmethod + def update_env(session, environment, **kwargs): return False @staticmethod @@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int: return counter @classmethod - def deploy_updated_stack(cls, session, prev_prefix, environment): + def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs): deploy_stack = prev_prefix != environment.resourcePrefix for resource in cls._resources: - deploy_stack |= resource.update_env(session, environment) + deploy_stack |= resource.update_env(session, environment, **kwargs) return deploy_stack @@ -51,6 +55,11 @@ def delete_env(cls, session, environment): for resource in cls._resources: resource.delete_env(session, environment) + @classmethod + def create_env(cls, session, environment, **kwargs): + for resource in cls._resources: + resource.create_env(session, environment, **kwargs) + @classmethod def count_consumption_role_resources(cls, session, role_uri): counter = 0 diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py index ddea435c4..1b2dbec07 100644 --- a/backend/dataall/core/environment/services/environment_service.py +++ b/backend/dataall/core/environment/services/environment_service.py @@ -66,6 +66,7 @@ def create_environment(session, uri, data=None): session.commit() EnvironmentService._update_env_parameters(session, env, data) + EnvironmentResourceManager.create_env(session, env, data=data) env.EnvironmentDefaultBucketName = NamingConventionService( target_uri=env.environmentUri, @@ -98,29 +99,6 @@ def create_environment(session, uri, data=None): env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn'] env.EnvironmentDefaultIAMRoleImported = True - if data.get('vpcId'): - vpc = Vpc( - environmentUri=env.environmentUri, - region=env.region, - AwsAccountId=env.AwsAccountId, - VpcId=data.get('vpcId'), - privateSubnetIds=data.get('privateSubnetIds', []), - publicSubnetIds=data.get('publicSubnetIds', []), - SamlGroupName=data['SamlGroupName'], - owner=context.username, - label=f"{env.name}-{data.get('vpcId')}", - name=f"{env.name}-{data.get('vpcId')}", - default=True, - ) - session.add(vpc) - session.commit() - ResourcePolicy.attach_resource_policy( - session=session, - group=data['SamlGroupName'], - permissions=permissions.NETWORK_ALL, - resource_uri=vpc.vpcUri, - resource_type=Vpc.__name__, - ) env_group = EnvironmentGroup( environmentUri=env.environmentUri, groupUri=data['SamlGroupName'], diff --git a/backend/dataall/modules/dashboards/db/dashboard_repositories.py b/backend/dataall/modules/dashboards/db/dashboard_repositories.py index 91916f8ff..a8d9d6a2f 100644 --- a/backend/dataall/modules/dashboards/db/dashboard_repositories.py +++ b/backend/dataall/modules/dashboards/db/dashboard_repositories.py @@ -26,7 +26,7 @@ def count_resources(session, environment, group_uri) -> int: ) @staticmethod - def update_env(session, environment): + def update_env(session, environment, **kwargs): return EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled") @staticmethod diff --git a/backend/dataall/modules/mlstudio/__init__.py b/backend/dataall/modules/mlstudio/__init__.py index 2db9c0a1e..a6ca73917 100644 --- a/backend/dataall/modules/mlstudio/__init__.py +++ b/backend/dataall/modules/mlstudio/__init__.py @@ -3,7 +3,8 @@ from dataall.base.loader import ImportMode, ModuleInterface from 
dataall.core.stacks.db.target_type_repositories import TargetType -from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource +from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager log = logging.getLogger(__name__) @@ -20,6 +21,8 @@ def __init__(self): from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER) + EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource()) + log.info("API of sagemaker mlstudio has been imported") diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py index 457559def..ee014839f 100644 --- a/backend/dataall/modules/mlstudio/api/queries.py +++ b/backend/dataall/modules/mlstudio/api/queries.py @@ -4,6 +4,7 @@ get_sagemaker_studio_user, list_sagemaker_studio_users, get_sagemaker_studio_user_presigned_url, + get_environment_sagemaker_studio_domain ) getSagemakerStudioUser = gql.QueryField( @@ -34,3 +35,12 @@ type=gql.String, resolver=get_sagemaker_studio_user_presigned_url, ) + +getEnvironmentMLStudioDomain = gql.QueryField( + name='getEnvironmentMLStudioDomain', + args=[ + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + ], + type=gql.Ref('SagemakerStudioDomain'), + resolver=get_environment_sagemaker_studio_domain, +) diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index 63dc25ed7..48c9350fa 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -18,7 +18,7 @@ def required_uri(uri): raise exceptions.RequiredParameter('URI') @staticmethod - def validate_creation_request(data): + def validate_user_creation_request(data): required = RequestValidator._required if not data: raise exceptions.RequiredParameter('data') @@ -36,7 +36,7 @@ def _required(data: dict, name: str): def create_sagemaker_studio_user(context: Context, source, input: dict = None): """Creates a SageMaker Studio user. 
Deploys the SageMaker Studio user stack into AWS""" - RequestValidator.validate_creation_request(input) + RequestValidator.validate_user_creation_request(input) request = SagemakerStudioCreationRequest.from_dict(input) return SagemakerStudioService.create_sagemaker_studio_user( uri=input["environmentUri"], @@ -90,6 +90,11 @@ def delete_sagemaker_studio_user( ) +def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None): + RequestValidator.required_uri(environmentUri) + return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri) + + def resolve_user_role(context: Context, source: SagemakerStudioUser): """ Resolves the role of the current user in reference with the SageMaker Studio User diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py index 21290711e..ca21df81d 100644 --- a/backend/dataall/modules/mlstudio/api/types.py +++ b/backend/dataall/modules/mlstudio/api/types.py @@ -79,3 +79,27 @@ gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioUser)), ], ) + +SagemakerStudioDomain = gql.ObjectType( + name='SagemakerStudioDomain', + fields=[ + gql.Field(name='sagemakerStudioUri', type=gql.ID), + gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Field(name='sagemakerStudioDomainName', type=gql.String), + gql.Field(name='DefaultDomainRoleName', type=gql.String), + gql.Field(name='label', type=gql.String), + gql.Field(name='name', type=gql.String), + gql.Field(name='vpcType', type=gql.String), + gql.Field(name='vpcId', type=gql.String), + gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)), + gql.Field(name='owner', type=gql.String), + gql.Field(name='created', type=gql.String), + gql.Field(name='updated', type=gql.String), + gql.Field(name='deleted', type=gql.String), + gql.Field( + name='environment', + type=gql.Ref('Environment'), + resolver=resolve_environment, + ) + ], +) diff --git a/backend/dataall/modules/mlstudio/aws/ec2_client.py b/backend/dataall/modules/mlstudio/aws/ec2_client.py deleted file mode 100644 index 3dc484254..000000000 --- a/backend/dataall/modules/mlstudio/aws/ec2_client.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging - -from dataall.base.aws.sts import SessionHelper - - -log = logging.getLogger(__name__) - - -class EC2: - - @staticmethod - def get_client(account_id: str, region: str, role=None): - session = SessionHelper.remote_session(accountid=account_id, role=role) - return session.client('ec2', region_name=region) - - @staticmethod - def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): - log.info("Check that default VPC exists..") - client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) - response = client.describe_vpcs( - Filters=[{'Name': 'isDefault', 'Values': ['true']}] - ) - vpcs = response['Vpcs'] - log.info(f"Default VPCs response: {vpcs}") - if vpcs: - return True - return False diff --git a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py index 2a82806ea..2ee872b1c 100644 --- a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py +++ b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py @@ -12,28 +12,22 @@ def get_client(AwsAccountId, region): return session.client('sagemaker', region_name=region) -def get_sagemaker_studio_domain(AwsAccountId, region): +def get_sagemaker_studio_domain(AwsAccountId, region, domain_name): """ Sagemaker studio domain 
is limited to 5 per account/region RETURN: an existing domain or None if no domain is in the AWS account """ client = get_client(AwsAccountId=AwsAccountId, region=region) - existing_domain = dict() try: domain_id_paginator = client.get_paginator('list_domains') - domains = domain_id_paginator.paginate() - for _domain in domains: - print(_domain) - for _domain in _domain.get('Domains'): - # Get the domain name created by dataall - if 'dataall' in _domain: - return _domain - else: - existing_domain = _domain - return existing_domain + for page in domain_id_paginator.paginate(): + for domain in page.get('Domains', []): + if domain.get("DomainName") == domain_name: + return domain + return dict() except ClientError as e: print(e) - return 'NotFound' + return dict() class SagemakerStudioClient: diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index fe9040ab9..49082ccfb 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -12,14 +12,12 @@ aws_ssm as ssm, RemovalPolicy, ) -from botocore.exceptions import ClientError +from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository -from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.modules.mlstudio.aws.ec2_client import EC2 -from dataall.modules.mlstudio.aws.sagemaker_studio_client import get_sagemaker_studio_domain +from dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) @@ -31,75 +29,84 @@ def extent(setup: EnvironmentSetup): _environment = setup.environment() with setup.get_engine().scoped_session() as session: enabled = EnvironmentService.get_boolean_env_param(session, _environment, "mlStudiosEnabled") - if not enabled: + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, _environment.environmentUri) + if not enabled or not domain: return sagemaker_principals = [setup.default_role] + setup.group_roles logger.info(f'Creating SageMaker base resources for sagemaker_principals = {sagemaker_principals}..') - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=_environment.AwsAccountId, region=_environment.region - ) - existing_default_vpc = EC2.check_default_vpc_exists( - AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn - ) - if existing_default_vpc: - logger.info("Using default VPC for Sagemaker Studio domain") - # Use default VPC - initial configuration (to be migrated) - vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] - subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + + if domain.vpcId and domain.subnetIds and domain.vpcType == 'imported': + logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) + subnet_ids = domain.subnetIds security_groups = [] else: - logger.info("Default VPC not found, Exception. 
Creating a VPC for SageMaker resources...") - # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways - log_group = logs.LogGroup( - setup, - f'SageMakerStudio{_environment.name}', - log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', - retention=logs.RetentionDays.ONE_MONTH, - removal_policy=RemovalPolicy.DESTROY, - ) - vpc_flow_role = iam.Role( - setup, 'FlowLog', - assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') - ) - vpc = ec2.Vpc( - setup, - "SageMakerVPC", - max_azs=3, - cidr="10.10.0.0/16", - subnet_configuration=[ - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PUBLIC, - name="Public", - cidr_mask=24 - ), - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, - name="Private", - cidr_mask=24 - ), - ], - enable_dns_hostnames=True, - enable_dns_support=True, - ) - ec2.FlowLog( - setup, "StudioVPCFlowLog", - resource_type=ec2.FlowLogResourceType.from_vpc(vpc), - destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=_environment.AwsAccountId, region=_environment.region ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", + existing_default_vpc = EC2.check_default_vpc_exists( + AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn ) + if existing_default_vpc: + logger.info("Using default VPC for Sagemaker Studio domain") + # Use default VPC - initial configuration (to be migrated) + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] + subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + security_groups = [] + else: + logger.info("Default VPC not found, Exception. 
Creating a VPC for SageMaker resources...") + # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways + log_group = logs.LogGroup( + setup, + f'SageMakerStudio{_environment.name}', + log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', + retention=logs.RetentionDays.ONE_MONTH, + removal_policy=RemovalPolicy.DESTROY, + ) + vpc_flow_role = iam.Role( + setup, 'FlowLog', + assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') + ) + vpc = ec2.Vpc( + setup, + "SageMakerVPC", + max_azs=3, + cidr="10.10.0.0/16", + subnet_configuration=[ + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PUBLIC, + name="Public", + cidr_mask=24 + ), + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, + name="Private", + cidr_mask=24 + ), + ], + enable_dns_hostnames=True, + enable_dns_support=True, + ) + ec2.FlowLog( + setup, "StudioVPCFlowLog", + resource_type=ec2.FlowLogResourceType.from_vpc(vpc), + destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + ) + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, + ) - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] vpc_id = vpc.vpc_id @@ -107,7 +114,7 @@ def extent(setup: EnvironmentSetup): setup, 'RoleForSagemakerStudioUsers', assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'), - role_name='RoleSagemakerStudioUsers', + role_name=domain.DefaultDomainRoleName, managed_policies=[ iam.ManagedPolicy.from_managed_policy_arn( setup, @@ -123,7 +130,7 @@ def extent(setup: EnvironmentSetup): sagemaker_domain_key = kms.Key( setup, 'SagemakerDomainKmsKey', - alias='SagemakerStudioDomain', + alias=domain.sagemakerStudioDomainName, enable_key_rotation=True, admins=[ iam.ArnPrincipal(_environment.CDKRoleArn) @@ -175,7 +182,7 @@ def extent(setup: EnvironmentSetup): sagemaker_domain = sagemaker.CfnDomain( setup, 'SagemakerStudioDomain', - domain_name=f'SagemakerStudioDomain-{_environment.region}-{_environment.AwsAccountId}', + domain_name=domain.sagemakerStudioDomainName, auth_mode='IAM', default_user_settings=sagemaker.CfnDomain.UserSettingsProperty( execution_role=sagemaker_domain_role.role_arn, @@ -199,22 +206,3 @@ def extent(setup: EnvironmentSetup): parameter_name=f'/{_environment.resourcePrefix}/{_environment.environmentUri}/sagemaker/sagemakerstudio/domain_id', ) return sagemaker_domain - - @staticmethod - def check_existing_sagemaker_studio_domain(environment): - logger.info('Check if there is an existing sagemaker studio domain in the account') - try: - logger.info('check sagemaker studio domain created as part of data.all environment stack.') - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=environment.AwsAccountId, region=environment.region - ) - dataall_created_domain = ParameterStoreManager.client( - AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn - 
).get_parameter(Name=f'/{environment.resourcePrefix}/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id') - return False - except ClientError as e: - logger.info(f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}') - existing_domain = get_sagemaker_studio_domain( - AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn - ) - return existing_domain.get('DomainId', False) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index 032826588..a4c93a2fa 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, String, ForeignKey from sqlalchemy.orm import query_expression +from sqlalchemy.dialects.postgresql import ARRAY from dataall.base.db import Base from dataall.base.db import Resource, utils @@ -10,16 +11,20 @@ class SagemakerStudioDomain(Resource, Base): """Describes ORM model for sagemaker ML Studio domain""" __tablename__ = 'sagemaker_studio_domain' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri")) sagemakerStudioUri = Column( String, primary_key=True, default=utils.uuid('sagemakerstudio') ) - sagemakerStudioDomainID = Column(String, nullable=False) - SagemakerStudioStatus = Column(String, nullable=False) + sagemakerStudioDomainID = Column(String, nullable=True) + SagemakerStudioStatus = Column(String, nullable=True) + sagemakerStudioDomainName = Column(String, nullable=False) AWSAccountId = Column(String, nullable=False) - RoleArn = Column(String, nullable=False) + DefaultDomainRoleName = Column(String, nullable=False) region = Column(String, default='eu-west-1') - userRoleForSagemakerStudio = query_expression() + SamlGroupName = Column(String, nullable=False) + vpcType = Column(String, nullable=True) + vpcId = Column(String, nullable=True) + subnetIds = Column(ARRAY(String), nullable=True) class SagemakerStudioUser(Resource, Base): diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 763ca6f92..21847b6ef 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -2,30 +2,34 @@ DAO layer that encapsulates the logic and interaction with the database for ML Studio Provides the API to retrieve / update / delete ml studio """ +from typing import Optional from sqlalchemy import or_ from sqlalchemy.sql import and_ from sqlalchemy.orm import Query +from dataall.base.utils import slugify from dataall.base.db import paginate -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -from dataall.core.environment.services.environment_resource_manager import EnvironmentResource +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser +from dataall.base.utils.naming_convention import ( + NamingConventionService, + NamingConventionPattern, +) -class SageMakerStudioRepository(EnvironmentResource): +class SageMakerStudioRepository: """DAO layer for ML Studio""" _DEFAULT_PAGE = 1 _DEFAULT_PAGE_SIZE = 10 - def __init__(self, session): - self._session = session - - def save_sagemaker_studio_user(self, user): + @staticmethod + def save_sagemaker_studio_user(session, user): """Save SageMaker Studio user to the 
database""" - self._session.add(user) - self._session.commit() + session.add(user) + session.commit() - def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: - query = self._session.query(SagemakerStudioUser).filter( + @staticmethod + def _query_user_sagemaker_studio_users(session, username, groups, filter) -> Query: + query = session.query(SagemakerStudioUser).filter( or_( SagemakerStudioUser.owner == username, SagemakerStudioUser.SamlAdminGroupName.in_(groups), @@ -44,21 +48,24 @@ def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: ) return query - def paginated_sagemaker_studio_users(self, username, groups, filter=None) -> dict: + @staticmethod + def paginated_sagemaker_studio_users(session, username, groups, filter={}) -> dict: """Returns a page of sagemaker studio users for a data.all user""" return paginate( - query=self._query_user_sagemaker_studio_users(username, groups, filter), + query=SageMakerStudioRepository._query_user_sagemaker_studio_users(session, username, groups, filter), page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE), page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE), ).to_dict() - def find_sagemaker_studio_user(self, uri): + @staticmethod + def find_sagemaker_studio_user(session, uri): """Finds a sagemaker studio user. Returns None if it doesn't exist""" - return self._session.query(SagemakerStudioUser).get(uri) + return session.query(SagemakerStudioUser).get(uri) - def count_resources(self, environment, group_uri): + @staticmethod + def count_resources(session, environment, group_uri): return ( - self._session.query(SagemakerStudioUser) + session.query(SagemakerStudioUser) .filter( and_( SagemakerStudioUser.environmentUri == environment.environmentUri, @@ -67,3 +74,57 @@ def count_resources(self, environment, group_uri): ) .count() ) + + @staticmethod + def create_sagemaker_studio_domain(session, username, environment, data): + domain = SagemakerStudioDomain( + label=f"{data.get('label')}-domain", + owner=username, + description=data.get('description', 'No description provided'), + tags=data.get('tags', []), + SamlGroupName=environment.SamlGroupName, + environmentUri=environment.environmentUri, + AWSAccountId=environment.AwsAccountId, + region=environment.region, + SagemakerStudioStatus="PENDING", + DefaultDomainRoleName="DefaultMLStudioRole", + sagemakerStudioDomainName=slugify(data.get('label'), separator=''), + vpcType=data.get('vpcType'), + vpcId=data.get('vpcId'), + subnetIds=data.get('subnetIds', []) + ) + session.add(domain) + session.commit() + + domain.sagemakerStudioDomainName = NamingConventionService( + target_uri=domain.sagemakerStudioUri, + target_label=domain.label, + pattern=NamingConventionPattern.MLSTUDIO_DOMAIN, + resource_prefix=environment.resourcePrefix, + ).build_compliant_name() + + domain.DefaultDomainRoleName = NamingConventionService( + target_uri=domain.sagemakerStudioUri, + target_label=domain.label, + pattern=NamingConventionPattern.IAM, + resource_prefix=environment.resourcePrefix, + ).build_compliant_name() + + return domain + + @staticmethod + def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env_uri, + ).first() + if not domain: + return None + return domain + + @staticmethod + def delete_sagemaker_studio_domain_by_env_uri(session, env_uri) -> 
Optional[SagemakerStudioDomain]: + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env_uri, + ).first() + if domain: + session.delete(domain) diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 06750b822..3738c118d 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -11,13 +11,18 @@ from dataall.core.environment.env_permission_checker import has_group_permission from dataall.core.environment.services.environment_service import EnvironmentService from dataall.core.permissions.db.resource_policy_repositories import ResourcePolicy +from dataall.core.permissions import permissions from dataall.core.permissions.permission_checker import has_resource_permission, has_tenant_permission from dataall.core.stacks.api import stack_helper from dataall.core.stacks.db.stack_repositories import Stack from dataall.base.db import exceptions from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.core.environment.services.environment_resource_manager import EnvironmentResource from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.base.aws.ec2_client import EC2 +from dataall.base.aws.sts import SessionHelper + from dataall.modules.mlstudio.services.mlstudio_permissions import ( MANAGE_SGMSTUDIO_USERS, CREATE_SGMSTUDIO_USER, @@ -54,6 +59,38 @@ def _session(): return get_context().db_engine.scoped_session() +class SagemakerStudioEnvironmentResource(EnvironmentResource): + @staticmethod + def count_resources(session, environment, group_uri) -> int: + return SageMakerStudioRepository.count_resources(session, environment, group_uri) + + @staticmethod + def create_env(session, environment, **kwargs): + enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + if enabled: + SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) + + @staticmethod + def update_env(session, environment, **kwargs): + current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) + previous_mlstudio_enabled = True if domain else False + if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled): + SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) + return True + elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): + SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) + return True + elif current_mlstudio_enabled and domain and domain.vpcType == "unknown": + SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs) + return True + return False + + @staticmethod + def delete_env(session, environment): + SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) + + class SagemakerStudioService: """ Encapsulate the logic of interactions with sagemaker ml studio. 
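Before the service hunks below, a condensed sketch of the three-way `vpcType` decision that `_update_sagemaker_studio_domain_vpc` applies when the domain record is created or updated; `resolve_vpc_type` and its `default_vpc` argument are illustrative names for this note, not identifiers from the patch:

```python
# Condensed sketch of the vpcType decision. Assumption: default_vpc is the
# (vpc_id, subnet_ids) tuple, or False, returned by EC2.check_default_vpc_exists.
def resolve_vpc_type(data: dict, default_vpc) -> dict:
    if data.get('vpcId'):
        # Customer brought their own VPC and subnets
        data['vpcType'] = 'imported'
    elif default_vpc:
        # Fall back to the account's default VPC and its subnets
        vpc_id, subnet_ids = default_vpc
        data.update(vpcType='default', vpcId=vpc_id, subnetIds=subnet_ids)
    else:
        # Nothing usable exists: the environment CDK stack creates a dedicated VPC
        data['vpcType'] = 'created'
    return data
```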
@@ -77,17 +114,19 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak action=CREATE_SGMSTUDIO_USER, message=f'ML Studio feature is disabled for the environment {env.label}', ) + + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=env.environmentUri) response = get_sagemaker_studio_domain( AwsAccountId=env.AwsAccountId, - region=env.region + region=env.region, + domain_name=domain.sagemakerStudioDomainName ) existing_domain = response.get('DomainId', False) if not existing_domain: raise exceptions.AWSResourceNotAvailable( action='Sagemaker Studio domain', - message='Update the environment stack ' - 'or create a Sagemaker studio domain on your AWS account.', + message='Update the environment stack and enable ML Studio Environment Feature' ) sagemaker_studio_user = SagemakerStudioUser( @@ -104,7 +143,7 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak SamlAdminGroupName=admin_group, tags=request.tags, ) - SageMakerStudioRepository(session).save_sagemaker_studio_user(user=sagemaker_studio_user) + SageMakerStudioRepository.save_sagemaker_studio_user(session, sagemaker_studio_user) ResourcePolicy.attach_resource_policy( session=session, @@ -135,10 +174,58 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak return sagemaker_studio_user + @staticmethod + def update_sagemaker_studio_domain(environment, domain, data): + SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data) + domain.vpcType = data.get('vpcType') + if data.get('vpcId'): + domain.vpcId = data.get('vpcId') + if data.get('subnetIds'): + domain.subnetIds = data.get('subnetIds') + + @staticmethod + def _update_sagemaker_studio_domain_vpc(account_id, region, data={}): + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=account_id, region=region + ) + if data.get("vpcId", None): + data["vpcType"] = "imported" + else: + response = EC2.check_default_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + ) + if response: + vpcId, subnetIds = response + data["vpcType"] = "default" + data["vpcId"] = vpcId + data["subnetIds"] = subnetIds + else: + data["vpcType"] = "created" + + @staticmethod + def create_sagemaker_studio_domain(session, environment, data: dict = {}): + SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data) + + domain = SageMakerStudioRepository.create_sagemaker_studio_domain( + session=session, + username=get_context().username, + environment=environment, + data=data, + ) + return domain + + @staticmethod + def get_environment_sagemaker_studio_domain(*, environment_uri: str): + with _session() as session: + return SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment_uri) + @staticmethod def list_sagemaker_studio_users(*, filter: dict) -> dict: with _session() as session: - return SageMakerStudioRepository(session).paginated_sagemaker_studio_users( + return SageMakerStudioRepository.paginated_sagemaker_studio_users( + session=session, username=get_context().username, groups=get_context().groups, filter=filter, @@ -197,7 +284,7 @@ def delete_sagemaker_studio_user(*, uri: str, delete_from_aws: bool): @staticmethod def _get_sagemaker_studio_user(session, uri): - user = SageMakerStudioRepository(session).find_sagemaker_studio_user(uri=uri) + user = 
SageMakerStudioRepository.find_sagemaker_studio_user(session=session, uri=uri) if not user: raise exceptions.ObjectNotFound('SagemakerStudioUser', uri) return user diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py new file mode 100644 index 000000000..a3ac794f3 --- /dev/null +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -0,0 +1,178 @@ +"""env_mlstudio_domain_table + +Revision ID: 71a5f5de322f +Revises: 8c79fb896983 +Create Date: 2023-11-29 09:44:04.160286 + +""" +import os +from sqlalchemy import orm, Column, String, Boolean, ForeignKey, and_ +from sqlalchemy.ext.declarative import declarative_base +import sqlalchemy as sa +from alembic import op + +from sqlalchemy.dialects import postgresql +from dataall.base.db import get_engine, has_table +from dataall.base.db import utils, Resource + +# revision identifiers, used by Alembic. +revision = '71a5f5de322f' +down_revision = '8c79fb896983' +branch_labels = None +depends_on = None + +Base = declarative_base() + + +class Environment(Resource, Base): + __tablename__ = "environment" + environmentUri = Column(String, primary_key=True) + AwsAccountId = Column(Boolean) + region = Column(Boolean) + SamlGroupName = Column(String) + + +class EnvironmentParameter(Base): + __tablename__ = 'environment_parameters' + environmentUri = Column(String, primary_key=True) + key = Column('paramKey', String, primary_key=True) + value = Column('paramValue', String, nullable=True) + + +class SagemakerStudioDomain(Resource, Base): + __tablename__ = 'sagemaker_studio_domain' + environmentUri = Column(String, ForeignKey("environment.environmentUri")) + sagemakerStudioUri = Column( + String, primary_key=True, default=utils.uuid('sagemakerstudio') + ) + sagemakerStudioDomainID = Column(String, nullable=True) + SagemakerStudioStatus = Column(String, nullable=True) + sagemakerStudioDomainName = Column(String, nullable=False) + AWSAccountId = Column(String, nullable=False) + DefaultDomainRoleName = Column(String, nullable=False) + region = Column(String, default='eu-west-1') + SamlGroupName = Column(String, nullable=False) + vpcType = Column(String, nullable=True) + + +def upgrade(): + """ + The script does the following migration: + 1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information + """ + try: + envname = os.getenv('envname', 'local') + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + if has_table('sagemaker_studio_domain', engine): + print("Updating sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'sagemakerStudioDomainID', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'SagemakerStudioStatus', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'RoleArn', + new_column_name='DefaultDomainRoleName', + nullable=False, + existing_type=sa.String() + ) + + op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False)) + op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True)) + op.add_column("sagemaker_studio_domain", 
Column("SamlGroupName", sa.String(), nullable=False)) + + op.create_foreign_key( + "fk_sagemaker_studio_domain_env_uri", + "sagemaker_studio_domain", "environment", + ["environmentUri"], ["environmentUri"], + ) + + print("Update sagemaker_studio_domain table done.") + print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") + + env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( + and_( + EnvironmentParameter.key == "mlStudiosEnabled", + EnvironmentParameter.value == "true" + ) + ).all() + for param in env_mlstudio_parameters: + env: Environment = session.query(Environment).filter( + Environment.environmentUri == param.environmentUri + ).first() + + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env.environmentUri + ).first() + if not domain: + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown", + SamlGroupName=env.SamlGroupName + ) + session.add(domain) + session.flush() + session.commit() + print("Fill of sagemaker_studio_domain table is done") + + except Exception as exception: + print('Failed to upgrade due to:', exception) + raise exception + + +def downgrade(): + try: + envname = os.getenv('envname', 'local') + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + if has_table('sagemaker_studio_domain', engine): + print("deleting sagemaker studio domain entries...") + session.query(SagemakerStudioDomain).delete() + + print("Updating of sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'DefaultDomainRoleName', + new_column_name='RoleArn', + nullable=False, + existing_type=sa.String() + ) + + op.drop_column("sagemaker_studio_domain", "sagemakerStudioDomainName") + op.drop_column("sagemaker_studio_domain", "vpcType") + op.drop_column("sagemaker_studio_domain", "vpcId") + op.drop_column("sagemaker_studio_domain", "subnetIds") + op.drop_column("sagemaker_studio_domain", "SamlGroupName") + + op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") + + session.commit() + print("Update of sagemaker_studio_domain table is done") + + except Exception as exception: + print('Failed to downgrade due to:', exception) + raise exception diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js new file mode 100644 index 000000000..44dac97b9 --- /dev/null +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -0,0 +1,156 @@ +import { + Box, + Card, + CardHeader, + Divider, + Grid, + CardContent, + Typography, + CircularProgress, + Chip +} from '@mui/material'; + +import PropTypes from 'prop-types'; +import React, { useCallback, useEffect, useState } from 'react'; +import { RefreshTableMenu } from 'design'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { getEnvironmentMLStudioDomain, useClient } from 'services'; + +export const EnvironmentMLStudio = ({ environment }) => { + const client = useClient(); + const dispatch = useDispatch(); + const [mlStudioDomain, 
setMLStudioDomain] = useState(null); + const [loading, setLoading] = useState(true); + + const fetchMLStudioDomain = useCallback(async () => { + try { + setLoading(true); + const response = await client.query( + getEnvironmentMLStudioDomain({ + environmentUri: environment.environmentUri + }) + ); + if (!response.errors) { + if (response.data.getEnvironmentMLStudioDomain) { + setMLStudioDomain(response.data.getEnvironmentMLStudioDomain); + } + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (e) { + dispatch({ type: SET_ERROR, error: e.message }); + } finally { + setLoading(false); + } + }, [client, dispatch, environment.environmentUri]); + + useEffect(() => { + if (client) { + fetchMLStudioDomain().catch((e) => + dispatch({ type: SET_ERROR, error: e.message }) + ); + } + }, [client, fetchMLStudioDomain, dispatch]); + + if (loading) { + return ; + } + + return ( + + + } + title={ML Studio Domain Information} + /> + + + + + {mlStudioDomain === null ? ( + + + No ML Studio Domain - To Create a ML Studio Domain for this + Environment: {environment.label}, edit the Environment and enable + the ML Studio Environment Feature + + + ) : ( + + + + + SageMaker ML Studio Domain Name + + + {mlStudioDomain.sagemakerStudioDomainName} + + + + + SageMaker ML Studio Default Execution Role + + + arn:aws:iam::{environment.AwsAccountId}:role/ + {mlStudioDomain.DefaultDomainRoleName} + + + + + Domain VPC Type + + + {mlStudioDomain.vpcType} + + + {(mlStudioDomain.vpcType === 'imported' || + mlStudioDomain.vpcType === 'default') && ( + <> + + + Domain VPC Id + + + {mlStudioDomain.vpcId} + + + + + Domain Subnet Ids + + + {mlStudioDomain.subnetIds?.map((subnet) => ( + + ))} + + + + )} + + + )} + + + ); +}; + +EnvironmentMLStudio.propTypes = { + environment: PropTypes.object.isRequired +}; diff --git a/frontend/src/modules/Environments/components/index.js b/frontend/src/modules/Environments/components/index.js index afccd1235..7aecd51fa 100644 --- a/frontend/src/modules/Environments/components/index.js +++ b/frontend/src/modules/Environments/components/index.js @@ -12,3 +12,4 @@ export * from './EnvironmentTeamInviteEditForm'; export * from './EnvironmentTeamInviteForm'; export * from './EnvironmentTeams'; export * from './NetworkCreateModal'; +export * from './EnvironmentMLStudio'; diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 2767e1fcc..a16cdfa7e 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -31,6 +31,13 @@ import { CopyToClipboard } from 'react-copy-to-clipboard/lib/Component'; import { Helmet } from 'react-helmet-async'; import { Link as RouterLink, useNavigate, useParams } from 'react-router-dom'; import * as Yup from 'yup'; +import { + createEnvironment, + getPivotRoleExternalId, + getPivotRoleName, + getPivotRolePresignedUrl, + getCDKExecPolicyPresignedUrl +} from '../services'; import { ArrowLeftIcon, ChevronRightIcon, @@ -44,13 +51,6 @@ import { useClient, useGroups } from 'services'; -import { - createEnvironment, - getPivotRoleExternalId, - getPivotRoleName, - getPivotRolePresignedUrl, - getCDKExecPolicyPresignedUrl -} from '../services'; import { AwsRegions, isAnyEnvironmentModuleEnabled, @@ -179,6 +179,8 @@ const EnvironmentCreateForm = (props) => { region: values.region, EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn, resourcePrefix: 
values.resourcePrefix, + vpcId: values.vpcId, + subnetIds: values.subnetIds, parameters: [ { key: 'notebooksEnabled', @@ -484,7 +486,9 @@ const EnvironmentCreateForm = (props) => { mlStudiosEnabled: isModuleEnabled(ModuleNames.MLSTUDIO), pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES), EnvironmentDefaultIAMRoleArn: '', - resourcePrefix: 'dataall' + resourcePrefix: 'dataall', + vpcId: '', + subnetIds: [] }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -508,8 +512,14 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - privateSubnetIds: Yup.array().nullable(), - publicSubnetIds: Yup.array().nullable(), + subnetIds: Yup.array().when('vpcId', { + is: (value) => !!value, + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) + }), vpcId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() @@ -862,6 +872,45 @@ const EnvironmentCreateForm = (props) => { + {values.mlStudiosEnabled && ( + + + + + + + + { + setFieldValue('subnetIds', [...chip]); + }} + /> + + + + )} {errors.submit && ( {errors.submit} diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index caa5d8441..382575920 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -30,7 +30,7 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { useClient } from 'services'; +import { getEnvironmentMLStudioDomain, useClient } from 'services'; import { getEnvironment, updateEnvironment } from '../services'; import { isAnyEnvironmentModuleEnabled, @@ -47,6 +47,9 @@ const EnvironmentEditForm = (props) => { const { settings } = useSettings(); const [loading, setLoading] = useState(true); const [env, setEnv] = useState(''); + const [envMLStudioDomain, setEnvMLStudioDomain] = useState(''); + const [previousEnvMLStudioEnabled, setPreviousEnvMLStudioEnabled] = + useState(false); const fetchItem = useCallback(async () => { const response = await client.query( @@ -58,6 +61,20 @@ const EnvironmentEditForm = (props) => { environment.parameters.map((x) => [x.key, x.value]) ); setEnv(environment); + if (environment.parameters['mlStudiosEnabled'] === 'true') { + setPreviousEnvMLStudioEnabled(true); + const response2 = await client.query( + getEnvironmentMLStudioDomain({ environmentUri: params.uri }) + ); + if (!response2.errors && response2.data.getEnvironmentMLStudioDomain) { + setEnvMLStudioDomain(response2.data.getEnvironmentMLStudioDomain); + } else { + const error = response2.errors + ? response2.errors[0].message + : 'Environment ML Studio Domain not found'; + dispatch({ type: SET_ERROR, error }); + } + } } else { const error = response.errors ? 
response.errors[0].message @@ -66,11 +83,13 @@ const EnvironmentEditForm = (props) => { } setLoading(false); }, [client, dispatch, params.uri]); + useEffect(() => { if (client) { fetchItem().catch((e) => dispatch({ type: SET_ERROR, error: e.message })); } }, [client, fetchItem, dispatch]); + async function submit(values, setStatus, setSubmitting, setErrors) { try { const response = await client.mutate( @@ -81,6 +100,8 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, resourcePrefix: values.resourcePrefix, + vpcId: values.vpcId, + subnetIds: values.subnetIds, parameters: [ { key: 'notebooksEnabled', @@ -213,6 +234,8 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], + vpcId: envMLStudioDomain.vpcId || '', + subnetIds: envMLStudioDomain.subnetIds || [], notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', @@ -226,6 +249,15 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), + subnetIds: Yup.array().when('vpcId', { + is: (value) => !!value, + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) + }), + vpcId: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() .matches( @@ -383,6 +415,48 @@ const EnvironmentEditForm = (props) => { + {!previousEnvMLStudioEnabled && + values.mlStudiosEnabled && ( + + + + + + + + { + setFieldValue('subnetIds', [...chip]); + }} + /> + + + + )} {isAnyEnvironmentModuleEnabled() && ( diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js index 0ba724320..792918c13 100644 --- a/frontend/src/modules/Environments/views/EnvironmentView.js +++ b/frontend/src/modules/Environments/views/EnvironmentView.js @@ -39,6 +39,7 @@ import { archiveEnvironment, getEnvironment } from '../services'; import { KeyValueTagList, Stack, StackStatus } from 'modules/Shared'; import { EnvironmentDatasets, + EnvironmentMLStudio, EnvironmentOverview, EnvironmentSubscriptions, EnvironmentTeams, @@ -59,6 +60,12 @@ const tabs = [ icon: , active: isModuleEnabled(ModuleNames.DATASETS) }, + { + label: 'ML Studio Domain', + value: 'mlstudio', + icon: , + active: isModuleEnabled(ModuleNames.MLSTUDIO) + }, { label: 'Networks', value: 'networks', icon: }, { label: 'Subscriptions', @@ -267,6 +274,9 @@ const EnvironmentView = () => { fetchItem={fetchItem} /> )} + {isAdmin && currentTab === 'mlstudio' && ( + + )} {isAdmin && currentTab === 'tags' && ( ({ + variables: { + environmentUri + }, + query: gql` + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + ` +}); diff --git a/frontend/src/services/graphql/MLStudio/index.js b/frontend/src/services/graphql/MLStudio/index.js new file mode 100644 index 000000000..97d3de110 --- /dev/null +++ b/frontend/src/services/graphql/MLStudio/index.js @@ -0,0 +1 @@ +export * from './getEnvironmentMLStudioDomain'; diff --git a/frontend/src/services/graphql/index.js b/frontend/src/services/graphql/index.js index 8d0e00804..ce1c3fba2 100644 --- 
a/frontend/src/services/graphql/index.js +++ b/frontend/src/services/graphql/index.js @@ -8,6 +8,7 @@ export * from './Glossary'; export * from './Groups'; export * from './KeyValueTags'; export * from './Metric'; +export * from './MLStudio'; export * from './Notification'; export * from './Organization'; export * from './Principal'; diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 6d8a449e4..738ab4d06 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -44,7 +44,6 @@ def factory(org, envname, owner, group, account, region, desc='test', parameters 'tags': ['a', 'b', 'c'], 'region': f'{region}', 'SamlGroupName': f'{group}', - 'vpcId': 'vpc-123456', 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()] }, ) diff --git a/tests/core/environments/test_environment.py b/tests/core/environments/test_environment.py index 31ba18e57..e806e07b2 100644 --- a/tests/core/environments/test_environment.py +++ b/tests/core/environments/test_environment.py @@ -221,26 +221,6 @@ def test_list_environments_no_filter(org_fixture, env_fixture, client, group): assert response.data.listEnvironments.count == 1 - response = client.query( - """ - query ListEnvironmentNetworks($environmentUri: String!,$filter:VpcFilter){ - listEnvironmentNetworks(environmentUri:$environmentUri,filter:$filter){ - count - nodes{ - VpcId - SamlGroupName - } - } - } - """, - environmentUri=env_fixture.environmentUri, - username='alice', - groups=[group.name], - ) - print(response) - - assert response.data.listEnvironmentNetworks.count == 1 - def test_list_environment_role_filter_as_creator(org_fixture, env_fixture, client, group): response = client.query( @@ -656,23 +636,16 @@ def test_create_environment(db, client, org_fixture, env_fixture, user, group): 'tags': ['a', 'b', 'c'], 'region': f'{env_fixture.region}', 'SamlGroupName': group.name, - 'vpcId': 'vpc-1234567', - 'privateSubnetIds': 'subnet-1', - 'publicSubnetIds': 'subnet-21', 'resourcePrefix': 'customer-prefix', }, ) body = response.data.createEnvironment - assert body.networks + assert len(body.networks) == 0 assert body.EnvironmentDefaultIAMRoleName == 'myOwnIamRole' assert body.EnvironmentDefaultIAMRoleImported assert body.resourcePrefix == 'customer-prefix' - for vpc in body.networks: - assert vpc.privateSubnetIds - assert vpc.publicSubnetIds - assert vpc.default with db.scoped_session() as session: env = EnvironmentService.get_environment_by_uri( diff --git a/tests/core/vpc/test_vpc.py b/tests/core/vpc/test_vpc.py index a55196d32..8f2391220 100644 --- a/tests/core/vpc/test_vpc.py +++ b/tests/core/vpc/test_vpc.py @@ -60,7 +60,7 @@ def test_list_networks(client, env_fixture, db, user, group, vpc): ) print(response) - assert response.data.listEnvironmentNetworks.count == 2 + assert response.data.listEnvironmentNetworks.count == 1 def test_list_networks_nopermissions(client, env_fixture, db, user, group2, vpc): @@ -119,4 +119,4 @@ def test_delete_network(client, env_fixture, db, user, group, module_mocker, vpc username='alice', groups=[group.name], ) - assert len(response.data.listEnvironmentNetworks['nodes']) == 1 + assert len(response.data.listEnvironmentNetworks['nodes']) == 0 diff --git a/tests/modules/mlstudio/cdk/conftest.py b/tests/modules/mlstudio/cdk/conftest.py index 4b3327838..2c6f1eddd 100644 --- a/tests/modules/mlstudio/cdk/conftest.py +++ b/tests/modules/mlstudio/cdk/conftest.py @@ -2,7 +2,7 @@ from dataall.core.environment.db.environment_models import Environment from 
dataall.core.organizations.db.organization_models import Organization -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain @pytest.fixture(scope='module', autouse=True) @@ -23,3 +23,21 @@ def sgm_studio(db, env_fixture: Environment) -> SagemakerStudioUser: ) session.add(sm_user) yield sm_user + +@pytest.fixture(scope='module', autouse=True) +def sgm_studio_domain(db, env_fixture: Environment) -> SagemakerStudioDomain: + with db.scoped_session() as session: + sm_domain = SagemakerStudioDomain( + label='sagemaker-domain', + owner='me', + environmentUri=env_fixture.environmentUri, + AWSAccountId=env_fixture.AwsAccountId, + region=env_fixture.region, + SagemakerStudioStatus="PENDING", + DefaultDomainRoleName="DefaultMLStudioRole", + sagemakerStudioDomainName="DomainName", + vpcType="created", + SamlGroupName=env_fixture.SamlGroupName, + ) + session.add(sm_domain) + yield sm_domain diff --git a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py index a2c1752e2..8e0cd6166 100644 --- a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py +++ b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py @@ -66,16 +66,19 @@ def patch_methods_sagemaker_studio(mocker, db, sgm_studio, env_fixture, org_fixt @pytest.fixture(scope='function', autouse=True) -def patch_methods_sagemaker_studio_extension(mocker): +def patch_methods_sagemaker_studio_extension(mocker, sgm_studio_domain): mocker.patch( 'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn', return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", ) mocker.patch( - 'dataall.modules.mlstudio.aws.ec2_client.EC2.check_default_vpc_exists', + 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists', return_value=False, ) - + mocker.patch( + 'dataall.modules.mlstudio.db.mlstudio_repositories.SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri', + return_value=sgm_studio_domain, + ) def test_resources_sgmstudio_stack_created(sgm_studio): app = App() diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index 433048894..d1fffb2cf 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -1,6 +1,6 @@ import pytest -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain @pytest.fixture(scope='module', autouse=True) @@ -16,8 +16,31 @@ def env_params(): yield {'mlStudiosEnabled': 'True'} +@pytest.fixture(scope='module', autouse=True) +def get_cdk_look_up_role_arn(module_mocker): + module_mocker.patch( + 'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn', + return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", + ) + +@pytest.fixture(scope='module', autouse=True) +def check_default_vpc(module_mocker): + module_mocker.patch( + 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists', + return_value=False, + ) + + +@pytest.fixture(scope='module', autouse=True) +def check_vpc_exists(module_mocker): + module_mocker.patch( + 'dataall.base.aws.ec2_client.EC2.check_vpc_exists', + return_value=True, + ) + + @pytest.fixture(scope='module') -def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudioUser: +def sagemaker_studio_user(client, tenant, group, 
env_with_mlstudio) -> SagemakerStudioUser: response = client.query( """ mutation createSagemakerStudioUser($input:NewSagemakerStudioUserInput){ @@ -36,7 +59,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudio input={ 'label': 'testcreate', 'SamlAdminGroupName': group.name, - 'environmentUri': env_fixture.environmentUri, + 'environmentUri': env_with_mlstudio.environmentUri, }, username='alice', groups=[group.name], @@ -45,7 +68,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudio @pytest.fixture(scope='module') -def multiple_sagemaker_studio_users(client, db, env_fixture, group): +def multiple_sagemaker_studio_users(client, db, env_with_mlstudio, group): for i in range(0, 10): response = client.query( """ @@ -65,7 +88,7 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): input={ 'label': f'test{i}', 'SamlAdminGroupName': group.name, - 'environmentUri': env_fixture.environmentUri, + 'environmentUri': env_with_mlstudio.environmentUri, }, username='alice', groups=[group.name], @@ -77,5 +100,92 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): ) assert ( response.data.createSagemakerStudioUser.environmentUri - == env_fixture.environmentUri + == env_with_mlstudio.environmentUri ) + +@pytest.fixture(scope='module') +def env_with_mlstudio(client, org_fixture, user, group, parameters=None, vpcId='', subnetIds=[]): + if not parameters: + parameters = {'mlStudiosEnabled': 'True'} + response = client.query( + """mutation CreateEnv($input:NewEnvironmentInput){ + createEnvironment(input:$input){ + organization{ + organizationUri + } + environmentUri + label + AwsAccountId + SamlGroupName + region + name + owner + parameters { + key + value + } + } + }""", + username=f'{user.username}', + groups=['testadmins'], + input={ + 'label': f'dev', + 'description': '', + 'organizationUri': org_fixture.organizationUri, + 'AwsAccountId': '111111111111', + 'tags': [], + 'region': 'us-east-1', + 'SamlGroupName': 'testadmins', + 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()], + 'vpcId': vpcId, + 'subnetIds': subnetIds + }, + ) + yield response.data.createEnvironment + + +@pytest.fixture(scope='module', autouse=True) +def org(client): + cache = {} + + def factory(orgname, owner, group): + key = orgname + owner + group + if cache.get(key): + print(f'returning item from cached key {key}') + return cache.get(key) + response = client.query( + """mutation CreateOrganization($input:NewOrganizationInput){ + createOrganization(input:$input){ + organizationUri + label + name + owner + SamlGroupName + } + }""", + username=f'{owner}', + groups=[group], + input={ + 'label': f'{orgname}', + 'description': f'test', + 'tags': ['a', 'b', 'c'], + 'SamlGroupName': f'{group}', + }, + ) + cache[key] = response.data.createOrganization + return cache[key] + + yield factory + + +@pytest.fixture(scope='module') +def org_fixture(org, user, group): + org1 = org('testorg', user.username, group.name) + yield org1 + + +@pytest.fixture(scope='module') +def env_mlstudio_fixture(env, org_fixture, user, group, tenant): + env1 = env_with_mlstudio(org_fixture, 'dev', 'alice', 'testadmins', '111111111111', 'eu-west-1') + yield env1 + diff --git a/tests/modules/mlstudio/test_sagemaker_studio.py b/tests/modules/mlstudio/test_sagemaker_studio.py index c55762522..3d90b405a 100644 --- a/tests/modules/mlstudio/test_sagemaker_studio.py +++ b/tests/modules/mlstudio/test_sagemaker_studio.py @@ -1,14 +1,43 @@ from 
dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_fixture): +def test_create_sagemaker_studio_domain(db, client, org_fixture, env_with_mlstudio, user, group, vpcId="vpc-1234", subnetIds=["subnet"]): + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, + ) + + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain.label == f'{env_with_mlstudio.label}-domain' + assert response.data.getEnvironmentMLStudioDomain.vpcType == 'created' + assert len(response.data.getEnvironmentMLStudioDomain.vpcId) == 0 + assert len(response.data.getEnvironmentMLStudioDomain.subnetIds) == 0 + assert response.data.getEnvironmentMLStudioDomain.environmentUri == env_with_mlstudio.environmentUri + + +def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_with_mlstudio): """Testing that the conftest sagemaker studio user has been created correctly""" assert sagemaker_studio_user.label == 'testcreate' assert sagemaker_studio_user.SamlAdminGroupName == group.name - assert sagemaker_studio_user.environmentUri == env_fixture.environmentUri + assert sagemaker_studio_user.environmentUri == env_with_mlstudio.environmentUri -def test_list_sagemaker_studio_users(client, env_fixture, db, group, multiple_sagemaker_studio_users): +def test_list_sagemaker_studio_users(client, db, group, multiple_sagemaker_studio_users): response = client.query( """ query listSagemakerStudioUsers($filter:SagemakerStudioUserFilter!){ @@ -67,3 +96,114 @@ def test_delete_sagemaker_studio_user( sagemaker_studio_user.sagemakerStudioUserUri ) assert not n + +def update_env_query(): + query = """ + mutation UpdateEnv($environmentUri:String!,$input:ModifyEnvironmentInput){ + updateEnvironment(environmentUri:$environmentUri,input:$input){ + organization{ + organizationUri + } + label + AwsAccountId + region + SamlGroupName + owner + tags + resourcePrefix + parameters { + key + value + } + } + } + """ + return query + +def test_update_env_delete_domain(client, org_fixture, env_with_mlstudio, group, group2): + response = client.query( + update_env_query(), + username='alice', + environmentUri=env_with_mlstudio.environmentUri, + input={ + 'label': 'DEV', + 'tags': [], + 'parameters': [ + { + 'key': 'mlStudiosEnabled', + 'value': 'False' + } + ], + }, + groups=[group.name], + ) + + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, + ) + assert response.data.getEnvironmentMLStudioDomain is None + + +def test_update_env_create_domain_with_vpc(db, client, org_fixture, env_with_mlstudio, user, group): + response = client.query( + update_env_query(), + username='alice', + environmentUri=env_with_mlstudio.environmentUri, + input={ + 'label': 'dev', + 'tags': [], + 'vpcId': "vpc-12345", + 'subnetIds': ['subnet-12345', 'subnet-67890'], + 'parameters': [ + { + 
'key': 'mlStudiosEnabled', + 'value': 'True' + } + ], + }, + groups=[group.name], + ) + + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, + ) + + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain.label == f'{env_with_mlstudio.label}-domain' + assert response.data.getEnvironmentMLStudioDomain.vpcType == 'imported' + assert response.data.getEnvironmentMLStudioDomain.vpcId == 'vpc-12345' + assert len(response.data.getEnvironmentMLStudioDomain.subnetIds) == 2 + assert response.data.getEnvironmentMLStudioDomain.environmentUri == env_with_mlstudio.environmentUri +
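The vpcType values asserted in these tests ('created' when no VPC is supplied and the mocked default-VPC lookup returns False, 'imported' when a vpcId is passed) follow from the precedence applied in _update_sagemaker_studio_domain_vpc. A minimal sketch of that precedence, with a stubbed default-VPC lookup standing in for the real EC2 call (the default_vpc parameter is hypothetical and exists only to make the example self-contained):

    from typing import List, Optional, Tuple

    def resolve_vpc_type(data: dict,
                         default_vpc: Optional[Tuple[str, List[str]]]) -> dict:
        """Precedence sketch: imported VPC > account default VPC > CDK-created VPC."""
        result = dict(data)  # the real code mutates `data` in place; a copy keeps the sketch pure
        if result.get("vpcId"):
            result["vpcType"] = "imported"      # caller supplied a VPC -> import it
        elif default_vpc:
            vpc_id, subnet_ids = default_vpc    # account has a default VPC -> reuse it
            result.update(vpcType="default", vpcId=vpc_id, subnetIds=subnet_ids)
        else:
            result["vpcType"] = "created"       # nothing available -> the stack creates a VPC
        return result

    # Matches the expectations in the tests above:
    assert resolve_vpc_type({}, default_vpc=None)["vpcType"] == "created"
    assert resolve_vpc_type({"vpcId": "vpc-12345", "subnetIds": ["subnet-12345", "subnet-67890"]},
                            default_vpc=None)["vpcType"] == "imported"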