From 94c93d9ca10f4a23c904fb7248df891dfe3e051a Mon Sep 17 00:00:00 2001
From: Noah Paige <69586985+noah-paige@users.noreply.github.com>
Date: Fri, 8 Dec 2023 08:43:34 -0500
Subject: [PATCH] Byo vpc mlstudio (#894)
### Feature or Bugfix
- Feature
### Detail
- Enable the SageMaker Studio Domain to be deployed in an already-provisioned
VPC
### Relates
- https://github.com/awslabs/aws-dataall/issues/795
### Security
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).
- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
- Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
- Is injection prevented by parametrizing queries?
- Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
- Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
- Do you use standard, proven implementations?
- Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
- Have you used the least-privilege principle? How?
By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.
---
backend/dataall/base/aws/ec2_client.py | 65 +++++++
.../dataall/base/utils/naming_convention.py | 1 +
.../core/environment/api/input_types.py | 15 +-
.../dataall/core/environment/api/resolvers.py | 28 ++-
.../services/environment_resource_manager.py | 15 +-
.../services/environment_service.py | 24 +--
.../dashboards/db/dashboard_repositories.py | 2 +-
backend/dataall/modules/mlstudio/__init__.py | 5 +-
.../dataall/modules/mlstudio/api/queries.py | 10 +
.../dataall/modules/mlstudio/api/resolvers.py | 9 +-
backend/dataall/modules/mlstudio/api/types.py | 24 +++
.../modules/mlstudio/aws/ec2_client.py | 27 ---
.../mlstudio/aws/sagemaker_studio_client.py | 20 +-
.../mlstudio/cdk/mlstudio_extension.py | 160 ++++++++--------
.../modules/mlstudio/db/mlstudio_models.py | 15 +-
.../mlstudio/db/mlstudio_repositories.py | 95 ++++++++--
.../mlstudio/services/mlstudio_service.py | 99 +++++++++-
...f5de322f_update_sagemaker_studio_domain.py | 178 ++++++++++++++++++
.../components/EnvironmentMLStudio.js | 156 +++++++++++++++
.../modules/Environments/components/index.js | 1 +
.../views/EnvironmentCreateForm.js | 69 ++++++-
.../Environments/views/EnvironmentEditForm.js | 76 +++++++-
.../Environments/views/EnvironmentView.js | 10 +
.../MLStudio/getEnvironmentMLStudioDomain.js | 23 +++
.../src/services/graphql/MLStudio/index.js | 1 +
frontend/src/services/graphql/index.js | 1 +
tests/core/conftest.py | 1 -
tests/core/environments/test_environment.py | 29 +--
tests/core/vpc/test_vpc.py | 4 +-
tests/modules/mlstudio/cdk/conftest.py | 20 +-
.../cdk/test_sagemaker_studio_stack.py | 9 +-
tests/modules/mlstudio/conftest.py | 122 +++++++++++-
.../modules/mlstudio/test_sagemaker_studio.py | 146 +++++++++++++-
33 files changed, 1207 insertions(+), 253 deletions(-)
create mode 100644 backend/dataall/base/aws/ec2_client.py
delete mode 100644 backend/dataall/modules/mlstudio/aws/ec2_client.py
create mode 100644 backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py
create mode 100644 frontend/src/modules/Environments/components/EnvironmentMLStudio.js
create mode 100644 frontend/src/services/graphql/MLStudio/getEnvironmentMLStudioDomain.js
create mode 100644 frontend/src/services/graphql/MLStudio/index.js
diff --git a/backend/dataall/base/aws/ec2_client.py b/backend/dataall/base/aws/ec2_client.py
new file mode 100644
index 000000000..06bd62c7a
--- /dev/null
+++ b/backend/dataall/base/aws/ec2_client.py
@@ -0,0 +1,65 @@
+import logging
+
+from dataall.base.aws.sts import SessionHelper
+from botocore.exceptions import ClientError
+
+log = logging.getLogger(__name__)
+
+
+class EC2:
+
+ @staticmethod
+ def get_client(account_id: str, region: str, role=None):
+ session = SessionHelper.remote_session(accountid=account_id, role=role)
+ return session.client('ec2', region_name=region)
+
+ @staticmethod
+ def check_default_vpc_exists(AwsAccountId: str, region: str, role=None):
+ log.info("Check that default VPC exists..")
+ client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
+ response = client.describe_vpcs(
+ Filters=[{'Name': 'isDefault', 'Values': ['true']}]
+ )
+ vpcs = response['Vpcs']
+ log.info(f"Default VPCs response: {vpcs}")
+ if vpcs:
+ vpc_id = vpcs[0]['VpcId']
+ subnetIds = EC2._get_vpc_subnets(AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role)
+ if subnetIds:
+ return vpc_id, subnetIds
+ return False
+
+ @staticmethod
+ def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None):
+ client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
+ response = client.describe_subnets(
+ Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]
+ )
+ return [subnet['SubnetId'] for subnet in response['Subnets']]
+
+ @staticmethod
+ def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=[]):
+ try:
+ ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
+ response = ec2.describe_vpcs(VpcIds=[vpc_id])
+ except ClientError as e:
+ log.exception(f'VPC Id {vpc_id} Not Found: {e}')
+ raise Exception(f'VPCNotFound: {vpc_id}')
+
+ try:
+ if subnet_ids:
+ response = ec2.describe_subnets(
+ Filters=[
+ {
+ 'Name': 'vpc-id',
+ 'Values': [vpc_id]
+ },
+ ],
+ SubnetIds=subnet_ids
+ )
+ except ClientError as e:
+ log.exception(f'Subnet Id {subnet_ids} Not Found: {e}')
+ raise Exception(f'VPCSubnetsNotFound: {subnet_ids}')
+
+ if not subnet_ids or len(response['Subnets']) != len(subnet_ids):
+ raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}')
diff --git a/backend/dataall/base/utils/naming_convention.py b/backend/dataall/base/utils/naming_convention.py
index 3501fa71b..262964560 100644
--- a/backend/dataall/base/utils/naming_convention.py
+++ b/backend/dataall/base/utils/naming_convention.py
@@ -10,6 +10,7 @@ class NamingConventionPattern(Enum):
GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63}
GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52}
NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
+ MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63}
OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27}
OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31}
diff --git a/backend/dataall/core/environment/api/input_types.py b/backend/dataall/core/environment/api/input_types.py
index 9b618d0e5..27188f4ed 100644
--- a/backend/dataall/core/environment/api/input_types.py
+++ b/backend/dataall/core/environment/api/input_types.py
@@ -28,13 +28,11 @@
gql.Argument('description', gql.String),
gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)),
gql.Argument('region', gql.NonNullableType(gql.String)),
- gql.Argument('vpcId', gql.String),
- gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
- gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String),
gql.Argument('resourcePrefix', gql.String),
- gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))
-
+ gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
+ gql.Argument('vpcId', gql.String),
+ gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)
@@ -45,11 +43,10 @@
gql.Argument('description', gql.String),
gql.Argument('tags', gql.ArrayType(gql.String)),
gql.Argument('SamlGroupName', gql.String),
- gql.Argument('vpcId', gql.String),
- gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
- gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('resourcePrefix', gql.String),
- gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))
+ gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
+ gql.Argument('vpcId', gql.String),
+ gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)
diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py
index 7f3e4c765..06878cdfc 100644
--- a/backend/dataall/core/environment/api/resolvers.py
+++ b/backend/dataall/core/environment/api/resolvers.py
@@ -20,6 +20,7 @@
from dataall.core.stacks.aws.cloudformation import CloudFormation
from dataall.core.stacks.db.stack_repositories import Stack
from dataall.core.vpc.db.vpc_repositories import Vpc
+from dataall.base.aws.ec2_client import EC2
from dataall.base.db import exceptions
from dataall.core.permissions import permissions
from dataall.base.feature_toggle_checker import is_feature_enabled
@@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs):
return True if ssm_param == "True" else False
-def check_environment(context: Context, source, account_id, region):
+def check_environment(context: Context, source, account_id, region, data):
""" Checks necessary resources for environment deployment.
- Check CDKToolkit exists in Account assuming cdk_look_up_role
- Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False
@@ -71,11 +72,25 @@ def check_environment(context: Context, source, account_id, region):
action='CHECK_PIVOT_ROLE',
message='Pivot Role has not been created in the Environment AWS Account',
)
+ mlStudioEnabled = None
+ for parameter in data.get("parameters", []):
+ if parameter['key'] == 'mlStudiosEnabled':
+ mlStudioEnabled = parameter['value']
+
+ if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []):
+ log.info("Check if ML Studio VPC Exists in the Account")
+ EC2.check_vpc_exists(
+ AwsAccountId=account_id,
+ region=region,
+ role=cdk_look_up_role_arn,
+ vpc_id=data.get("vpcId", None),
+ subnet_ids=data.get('subnetIds', []),
+ )
return cdk_role_name
-def create_environment(context: Context, source, input=None):
+def create_environment(context: Context, source, input={}):
if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups:
raise exceptions.UnauthorizedOperation(
action=permissions.LINK_ENVIRONMENT,
@@ -85,8 +100,10 @@ def create_environment(context: Context, source, input=None):
with context.engine.scoped_session() as session:
cdk_role_name = check_environment(context, source,
account_id=input.get('AwsAccountId'),
- region=input.get('region')
+ region=input.get('region'),
+ data=input
)
+
input['cdk_role_name'] = cdk_role_name
env = EnvironmentService.create_environment(
session=session,
@@ -119,7 +136,8 @@ def update_environment(
environment = EnvironmentService.get_environment_by_uri(session, environmentUri)
cdk_role_name = check_environment(context, source,
account_id=environment.AwsAccountId,
- region=environment.region
+ region=environment.region,
+ data=input
)
previous_resource_prefix = environment.resourcePrefix
@@ -130,7 +148,7 @@ def update_environment(
data=input,
)
- if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment):
+ if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input):
stack_helper.deploy_stack(targetUri=environment.environmentUri)
return environment
diff --git a/backend/dataall/core/environment/services/environment_resource_manager.py b/backend/dataall/core/environment/services/environment_resource_manager.py
index bc74f01bf..f5c2551fa 100644
--- a/backend/dataall/core/environment/services/environment_resource_manager.py
+++ b/backend/dataall/core/environment/services/environment_resource_manager.py
@@ -12,7 +12,11 @@ def delete_env(session, environment):
pass
@staticmethod
- def update_env(session, environment):
+ def create_env(session, environment, **kwargs):
+ pass
+
+ @staticmethod
+ def update_env(session, environment, **kwargs):
return False
@staticmethod
@@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int:
return counter
@classmethod
- def deploy_updated_stack(cls, session, prev_prefix, environment):
+ def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs):
deploy_stack = prev_prefix != environment.resourcePrefix
for resource in cls._resources:
- deploy_stack |= resource.update_env(session, environment)
+ deploy_stack |= resource.update_env(session, environment, **kwargs)
return deploy_stack
@@ -51,6 +55,11 @@ def delete_env(cls, session, environment):
for resource in cls._resources:
resource.delete_env(session, environment)
+ @classmethod
+ def create_env(cls, session, environment, **kwargs):
+ for resource in cls._resources:
+ resource.create_env(session, environment, **kwargs)
+
@classmethod
def count_consumption_role_resources(cls, session, role_uri):
counter = 0
diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py
index ddea435c4..1b2dbec07 100644
--- a/backend/dataall/core/environment/services/environment_service.py
+++ b/backend/dataall/core/environment/services/environment_service.py
@@ -66,6 +66,7 @@ def create_environment(session, uri, data=None):
session.commit()
EnvironmentService._update_env_parameters(session, env, data)
+ EnvironmentResourceManager.create_env(session, env, data=data)
env.EnvironmentDefaultBucketName = NamingConventionService(
target_uri=env.environmentUri,
@@ -98,29 +99,6 @@ def create_environment(session, uri, data=None):
env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn']
env.EnvironmentDefaultIAMRoleImported = True
- if data.get('vpcId'):
- vpc = Vpc(
- environmentUri=env.environmentUri,
- region=env.region,
- AwsAccountId=env.AwsAccountId,
- VpcId=data.get('vpcId'),
- privateSubnetIds=data.get('privateSubnetIds', []),
- publicSubnetIds=data.get('publicSubnetIds', []),
- SamlGroupName=data['SamlGroupName'],
- owner=context.username,
- label=f"{env.name}-{data.get('vpcId')}",
- name=f"{env.name}-{data.get('vpcId')}",
- default=True,
- )
- session.add(vpc)
- session.commit()
- ResourcePolicy.attach_resource_policy(
- session=session,
- group=data['SamlGroupName'],
- permissions=permissions.NETWORK_ALL,
- resource_uri=vpc.vpcUri,
- resource_type=Vpc.__name__,
- )
env_group = EnvironmentGroup(
environmentUri=env.environmentUri,
groupUri=data['SamlGroupName'],
diff --git a/backend/dataall/modules/dashboards/db/dashboard_repositories.py b/backend/dataall/modules/dashboards/db/dashboard_repositories.py
index 91916f8ff..a8d9d6a2f 100644
--- a/backend/dataall/modules/dashboards/db/dashboard_repositories.py
+++ b/backend/dataall/modules/dashboards/db/dashboard_repositories.py
@@ -26,7 +26,7 @@ def count_resources(session, environment, group_uri) -> int:
)
@staticmethod
- def update_env(session, environment):
+ def update_env(session, environment, **kwargs):
return EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled")
@staticmethod
diff --git a/backend/dataall/modules/mlstudio/__init__.py b/backend/dataall/modules/mlstudio/__init__.py
index 2db9c0a1e..a6ca73917 100644
--- a/backend/dataall/modules/mlstudio/__init__.py
+++ b/backend/dataall/modules/mlstudio/__init__.py
@@ -3,7 +3,8 @@
from dataall.base.loader import ImportMode, ModuleInterface
from dataall.core.stacks.db.target_type_repositories import TargetType
-from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository
+from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource
+from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager
log = logging.getLogger(__name__)
@@ -20,6 +21,8 @@ def __init__(self):
from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER
TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER)
+ EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource())
+
log.info("API of sagemaker mlstudio has been imported")
diff --git a/backend/dataall/modules/mlstudio/api/queries.py b/backend/dataall/modules/mlstudio/api/queries.py
index 457559def..ee014839f 100644
--- a/backend/dataall/modules/mlstudio/api/queries.py
+++ b/backend/dataall/modules/mlstudio/api/queries.py
@@ -4,6 +4,7 @@
get_sagemaker_studio_user,
list_sagemaker_studio_users,
get_sagemaker_studio_user_presigned_url,
+ get_environment_sagemaker_studio_domain
)
getSagemakerStudioUser = gql.QueryField(
@@ -34,3 +35,12 @@
type=gql.String,
resolver=get_sagemaker_studio_user_presigned_url,
)
+
+getEnvironmentMLStudioDomain = gql.QueryField(
+ name='getEnvironmentMLStudioDomain',
+ args=[
+ gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
+ ],
+ type=gql.Ref('SagemakerStudioDomain'),
+ resolver=get_environment_sagemaker_studio_domain,
+)
diff --git a/backend/dataall/modules/mlstudio/api/resolvers.py b/backend/dataall/modules/mlstudio/api/resolvers.py
index 63dc25ed7..48c9350fa 100644
--- a/backend/dataall/modules/mlstudio/api/resolvers.py
+++ b/backend/dataall/modules/mlstudio/api/resolvers.py
@@ -18,7 +18,7 @@ def required_uri(uri):
raise exceptions.RequiredParameter('URI')
@staticmethod
- def validate_creation_request(data):
+ def validate_user_creation_request(data):
required = RequestValidator._required
if not data:
raise exceptions.RequiredParameter('data')
@@ -36,7 +36,7 @@ def _required(data: dict, name: str):
def create_sagemaker_studio_user(context: Context, source, input: dict = None):
"""Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS"""
- RequestValidator.validate_creation_request(input)
+ RequestValidator.validate_user_creation_request(input)
request = SagemakerStudioCreationRequest.from_dict(input)
return SagemakerStudioService.create_sagemaker_studio_user(
uri=input["environmentUri"],
@@ -90,6 +90,11 @@ def delete_sagemaker_studio_user(
)
+def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None):
+ RequestValidator.required_uri(environmentUri)
+ return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri)
+
+
def resolve_user_role(context: Context, source: SagemakerStudioUser):
"""
Resolves the role of the current user in reference with the SageMaker Studio User
diff --git a/backend/dataall/modules/mlstudio/api/types.py b/backend/dataall/modules/mlstudio/api/types.py
index 21290711e..ca21df81d 100644
--- a/backend/dataall/modules/mlstudio/api/types.py
+++ b/backend/dataall/modules/mlstudio/api/types.py
@@ -79,3 +79,27 @@
gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioUser)),
],
)
+
+SagemakerStudioDomain = gql.ObjectType(
+ name='SagemakerStudioDomain',
+ fields=[
+ gql.Field(name='sagemakerStudioUri', type=gql.ID),
+ gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)),
+ gql.Field(name='sagemakerStudioDomainName', type=gql.String),
+ gql.Field(name='DefaultDomainRoleName', type=gql.String),
+ gql.Field(name='label', type=gql.String),
+ gql.Field(name='name', type=gql.String),
+ gql.Field(name='vpcType', type=gql.String),
+ gql.Field(name='vpcId', type=gql.String),
+ gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)),
+ gql.Field(name='owner', type=gql.String),
+ gql.Field(name='created', type=gql.String),
+ gql.Field(name='updated', type=gql.String),
+ gql.Field(name='deleted', type=gql.String),
+ gql.Field(
+ name='environment',
+ type=gql.Ref('Environment'),
+ resolver=resolve_environment,
+ )
+ ],
+)
diff --git a/backend/dataall/modules/mlstudio/aws/ec2_client.py b/backend/dataall/modules/mlstudio/aws/ec2_client.py
deleted file mode 100644
index 3dc484254..000000000
--- a/backend/dataall/modules/mlstudio/aws/ec2_client.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import logging
-
-from dataall.base.aws.sts import SessionHelper
-
-
-log = logging.getLogger(__name__)
-
-
-class EC2:
-
- @staticmethod
- def get_client(account_id: str, region: str, role=None):
- session = SessionHelper.remote_session(accountid=account_id, role=role)
- return session.client('ec2', region_name=region)
-
- @staticmethod
- def check_default_vpc_exists(AwsAccountId: str, region: str, role=None):
- log.info("Check that default VPC exists..")
- client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
- response = client.describe_vpcs(
- Filters=[{'Name': 'isDefault', 'Values': ['true']}]
- )
- vpcs = response['Vpcs']
- log.info(f"Default VPCs response: {vpcs}")
- if vpcs:
- return True
- return False
diff --git a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py
index 2a82806ea..2ee872b1c 100644
--- a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py
+++ b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py
@@ -12,28 +12,22 @@ def get_client(AwsAccountId, region):
return session.client('sagemaker', region_name=region)
-def get_sagemaker_studio_domain(AwsAccountId, region):
+def get_sagemaker_studio_domain(AwsAccountId, region, domain_name):
"""
Sagemaker studio domain is limited to 5 per account/region
RETURN: an existing domain or None if no domain is in the AWS account
"""
client = get_client(AwsAccountId=AwsAccountId, region=region)
- existing_domain = dict()
try:
domain_id_paginator = client.get_paginator('list_domains')
- domains = domain_id_paginator.paginate()
- for _domain in domains:
- print(_domain)
- for _domain in _domain.get('Domains'):
- # Get the domain name created by dataall
- if 'dataall' in _domain:
- return _domain
- else:
- existing_domain = _domain
- return existing_domain
+ for page in domain_id_paginator.paginate():
+ for domain in page.get('Domains', []):
+ if domain.get("DomainName") == domain_name:
+ return domain
+ return dict()
except ClientError as e:
print(e)
- return 'NotFound'
+ return dict()
class SagemakerStudioClient:
diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py
index fe9040ab9..49082ccfb 100644
--- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py
+++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py
@@ -12,14 +12,12 @@
aws_ssm as ssm,
RemovalPolicy,
)
-from botocore.exceptions import ClientError
+from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository
-from dataall.base.aws.parameter_store import ParameterStoreManager
from dataall.base.aws.sts import SessionHelper
from dataall.core.environment.cdk.environment_stack import EnvironmentSetup, EnvironmentStackExtension
from dataall.core.environment.services.environment_service import EnvironmentService
-from dataall.modules.mlstudio.aws.ec2_client import EC2
-from dataall.modules.mlstudio.aws.sagemaker_studio_client import get_sagemaker_studio_domain
+from dataall.base.aws.ec2_client import EC2
logger = logging.getLogger(__name__)
@@ -31,75 +29,84 @@ def extent(setup: EnvironmentSetup):
_environment = setup.environment()
with setup.get_engine().scoped_session() as session:
enabled = EnvironmentService.get_boolean_env_param(session, _environment, "mlStudiosEnabled")
- if not enabled:
+ domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, _environment.environmentUri)
+ if not enabled or not domain:
return
sagemaker_principals = [setup.default_role] + setup.group_roles
logger.info(f'Creating SageMaker base resources for sagemaker_principals = {sagemaker_principals}..')
- cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
- accountid=_environment.AwsAccountId, region=_environment.region
- )
- existing_default_vpc = EC2.check_default_vpc_exists(
- AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn
- )
- if existing_default_vpc:
- logger.info("Using default VPC for Sagemaker Studio domain")
- # Use default VPC - initial configuration (to be migrated)
- vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True)
- subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
- subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets]
- subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets]
+
+ if domain.vpcId and domain.subnetIds and domain.vpcType == 'imported':
+ logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain')
+ vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId)
+ subnet_ids = domain.subnetIds
security_groups = []
else:
- logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...")
- # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways
- log_group = logs.LogGroup(
- setup,
- f'SageMakerStudio{_environment.name}',
- log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio',
- retention=logs.RetentionDays.ONE_MONTH,
- removal_policy=RemovalPolicy.DESTROY,
- )
- vpc_flow_role = iam.Role(
- setup, 'FlowLog',
- assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com')
- )
- vpc = ec2.Vpc(
- setup,
- "SageMakerVPC",
- max_azs=3,
- cidr="10.10.0.0/16",
- subnet_configuration=[
- ec2.SubnetConfiguration(
- subnet_type=ec2.SubnetType.PUBLIC,
- name="Public",
- cidr_mask=24
- ),
- ec2.SubnetConfiguration(
- subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT,
- name="Private",
- cidr_mask=24
- ),
- ],
- enable_dns_hostnames=True,
- enable_dns_support=True,
- )
- ec2.FlowLog(
- setup, "StudioVPCFlowLog",
- resource_type=ec2.FlowLogResourceType.from_vpc(vpc),
- destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role)
+ cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
+ accountid=_environment.AwsAccountId, region=_environment.region
)
- # setup security group to be used for sagemaker studio domain
- sagemaker_sg = ec2.SecurityGroup(
- setup,
- "SecurityGroup",
- vpc=vpc,
- description="Security Group for SageMaker Studio",
+ existing_default_vpc = EC2.check_default_vpc_exists(
+ AwsAccountId=_environment.AwsAccountId, region=_environment.region, role=cdk_look_up_role_arn
)
+ if existing_default_vpc:
+ logger.info("Using default VPC for Sagemaker Studio domain")
+ # Use default VPC - initial configuration (to be migrated)
+ vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', is_default=True)
+ subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
+ subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets]
+ subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets]
+ security_groups = []
+ else:
+ logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...")
+ # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways
+ log_group = logs.LogGroup(
+ setup,
+ f'SageMakerStudio{_environment.name}',
+ log_group_name=f'/{_environment.resourcePrefix}/{_environment.name}/vpc/sagemakerstudio',
+ retention=logs.RetentionDays.ONE_MONTH,
+ removal_policy=RemovalPolicy.DESTROY,
+ )
+ vpc_flow_role = iam.Role(
+ setup, 'FlowLog',
+ assumed_by=iam.ServicePrincipal('vpc-flow-logs.amazonaws.com')
+ )
+ vpc = ec2.Vpc(
+ setup,
+ "SageMakerVPC",
+ max_azs=3,
+ cidr="10.10.0.0/16",
+ subnet_configuration=[
+ ec2.SubnetConfiguration(
+ subnet_type=ec2.SubnetType.PUBLIC,
+ name="Public",
+ cidr_mask=24
+ ),
+ ec2.SubnetConfiguration(
+ subnet_type=ec2.SubnetType.PRIVATE_WITH_NAT,
+ name="Private",
+ cidr_mask=24
+ ),
+ ],
+ enable_dns_hostnames=True,
+ enable_dns_support=True,
+ )
+ ec2.FlowLog(
+ setup, "StudioVPCFlowLog",
+ resource_type=ec2.FlowLogResourceType.from_vpc(vpc),
+ destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role)
+ )
+ # setup security group to be used for sagemaker studio domain
+ sagemaker_sg = ec2.SecurityGroup(
+ setup,
+ "SecurityGroup",
+ vpc=vpc,
+ description="Security Group for SageMaker Studio",
+ security_group_name=domain.sagemakerStudioDomainName,
+ )
- sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
- security_groups = [sagemaker_sg.security_group_id]
- subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
+ sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
+ security_groups = [sagemaker_sg.security_group_id]
+ subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
vpc_id = vpc.vpc_id
@@ -107,7 +114,7 @@ def extent(setup: EnvironmentSetup):
setup,
'RoleForSagemakerStudioUsers',
assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'),
- role_name='RoleSagemakerStudioUsers',
+ role_name=domain.DefaultDomainRoleName,
managed_policies=[
iam.ManagedPolicy.from_managed_policy_arn(
setup,
@@ -123,7 +130,7 @@ def extent(setup: EnvironmentSetup):
sagemaker_domain_key = kms.Key(
setup,
'SagemakerDomainKmsKey',
- alias='SagemakerStudioDomain',
+ alias=domain.sagemakerStudioDomainName,
enable_key_rotation=True,
admins=[
iam.ArnPrincipal(_environment.CDKRoleArn)
@@ -175,7 +182,7 @@ def extent(setup: EnvironmentSetup):
sagemaker_domain = sagemaker.CfnDomain(
setup,
'SagemakerStudioDomain',
- domain_name=f'SagemakerStudioDomain-{_environment.region}-{_environment.AwsAccountId}',
+ domain_name=domain.sagemakerStudioDomainName,
auth_mode='IAM',
default_user_settings=sagemaker.CfnDomain.UserSettingsProperty(
execution_role=sagemaker_domain_role.role_arn,
@@ -199,22 +206,3 @@ def extent(setup: EnvironmentSetup):
parameter_name=f'/{_environment.resourcePrefix}/{_environment.environmentUri}/sagemaker/sagemakerstudio/domain_id',
)
return sagemaker_domain
-
- @staticmethod
- def check_existing_sagemaker_studio_domain(environment):
- logger.info('Check if there is an existing sagemaker studio domain in the account')
- try:
- logger.info('check sagemaker studio domain created as part of data.all environment stack.')
- cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
- accountid=environment.AwsAccountId, region=environment.region
- )
- dataall_created_domain = ParameterStoreManager.client(
- AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn
- ).get_parameter(Name=f'/{environment.resourcePrefix}/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id')
- return False
- except ClientError as e:
- logger.info(f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}')
- existing_domain = get_sagemaker_studio_domain(
- AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn
- )
- return existing_domain.get('DomainId', False)
diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_models.py b/backend/dataall/modules/mlstudio/db/mlstudio_models.py
index 032826588..a4c93a2fa 100644
--- a/backend/dataall/modules/mlstudio/db/mlstudio_models.py
+++ b/backend/dataall/modules/mlstudio/db/mlstudio_models.py
@@ -2,6 +2,7 @@
from sqlalchemy import Column, String, ForeignKey
from sqlalchemy.orm import query_expression
+from sqlalchemy.dialects.postgresql import ARRAY
from dataall.base.db import Base
from dataall.base.db import Resource, utils
@@ -10,16 +11,20 @@
class SagemakerStudioDomain(Resource, Base):
"""Describes ORM model for sagemaker ML Studio domain"""
__tablename__ = 'sagemaker_studio_domain'
- environmentUri = Column(String, nullable=False)
+ environmentUri = Column(String, ForeignKey("environment.environmentUri"))
sagemakerStudioUri = Column(
String, primary_key=True, default=utils.uuid('sagemakerstudio')
)
- sagemakerStudioDomainID = Column(String, nullable=False)
- SagemakerStudioStatus = Column(String, nullable=False)
+ sagemakerStudioDomainID = Column(String, nullable=True)
+ SagemakerStudioStatus = Column(String, nullable=True)
+ sagemakerStudioDomainName = Column(String, nullable=False)
AWSAccountId = Column(String, nullable=False)
- RoleArn = Column(String, nullable=False)
+ DefaultDomainRoleName = Column(String, nullable=False)
region = Column(String, default='eu-west-1')
- userRoleForSagemakerStudio = query_expression()
+ SamlGroupName = Column(String, nullable=False)
+ vpcType = Column(String, nullable=True)
+ vpcId = Column(String, nullable=True)
+ subnetIds = Column(ARRAY(String), nullable=True)
class SagemakerStudioUser(Resource, Base):
diff --git a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py
index 763ca6f92..21847b6ef 100644
--- a/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py
+++ b/backend/dataall/modules/mlstudio/db/mlstudio_repositories.py
@@ -2,30 +2,34 @@
DAO layer that encapsulates the logic and interaction with the database for ML Studio
Provides the API to retrieve / update / delete ml studio
"""
+from typing import Optional
from sqlalchemy import or_
from sqlalchemy.sql import and_
from sqlalchemy.orm import Query
+from dataall.base.utils import slugify
from dataall.base.db import paginate
-from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser
-from dataall.core.environment.services.environment_resource_manager import EnvironmentResource
+from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioDomain, SagemakerStudioUser
+from dataall.base.utils.naming_convention import (
+ NamingConventionService,
+ NamingConventionPattern,
+)
-class SageMakerStudioRepository(EnvironmentResource):
+class SageMakerStudioRepository:
"""DAO layer for ML Studio"""
_DEFAULT_PAGE = 1
_DEFAULT_PAGE_SIZE = 10
- def __init__(self, session):
- self._session = session
-
- def save_sagemaker_studio_user(self, user):
+ @staticmethod
+ def save_sagemaker_studio_user(session, user):
"""Save SageMaker Studio user to the database"""
- self._session.add(user)
- self._session.commit()
+ session.add(user)
+ session.commit()
- def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query:
- query = self._session.query(SagemakerStudioUser).filter(
+ @staticmethod
+ def _query_user_sagemaker_studio_users(session, username, groups, filter) -> Query:
+ query = session.query(SagemakerStudioUser).filter(
or_(
SagemakerStudioUser.owner == username,
SagemakerStudioUser.SamlAdminGroupName.in_(groups),
@@ -44,21 +48,24 @@ def _query_user_sagemaker_studio_users(self, username, groups, filter) -> Query:
)
return query
- def paginated_sagemaker_studio_users(self, username, groups, filter=None) -> dict:
+ @staticmethod
+ def paginated_sagemaker_studio_users(session, username, groups, filter={}) -> dict:
"""Returns a page of sagemaker studio users for a data.all user"""
return paginate(
- query=self._query_user_sagemaker_studio_users(username, groups, filter),
+ query=SageMakerStudioRepository._query_user_sagemaker_studio_users(session, username, groups, filter),
page=filter.get('page', SageMakerStudioRepository._DEFAULT_PAGE),
page_size=filter.get('pageSize', SageMakerStudioRepository._DEFAULT_PAGE_SIZE),
).to_dict()
- def find_sagemaker_studio_user(self, uri):
+ @staticmethod
+ def find_sagemaker_studio_user(session, uri):
"""Finds a sagemaker studio user. Returns None if it doesn't exist"""
- return self._session.query(SagemakerStudioUser).get(uri)
+ return session.query(SagemakerStudioUser).get(uri)
- def count_resources(self, environment, group_uri):
+ @staticmethod
+ def count_resources(session, environment, group_uri):
return (
- self._session.query(SagemakerStudioUser)
+ session.query(SagemakerStudioUser)
.filter(
and_(
SagemakerStudioUser.environmentUri == environment.environmentUri,
@@ -67,3 +74,57 @@ def count_resources(self, environment, group_uri):
)
.count()
)
+
+ @staticmethod
+ def create_sagemaker_studio_domain(session, username, environment, data):
+ domain = SagemakerStudioDomain(
+ label=f"{data.get('label')}-domain",
+ owner=username,
+ description=data.get('description', 'No description provided'),
+ tags=data.get('tags', []),
+ SamlGroupName=environment.SamlGroupName,
+ environmentUri=environment.environmentUri,
+ AWSAccountId=environment.AwsAccountId,
+ region=environment.region,
+ SagemakerStudioStatus="PENDING",
+ DefaultDomainRoleName="DefaultMLStudioRole",
+ sagemakerStudioDomainName=slugify(data.get('label'), separator=''),
+ vpcType=data.get('vpcType'),
+ vpcId=data.get('vpcId'),
+ subnetIds=data.get('subnetIds', [])
+ )
+ session.add(domain)
+ session.commit()
+
+ domain.sagemakerStudioDomainName = NamingConventionService(
+ target_uri=domain.sagemakerStudioUri,
+ target_label=domain.label,
+ pattern=NamingConventionPattern.MLSTUDIO_DOMAIN,
+ resource_prefix=environment.resourcePrefix,
+ ).build_compliant_name()
+
+ domain.DefaultDomainRoleName = NamingConventionService(
+ target_uri=domain.sagemakerStudioUri,
+ target_label=domain.label,
+ pattern=NamingConventionPattern.IAM,
+ resource_prefix=environment.resourcePrefix,
+ ).build_compliant_name()
+
+ return domain
+
+ @staticmethod
+ def get_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]:
+ domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
+ SagemakerStudioDomain.environmentUri == env_uri,
+ ).first()
+ if not domain:
+ return None
+ return domain
+
+ @staticmethod
+ def delete_sagemaker_studio_domain_by_env_uri(session, env_uri) -> Optional[SagemakerStudioDomain]:
+ domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
+ SagemakerStudioDomain.environmentUri == env_uri,
+ ).first()
+ if domain:
+ session.delete(domain)
diff --git a/backend/dataall/modules/mlstudio/services/mlstudio_service.py b/backend/dataall/modules/mlstudio/services/mlstudio_service.py
index 06750b822..3738c118d 100644
--- a/backend/dataall/modules/mlstudio/services/mlstudio_service.py
+++ b/backend/dataall/modules/mlstudio/services/mlstudio_service.py
@@ -11,13 +11,18 @@
from dataall.core.environment.env_permission_checker import has_group_permission
from dataall.core.environment.services.environment_service import EnvironmentService
from dataall.core.permissions.db.resource_policy_repositories import ResourcePolicy
+from dataall.core.permissions import permissions
from dataall.core.permissions.permission_checker import has_resource_permission, has_tenant_permission
from dataall.core.stacks.api import stack_helper
from dataall.core.stacks.db.stack_repositories import Stack
from dataall.base.db import exceptions
from dataall.modules.mlstudio.aws.sagemaker_studio_client import sagemaker_studio_client, get_sagemaker_studio_domain
from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository
+from dataall.core.environment.services.environment_resource_manager import EnvironmentResource
from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser
+from dataall.base.aws.ec2_client import EC2
+from dataall.base.aws.sts import SessionHelper
+
from dataall.modules.mlstudio.services.mlstudio_permissions import (
MANAGE_SGMSTUDIO_USERS,
CREATE_SGMSTUDIO_USER,
@@ -54,6 +59,38 @@ def _session():
return get_context().db_engine.scoped_session()
+class SagemakerStudioEnvironmentResource(EnvironmentResource):
+ @staticmethod
+ def count_resources(session, environment, group_uri) -> int:
+ return SageMakerStudioRepository.count_resources(session, environment, group_uri)
+
+ @staticmethod
+ def create_env(session, environment, **kwargs):
+ enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled")
+ if enabled:
+ SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs)
+
+ @staticmethod
+ def update_env(session, environment, **kwargs):
+ current_mlstudio_enabled = EnvironmentService.get_boolean_env_param(session, environment, "mlStudiosEnabled")
+ domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, environment.environmentUri)
+ previous_mlstudio_enabled = True if domain else False
+ if (current_mlstudio_enabled != previous_mlstudio_enabled and previous_mlstudio_enabled):
+ SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri)
+ return True
+ elif (current_mlstudio_enabled != previous_mlstudio_enabled and not previous_mlstudio_enabled):
+ SagemakerStudioService.create_sagemaker_studio_domain(session, environment, **kwargs)
+ return True
+ elif current_mlstudio_enabled and domain and domain.vpcType == "unknown":
+ SagemakerStudioService.update_sagemaker_studio_domain(environment, domain, **kwargs)
+ return True
+ return False
+
+ @staticmethod
+ def delete_env(session, environment):
+ SageMakerStudioRepository.delete_sagemaker_studio_domain_by_env_uri(session=session, env_uri=environment.environmentUri)
+
+
class SagemakerStudioService:
"""
Encapsulate the logic of interactions with sagemaker ml studio.
@@ -77,17 +114,19 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak
action=CREATE_SGMSTUDIO_USER,
message=f'ML Studio feature is disabled for the environment {env.label}',
)
+
+ domain = SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=env.environmentUri)
response = get_sagemaker_studio_domain(
AwsAccountId=env.AwsAccountId,
- region=env.region
+ region=env.region,
+ domain_name=domain.sagemakerStudioDomainName
)
existing_domain = response.get('DomainId', False)
if not existing_domain:
raise exceptions.AWSResourceNotAvailable(
action='Sagemaker Studio domain',
- message='Update the environment stack '
- 'or create a Sagemaker studio domain on your AWS account.',
+ message='Update the environment stack and enable ML Studio Environment Feature'
)
sagemaker_studio_user = SagemakerStudioUser(
@@ -104,7 +143,7 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak
SamlAdminGroupName=admin_group,
tags=request.tags,
)
- SageMakerStudioRepository(session).save_sagemaker_studio_user(user=sagemaker_studio_user)
+ SageMakerStudioRepository.save_sagemaker_studio_user(session, sagemaker_studio_user)
ResourcePolicy.attach_resource_policy(
session=session,
@@ -135,10 +174,58 @@ def create_sagemaker_studio_user(*, uri: str, admin_group: str, request: Sagemak
return sagemaker_studio_user
+ @staticmethod
+ def update_sagemaker_studio_domain(environment, domain, data):
+ SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data)
+ domain.vpcType = data.get('vpcType')
+ if data.get('vpcId'):
+ domain.vpcId = data.get('vpcId')
+ if data.get('subnetIds'):
+ domain.subnetIds = data.get('subnetIds')
+
+ @staticmethod
+ def _update_sagemaker_studio_domain_vpc(account_id, region, data={}):
+ cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
+ accountid=account_id, region=region
+ )
+ if data.get("vpcId", None):
+ data["vpcType"] = "imported"
+ else:
+ response = EC2.check_default_vpc_exists(
+ AwsAccountId=account_id,
+ region=region,
+ role=cdk_look_up_role_arn,
+ )
+ if response:
+ vpcId, subnetIds = response
+ data["vpcType"] = "default"
+ data["vpcId"] = vpcId
+ data["subnetIds"] = subnetIds
+ else:
+ data["vpcType"] = "created"
+
+ @staticmethod
+ def create_sagemaker_studio_domain(session, environment, data: dict = {}):
+ SagemakerStudioService._update_sagemaker_studio_domain_vpc(environment.AwsAccountId, environment.region, data)
+
+ domain = SageMakerStudioRepository.create_sagemaker_studio_domain(
+ session=session,
+ username=get_context().username,
+ environment=environment,
+ data=data,
+ )
+ return domain
+
+ @staticmethod
+ def get_environment_sagemaker_studio_domain(*, environment_uri: str):
+ with _session() as session:
+ return SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri(session, env_uri=environment_uri)
+
@staticmethod
def list_sagemaker_studio_users(*, filter: dict) -> dict:
with _session() as session:
- return SageMakerStudioRepository(session).paginated_sagemaker_studio_users(
+ return SageMakerStudioRepository.paginated_sagemaker_studio_users(
+ session=session,
username=get_context().username,
groups=get_context().groups,
filter=filter,
@@ -197,7 +284,7 @@ def delete_sagemaker_studio_user(*, uri: str, delete_from_aws: bool):
@staticmethod
def _get_sagemaker_studio_user(session, uri):
- user = SageMakerStudioRepository(session).find_sagemaker_studio_user(uri=uri)
+ user = SageMakerStudioRepository.find_sagemaker_studio_user(session=session, uri=uri)
if not user:
raise exceptions.ObjectNotFound('SagemakerStudioUser', uri)
return user
diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py
new file mode 100644
index 000000000..a3ac794f3
--- /dev/null
+++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py
@@ -0,0 +1,178 @@
+"""env_mlstudio_domain_table
+
+Revision ID: 71a5f5de322f
+Revises: 8c79fb896983
+Create Date: 2023-11-29 09:44:04.160286
+
+"""
+import os
+from sqlalchemy import orm, Column, String, Boolean, ForeignKey, and_
+from sqlalchemy.ext.declarative import declarative_base
+import sqlalchemy as sa
+from alembic import op
+
+from sqlalchemy.dialects import postgresql
+from dataall.base.db import get_engine, has_table
+from dataall.base.db import utils, Resource
+
+# revision identifiers, used by Alembic.
+revision = '71a5f5de322f'
+down_revision = '8c79fb896983'
+branch_labels = None
+depends_on = None
+
+Base = declarative_base()
+
+
+class Environment(Resource, Base):
+ __tablename__ = "environment"
+ environmentUri = Column(String, primary_key=True)
+ AwsAccountId = Column(Boolean)
+ region = Column(Boolean)
+ SamlGroupName = Column(String)
+
+
+class EnvironmentParameter(Base):
+ __tablename__ = 'environment_parameters'
+ environmentUri = Column(String, primary_key=True)
+ key = Column('paramKey', String, primary_key=True)
+ value = Column('paramValue', String, nullable=True)
+
+
+class SagemakerStudioDomain(Resource, Base):
+ __tablename__ = 'sagemaker_studio_domain'
+ environmentUri = Column(String, ForeignKey("environment.environmentUri"))
+ sagemakerStudioUri = Column(
+ String, primary_key=True, default=utils.uuid('sagemakerstudio')
+ )
+ sagemakerStudioDomainID = Column(String, nullable=True)
+ SagemakerStudioStatus = Column(String, nullable=True)
+ sagemakerStudioDomainName = Column(String, nullable=False)
+ AWSAccountId = Column(String, nullable=False)
+ DefaultDomainRoleName = Column(String, nullable=False)
+ region = Column(String, default='eu-west-1')
+ SamlGroupName = Column(String, nullable=False)
+ vpcType = Column(String, nullable=True)
+
+
+def upgrade():
+ """
+ The script does the following migration:
+ 1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information
+ """
+ try:
+ envname = os.getenv('envname', 'local')
+ engine = get_engine(envname=envname).engine
+
+ bind = op.get_bind()
+ session = orm.Session(bind=bind)
+
+ if has_table('sagemaker_studio_domain', engine):
+ print("Updating sagemaker_studio_domain table...")
+ op.alter_column(
+ 'sagemaker_studio_domain',
+ 'sagemakerStudioDomainID',
+ nullable=True,
+ existing_type=sa.String()
+ )
+ op.alter_column(
+ 'sagemaker_studio_domain',
+ 'SagemakerStudioStatus',
+ nullable=True,
+ existing_type=sa.String()
+ )
+ op.alter_column(
+ 'sagemaker_studio_domain',
+ 'RoleArn',
+ new_column_name='DefaultDomainRoleName',
+ nullable=False,
+ existing_type=sa.String()
+ )
+
+ op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
+ op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
+ op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
+ op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
+ op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))
+
+ op.create_foreign_key(
+ "fk_sagemaker_studio_domain_env_uri",
+ "sagemaker_studio_domain", "environment",
+ ["environmentUri"], ["environmentUri"],
+ )
+
+ print("Update sagemaker_studio_domain table done.")
+ print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")
+
+ env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
+ and_(
+ EnvironmentParameter.key == "mlStudiosEnabled",
+ EnvironmentParameter.value == "true"
+ )
+ ).all()
+ for param in env_mlstudio_parameters:
+ env: Environment = session.query(Environment).filter(
+ Environment.environmentUri == param.environmentUri
+ ).first()
+
+ domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
+ SagemakerStudioDomain.environmentUri == env.environmentUri
+ ).first()
+ if not domain:
+ domain = SagemakerStudioDomain(
+ label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
+ owner=env.owner,
+ description='No description provided',
+ environmentUri=env.environmentUri,
+ AWSAccountId=env.AwsAccountId,
+ region=env.region,
+ DefaultDomainRoleName="RoleSagemakerStudioUsers",
+ sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
+ vpcType="unknown",
+ SamlGroupName=env.SamlGroupName
+ )
+ session.add(domain)
+ session.flush()
+ session.commit()
+ print("Fill of sagemaker_studio_domain table is done")
+
+ except Exception as exception:
+ print('Failed to upgrade due to:', exception)
+ raise exception
+
+
+def downgrade():
+ try:
+ envname = os.getenv('envname', 'local')
+ engine = get_engine(envname=envname).engine
+
+ bind = op.get_bind()
+ session = orm.Session(bind=bind)
+
+ if has_table('sagemaker_studio_domain', engine):
+ print("deleting sagemaker studio domain entries...")
+ session.query(SagemakerStudioDomain).delete()
+
+ print("Updating of sagemaker_studio_domain table...")
+ op.alter_column(
+ 'sagemaker_studio_domain',
+ 'DefaultDomainRoleName',
+ new_column_name='RoleArn',
+ nullable=False,
+ existing_type=sa.String()
+ )
+
+ op.drop_column("sagemaker_studio_domain", "sagemakerStudioDomainName")
+ op.drop_column("sagemaker_studio_domain", "vpcType")
+ op.drop_column("sagemaker_studio_domain", "vpcId")
+ op.drop_column("sagemaker_studio_domain", "subnetIds")
+ op.drop_column("sagemaker_studio_domain", "SamlGroupName")
+
+ op.drop_constraint("fk_sagemaker_studio_domain_env_uri", "sagemaker_studio_domain")
+
+ session.commit()
+ print("Update of sagemaker_studio_domain table is done")
+
+ except Exception as exception:
+ print('Failed to downgrade due to:', exception)
+ raise exception
diff --git a/frontend/src/modules/Environments/components/EnvironmentMLStudio.js b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js
new file mode 100644
index 000000000..44dac97b9
--- /dev/null
+++ b/frontend/src/modules/Environments/components/EnvironmentMLStudio.js
@@ -0,0 +1,156 @@
+import {
+ Box,
+ Card,
+ CardHeader,
+ Divider,
+ Grid,
+ CardContent,
+ Typography,
+ CircularProgress,
+ Chip
+} from '@mui/material';
+
+import PropTypes from 'prop-types';
+import React, { useCallback, useEffect, useState } from 'react';
+import { RefreshTableMenu } from 'design';
+import { SET_ERROR, useDispatch } from 'globalErrors';
+import { getEnvironmentMLStudioDomain, useClient } from 'services';
+
+export const EnvironmentMLStudio = ({ environment }) => {
+ const client = useClient();
+ const dispatch = useDispatch();
+ const [mlStudioDomain, setMLStudioDomain] = useState(null);
+ const [loading, setLoading] = useState(true);
+
+ const fetchMLStudioDomain = useCallback(async () => {
+ try {
+ setLoading(true);
+ const response = await client.query(
+ getEnvironmentMLStudioDomain({
+ environmentUri: environment.environmentUri
+ })
+ );
+ if (!response.errors) {
+ if (response.data.getEnvironmentMLStudioDomain) {
+ setMLStudioDomain(response.data.getEnvironmentMLStudioDomain);
+ }
+ } else {
+ dispatch({ type: SET_ERROR, error: response.errors[0].message });
+ }
+ } catch (e) {
+ dispatch({ type: SET_ERROR, error: e.message });
+ } finally {
+ setLoading(false);
+ }
+ }, [client, dispatch, environment.environmentUri]);
+
+ useEffect(() => {
+ if (client) {
+ fetchMLStudioDomain().catch((e) =>
+ dispatch({ type: SET_ERROR, error: e.message })
+ );
+ }
+ }, [client, fetchMLStudioDomain, dispatch]);
+
+ if (loading) {
+ return ;
+ }
+
+ return (
+
+
+ }
+ title={ML Studio Domain Information}
+ />
+
+
+
+
+ {mlStudioDomain === null ? (
+
+
+ No ML Studio Domain - To Create a ML Studio Domain for this
+ Environment: {environment.label}, edit the Environment and enable
+ the ML Studio Environment Feature
+
+
+ ) : (
+
+
+
+
+ SageMaker ML Studio Domain Name
+
+
+ {mlStudioDomain.sagemakerStudioDomainName}
+
+
+
+
+ SageMaker ML Studio Default Execution Role
+
+
+ arn:aws:iam::{environment.AwsAccountId}:role/
+ {mlStudioDomain.DefaultDomainRoleName}
+
+
+
+
+ Domain VPC Type
+
+
+ {mlStudioDomain.vpcType}
+
+
+ {(mlStudioDomain.vpcType === 'imported' ||
+ mlStudioDomain.vpcType === 'default') && (
+ <>
+
+
+ Domain VPC Id
+
+
+ {mlStudioDomain.vpcId}
+
+
+
+
+ Domain Subnet Ids
+
+
+ {mlStudioDomain.subnetIds?.map((subnet) => (
+
+ ))}
+
+
+ >
+ )}
+
+
+ )}
+
+
+ );
+};
+
+EnvironmentMLStudio.propTypes = {
+ environment: PropTypes.object.isRequired
+};
diff --git a/frontend/src/modules/Environments/components/index.js b/frontend/src/modules/Environments/components/index.js
index afccd1235..7aecd51fa 100644
--- a/frontend/src/modules/Environments/components/index.js
+++ b/frontend/src/modules/Environments/components/index.js
@@ -12,3 +12,4 @@ export * from './EnvironmentTeamInviteEditForm';
export * from './EnvironmentTeamInviteForm';
export * from './EnvironmentTeams';
export * from './NetworkCreateModal';
+export * from './EnvironmentMLStudio';
diff --git a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js
index 2767e1fcc..a16cdfa7e 100644
--- a/frontend/src/modules/Environments/views/EnvironmentCreateForm.js
+++ b/frontend/src/modules/Environments/views/EnvironmentCreateForm.js
@@ -31,6 +31,13 @@ import { CopyToClipboard } from 'react-copy-to-clipboard/lib/Component';
import { Helmet } from 'react-helmet-async';
import { Link as RouterLink, useNavigate, useParams } from 'react-router-dom';
import * as Yup from 'yup';
+import {
+ createEnvironment,
+ getPivotRoleExternalId,
+ getPivotRoleName,
+ getPivotRolePresignedUrl,
+ getCDKExecPolicyPresignedUrl
+} from '../services';
import {
ArrowLeftIcon,
ChevronRightIcon,
@@ -44,13 +51,6 @@ import {
useClient,
useGroups
} from 'services';
-import {
- createEnvironment,
- getPivotRoleExternalId,
- getPivotRoleName,
- getPivotRolePresignedUrl,
- getCDKExecPolicyPresignedUrl
-} from '../services';
import {
AwsRegions,
isAnyEnvironmentModuleEnabled,
@@ -179,6 +179,8 @@ const EnvironmentCreateForm = (props) => {
region: values.region,
EnvironmentDefaultIAMRoleArn: values.EnvironmentDefaultIAMRoleArn,
resourcePrefix: values.resourcePrefix,
+ vpcId: values.vpcId,
+ subnetIds: values.subnetIds,
parameters: [
{
key: 'notebooksEnabled',
@@ -484,7 +486,9 @@ const EnvironmentCreateForm = (props) => {
mlStudiosEnabled: isModuleEnabled(ModuleNames.MLSTUDIO),
pipelinesEnabled: isModuleEnabled(ModuleNames.DATAPIPELINES),
EnvironmentDefaultIAMRoleArn: '',
- resourcePrefix: 'dataall'
+ resourcePrefix: 'dataall',
+ vpcId: '',
+ subnetIds: []
}}
validationSchema={Yup.object().shape({
label: Yup.string()
@@ -508,8 +512,14 @@ const EnvironmentCreateForm = (props) => {
).length >= 1
),
tags: Yup.array().nullable(),
- privateSubnetIds: Yup.array().nullable(),
- publicSubnetIds: Yup.array().nullable(),
+ subnetIds: Yup.array().when('vpcId', {
+ is: (value) => !!value,
+ then: Yup.array()
+ .min(1)
+ .required(
+ 'At least 1 Subnet Id required if VPC Id specified'
+ )
+ }),
vpcId: Yup.string().nullable(),
EnvironmentDefaultIAMRoleArn: Yup.string().nullable(),
resourcePrefix: Yup.string()
@@ -862,6 +872,45 @@ const EnvironmentCreateForm = (props) => {
+ {values.mlStudiosEnabled && (
+
+
+
+
+
+
+
+ {
+ setFieldValue('subnetIds', [...chip]);
+ }}
+ />
+
+
+
+ )}
{errors.submit && (
{errors.submit}
diff --git a/frontend/src/modules/Environments/views/EnvironmentEditForm.js b/frontend/src/modules/Environments/views/EnvironmentEditForm.js
index caa5d8441..382575920 100644
--- a/frontend/src/modules/Environments/views/EnvironmentEditForm.js
+++ b/frontend/src/modules/Environments/views/EnvironmentEditForm.js
@@ -30,7 +30,7 @@ import {
useSettings
} from 'design';
import { SET_ERROR, useDispatch } from 'globalErrors';
-import { useClient } from 'services';
+import { getEnvironmentMLStudioDomain, useClient } from 'services';
import { getEnvironment, updateEnvironment } from '../services';
import {
isAnyEnvironmentModuleEnabled,
@@ -47,6 +47,9 @@ const EnvironmentEditForm = (props) => {
const { settings } = useSettings();
const [loading, setLoading] = useState(true);
const [env, setEnv] = useState('');
+ const [envMLStudioDomain, setEnvMLStudioDomain] = useState('');
+ const [previousEnvMLStudioEnabled, setPreviousEnvMLStudioEnabled] =
+ useState(false);
const fetchItem = useCallback(async () => {
const response = await client.query(
@@ -58,6 +61,20 @@ const EnvironmentEditForm = (props) => {
environment.parameters.map((x) => [x.key, x.value])
);
setEnv(environment);
+ if (environment.parameters['mlStudiosEnabled'] === 'true') {
+ setPreviousEnvMLStudioEnabled(true);
+ const response2 = await client.query(
+ getEnvironmentMLStudioDomain({ environmentUri: params.uri })
+ );
+ if (!response2.errors && response2.data.getEnvironmentMLStudioDomain) {
+ setEnvMLStudioDomain(response2.data.getEnvironmentMLStudioDomain);
+ } else {
+ const error = response2.errors
+ ? response2.errors[0].message
+ : 'Environment ML Studio Domain not found';
+ dispatch({ type: SET_ERROR, error });
+ }
+ }
} else {
const error = response.errors
? response.errors[0].message
@@ -66,11 +83,13 @@ const EnvironmentEditForm = (props) => {
}
setLoading(false);
}, [client, dispatch, params.uri]);
+
useEffect(() => {
if (client) {
fetchItem().catch((e) => dispatch({ type: SET_ERROR, error: e.message }));
}
}, [client, fetchItem, dispatch]);
+
async function submit(values, setStatus, setSubmitting, setErrors) {
try {
const response = await client.mutate(
@@ -81,6 +100,8 @@ const EnvironmentEditForm = (props) => {
tags: values.tags,
description: values.description,
resourcePrefix: values.resourcePrefix,
+ vpcId: values.vpcId,
+ subnetIds: values.subnetIds,
parameters: [
{
key: 'notebooksEnabled',
@@ -213,6 +234,8 @@ const EnvironmentEditForm = (props) => {
label: env.label,
description: env.description,
tags: env.tags || [],
+ vpcId: envMLStudioDomain.vpcId || '',
+ subnetIds: envMLStudioDomain.subnetIds || [],
notebooksEnabled: env.parameters['notebooksEnabled'] === 'true',
mlStudiosEnabled: env.parameters['mlStudiosEnabled'] === 'true',
pipelinesEnabled: env.parameters['pipelinesEnabled'] === 'true',
@@ -226,6 +249,15 @@ const EnvironmentEditForm = (props) => {
.required('*Environment name is required'),
description: Yup.string().max(5000),
tags: Yup.array().nullable(),
+ subnetIds: Yup.array().when('vpcId', {
+ is: (value) => !!value,
+ then: Yup.array()
+ .min(1)
+ .required(
+ 'At least 1 Subnet Id required if VPC Id specified'
+ )
+ }),
+ vpcId: Yup.string().nullable(),
resourcePrefix: Yup.string()
.trim()
.matches(
@@ -383,6 +415,48 @@ const EnvironmentEditForm = (props) => {
+ {!previousEnvMLStudioEnabled &&
+ values.mlStudiosEnabled && (
+
+
+
+
+
+
+
+ {
+ setFieldValue('subnetIds', [...chip]);
+ }}
+ />
+
+
+
+ )}
{isAnyEnvironmentModuleEnabled() && (
diff --git a/frontend/src/modules/Environments/views/EnvironmentView.js b/frontend/src/modules/Environments/views/EnvironmentView.js
index 0ba724320..792918c13 100644
--- a/frontend/src/modules/Environments/views/EnvironmentView.js
+++ b/frontend/src/modules/Environments/views/EnvironmentView.js
@@ -39,6 +39,7 @@ import { archiveEnvironment, getEnvironment } from '../services';
import { KeyValueTagList, Stack, StackStatus } from 'modules/Shared';
import {
EnvironmentDatasets,
+ EnvironmentMLStudio,
EnvironmentOverview,
EnvironmentSubscriptions,
EnvironmentTeams,
@@ -59,6 +60,12 @@ const tabs = [
icon: ,
active: isModuleEnabled(ModuleNames.DATASETS)
},
+ {
+ label: 'ML Studio Domain',
+ value: 'mlstudio',
+ icon: ,
+ active: isModuleEnabled(ModuleNames.MLSTUDIO)
+ },
{ label: 'Networks', value: 'networks', icon: },
{
label: 'Subscriptions',
@@ -267,6 +274,9 @@ const EnvironmentView = () => {
fetchItem={fetchItem}
/>
)}
+ {isAdmin && currentTab === 'mlstudio' && (
+
+ )}
{isAdmin && currentTab === 'tags' && (
({
+ variables: {
+ environmentUri
+ },
+ query: gql`
+ query getEnvironmentMLStudioDomain($environmentUri: String) {
+ getEnvironmentMLStudioDomain(environmentUri: $environmentUri) {
+ sagemakerStudioUri
+ environmentUri
+ label
+ sagemakerStudioDomainName
+ DefaultDomainRoleName
+ vpcType
+ vpcId
+ subnetIds
+ owner
+ created
+ }
+ }
+ `
+});
diff --git a/frontend/src/services/graphql/MLStudio/index.js b/frontend/src/services/graphql/MLStudio/index.js
new file mode 100644
index 000000000..97d3de110
--- /dev/null
+++ b/frontend/src/services/graphql/MLStudio/index.js
@@ -0,0 +1 @@
+export * from './getEnvironmentMLStudioDomain';
diff --git a/frontend/src/services/graphql/index.js b/frontend/src/services/graphql/index.js
index 8d0e00804..ce1c3fba2 100644
--- a/frontend/src/services/graphql/index.js
+++ b/frontend/src/services/graphql/index.js
@@ -8,6 +8,7 @@ export * from './Glossary';
export * from './Groups';
export * from './KeyValueTags';
export * from './Metric';
+export * from './MLStudio';
export * from './Notification';
export * from './Organization';
export * from './Principal';
diff --git a/tests/core/conftest.py b/tests/core/conftest.py
index 6d8a449e4..738ab4d06 100644
--- a/tests/core/conftest.py
+++ b/tests/core/conftest.py
@@ -44,7 +44,6 @@ def factory(org, envname, owner, group, account, region, desc='test', parameters
'tags': ['a', 'b', 'c'],
'region': f'{region}',
'SamlGroupName': f'{group}',
- 'vpcId': 'vpc-123456',
'parameters': [{'key': k, 'value': v} for k, v in parameters.items()]
},
)
diff --git a/tests/core/environments/test_environment.py b/tests/core/environments/test_environment.py
index 31ba18e57..e806e07b2 100644
--- a/tests/core/environments/test_environment.py
+++ b/tests/core/environments/test_environment.py
@@ -221,26 +221,6 @@ def test_list_environments_no_filter(org_fixture, env_fixture, client, group):
assert response.data.listEnvironments.count == 1
- response = client.query(
- """
- query ListEnvironmentNetworks($environmentUri: String!,$filter:VpcFilter){
- listEnvironmentNetworks(environmentUri:$environmentUri,filter:$filter){
- count
- nodes{
- VpcId
- SamlGroupName
- }
- }
- }
- """,
- environmentUri=env_fixture.environmentUri,
- username='alice',
- groups=[group.name],
- )
- print(response)
-
- assert response.data.listEnvironmentNetworks.count == 1
-
def test_list_environment_role_filter_as_creator(org_fixture, env_fixture, client, group):
response = client.query(
@@ -656,23 +636,16 @@ def test_create_environment(db, client, org_fixture, env_fixture, user, group):
'tags': ['a', 'b', 'c'],
'region': f'{env_fixture.region}',
'SamlGroupName': group.name,
- 'vpcId': 'vpc-1234567',
- 'privateSubnetIds': 'subnet-1',
- 'publicSubnetIds': 'subnet-21',
'resourcePrefix': 'customer-prefix',
},
)
body = response.data.createEnvironment
- assert body.networks
+ assert len(body.networks) == 0
assert body.EnvironmentDefaultIAMRoleName == 'myOwnIamRole'
assert body.EnvironmentDefaultIAMRoleImported
assert body.resourcePrefix == 'customer-prefix'
- for vpc in body.networks:
- assert vpc.privateSubnetIds
- assert vpc.publicSubnetIds
- assert vpc.default
with db.scoped_session() as session:
env = EnvironmentService.get_environment_by_uri(
diff --git a/tests/core/vpc/test_vpc.py b/tests/core/vpc/test_vpc.py
index a55196d32..8f2391220 100644
--- a/tests/core/vpc/test_vpc.py
+++ b/tests/core/vpc/test_vpc.py
@@ -60,7 +60,7 @@ def test_list_networks(client, env_fixture, db, user, group, vpc):
)
print(response)
- assert response.data.listEnvironmentNetworks.count == 2
+ assert response.data.listEnvironmentNetworks.count == 1
def test_list_networks_nopermissions(client, env_fixture, db, user, group2, vpc):
@@ -119,4 +119,4 @@ def test_delete_network(client, env_fixture, db, user, group, module_mocker, vpc
username='alice',
groups=[group.name],
)
- assert len(response.data.listEnvironmentNetworks['nodes']) == 1
+ assert len(response.data.listEnvironmentNetworks['nodes']) == 0
diff --git a/tests/modules/mlstudio/cdk/conftest.py b/tests/modules/mlstudio/cdk/conftest.py
index 4b3327838..2c6f1eddd 100644
--- a/tests/modules/mlstudio/cdk/conftest.py
+++ b/tests/modules/mlstudio/cdk/conftest.py
@@ -2,7 +2,7 @@
from dataall.core.environment.db.environment_models import Environment
from dataall.core.organizations.db.organization_models import Organization
-from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser
+from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain
@pytest.fixture(scope='module', autouse=True)
@@ -23,3 +23,21 @@ def sgm_studio(db, env_fixture: Environment) -> SagemakerStudioUser:
)
session.add(sm_user)
yield sm_user
+
+@pytest.fixture(scope='module', autouse=True)
+def sgm_studio_domain(db, env_fixture: Environment) -> SagemakerStudioDomain:
+    """Add a SageMaker Studio domain for ``env_fixture`` to a scoped session and yield it.
+
+    The domain copies its environment URI, AWS account, region and SAML group
+    from ``env_fixture`` and is registered with ``vpcType`` set to ``created``.
+    """
+    with db.scoped_session() as session:
+        domain_attributes = {
+            'label': 'sagemaker-domain',
+            'owner': 'me',
+            'environmentUri': env_fixture.environmentUri,
+            'AWSAccountId': env_fixture.AwsAccountId,
+            'region': env_fixture.region,
+            'SagemakerStudioStatus': "PENDING",
+            'DefaultDomainRoleName': "DefaultMLStudioRole",
+            'sagemakerStudioDomainName': "DomainName",
+            'vpcType': "created",
+            'SamlGroupName': env_fixture.SamlGroupName,
+        }
+        domain = SagemakerStudioDomain(**domain_attributes)
+        session.add(domain)
+        # Yield while the session is still open, as the sibling fixtures do.
+        yield domain
diff --git a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py
index a2c1752e2..8e0cd6166 100644
--- a/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py
+++ b/tests/modules/mlstudio/cdk/test_sagemaker_studio_stack.py
@@ -66,16 +66,19 @@ def patch_methods_sagemaker_studio(mocker, db, sgm_studio, env_fixture, org_fixt
@pytest.fixture(scope='function', autouse=True)
-def patch_methods_sagemaker_studio_extension(mocker):
+def patch_methods_sagemaker_studio_extension(mocker, sgm_studio_domain):
mocker.patch(
'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn',
return_value="arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1",
)
mocker.patch(
- 'dataall.modules.mlstudio.aws.ec2_client.EC2.check_default_vpc_exists',
+ 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists',
return_value=False,
)
-
+ mocker.patch(
+ 'dataall.modules.mlstudio.db.mlstudio_repositories.SageMakerStudioRepository.get_sagemaker_studio_domain_by_env_uri',
+ return_value=sgm_studio_domain,
+ )
def test_resources_sgmstudio_stack_created(sgm_studio):
app = App()
diff --git a/tests/modules/mlstudio/conftest.py b/tests/modules/mlstudio/conftest.py
index 433048894..d1fffb2cf 100644
--- a/tests/modules/mlstudio/conftest.py
+++ b/tests/modules/mlstudio/conftest.py
@@ -1,6 +1,6 @@
import pytest
-from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser
+from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser, SagemakerStudioDomain
@pytest.fixture(scope='module', autouse=True)
@@ -16,8 +16,31 @@ def env_params():
yield {'mlStudiosEnabled': 'True'}
+@pytest.fixture(scope='module', autouse=True)
+def get_cdk_look_up_role_arn(module_mocker):
+    """Patch ``SessionHelper.get_cdk_look_up_role_arn`` to return a fixed role ARN."""
+    lookup_role_arn = "arn:aws:iam::1111111111:role/cdk-hnb659fds-lookup-role-1111111111-eu-west-1"
+    module_mocker.patch(
+        'dataall.base.aws.sts.SessionHelper.get_cdk_look_up_role_arn',
+        return_value=lookup_role_arn,
+    )
+
+@pytest.fixture(scope='module', autouse=True)
+def check_default_vpc(module_mocker):
+    """Patch ``EC2.check_default_vpc_exists`` so it always returns ``False``."""
+    patched_method = 'dataall.base.aws.ec2_client.EC2.check_default_vpc_exists'
+    module_mocker.patch(patched_method, return_value=False)
+
+
+@pytest.fixture(scope='module', autouse=True)
+def check_vpc_exists(module_mocker):
+    """Patch ``EC2.check_vpc_exists`` so it always returns ``True``."""
+    patched_method = 'dataall.base.aws.ec2_client.EC2.check_vpc_exists'
+    module_mocker.patch(patched_method, return_value=True)
+
+
@pytest.fixture(scope='module')
-def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudioUser:
+def sagemaker_studio_user(client, tenant, group, env_with_mlstudio) -> SagemakerStudioUser:
response = client.query(
"""
mutation createSagemakerStudioUser($input:NewSagemakerStudioUserInput){
@@ -36,7 +59,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudio
input={
'label': 'testcreate',
'SamlAdminGroupName': group.name,
- 'environmentUri': env_fixture.environmentUri,
+ 'environmentUri': env_with_mlstudio.environmentUri,
},
username='alice',
groups=[group.name],
@@ -45,7 +68,7 @@ def sagemaker_studio_user(client, tenant, group, env_fixture) -> SagemakerStudio
@pytest.fixture(scope='module')
-def multiple_sagemaker_studio_users(client, db, env_fixture, group):
+def multiple_sagemaker_studio_users(client, db, env_with_mlstudio, group):
for i in range(0, 10):
response = client.query(
"""
@@ -65,7 +88,7 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group):
input={
'label': f'test{i}',
'SamlAdminGroupName': group.name,
- 'environmentUri': env_fixture.environmentUri,
+ 'environmentUri': env_with_mlstudio.environmentUri,
},
username='alice',
groups=[group.name],
@@ -77,5 +100,92 @@ def multiple_sagemaker_studio_users(client, db, env_fixture, group):
)
assert (
response.data.createSagemakerStudioUser.environmentUri
- == env_fixture.environmentUri
+ == env_with_mlstudio.environmentUri
)
+
+@pytest.fixture(scope='module')
+def env_with_mlstudio(client, org_fixture, user, group, parameters=None, vpcId='', subnetIds=None):
+    """Create an environment through the ``createEnvironment`` mutation and yield it.
+
+    The environment is created in ``org_fixture`` for account ``111111111111``
+    in ``us-east-1`` under the ``testadmins`` group.
+
+    Args:
+        parameters: environment parameters as a ``{key: value}`` mapping;
+            falls back to ``{'mlStudiosEnabled': 'True'}`` when empty or None.
+        vpcId: VPC id sent in the mutation input (empty string by default).
+        subnetIds: subnet ids sent in the mutation input; ``None`` means an
+            empty list.
+
+    NOTE(review): pytest does not inject fixtures for arguments that carry a
+    default, so the last three arguments presumably always take their defaults
+    when this fixture is resolved - confirm before relying on them.
+    """
+    if not parameters:
+        parameters = {'mlStudiosEnabled': 'True'}
+    # Build a fresh list per call instead of sharing one mutable default.
+    if subnetIds is None:
+        subnetIds = []
+    response = client.query(
+        """mutation CreateEnv($input:NewEnvironmentInput){
+ createEnvironment(input:$input){
+ organization{
+ organizationUri
+ }
+ environmentUri
+ label
+ AwsAccountId
+ SamlGroupName
+ region
+ name
+ owner
+ parameters {
+ key
+ value
+ }
+ }
+ }""",
+        username=f'{user.username}',
+        groups=['testadmins'],
+        input={
+            'label': 'dev',
+            'description': '',
+            'organizationUri': org_fixture.organizationUri,
+            'AwsAccountId': '111111111111',
+            'tags': [],
+            'region': 'us-east-1',
+            'SamlGroupName': 'testadmins',
+            'parameters': [{'key': k, 'value': v} for k, v in parameters.items()],
+            'vpcId': vpcId,
+            'subnetIds': subnetIds,
+        },
+    )
+    yield response.data.createEnvironment
+
+
+@pytest.fixture(scope='module', autouse=True)
+def org(client):
+    """Yield a factory that creates organizations via the ``createOrganization`` mutation.
+
+    Results are memoised for the lifetime of the fixture, keyed by the
+    concatenation of organization name, owner and group.
+    """
+    created = {}
+
+    def factory(orgname, owner, group):
+        """Return the organization for the given name/owner/group, creating it on first use."""
+        key = orgname + owner + group
+        cached = created.get(key)
+        if cached:
+            print(f'returning item from cached key {key}')
+            return cached
+        response = client.query(
+            """mutation CreateOrganization($input:NewOrganizationInput){
+ createOrganization(input:$input){
+ organizationUri
+ label
+ name
+ owner
+ SamlGroupName
+ }
+ }""",
+            username=f'{owner}',
+            groups=[group],
+            input={
+                'label': f'{orgname}',
+                'description': 'test',
+                'tags': ['a', 'b', 'c'],
+                'SamlGroupName': f'{group}',
+            },
+        )
+        organization = response.data.createOrganization
+        created[key] = organization
+        return organization
+
+    yield factory
+
+
+@pytest.fixture(scope='module')
+def org_fixture(org, user, group):
+    """Yield the ``testorg`` organization built by the ``org`` factory for ``user`` and ``group``."""
+    yield org('testorg', user.username, group.name)
+
+
+@pytest.fixture(scope='module')
+def env_mlstudio_fixture(env, org_fixture, user, group, tenant):
+    """Yield a ``dev`` environment in ``eu-west-1`` built with the ``env`` factory fixture.
+
+    The body previously called the ``env_with_mlstudio`` fixture function
+    directly; pytest refuses direct fixture calls, and the positional
+    arguments did not match that fixture's signature anyway.
+
+    NOTE(review): assumes ``env`` is the environment factory taking
+    ``(org, envname, owner, group, account, region)`` - confirm against the
+    conftest that defines it.
+    """
+    yield env(org_fixture, 'dev', 'alice', 'testadmins', '111111111111', 'eu-west-1')
+
diff --git a/tests/modules/mlstudio/test_sagemaker_studio.py b/tests/modules/mlstudio/test_sagemaker_studio.py
index c55762522..3d90b405a 100644
--- a/tests/modules/mlstudio/test_sagemaker_studio.py
+++ b/tests/modules/mlstudio/test_sagemaker_studio.py
@@ -1,14 +1,43 @@
from dataall.modules.mlstudio.db.mlstudio_models import SagemakerStudioUser
-def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_fixture):
+def test_create_sagemaker_studio_domain(db, client, org_fixture, env_with_mlstudio, user, group, vpcId="vpc-1234", subnetIds=["subnet"]):
+    """Query the ML Studio domain of ``env_with_mlstudio`` and check its default shape.
+
+    The domain must exist, be labelled ``<env label>-domain``, report
+    ``vpcType`` ``created`` and carry no VPC id or subnet ids.
+    """
+    response = client.query(
+ """
+ query getEnvironmentMLStudioDomain($environmentUri: String) {
+ getEnvironmentMLStudioDomain(environmentUri: $environmentUri) {
+ sagemakerStudioUri
+ environmentUri
+ label
+ sagemakerStudioDomainName
+ DefaultDomainRoleName
+ vpcType
+ vpcId
+ subnetIds
+ owner
+ created
+ }
+ }
+ """,
+        environmentUri=env_with_mlstudio.environmentUri,
+    )
+
+    domain = response.data.getEnvironmentMLStudioDomain
+    assert domain.sagemakerStudioUri
+    assert domain.label == f'{env_with_mlstudio.label}-domain'
+    assert domain.vpcType == 'created'
+    assert len(domain.vpcId) == 0
+    assert len(domain.subnetIds) == 0
+    assert domain.environmentUri == env_with_mlstudio.environmentUri
+
+
+def test_create_sagemaker_studio_user(sagemaker_studio_user, group, env_with_mlstudio):
"""Testing that the conftest sagemaker studio user has been created correctly"""
assert sagemaker_studio_user.label == 'testcreate'
assert sagemaker_studio_user.SamlAdminGroupName == group.name
- assert sagemaker_studio_user.environmentUri == env_fixture.environmentUri
+ assert sagemaker_studio_user.environmentUri == env_with_mlstudio.environmentUri
-def test_list_sagemaker_studio_users(client, env_fixture, db, group, multiple_sagemaker_studio_users):
+def test_list_sagemaker_studio_users(client, db, group, multiple_sagemaker_studio_users):
response = client.query(
"""
query listSagemakerStudioUsers($filter:SagemakerStudioUserFilter!){
@@ -67,3 +96,114 @@ def test_delete_sagemaker_studio_user(
sagemaker_studio_user.sagemakerStudioUserUri
)
assert not n
+
+def update_env_query():
+    """Return the GraphQL document for the ``UpdateEnv`` (``updateEnvironment``) mutation."""
+    return """
+ mutation UpdateEnv($environmentUri:String!,$input:ModifyEnvironmentInput){
+ updateEnvironment(environmentUri:$environmentUri,input:$input){
+ organization{
+ organizationUri
+ }
+ label
+ AwsAccountId
+ region
+ SamlGroupName
+ owner
+ tags
+ resourcePrefix
+ parameters {
+ key
+ value
+ }
+ }
+ }
+ """
+
+def test_update_env_delete_domain(client, org_fixture, env_with_mlstudio, group, group2):
+    """Set ``mlStudiosEnabled`` to ``False`` and check the domain query then returns ``None``."""
+    disable_mlstudio = {
+        'label': 'DEV',
+        'tags': [],
+        'parameters': [{'key': 'mlStudiosEnabled', 'value': 'False'}],
+    }
+    # Only the side effect of the update matters; its response is not inspected.
+    client.query(
+        update_env_query(),
+        username='alice',
+        environmentUri=env_with_mlstudio.environmentUri,
+        input=disable_mlstudio,
+        groups=[group.name],
+    )
+
+    domain_response = client.query(
+ """
+ query getEnvironmentMLStudioDomain($environmentUri: String) {
+ getEnvironmentMLStudioDomain(environmentUri: $environmentUri) {
+ sagemakerStudioUri
+ environmentUri
+ label
+ sagemakerStudioDomainName
+ DefaultDomainRoleName
+ vpcType
+ vpcId
+ subnetIds
+ owner
+ created
+ }
+ }
+ """,
+        environmentUri=env_with_mlstudio.environmentUri,
+    )
+    assert domain_response.data.getEnvironmentMLStudioDomain is None
+
+
+def test_update_env_create_domain_with_vpc(db, client, org_fixture, env_with_mlstudio, user, group):
+    """Enable ML Studio with a VPC id and two subnets, then check the domain is ``imported``."""
+    enable_mlstudio_with_vpc = {
+        'label': 'dev',
+        'tags': [],
+        'vpcId': "vpc-12345",
+        'subnetIds': ['subnet-12345', 'subnet-67890'],
+        'parameters': [{'key': 'mlStudiosEnabled', 'value': 'True'}],
+    }
+    # Only the side effect of the update matters; its response is not inspected.
+    client.query(
+        update_env_query(),
+        username='alice',
+        environmentUri=env_with_mlstudio.environmentUri,
+        input=enable_mlstudio_with_vpc,
+        groups=[group.name],
+    )
+
+    domain_response = client.query(
+ """
+ query getEnvironmentMLStudioDomain($environmentUri: String) {
+ getEnvironmentMLStudioDomain(environmentUri: $environmentUri) {
+ sagemakerStudioUri
+ environmentUri
+ label
+ sagemakerStudioDomainName
+ DefaultDomainRoleName
+ vpcType
+ vpcId
+ subnetIds
+ owner
+ created
+ }
+ }
+ """,
+        environmentUri=env_with_mlstudio.environmentUri,
+    )
+
+    domain = domain_response.data.getEnvironmentMLStudioDomain
+    assert domain.sagemakerStudioUri
+    assert domain.label == f'{env_with_mlstudio.label}-domain'
+    assert domain.vpcType == 'imported'
+    assert domain.vpcId == 'vpc-12345'
+    assert len(domain.subnetIds) == 2
+    assert domain.environmentUri == env_with_mlstudio.environmentUri
+