Skip to content

Commit

Permalink
Byo vpc mlstudio (data-dot-all#894)
Browse files Browse the repository at this point in the history
### Feature or Bugfix
<!-- please choose -->
- Feature

### Detail
- Enable SageMaker Studio Domain to be deployed in an already provisioned
VPC


### Relates
- data-dot-all#795

### Security
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
  - Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
  - Is injection prevented by parametrizing queries?
  - Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
  - Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use standard, proven implementations?
  - Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
  - Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.
  • Loading branch information
noah-paige authored Dec 8, 2023
1 parent 5061ecb commit 94c93d9
Show file tree
Hide file tree
Showing 33 changed files with 1,207 additions and 253 deletions.
65 changes: 65 additions & 0 deletions backend/dataall/base/aws/ec2_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging

from dataall.base.aws.sts import SessionHelper
from botocore.exceptions import ClientError

log = logging.getLogger(__name__)


class EC2:
    """Thin wrapper around the boto3 EC2 client for VPC/subnet validation."""

    @staticmethod
    def get_client(account_id: str, region: str, role=None):
        """Return an EC2 client for the account/region, assuming `role` if provided."""
        session = SessionHelper.remote_session(accountid=account_id, role=role)
        return session.client('ec2', region_name=region)

    @staticmethod
    def check_default_vpc_exists(AwsAccountId: str, region: str, role=None):
        """Check whether the account/region has a default VPC with at least one subnet.

        Returns:
            (vpc_id, subnet_ids) when a default VPC with subnets exists, otherwise False.
        """
        log.info("Check that default VPC exists..")
        client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
        response = client.describe_vpcs(
            Filters=[{'Name': 'isDefault', 'Values': ['true']}]
        )
        vpcs = response['Vpcs']
        log.info(f"Default VPCs response: {vpcs}")
        if vpcs:
            vpc_id = vpcs[0]['VpcId']
            subnetIds = EC2._get_vpc_subnets(AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role)
            if subnetIds:
                return vpc_id, subnetIds
        return False

    @staticmethod
    def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None):
        """Return the ids of all subnets belonging to `vpc_id`."""
        client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
        response = client.describe_subnets(
            Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]
        )
        return [subnet['SubnetId'] for subnet in response['Subnets']]

    @staticmethod
    def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=None):
        """Validate that `vpc_id` exists and, when given, that every id in
        `subnet_ids` is a subnet of that VPC.

        Raises:
            Exception: 'VPCNotFound: ...' when the VPC cannot be described,
                'VPCSubnetsNotFound: ...' when the subnets cannot be described,
                or a mismatch error when not all subnets belong to the VPC.
        """
        # Use None (not a mutable []) as the default; normalize here.
        subnet_ids = subnet_ids or []
        try:
            ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
            ec2.describe_vpcs(VpcIds=[vpc_id])
        except ClientError as e:
            log.exception(f'VPC Id {vpc_id} Not Found: {e}')
            raise Exception(f'VPCNotFound: {vpc_id}') from e

        if not subnet_ids:
            # No subnets supplied: VPC existence is all we can verify.
            # (Previously an empty list spuriously raised the mismatch error below.)
            return

        try:
            response = ec2.describe_subnets(
                Filters=[
                    {
                        'Name': 'vpc-id',
                        'Values': [vpc_id]
                    },
                ],
                SubnetIds=subnet_ids
            )
        except ClientError as e:
            log.exception(f'Subnet Id {subnet_ids} Not Found: {e}')
            raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') from e

        # describe_subnets filters out subnets that exist but sit in another VPC,
        # so a shorter result list means at least one id is outside `vpc_id`.
        if len(response['Subnets']) != len(subnet_ids):
            raise Exception(f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}')
1 change: 1 addition & 0 deletions backend/dataall/base/utils/naming_convention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class NamingConventionPattern(Enum):
GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63}
GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52}
NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63}
OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27}
OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31}
Expand Down
15 changes: 6 additions & 9 deletions backend/dataall/core/environment/api/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@
gql.Argument('description', gql.String),
gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)),
gql.Argument('region', gql.NonNullableType(gql.String)),
gql.Argument('vpcId', gql.String),
gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String),
gql.Argument('resourcePrefix', gql.String),
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))

gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
gql.Argument('vpcId', gql.String),
gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)

Expand All @@ -45,11 +43,10 @@
gql.Argument('description', gql.String),
gql.Argument('tags', gql.ArrayType(gql.String)),
gql.Argument('SamlGroupName', gql.String),
gql.Argument('vpcId', gql.String),
gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('resourcePrefix', gql.String),
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
gql.Argument('vpcId', gql.String),
gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)

Expand Down
28 changes: 23 additions & 5 deletions backend/dataall/core/environment/api/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dataall.core.stacks.aws.cloudformation import CloudFormation
from dataall.core.stacks.db.stack_repositories import Stack
from dataall.core.vpc.db.vpc_repositories import Vpc
from dataall.base.aws.ec2_client import EC2
from dataall.base.db import exceptions
from dataall.core.permissions import permissions
from dataall.base.feature_toggle_checker import is_feature_enabled
Expand All @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs):
return True if ssm_param == "True" else False


def check_environment(context: Context, source, account_id, region):
def check_environment(context: Context, source, account_id, region, data):
""" Checks necessary resources for environment deployment.
- Check CDKToolkit exists in Account assuming cdk_look_up_role
- Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False
Expand Down Expand Up @@ -71,11 +72,25 @@ def check_environment(context: Context, source, account_id, region):
action='CHECK_PIVOT_ROLE',
message='Pivot Role has not been created in the Environment AWS Account',
)
mlStudioEnabled = None
for parameter in data.get("parameters", []):
if parameter['key'] == 'mlStudiosEnabled':
mlStudioEnabled = parameter['value']

if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []):
log.info("Check if ML Studio VPC Exists in the Account")
EC2.check_vpc_exists(
AwsAccountId=account_id,
region=region,
role=cdk_look_up_role_arn,
vpc_id=data.get("vpcId", None),
subnet_ids=data.get('subnetIds', []),
)

return cdk_role_name


def create_environment(context: Context, source, input=None):
def create_environment(context: Context, source, input={}):
if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups:
raise exceptions.UnauthorizedOperation(
action=permissions.LINK_ENVIRONMENT,
Expand All @@ -85,8 +100,10 @@ def create_environment(context: Context, source, input=None):
with context.engine.scoped_session() as session:
cdk_role_name = check_environment(context, source,
account_id=input.get('AwsAccountId'),
region=input.get('region')
region=input.get('region'),
data=input
)

input['cdk_role_name'] = cdk_role_name
env = EnvironmentService.create_environment(
session=session,
Expand Down Expand Up @@ -119,7 +136,8 @@ def update_environment(
environment = EnvironmentService.get_environment_by_uri(session, environmentUri)
cdk_role_name = check_environment(context, source,
account_id=environment.AwsAccountId,
region=environment.region
region=environment.region,
data=input
)

previous_resource_prefix = environment.resourcePrefix
Expand All @@ -130,7 +148,7 @@ def update_environment(
data=input,
)

if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment):
if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input):
stack_helper.deploy_stack(targetUri=environment.environmentUri)

return environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def delete_env(session, environment):
pass

@staticmethod
def update_env(session, environment):
def create_env(session, environment, **kwargs):
pass

@staticmethod
def update_env(session, environment, **kwargs):
return False

@staticmethod
Expand All @@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int:
return counter

@classmethod
def deploy_updated_stack(cls, session, prev_prefix, environment):
def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs):
deploy_stack = prev_prefix != environment.resourcePrefix
for resource in cls._resources:
deploy_stack |= resource.update_env(session, environment)
deploy_stack |= resource.update_env(session, environment, **kwargs)

return deploy_stack

Expand All @@ -51,6 +55,11 @@ def delete_env(cls, session, environment):
for resource in cls._resources:
resource.delete_env(session, environment)

@classmethod
def create_env(cls, session, environment, **kwargs):
for resource in cls._resources:
resource.create_env(session, environment, **kwargs)

@classmethod
def count_consumption_role_resources(cls, session, role_uri):
counter = 0
Expand Down
24 changes: 1 addition & 23 deletions backend/dataall/core/environment/services/environment_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def create_environment(session, uri, data=None):
session.commit()

EnvironmentService._update_env_parameters(session, env, data)
EnvironmentResourceManager.create_env(session, env, data=data)

env.EnvironmentDefaultBucketName = NamingConventionService(
target_uri=env.environmentUri,
Expand Down Expand Up @@ -98,29 +99,6 @@ def create_environment(session, uri, data=None):
env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn']
env.EnvironmentDefaultIAMRoleImported = True

if data.get('vpcId'):
vpc = Vpc(
environmentUri=env.environmentUri,
region=env.region,
AwsAccountId=env.AwsAccountId,
VpcId=data.get('vpcId'),
privateSubnetIds=data.get('privateSubnetIds', []),
publicSubnetIds=data.get('publicSubnetIds', []),
SamlGroupName=data['SamlGroupName'],
owner=context.username,
label=f"{env.name}-{data.get('vpcId')}",
name=f"{env.name}-{data.get('vpcId')}",
default=True,
)
session.add(vpc)
session.commit()
ResourcePolicy.attach_resource_policy(
session=session,
group=data['SamlGroupName'],
permissions=permissions.NETWORK_ALL,
resource_uri=vpc.vpcUri,
resource_type=Vpc.__name__,
)
env_group = EnvironmentGroup(
environmentUri=env.environmentUri,
groupUri=data['SamlGroupName'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def count_resources(session, environment, group_uri) -> int:
)

@staticmethod
def update_env(session, environment):
def update_env(session, environment, **kwargs):
return EnvironmentService.get_boolean_env_param(session, environment, "dashboardsEnabled")

@staticmethod
Expand Down
5 changes: 4 additions & 1 deletion backend/dataall/modules/mlstudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from dataall.base.loader import ImportMode, ModuleInterface
from dataall.core.stacks.db.target_type_repositories import TargetType
from dataall.modules.mlstudio.db.mlstudio_repositories import SageMakerStudioRepository
from dataall.modules.mlstudio.services.mlstudio_service import SagemakerStudioEnvironmentResource
from dataall.core.environment.services.environment_resource_manager import EnvironmentResourceManager

log = logging.getLogger(__name__)

Expand All @@ -20,6 +21,8 @@ def __init__(self):
from dataall.modules.mlstudio.services.mlstudio_permissions import GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER
TargetType("mlstudio", GET_SGMSTUDIO_USER, UPDATE_SGMSTUDIO_USER)

EnvironmentResourceManager.register(SagemakerStudioEnvironmentResource())

log.info("API of sagemaker mlstudio has been imported")


Expand Down
10 changes: 10 additions & 0 deletions backend/dataall/modules/mlstudio/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
get_sagemaker_studio_user,
list_sagemaker_studio_users,
get_sagemaker_studio_user_presigned_url,
get_environment_sagemaker_studio_domain
)

getSagemakerStudioUser = gql.QueryField(
Expand Down Expand Up @@ -34,3 +35,12 @@
type=gql.String,
resolver=get_sagemaker_studio_user_presigned_url,
)

getEnvironmentMLStudioDomain = gql.QueryField(
name='getEnvironmentMLStudioDomain',
args=[
gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
],
type=gql.Ref('SagemakerStudioDomain'),
resolver=get_environment_sagemaker_studio_domain,
)
9 changes: 7 additions & 2 deletions backend/dataall/modules/mlstudio/api/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def required_uri(uri):
raise exceptions.RequiredParameter('URI')

@staticmethod
def validate_creation_request(data):
def validate_user_creation_request(data):
required = RequestValidator._required
if not data:
raise exceptions.RequiredParameter('data')
Expand All @@ -36,7 +36,7 @@ def _required(data: dict, name: str):

def create_sagemaker_studio_user(context: Context, source, input: dict = None):
"""Creates a SageMaker Studio user. Deploys the SageMaker Studio user stack into AWS"""
RequestValidator.validate_creation_request(input)
RequestValidator.validate_user_creation_request(input)
request = SagemakerStudioCreationRequest.from_dict(input)
return SagemakerStudioService.create_sagemaker_studio_user(
uri=input["environmentUri"],
Expand Down Expand Up @@ -90,6 +90,11 @@ def delete_sagemaker_studio_user(
)


def get_environment_sagemaker_studio_domain(context, source, environmentUri: str = None):
RequestValidator.required_uri(environmentUri)
return SagemakerStudioService.get_environment_sagemaker_studio_domain(environment_uri=environmentUri)


def resolve_user_role(context: Context, source: SagemakerStudioUser):
"""
Resolves the role of the current user in reference with the SageMaker Studio User
Expand Down
24 changes: 24 additions & 0 deletions backend/dataall/modules/mlstudio/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,27 @@
gql.Field(name='nodes', type=gql.ArrayType(SagemakerStudioUser)),
],
)

SagemakerStudioDomain = gql.ObjectType(
name='SagemakerStudioDomain',
fields=[
gql.Field(name='sagemakerStudioUri', type=gql.ID),
gql.Field(name='environmentUri', type=gql.NonNullableType(gql.String)),
gql.Field(name='sagemakerStudioDomainName', type=gql.String),
gql.Field(name='DefaultDomainRoleName', type=gql.String),
gql.Field(name='label', type=gql.String),
gql.Field(name='name', type=gql.String),
gql.Field(name='vpcType', type=gql.String),
gql.Field(name='vpcId', type=gql.String),
gql.Field(name='subnetIds', type=gql.ArrayType(gql.String)),
gql.Field(name='owner', type=gql.String),
gql.Field(name='created', type=gql.String),
gql.Field(name='updated', type=gql.String),
gql.Field(name='deleted', type=gql.String),
gql.Field(
name='environment',
type=gql.Ref('Environment'),
resolver=resolve_environment,
)
],
)
27 changes: 0 additions & 27 deletions backend/dataall/modules/mlstudio/aws/ec2_client.py

This file was deleted.

Loading

0 comments on commit 94c93d9

Please sign in to comment.