diff --git a/backend/dataall/base/aws/ec2_client.py b/backend/dataall/base/aws/ec2_client.py new file mode 100644 index 000000000..06bd62c7a --- /dev/null +++ b/backend/dataall/base/aws/ec2_client.py @@ -0,0 +1,65 @@ +import logging + +from dataall.base.aws.sts import SessionHelper +from botocore.exceptions import ClientError + +log = logging.getLogger(__name__) + + +class EC2: + + @staticmethod + def get_client(account_id: str, region: str, role=None): + session = SessionHelper.remote_session(accountid=account_id, role=role) + return session.client('ec2', region_name=region) + + @staticmethod + def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): + log.info("Check that default VPC exists..") + client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = client.describe_vpcs( + Filters=[{'Name': 'isDefault', 'Values': ['true']}] + ) + vpcs = response['Vpcs'] + log.info(f"Default VPCs response: {vpcs}") + if vpcs: + vpc_id = vpcs[0]['VpcId'] + subnetIds = EC2._get_vpc_subnets(AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role) + if subnetIds: + return vpc_id, subnetIds + return False + + @staticmethod + def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None): + client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = client.describe_subnets( + Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}] + ) + return [subnet['SubnetId'] for subnet in response['Subnets']] + + @staticmethod + def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]): + try: + ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role) + response = ec2.describe_vpcs(VpcIds=[vpc_id]) + except ClientError as e: + log.exception(f'VPC Id {vpc_id} Not Found: {e}') + raise Exception(f'VPCNotFound: {vpc_id}') + + try: + if subnet_ids: + response = ec2.describe_subnets( + Filters=[ + { + 'Name': 'vpc-id', + 'Values': [vpc_id] + }, + ], + SubnetIds=subnet_ids + ) + except ClientError as e: + log.exception(f'Subnet Id {subnet_ids} Not Found: {e}') + raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') + + if not subnet_ids or len(response['Subnets']) != len(subnet_ids): + raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}') diff --git a/backend/dataall/base/utils/naming_convention.py b/backend/dataall/base/utils/naming_convention.py index 3501fa71b..262964560 100644 --- a/backend/dataall/base/utils/naming_convention.py +++ b/backend/dataall/base/utils/naming_convention.py @@ -10,6 +10,7 @@ class NamingConventionPattern(Enum): GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63} GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52} NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63} + MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63} DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63} OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27} OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31} diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py index 9b618d0e5..27188f4ed 100644 --- a/backend/dataall/core/environment/api/input_types.py +++ b/backend/dataall/core/environment/api/input_types.py @@ -28,13 +28,11 @@ gql.Argument('description', gql.String), gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)), gql.Argument('region', gql.NonNullableType(gql.String)), - gql.Argument('vpcId', gql.String), - gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), - gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String), gql.Argument('resourcePrefix', gql.String), - gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) - + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), + gql.Argument('vpcId', gql.String), + gql.Argument('subnetIds', gql.ArrayType(gql.String)) ], ) @@ -45,11 +43,10 @@ gql.Argument('description', gql.String), gql.Argument('tags', gql.ArrayType(gql.String)), gql.Argument('SamlGroupName', gql.String), - gql.Argument('vpcId', gql.String), - gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)), - gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)), gql.Argument('resourcePrefix', gql.String), - gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)) + gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)), + gql.Argument('vpcId', gql.String), + gql.Argument('subnetIds', gql.ArrayType(gql.String)) ], ) diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index 7f3e4c765..06878cdfc 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -20,6 +20,7 @@ from dataall.core.stacks.aws.cloudformation import CloudFormation from dataall.core.stacks.db.stack_repositories import Stack from dataall.core.vpc.db.vpc_repositories import Vpc +from dataall.base.aws.ec2_client import EC2 from dataall.base.db import exceptions from dataall.core.permissions import permissions from dataall.base.feature_toggle_checker import is_feature_enabled @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs): return True if ssm_param == "True" else False -def check_environment(context: Context, source, account_id, region): +def check_environment(context: Context, source, account_id, region, data): """ Checks necessary resources for environment deployment. - Check CDKToolkit exists in Account assuming cdk_look_up_role - Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False @@ -71,11 +72,25 @@ def check_environment(context: Context, source, account_id, region): action='CHECK_PIVOT_ROLE', message='Pivot Role has not been created in the Environment AWS Account', ) + mlStudioEnabled = None + for parameter in data.get("parameters", []): + if parameter['key'] == 'mlStudiosEnabled': + mlStudioEnabled = parameter['value'] + + if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []): + log.info("Check if ML Studio VPC Exists in the Account") + EC2.check_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + vpc_id=data.get("vpcId", None), + subnet_ids=data.get('subnetIds', []), + ) return cdk_role_name -def create_environment(context: Context, source, input=None): +def create_environment(context: Context, source, input={}): if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups: raise exceptions.UnauthorizedOperation( action=permissions.LINK_ENVIRONMENT, @@ -85,8 +100,10 @@ def create_environment(context: Context, source, input=None): with context.engine.scoped_session() as session: cdk_role_name = check_environment(context, source, account_id=input.get('AwsAccountId'), - region=input.get('region') + region=input.get('region'), + data=input ) + input['cdk_role_name'] = cdk_role_name env = EnvironmentService.create_environment( session=session, @@ -119,7 +136,8 @@ def update_environment( environment = EnvironmentService.get_environment_by_uri(session, environmentUri) cdk_role_name = check_environment(context, source, account_id=environment.AwsAccountId, - region=environment.region + region=environment.region, + data=input ) previous_resource_prefix = environment.resourcePrefix @@ -130,7 +148,7 @@ def update_environment( data=input, ) - if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment): + if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input): stack_helper.deploy_stack(targetUri=environment.environmentUri) return environment diff --git a/backend/dataall/core/environment/services/environment_resource_manager.py b/backend/dataall/core/environment/services/environment_resource_manager.py index bc74f01bf..f5c2551fa 100644 --- a/backend/dataall/core/environment/services/environment_resource_manager.py +++ b/backend/dataall/core/environment/services/environment_resource_manager.py @@ -12,7 +12,11 @@ def delete_env(session, environment): pass @staticmethod - def update_env(session, environment): + def create_env(session, environment, **kwargs): + pass + + @staticmethod + def update_env(session, environment, **kwargs): return False @staticmethod @@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int: return counter @classmethod - def deploy_updated_stack(cls, session, prev_prefix, environment): + def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs): deploy_stack = prev_prefix != environment.resourcePrefix for resource in cls._resources: - deploy_stack |= resource.update_env(session, environment) + deploy_stack |= resource.update_env(session, environment, **kwargs) return deploy_stack @@ -51,6 +55,11 @@ def delete_env(cls, session, environment): for resource in cls._resources: resource.delete_env(session, environment) + @classmethod + def create_env(cls, session, environment, **kwargs): + for resource in cls._resources: + resource.create_env(session, environment, **kwargs) + @classmethod def count_consumption_role_resources(cls, session, role_uri): counter = 0 diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py index ddea435c4..1b2dbec07 100644 --- a/backend/dataall/core/environment/services/environment_service.py +++ b/backend/dataall/core/environment/services/environment_service.py @@ -66,6 +66,7 @@ def create_environment(session, uri, data=None): session.commit() EnvironmentService._update_env_parameters(session, env, data) + EnvironmentResourceManager.create_env(session, env, data=data) env.EnvironmentDefaultBucketName = NamingConventionService( target_uri=env.environmentUri, @@ -98,29 +99,6 @@ def create_environment(session, uri, data=None): env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn'] env.EnvironmentDefaultIAMRoleImported = True - if data.get('vpcId'): - vpc = Vpc( - environmentUri=env.environmentUri, - region=env.region, - AwsAccountId=env.AwsAccountId, - VpcId=data.get('vpcId'), - privateSubnetIds=data.get('privateSubnetIds', []), - publicSubnetIds=data.get('publicSubnetIds', []), - SamlGroupName=data['SamlGroupName'], - owner=context.username, - label=f"{env.name}-{data.get('vpcId')}", - name=f"{env.name}-{data.get('vpcId')}", - default=True, - ) - session.add(vpc) - session.commit() - ResourcePolicy.attach_resource_policy( - session=session, - group=data['SamlGroupName'], - permissions=permissions.NETWORK_ALL, - resource_uri=vpc.vpcUri, - resource_type=Vpc.__name__, - ) env_group = EnvironmentGroup( environmentUri=env.environmentUri, groupUri=data['SamlGroupName'], diff --git a/backend/dataall/modules/dashboards/db/dashboard_repositories.py b/backend/dataall/modules/dashboards/db/dashboard_repositories.py index 91916f8ff..a8d9d6a2f 100644 --- a/backend/dataall/modules/dashboards/db/dashboard_repositories.py +++ b/backend/dataall/modules/dashboards/db/dashboard_repositories.py @@ -26,7 +26,7 @@ def count_resources(session, environment, group_uri) -> int: ) @staticmethod - def update_env(session, environment): + def update_env(session, environment, **kwargs): return EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled") @staticmethod diff --git a/backend/dataall/modules/mlstudio/__init__.py b/backend/dataall/modules/mlstudio/__init__.py index 2db9c0a1e..a6ca73917 100644 --- a/backend/dataall/modules/mlstudio/__init__.py +++ b/backend/dataall/modules/mlstudio/__init__.py @@ -3,7 +3,8 @@ from dataall.base.loader import ImportMode, ModuleInterface from dataall.core.stacks.db.target_type_repositories import TargetType -from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource +from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager log = logging.getLogger(__name__) @@ -20,6 +21,8 @@ def __init__(self): from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER) + EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource()) + log.info("API of sagemaker mlstudio has been imported") diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py index 457559def..ee014839f 100644 --- a/backend/dataall/modules/mlstudio/api/queries.py +++ b/backend/dataall/modules/mlstudio/api/queries.py @@ -4,6 +4,7 @@ get_sagemaker_studio_user, list_sagemaker_studio_users, get_sagemaker_studio_user_presigned_url, + get_environment_sagemaker_studio_domain ) getSagemakerStudioUser = gql.QueryField( @@ -34,3 +35,12 @@ type=gql.String, resolver=get_sagemaker_studio_user_presigned_url, ) + +getEnvironmentMLStudioDomain = gql.QueryField( + name='getEnvironmentMLStudioDomain', + args=[ + gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)), + ], + type=gql.Ref('SagemakerStudioDomain'), + resolver=get_environment_sagemaker_studio_domain, +) diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py index 63dc25ed7..48c9350fa 100644 --- a/backend/dataall/modules/mlstudio/api/resolvers.py +++ b/backend/dataall/modules/mlstudio/api/resolvers.py @@ -18,7 +18,7 @@ def required_uri(uri): raise exceptions.RequiredParameter('URI') @staticmethod - def validate_creation_request(data): + def validate_user_creation_request(data): required = RequestValidator._required if not data: raise exceptions.RequiredParameter('data') @@ -36,7 +36,7 @@ def _required(data: dict, name: str): def create_sagemaker_studio_user(context: Context, source, input: dict = None): """Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS""" - RequestValidator.validate_creation_request(input) + RequestValidator.validate_user_creation_request(input) request = SagemakerStudioCreationRequest.from_dict(input) return SagemakerStudioService.create_sagemaker_studio_user( uri=input["environmentUri"], @@ -90,6 +90,11 @@ def delete_sagemaker_studio_user( ) +def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None): + RequestValidator.required_uri(environmentUri) + return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri) + + def resolve_user_role(context: Context, source: SagemakerStudioUser): """ Resolves the role of the current user in reference with the SageMaker Studio User diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py index 21290711e..ca21df81d 100644 --- a/backend/dataall/modules/mlstudio/api/types.py +++ b/backend/dataall/modules/mlstudio/api/types.py @@ -79,3 +79,27 @@ gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioUser)), ], ) + +SagemakerStudioDomain = gql.ObjectType( + name='SagemakerStudioDomain', + fields=[ + gql.Field(name='sagemakerStudioUri', type=gql.ID), + gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)), + gql.Field(name='sagemakerStudioDomainName', type=gql.String), + gql.Field(name='DefaultDomainRoleName', type=gql.String), + gql.Field(name='label', type=gql.String), + gql.Field(name='name', type=gql.String), + gql.Field(name='vpcType', type=gql.String), + gql.Field(name='vpcId', type=gql.String), + gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)), + gql.Field(name='owner', type=gql.String), + gql.Field(name='created', type=gql.String), + gql.Field(name='updated', type=gql.String), + gql.Field(name='deleted', type=gql.String), + gql.Field( + name='environment', + type=gql.Ref('Environment'), + resolver=resolve_environment, + ) + ], +) diff --git a/backend/dataall/modules/mlstudio/aws/ec2_client.py b/backend/dataall/modules/mlstudio/aws/ec2_client.py deleted file mode 100644 index 3dc484254..000000000 --- a/backend/dataall/modules/mlstudio/aws/ec2_client.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging - -from dataall.base.aws.sts import SessionHelper - - -log = logging.getLogger(__name__) - - -class EC2: - - @staticmethod - def get_client(account_id: str, region: str, role=None): - session = SessionHelper.remote_session(accountid=account_id, role=role) - return session.client('ec2', region_name=region) - - @staticmethod - def check_default_vpc_exists(AwsAccountId: str, region: str, role=None): - log.info("Check that default VPC exists..") - client = EC2.get_client(account_id=AwsAccountId, region=region, role=role) - response = client.describe_vpcs( - Filters=[{'Name': 'isDefault', 'Values': ['true']}] - ) - vpcs = response['Vpcs'] - log.info(f"Default VPCs response: {vpcs}") - if vpcs: - return True - return False diff --git a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py index 2a82806ea..2ee872b1c 100644 --- a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py +++ b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py @@ -12,28 +12,22 @@ def get_client(AwsAccountId, region): return session.client('sagemaker', region_name=region) -def get_sagemaker_studio_domain(AwsAccountId, region): +def get_sagemaker_studio_domain(AwsAccountId, region, domain_name): """ Sagemaker studio domain is limited to 5 per account/region RETURN: an existing domain or None if no domain is in the AWS account """ client = get_client(AwsAccountId=AwsAccountId, region=region) - existing_domain = dict() try: domain_id_paginator = client.get_paginator('list_domains') - domains = domain_id_paginator.paginate() - for _domain in domains: - print(_domain) - for _domain in _domain.get('Domains'): - # Get the domain name created by dataall - if 'dataall' in _domain: - return _domain - else: - existing_domain = _domain - return existing_domain + for page in domain_id_paginator.paginate(): + for domain in page.get('Domains', []): + if domain.get("DomainName") == domain_name: + return domain + return dict() except ClientError as e: print(e) - return 'NotFound' + return dict() class SagemakerStudioClient: diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index fe9040ab9..49082ccfb 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -12,14 +12,12 @@ aws_ssm as ssm, RemovalPolicy, ) -from botocore.exceptions import ClientError +from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository -from dataall.base.aws.parameter_store import ParameterStoreManager from dataall.base.aws.sts import SessionHelper from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension from dataall.core.environment.services.environment_service import EnvironmentService -from dataall.modules.mlstudio.aws.ec2_client import EC2 -from dataall.modules.mlstudio.aws.sagemaker_studio_client import get_sagemaker_studio_domain +from dataall.base.aws.ec2_client import EC2 logger = logging.getLogger(__name__) @@ -31,75 +29,84 @@ def extent(setup: EnvironmentSetup): _environment = setup.environment() with setup.get_engine().scoped_session() as session: enabled = EnvironmentService.get_boolean_env_param(session, _environment, "mlStudiosEnabled") - if not enabled: + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, _environment.environmentUri) + if not enabled or not domain: return sagemaker_principals = [setup.default_role] + setup.group_roles logger.info(f'Creating SageMaker base resources for sagemaker_principals = {sagemaker_principals}..') - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=_environment.AwsAccountId, region=_environment.region - ) - existing_default_vpc = EC2.check_default_vpc_exists( - AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn - ) - if existing_default_vpc: - logger.info("Using default VPC for Sagemaker Studio domain") - # Use default VPC - initial configuration (to be migrated) - vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] - subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] - subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + + if domain.vpcId and domain.subnetIds and domain.vpcType == 'imported': + logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) + subnet_ids = domain.subnetIds security_groups = [] else: - logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") - # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways - log_group = logs.LogGroup( - setup, - f'SageMakerStudio{_environment.name}', - log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', - retention=logs.RetentionDays.ONE_MONTH, - removal_policy=RemovalPolicy.DESTROY, - ) - vpc_flow_role = iam.Role( - setup, 'FlowLog', - assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') - ) - vpc = ec2.Vpc( - setup, - "SageMakerVPC", - max_azs=3, - cidr="10.10.0.0/16", - subnet_configuration=[ - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PUBLIC, - name="Public", - cidr_mask=24 - ), - ec2.SubnetConfiguration( - subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, - name="Private", - cidr_mask=24 - ), - ], - enable_dns_hostnames=True, - enable_dns_support=True, - ) - ec2.FlowLog( - setup, "StudioVPCFlowLog", - resource_type=ec2.FlowLogResourceType.from_vpc(vpc), - destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=_environment.AwsAccountId, region=_environment.region ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", + existing_default_vpc = EC2.check_default_vpc_exists( + AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn ) + if existing_default_vpc: + logger.info("Using default VPC for Sagemaker Studio domain") + # Use default VPC - initial configuration (to be migrated) + vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True) + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] + subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] + security_groups = [] + else: + logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") + # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways + log_group = logs.LogGroup( + setup, + f'SageMakerStudio{_environment.name}', + log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio', + retention=logs.RetentionDays.ONE_MONTH, + removal_policy=RemovalPolicy.DESTROY, + ) + vpc_flow_role = iam.Role( + setup, 'FlowLog', + assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com') + ) + vpc = ec2.Vpc( + setup, + "SageMakerVPC", + max_azs=3, + cidr="10.10.0.0/16", + subnet_configuration=[ + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PUBLIC, + name="Public", + cidr_mask=24 + ), + ec2.SubnetConfiguration( + subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT, + name="Private", + cidr_mask=24 + ), + ], + enable_dns_hostnames=True, + enable_dns_support=True, + ) + ec2.FlowLog( + setup, "StudioVPCFlowLog", + resource_type=ec2.FlowLogResourceType.from_vpc(vpc), + destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) + ) + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, + ) - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] - subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] vpc_id = vpc.vpc_id @@ -107,7 +114,7 @@ def extent(setup: EnvironmentSetup): setup, 'RoleForSagemakerStudioUsers', assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'), - role_name='RoleSagemakerStudioUsers', + role_name=domain.DefaultDomainRoleName, managed_policies=[ iam.ManagedPolicy.from_managed_policy_arn( setup, @@ -123,7 +130,7 @@ def extent(setup: EnvironmentSetup): sagemaker_domain_key = kms.Key( setup, 'SagemakerDomainKmsKey', - alias='SagemakerStudioDomain', + alias=domain.sagemakerStudioDomainName, enable_key_rotation=True, admins=[ iam.ArnPrincipal(_environment.CDKRoleArn) @@ -175,7 +182,7 @@ def extent(setup: EnvironmentSetup): sagemaker_domain = sagemaker.CfnDomain( setup, 'SagemakerStudioDomain', - domain_name=f'SagemakerStudioDomain-{_environment.region}-{_environment.AwsAccountId}', + domain_name=domain.sagemakerStudioDomainName, auth_mode='IAM', default_user_settings=sagemaker.CfnDomain.UserSettingsProperty( execution_role=sagemaker_domain_role.role_arn, @@ -199,22 +206,3 @@ def extent(setup: EnvironmentSetup): parameter_name=f'/{_environment.resourcePrefix}/{_environment.environmentUri}/sagemaker/sagemakerstudio/domain_id', ) return sagemaker_domain - - @staticmethod - def check_existing_sagemaker_studio_domain(environment): - logger.info('Check if there is an existing sagemaker studio domain in the account') - try: - logger.info('check sagemaker studio domain created as part of data.all environment stack.') - cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( - accountid=environment.AwsAccountId, region=environment.region - ) - dataall_created_domain = ParameterStoreManager.client( - AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn - ).get_parameter(Name=f'/{environment.resourcePrefix}/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id') - return False - except ClientError as e: - logger.info(f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}') - existing_domain = get_sagemaker_studio_domain( - AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn - ) - return existing_domain.get('DomainId', False) diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py index 032826588..a4c93a2fa 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, String, ForeignKey from sqlalchemy.orm import query_expression +from sqlalchemy.dialects.postgresql import ARRAY from dataall.base.db import Base from dataall.base.db import Resource, utils @@ -10,16 +11,20 @@ class SagemakerStudioDomain(Resource, Base): """Describes ORM model for sagemaker ML Studio domain""" __tablename__ = 'sagemaker_studio_domain' - environmentUri = Column(String, nullable=False) + environmentUri = Column(String, ForeignKey("environment.environmentUri")) sagemakerStudioUri = Column( String, primary_key=True, default=utils.uuid('sagemakerstudio') ) - sagemakerStudioDomainID = Column(String, nullable=False) - SagemakerStudioStatus = Column(String, nullable=False) + sagemakerStudioDomainID = Column(String, nullable=True) + SagemakerStudioStatus = Column(String, nullable=True) + sagemakerStudioDomainName = Column(String, nullable=False) AWSAccountId = Column(String, nullable=False) - RoleArn = Column(String, nullable=False) + DefaultDomainRoleName = Column(String, nullable=False) region = Column(String, default='eu-west-1') - userRoleForSagemakerStudio = query_expression() + SamlGroupName = Column(String, nullable=False) + vpcType = Column(String, nullable=True) + vpcId = Column(String, nullable=True) + subnetIds = Column(ARRAY(String), nullable=True) class SagemakerStudioUser(Resource, Base): diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py index 763ca6f92..21847b6ef 100644 --- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py +++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py @@ -2,30 +2,34 @@ DAO layer that encapsulates the logic and interaction with the database for ML Studio Provides the API to retrieve / update / delete ml studio """ +from typing import Optional from sqlalchemy import or_ from sqlalchemy.sql import and_ from sqlalchemy.orm import Query +from dataall.base.utils import slugify from dataall.base.db import paginate -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -from dataall.core.environment.services.environment_resource_manager import EnvironmentResource +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser +from dataall.base.utils.naming_convention import ( + NamingConventionService, + NamingConventionPattern, +) -class SageMakerStudioRepository(EnvironmentResource): +class SageMakerStudioRepository: """DAO layer for ML Studio""" _DEFAULT_PAGE = 1 _DEFAULT_PAGE_SIZE = 10 - def __init__(self, session): - self._session = session - - def save_sagemaker_studio_user(self, user): + @staticmethod + def save_sagemaker_studio_user(session, user): """Save SageMaker Studio user to the database""" - self._session.add(user) - self._session.commit() + session.add(user) + session.commit() - def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: - query = self._session.query(SagemakerStudioUser).filter( + @staticmethod + def _query_user_sagemaker_studio_users(session, username, groups, filter) -> Query: + query = session.query(SagemakerStudioUser).filter( or_( SagemakerStudioUser.owner == username, SagemakerStudioUser.SamlAdminGroupName.in_(groups), @@ -44,21 +48,24 @@ def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query: ) return query - def paginated_sagemaker_studio_users(self, username, groups, filter=None) -> dict: + @staticmethod + def paginated_sagemaker_studio_users(session, username, groups, filter={}) -> dict: """Returns a page of sagemaker studio users for a data.all user""" return paginate( - query=self._query_user_sagemaker_studio_users(username, groups, filter), + query=SageMakerStudioRepository._query_user_sagemaker_studio_users(session, username, groups, filter), page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE), page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE), ).to_dict() - def find_sagemaker_studio_user(self, uri): + @staticmethod + def find_sagemaker_studio_user(session, uri): """Finds a sagemaker studio user. Returns None if it doesn't exist""" - return self._session.query(SagemakerStudioUser).get(uri) + return session.query(SagemakerStudioUser).get(uri) - def count_resources(self, environment, group_uri): + @staticmethod + def count_resources(session, environment, group_uri): return ( - self._session.query(SagemakerStudioUser) + session.query(SagemakerStudioUser) .filter( and_( SagemakerStudioUser.environmentUri == environment.environmentUri, @@ -67,3 +74,57 @@ def count_resources(self, environment, group_uri): ) .count() ) + + @staticmethod + def create_sagemaker_studio_domain(session, username, environment, data): + domain = SagemakerStudioDomain( + label=f"{data.get('label')}-domain", + owner=username, + description=data.get('description', 'No description provided'), + tags=data.get('tags', []), + SamlGroupName=environment.SamlGroupName, + environmentUri=environment.environmentUri, + AWSAccountId=environment.AwsAccountId, + region=environment.region, + SagemakerStudioStatus="PENDING", + DefaultDomainRoleName="DefaultMLStudioRole", + sagemakerStudioDomainName=slugify(data.get('label'), separator=''), + vpcType=data.get('vpcType'), + vpcId=data.get('vpcId'), + subnetIds=data.get('subnetIds', []) + ) + session.add(domain) + session.commit() + + domain.sagemakerStudioDomainName = NamingConventionService( + target_uri=domain.sagemakerStudioUri, + target_label=domain.label, + pattern=NamingConventionPattern.MLSTUDIO_DOMAIN, + resource_prefix=environment.resourcePrefix, + ).build_compliant_name() + + domain.DefaultDomainRoleName = NamingConventionService( + target_uri=domain.sagemakerStudioUri, + target_label=domain.label, + pattern=NamingConventionPattern.IAM, + resource_prefix=environment.resourcePrefix, + ).build_compliant_name() + + return domain + + @staticmethod + def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env_uri, + ).first() + if not domain: + return None + return domain + + @staticmethod + def delete_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]: + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env_uri, + ).first() + if domain: + session.delete(domain) diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py index 06750b822..3738c118d 100644 --- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py +++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py @@ -11,13 +11,18 @@ from dataall.core.environment.env_permission_checker import has_group_permission from dataall.core.environment.services.environment_service import EnvironmentService from dataall.core.permissions.db.resource_policy_repositories import ResourcePolicy +from dataall.core.permissions import permissions from dataall.core.permissions.permission_checker import has_resource_permission, has_tenant_permission from dataall.core.stacks.api import stack_helper from dataall.core.stacks.db.stack_repositories import Stack from dataall.base.db import exceptions from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository +from dataall.core.environment.services.environment_resource_manager import EnvironmentResource from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.base.aws.ec2_client import EC2 +from dataall.base.aws.sts import SessionHelper + from dataall.modules.mlstudio.services.mlstudio_permissions import ( MANAGE_SGMSTUDIO_USERS, CREATE_SGMSTUDIO_USER, @@ -54,6 +59,38 @@ def _session(): return get_context().db_engine.scoped_session() +class SagemakerStudioEnvironmentResource(EnvironmentResource): + @staticmethod + def count_resources(session, environment, group_uri) -> int: + return SageMakerStudioRepository.count_resources(session, environment, group_uri) + + @staticmethod + def create_env(session, environment, **kwargs): + enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + if enabled: + SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) + + @staticmethod + def update_env(session, environment, **kwargs): + current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled") + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri) + previous_mlstudio_enabled = True if domain else False + if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled): + SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) + return True + elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled): + SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs) + return True + elif current_mlstudio_enabled and domain and domain.vpcType == "unknown": + SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs) + return True + return False + + @staticmethod + def delete_env(session, environment): + SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri) + + class SagemakerStudioService: """ Encapsulate the logic of interactions with sagemaker ml studio. @@ -77,17 +114,19 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak action=CREATE_SGMSTUDIO_USER, message=f'ML Studio feature is disabled for the environment {env.label}', ) + + domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=env.environmentUri) response = get_sagemaker_studio_domain( AwsAccountId=env.AwsAccountId, - region=env.region + region=env.region, + domain_name=domain.sagemakerStudioDomainName ) existing_domain = response.get('DomainId', False) if not existing_domain: raise exceptions.AWSResourceNotAvailable( action='Sagemaker Studio domain', - message='Update the environment stack ' - 'or create a Sagemaker studio domain on your AWS account.', + message='Update the environment stack and enable ML Studio Environment Feature' ) sagemaker_studio_user = SagemakerStudioUser( @@ -104,7 +143,7 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak SamlAdminGroupName=admin_group, tags=request.tags, ) - SageMakerStudioRepository(session).save_sagemaker_studio_user(user=sagemaker_studio_user) + SageMakerStudioRepository.save_sagemaker_studio_user(session, sagemaker_studio_user) ResourcePolicy.attach_resource_policy( session=session, @@ -135,10 +174,58 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak return sagemaker_studio_user + @staticmethod + def update_sagemaker_studio_domain(environment, domain, data): + SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data) + domain.vpcType = data.get('vpcType') + if data.get('vpcId'): + domain.vpcId = data.get('vpcId') + if data.get('subnetIds'): + domain.subnetIds = data.get('subnetIds') + + @staticmethod + def _update_sagemaker_studio_domain_vpc(account_id, region, data={}): + cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( + accountid=account_id, region=region + ) + if data.get("vpcId", None): + data["vpcType"] = "imported" + else: + response = EC2.check_default_vpc_exists( + AwsAccountId=account_id, + region=region, + role=cdk_look_up_role_arn, + ) + if response: + vpcId, subnetIds = response + data["vpcType"] = "default" + data["vpcId"] = vpcId + data["subnetIds"] = subnetIds + else: + data["vpcType"] = "created" + + @staticmethod + def create_sagemaker_studio_domain(session, environment, data: dict = {}): + SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data) + + domain = SageMakerStudioRepository.create_sagemaker_studio_domain( + session=session, + username=get_context().username, + environment=environment, + data=data, + ) + return domain + + @staticmethod + def get_environment_sagemaker_studio_domain(*, environment_uri: str): + with _session() as session: + return SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment_uri) + @staticmethod def list_sagemaker_studio_users(*, filter: dict) -> dict: with _session() as session: - return SageMakerStudioRepository(session).paginated_sagemaker_studio_users( + return SageMakerStudioRepository.paginated_sagemaker_studio_users( + session=session, username=get_context().username, groups=get_context().groups, filter=filter, @@ -197,7 +284,7 @@ def delete_sagemaker_studio_user(*, uri: str, delete_from_aws: bool): @staticmethod def _get_sagemaker_studio_user(session, uri): - user = SageMakerStudioRepository(session).find_sagemaker_studio_user(uri=uri) + user = SageMakerStudioRepository.find_sagemaker_studio_user(session=session, uri=uri) if not user: raise exceptions.ObjectNotFound('SagemakerStudioUser', uri) return user diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py new file mode 100644 index 000000000..a3ac794f3 --- /dev/null +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -0,0 +1,178 @@ +"""env_mlstudio_domain_table + +Revision ID: 71a5f5de322f +Revises: 8c79fb896983 +Create Date: 2023-11-29 09:44:04.160286 + +""" +import os +from sqlalchemy import orm, Column, String, Boolean, ForeignKey, and_ +from sqlalchemy.ext.declarative import declarative_base +import sqlalchemy as sa +from alembic import op + +from sqlalchemy.dialects import postgresql +from dataall.base.db import get_engine, has_table +from dataall.base.db import utils, Resource + +# revision identifiers, used by Alembic. +revision = '71a5f5de322f' +down_revision = '8c79fb896983' +branch_labels = None +depends_on = None + +Base = declarative_base() + + +class Environment(Resource, Base): + __tablename__ = "environment" + environmentUri = Column(String, primary_key=True) + AwsAccountId = Column(Boolean) + region = Column(Boolean) + SamlGroupName = Column(String) + + +class EnvironmentParameter(Base): + __tablename__ = 'environment_parameters' + environmentUri = Column(String, primary_key=True) + key = Column('paramKey', String, primary_key=True) + value = Column('paramValue', String, nullable=True) + + +class SagemakerStudioDomain(Resource, Base): + __tablename__ = 'sagemaker_studio_domain' + environmentUri = Column(String, ForeignKey("environment.environmentUri")) + sagemakerStudioUri = Column( + String, primary_key=True, default=utils.uuid('sagemakerstudio') + ) + sagemakerStudioDomainID = Column(String, nullable=True) + SagemakerStudioStatus = Column(String, nullable=True) + sagemakerStudioDomainName = Column(String, nullable=False) + AWSAccountId = Column(String, nullable=False) + DefaultDomainRoleName = Column(String, nullable=False) + region = Column(String, default='eu-west-1') + SamlGroupName = Column(String, nullable=False) + vpcType = Column(String, nullable=True) + + +def upgrade(): + """ + The script does the following migration: + 1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information + """ + try: + envname = os.getenv('envname', 'local') + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + if has_table('sagemaker_studio_domain', engine): + print("Updating sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'sagemakerStudioDomainID', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'SagemakerStudioStatus', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'RoleArn', + new_column_name='DefaultDomainRoleName', + nullable=False, + existing_type=sa.String() + ) + + op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False)) + op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False)) + + op.create_foreign_key( + "fk_sagemaker_studio_domain_env_uri", + "sagemaker_studio_domain", "environment", + ["environmentUri"], ["environmentUri"], + ) + + print("Update sagemaker_studio_domain table done.") + print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") + + env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( + and_( + EnvironmentParameter.key == "mlStudiosEnabled", + EnvironmentParameter.value == "true" + ) + ).all() + for param in env_mlstudio_parameters: + env: Environment = session.query(Environment).filter( + Environment.environmentUri == param.environmentUri + ).first() + + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env.environmentUri + ).first() + if not domain: + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown", + SamlGroupName=env.SamlGroupName + ) + session.add(domain) + session.flush() + session.commit() + print("Fill of sagemaker_studio_domain table is done") + + except Exception as exception: + print('Failed to upgrade due to:', exception) + raise exception + + +def downgrade(): + try: + envname = os.getenv('envname', 'local') + engine = get_engine(envname=envname).engine + + bind = op.get_bind() + session = orm.Session(bind=bind) + + if has_table('sagemaker_studio_domain', engine): + print("deleting sagemaker studio domain entries...") + session.query(SagemakerStudioDomain).delete() + + print("Updating of sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'DefaultDomainRoleName', + new_column_name='RoleArn', + nullable=False, + existing_type=sa.String() + ) + + op.drop_column("sagemaker_studio_domain", "sagemakerStudioDomainName") + op.drop_column("sagemaker_studio_domain", "vpcType") + op.drop_column("sagemaker_studio_domain", "vpcId") + op.drop_column("sagemaker_studio_domain", "subnetIds") + op.drop_column("sagemaker_studio_domain", "SamlGroupName") + + op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain") + + session.commit() + print("Update of sagemaker_studio_domain table is done") + + except Exception as exception: + print('Failed to downgrade due to:', exception) + raise exception diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js new file mode 100644 index 000000000..44dac97b9 --- /dev/null +++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js @@ -0,0 +1,156 @@ +import { + Box, + Card, + CardHeader, + Divider, + Grid, + CardContent, + Typography, + CircularProgress, + Chip +} from '@mui/material'; + +import PropTypes from 'prop-types'; +import React, { useCallback, useEffect, useState } from 'react'; +import { RefreshTableMenu } from 'design'; +import { SET_ERROR, useDispatch } from 'globalErrors'; +import { getEnvironmentMLStudioDomain, useClient } from 'services'; + +export const EnvironmentMLStudio = ({ environment }) => { + const client = useClient(); + const dispatch = useDispatch(); + const [mlStudioDomain, setMLStudioDomain] = useState(null); + const [loading, setLoading] = useState(true); + + const fetchMLStudioDomain = useCallback(async () => { + try { + setLoading(true); + const response = await client.query( + getEnvironmentMLStudioDomain({ + environmentUri: environment.environmentUri + }) + ); + if (!response.errors) { + if (response.data.getEnvironmentMLStudioDomain) { + setMLStudioDomain(response.data.getEnvironmentMLStudioDomain); + } + } else { + dispatch({ type: SET_ERROR, error: response.errors[0].message }); + } + } catch (e) { + dispatch({ type: SET_ERROR, error: e.message }); + } finally { + setLoading(false); + } + }, [client, dispatch, environment.environmentUri]); + + useEffect(() => { + if (client) { + fetchMLStudioDomain().catch((e) => + dispatch({ type: SET_ERROR, error: e.message }) + ); + } + }, [client, fetchMLStudioDomain, dispatch]); + + if (loading) { + return ; + } + + return ( + + + } + title={ML Studio Domain Information} + /> + + + + + {mlStudioDomain === null ? ( + + + No ML Studio Domain - To Create a ML Studio Domain for this + Environment: {environment.label}, edit the Environment and enable + the ML Studio Environment Feature + + + ) : ( + + + + + SageMaker ML Studio Domain Name + + + {mlStudioDomain.sagemakerStudioDomainName} + + + + + SageMaker ML Studio Default Execution Role + + + arn:aws:iam::{environment.AwsAccountId}:role/ + {mlStudioDomain.DefaultDomainRoleName} + + + + + Domain VPC Type + + + {mlStudioDomain.vpcType} + + + {(mlStudioDomain.vpcType === 'imported' || + mlStudioDomain.vpcType === 'default') && ( + <> + + + Domain VPC Id + + + {mlStudioDomain.vpcId} + + + + + Domain Subnet Ids + + + {mlStudioDomain.subnetIds?.map((subnet) => ( + + ))} + + + + )} + + + )} + + + ); +}; + +EnvironmentMLStudio.propTypes = { + environment: PropTypes.object.isRequired +}; diff --git a/frontend/src/modules/Environments/components/index.js b/frontend/src/modules/Environments/components/index.js index afccd1235..7aecd51fa 100644 --- a/frontend/src/modules/Environments/components/index.js +++ b/frontend/src/modules/Environments/components/index.js @@ -12,3 +12,4 @@ export * from './EnvironmentTeamInviteEditForm'; export * from './EnvironmentTeamInviteForm'; export * from './EnvironmentTeams'; export * from './NetworkCreateModal'; +export * from './EnvironmentMLStudio'; diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js index 2767e1fcc..a16cdfa7e 100644 --- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js @@ -31,6 +31,13 @@ import { CopyToClipboard } from 'react-copy-to-clipboard/lib/Component'; import { Helmet } from 'react-helmet-async'; import { Link as RouterLink, useNavigate, useParams } from 'react-router-dom'; import * as Yup from 'yup'; +import { + createEnvironment, + getPivotRoleExternalId, + getPivotRoleName, + getPivotRolePresignedUrl, + getCDKExecPolicyPresignedUrl +} from '../services'; import { ArrowLeftIcon, ChevronRightIcon, @@ -44,13 +51,6 @@ import { useClient, useGroups } from 'services'; -import { - createEnvironment, - getPivotRoleExternalId, - getPivotRoleName, - getPivotRolePresignedUrl, - getCDKExecPolicyPresignedUrl -} from '../services'; import { AwsRegions, isAnyEnvironmentModuleEnabled, @@ -179,6 +179,8 @@ const EnvironmentCreateForm = (props) => { region: values.region, EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn, resourcePrefix: values.resourcePrefix, + vpcId: values.vpcId, + subnetIds: values.subnetIds, parameters: [ { key: 'notebooksEnabled', @@ -484,7 +486,9 @@ const EnvironmentCreateForm = (props) => { mlStudiosEnabled: isModuleEnabled(ModuleNames.MLSTUDIO), pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES), EnvironmentDefaultIAMRoleArn: '', - resourcePrefix: 'dataall' + resourcePrefix: 'dataall', + vpcId: '', + subnetIds: [] }} validationSchema={Yup.object().shape({ label: Yup.string() @@ -508,8 +512,14 @@ const EnvironmentCreateForm = (props) => { ).length >= 1 ), tags: Yup.array().nullable(), - privateSubnetIds: Yup.array().nullable(), - publicSubnetIds: Yup.array().nullable(), + subnetIds: Yup.array().when('vpcId', { + is: (value) => !!value, + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) + }), vpcId: Yup.string().nullable(), EnvironmentDefaultIAMRoleArn: Yup.string().nullable(), resourcePrefix: Yup.string() @@ -862,6 +872,45 @@ const EnvironmentCreateForm = (props) => { + {values.mlStudiosEnabled && ( + + + + + + + + { + setFieldValue('subnetIds', [...chip]); + }} + /> + + + + )} {errors.submit && ( {errors.submit} diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js index caa5d8441..382575920 100644 --- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js +++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js @@ -30,7 +30,7 @@ import { useSettings } from 'design'; import { SET_ERROR, useDispatch } from 'globalErrors'; -import { useClient } from 'services'; +import { getEnvironmentMLStudioDomain, useClient } from 'services'; import { getEnvironment, updateEnvironment } from '../services'; import { isAnyEnvironmentModuleEnabled, @@ -47,6 +47,9 @@ const EnvironmentEditForm = (props) => { const { settings } = useSettings(); const [loading, setLoading] = useState(true); const [env, setEnv] = useState(''); + const [envMLStudioDomain, setEnvMLStudioDomain] = useState(''); + const [previousEnvMLStudioEnabled, setPreviousEnvMLStudioEnabled] = + useState(false); const fetchItem = useCallback(async () => { const response = await client.query( @@ -58,6 +61,20 @@ const EnvironmentEditForm = (props) => { environment.parameters.map((x) => [x.key, x.value]) ); setEnv(environment); + if (environment.parameters['mlStudiosEnabled'] === 'true') { + setPreviousEnvMLStudioEnabled(true); + const response2 = await client.query( + getEnvironmentMLStudioDomain({ environmentUri: params.uri }) + ); + if (!response2.errors && response2.data.getEnvironmentMLStudioDomain) { + setEnvMLStudioDomain(response2.data.getEnvironmentMLStudioDomain); + } else { + const error = response2.errors + ? response2.errors[0].message + : 'Environment ML Studio Domain not found'; + dispatch({ type: SET_ERROR, error }); + } + } } else { const error = response.errors ? response.errors[0].message @@ -66,11 +83,13 @@ const EnvironmentEditForm = (props) => { } setLoading(false); }, [client, dispatch, params.uri]); + useEffect(() => { if (client) { fetchItem().catch((e) => dispatch({ type: SET_ERROR, error: e.message })); } }, [client, fetchItem, dispatch]); + async function submit(values, setStatus, setSubmitting, setErrors) { try { const response = await client.mutate( @@ -81,6 +100,8 @@ const EnvironmentEditForm = (props) => { tags: values.tags, description: values.description, resourcePrefix: values.resourcePrefix, + vpcId: values.vpcId, + subnetIds: values.subnetIds, parameters: [ { key: 'notebooksEnabled', @@ -213,6 +234,8 @@ const EnvironmentEditForm = (props) => { label: env.label, description: env.description, tags: env.tags || [], + vpcId: envMLStudioDomain.vpcId || '', + subnetIds: envMLStudioDomain.subnetIds || [], notebooksEnabled: env.parameters['notebooksEnabled'] === 'true', mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true', pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true', @@ -226,6 +249,15 @@ const EnvironmentEditForm = (props) => { .required('*Environment name is required'), description: Yup.string().max(5000), tags: Yup.array().nullable(), + subnetIds: Yup.array().when('vpcId', { + is: (value) => !!value, + then: Yup.array() + .min(1) + .required( + 'At least 1 Subnet Id required if VPC Id specified' + ) + }), + vpcId: Yup.string().nullable(), resourcePrefix: Yup.string() .trim() .matches( @@ -383,6 +415,48 @@ const EnvironmentEditForm = (props) => { + {!previousEnvMLStudioEnabled && + values.mlStudiosEnabled && ( + + + + + + + + { + setFieldValue('subnetIds', [...chip]); + }} + /> + + + + )} {isAnyEnvironmentModuleEnabled() && ( diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js index 0ba724320..792918c13 100644 --- a/frontend/src/modules/Environments/views/EnvironmentView.js +++ b/frontend/src/modules/Environments/views/EnvironmentView.js @@ -39,6 +39,7 @@ import { archiveEnvironment, getEnvironment } from '../services'; import { KeyValueTagList, Stack, StackStatus } from 'modules/Shared'; import { EnvironmentDatasets, + EnvironmentMLStudio, EnvironmentOverview, EnvironmentSubscriptions, EnvironmentTeams, @@ -59,6 +60,12 @@ const tabs = [ icon: , active: isModuleEnabled(ModuleNames.DATASETS) }, + { + label: 'ML Studio Domain', + value: 'mlstudio', + icon: , + active: isModuleEnabled(ModuleNames.MLSTUDIO) + }, { label: 'Networks', value: 'networks', icon: }, { label: 'Subscriptions', @@ -267,6 +274,9 @@ const EnvironmentView = () => { fetchItem={fetchItem} /> )} + {isAdmin && currentTab === 'mlstudio' && ( + + )} {isAdmin && currentTab === 'tags' && ( ({ + variables: { + environmentUri + }, + query: gql` + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + ` +}); diff --git a/frontend/src/services/graphql/MLStudio/index.js b/frontend/src/services/graphql/MLStudio/index.js new file mode 100644 index 000000000..97d3de110 --- /dev/null +++ b/frontend/src/services/graphql/MLStudio/index.js @@ -0,0 +1 @@ +export * from './getEnvironmentMLStudioDomain'; diff --git a/frontend/src/services/graphql/index.js b/frontend/src/services/graphql/index.js index 8d0e00804..ce1c3fba2 100644 --- a/frontend/src/services/graphql/index.js +++ b/frontend/src/services/graphql/index.js @@ -8,6 +8,7 @@ export * from './Glossary'; export * from './Groups'; export * from './KeyValueTags'; export * from './Metric'; +export * from './MLStudio'; export * from './Notification'; export * from './Organization'; export * from './Principal'; diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 6d8a449e4..738ab4d06 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -44,7 +44,6 @@ def factory(org, envname, owner, group, account, region, desc='test', parameters 'tags': ['a', 'b', 'c'], 'region': f'{region}', 'SamlGroupName': f'{group}', - 'vpcId': 'vpc-123456', 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()] }, ) diff --git a/tests/core/environments/test_environment.py b/tests/core/environments/test_environment.py index 31ba18e57..e806e07b2 100644 --- a/tests/core/environments/test_environment.py +++ b/tests/core/environments/test_environment.py @@ -221,26 +221,6 @@ def test_list_environments_no_filter(org_fixture, env_fixture, client, group): assert response.data.listEnvironments.count == 1 - response = client.query( - """ - query ListEnvironmentNetworks($environmentUri: String!,$filter:VpcFilter){ - listEnvironmentNetworks(environmentUri:$environmentUri,filter:$filter){ - count - nodes{ - VpcId - SamlGroupName - } - } - } - """, - environmentUri=env_fixture.environmentUri, - username='alice', - groups=[group.name], - ) - print(response) - - assert response.data.listEnvironmentNetworks.count == 1 - def test_list_environment_role_filter_as_creator(org_fixture, env_fixture, client, group): response = client.query( @@ -656,23 +636,16 @@ def test_create_environment(db, client, org_fixture, env_fixture, user, group): 'tags': ['a', 'b', 'c'], 'region': f'{env_fixture.region}', 'SamlGroupName': group.name, - 'vpcId': 'vpc-1234567', - 'privateSubnetIds': 'subnet-1', - 'publicSubnetIds': 'subnet-21', 'resourcePrefix': 'customer-prefix', }, ) body = response.data.createEnvironment - assert body.networks + assert len(body.networks) == 0 assert body.EnvironmentDefaultIAMRoleName == 'myOwnIamRole' assert body.EnvironmentDefaultIAMRoleImported assert body.resourcePrefix == 'customer-prefix' - for vpc in body.networks: - assert vpc.privateSubnetIds - assert vpc.publicSubnetIds - assert vpc.default with db.scoped_session() as session: env = EnvironmentService.get_environment_by_uri( diff --git a/tests/core/vpc/test_vpc.py b/tests/core/vpc/test_vpc.py index a55196d32..8f2391220 100644 --- a/tests/core/vpc/test_vpc.py +++ b/tests/core/vpc/test_vpc.py @@ -60,7 +60,7 @@ def test_list_networks(client, env_fixture, db, user, group, vpc): ) print(response) - assert response.data.listEnvironmentNetworks.count == 2 + assert response.data.listEnvironmentNetworks.count == 1 def test_list_networks_nopermissions(client, env_fixture, db, user, group2, vpc): @@ -119,4 +119,4 @@ def test_delete_network(client, env_fixture, db, user, group, module_mocker, vpc username='alice', groups=[group.name], ) - assert len(response.data.listEnvironmentNetworks['nodes']) == 1 + assert len(response.data.listEnvironmentNetworks['nodes']) == 0 diff --git a/tests/modules/mlstudio/cdk/conftest.py b/tests/modules/mlstudio/cdk/conftest.py index 4b3327838..2c6f1eddd 100644 --- a/tests/modules/mlstudio/cdk/conftest.py +++ b/tests/modules/mlstudio/cdk/conftest.py @@ -2,7 +2,7 @@ from dataall.core.environment.db.environment_models import Environment from dataall.core.organizations.db.organization_models import Organization -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain @pytest.fixture(scope='module', autouse=True) @@ -23,3 +23,21 @@ def sgm_studio(db, env_fixture: Environment) -> SagemakerStudioUser: ) session.add(sm_user) yield sm_user + +@pytest.fixture(scope='module', autouse=True) +def sgm_studio_domain(db, env_fixture: Environment) -> SagemakerStudioDomain: + with db.scoped_session() as session: + sm_domain = SagemakerStudioDomain( + label='sagemaker-domain', + owner='me', + environmentUri=env_fixture.environmentUri, + AWSAccountId=env_fixture.AwsAccountId, + region=env_fixture.region, + SagemakerStudioStatus="PENDING", + DefaultDomainRoleName="DefaultMLStudioRole", + sagemakerStudioDomainName="DomainName", + vpcType="created", + SamlGroupName=env_fixture.SamlGroupName, + ) + session.add(sm_domain) + yield sm_domain diff --git a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py index a2c1752e2..8e0cd6166 100644 --- a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py +++ b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py @@ -66,16 +66,19 @@ def patch_methods_sagemaker_studio(mocker, db, sgm_studio, env_fixture, org_fixt @pytest.fixture(scope='function', autouse=True) -def patch_methods_sagemaker_studio_extension(mocker): +def patch_methods_sagemaker_studio_extension(mocker, sgm_studio_domain): mocker.patch( 'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn', return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", ) mocker.patch( - 'dataall.modules.mlstudio.aws.ec2_client.EC2.check_default_vpc_exists', + 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists', return_value=False, ) - + mocker.patch( + 'dataall.modules.mlstudio.db.mlstudio_repositories.SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri', + return_value=sgm_studio_domain, + ) def test_resources_sgmstudio_stack_created(sgm_studio): app = App() diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py index 433048894..d1fffb2cf 100644 --- a/tests/modules/mlstudio/conftest.py +++ b/tests/modules/mlstudio/conftest.py @@ -1,6 +1,6 @@ import pytest -from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser +from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain @pytest.fixture(scope='module', autouse=True) @@ -16,8 +16,31 @@ def env_params(): yield {'mlStudiosEnabled': 'True'} +@pytest.fixture(scope='module', autouse=True) +def get_cdk_look_up_role_arn(module_mocker): + module_mocker.patch( + 'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn', + return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1", + ) + +@pytest.fixture(scope='module', autouse=True) +def check_default_vpc(module_mocker): + module_mocker.patch( + 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists', + return_value=False, + ) + + +@pytest.fixture(scope='module', autouse=True) +def check_vpc_exists(module_mocker): + module_mocker.patch( + 'dataall.base.aws.ec2_client.EC2.check_vpc_exists', + return_value=True, + ) + + @pytest.fixture(scope='module') -def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudioUser: +def sagemaker_studio_user(client, tenant, group, env_with_mlstudio) -> SagemakerStudioUser: response = client.query( """ mutation createSagemakerStudioUser($input:NewSagemakerStudioUserInput){ @@ -36,7 +59,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudio input={ 'label': 'testcreate', 'SamlAdminGroupName': group.name, - 'environmentUri': env_fixture.environmentUri, + 'environmentUri': env_with_mlstudio.environmentUri, }, username='alice', groups=[group.name], @@ -45,7 +68,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudio @pytest.fixture(scope='module') -def multiple_sagemaker_studio_users(client, db, env_fixture, group): +def multiple_sagemaker_studio_users(client, db, env_with_mlstudio, group): for i in range(0, 10): response = client.query( """ @@ -65,7 +88,7 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): input={ 'label': f'test{i}', 'SamlAdminGroupName': group.name, - 'environmentUri': env_fixture.environmentUri, + 'environmentUri': env_with_mlstudio.environmentUri, }, username='alice', groups=[group.name], @@ -77,5 +100,92 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group): ) assert ( response.data.createSagemakerStudioUser.environmentUri - == env_fixture.environmentUri + == env_with_mlstudio.environmentUri ) + +@pytest.fixture(scope='module') +def env_with_mlstudio(client, org_fixture, user, group, parameters=None, vpcId='', subnetIds=[]): + if not parameters: + parameters = {'mlStudiosEnabled': 'True'} + response = client.query( + """mutation CreateEnv($input:NewEnvironmentInput){ + createEnvironment(input:$input){ + organization{ + organizationUri + } + environmentUri + label + AwsAccountId + SamlGroupName + region + name + owner + parameters { + key + value + } + } + }""", + username=f'{user.username}', + groups=['testadmins'], + input={ + 'label': f'dev', + 'description': '', + 'organizationUri': org_fixture.organizationUri, + 'AwsAccountId': '111111111111', + 'tags': [], + 'region': 'us-east-1', + 'SamlGroupName': 'testadmins', + 'parameters': [{'key': k, 'value': v} for k, v in parameters.items()], + 'vpcId': vpcId, + 'subnetIds': subnetIds + }, + ) + yield response.data.createEnvironment + + +@pytest.fixture(scope='module', autouse=True) +def org(client): + cache = {} + + def factory(orgname, owner, group): + key = orgname + owner + group + if cache.get(key): + print(f'returning item from cached key {key}') + return cache.get(key) + response = client.query( + """mutation CreateOrganization($input:NewOrganizationInput){ + createOrganization(input:$input){ + organizationUri + label + name + owner + SamlGroupName + } + }""", + username=f'{owner}', + groups=[group], + input={ + 'label': f'{orgname}', + 'description': f'test', + 'tags': ['a', 'b', 'c'], + 'SamlGroupName': f'{group}', + }, + ) + cache[key] = response.data.createOrganization + return cache[key] + + yield factory + + +@pytest.fixture(scope='module') +def org_fixture(org, user, group): + org1 = org('testorg', user.username, group.name) + yield org1 + + +@pytest.fixture(scope='module') +def env_mlstudio_fixture(env, org_fixture, user, group, tenant): + env1 = env_with_mlstudio(org_fixture, 'dev', 'alice', 'testadmins', '111111111111', 'eu-west-1') + yield env1 + diff --git a/tests/modules/mlstudio/test_sagemaker_studio.py b/tests/modules/mlstudio/test_sagemaker_studio.py index c55762522..3d90b405a 100644 --- a/tests/modules/mlstudio/test_sagemaker_studio.py +++ b/tests/modules/mlstudio/test_sagemaker_studio.py @@ -1,14 +1,43 @@ from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser -def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_fixture): +def test_create_sagemaker_studio_domain(db, client, org_fixture, env_with_mlstudio, user, group, vpcId="vpc-1234", subnetIds=["subnet"]): + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, + ) + + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain.label == f'{env_with_mlstudio.label}-domain' + assert response.data.getEnvironmentMLStudioDomain.vpcType == 'created' + assert len(response.data.getEnvironmentMLStudioDomain.vpcId) == 0 + assert len(response.data.getEnvironmentMLStudioDomain.subnetIds) == 0 + assert response.data.getEnvironmentMLStudioDomain.environmentUri == env_with_mlstudio.environmentUri + + +def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_with_mlstudio): """Testing that the conftest sagemaker studio user has been created correctly""" assert sagemaker_studio_user.label == 'testcreate' assert sagemaker_studio_user.SamlAdminGroupName == group.name - assert sagemaker_studio_user.environmentUri == env_fixture.environmentUri + assert sagemaker_studio_user.environmentUri == env_with_mlstudio.environmentUri -def test_list_sagemaker_studio_users(client, env_fixture, db, group, multiple_sagemaker_studio_users): +def test_list_sagemaker_studio_users(client, db, group, multiple_sagemaker_studio_users): response = client.query( """ query listSagemakerStudioUsers($filter:SagemakerStudioUserFilter!){ @@ -67,3 +96,114 @@ def test_delete_sagemaker_studio_user( sagemaker_studio_user.sagemakerStudioUserUri ) assert not n + +def update_env_query(): + query = """ + mutation UpdateEnv($environmentUri:String!,$input:ModifyEnvironmentInput){ + updateEnvironment(environmentUri:$environmentUri,input:$input){ + organization{ + organizationUri + } + label + AwsAccountId + region + SamlGroupName + owner + tags + resourcePrefix + parameters { + key + value + } + } + } + """ + return query + +def test_update_env_delete_domain(client, org_fixture, env_with_mlstudio, group, group2): + response = client.query( + update_env_query(), + username='alice', + environmentUri=env_with_mlstudio.environmentUri, + input={ + 'label': 'DEV', + 'tags': [], + 'parameters': [ + { + 'key': 'mlStudiosEnabled', + 'value': 'False' + } + ], + }, + groups=[group.name], + ) + + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, + ) + assert response.data.getEnvironmentMLStudioDomain is None + + +def test_update_env_create_domain_with_vpc(db, client, org_fixture, env_with_mlstudio, user, group): + response = client.query( + update_env_query(), + username='alice', + environmentUri=env_with_mlstudio.environmentUri, + input={ + 'label': 'dev', + 'tags': [], + 'vpcId': "vpc-12345", + 'subnetIds': ['subnet-12345', 'subnet-67890'], + 'parameters': [ + { + 'key': 'mlStudiosEnabled', + 'value': 'True' + } + ], + }, + groups=[group.name], + ) + + response = client.query( + """ + query getEnvironmentMLStudioDomain($environmentUri: String) { + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) { + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + } + } + """, + environmentUri=env_with_mlstudio.environmentUri, + ) + + assert response.data.getEnvironmentMLStudioDomain.sagemakerStudioUri + assert response.data.getEnvironmentMLStudioDomain.label == f'{env_with_mlstudio.label}-domain' + assert response.data.getEnvironmentMLStudioDomain.vpcType == 'imported' + assert response.data.getEnvironmentMLStudioDomain.vpcId == 'vpc-12345' + assert len(response.data.getEnvironmentMLStudioDomain.subnetIds) == 2 + assert response.data.getEnvironmentMLStudioDomain.environmentUri == env_with_mlstudio.environmentUri +