Skip to content

Commit

Permalink
Merging Files from Open source
Browse files Browse the repository at this point in the history
  • Loading branch information
TejasRGitHub authored and trajopadhye committed Dec 8, 2023
2 parents 6f3aee3 + 94c93d9 commit 692e5be
Show file tree
Hide file tree
Showing 64 changed files with 1,755 additions and 632 deletions.
65 changes: 65 additions & 0 deletions backend/dataall/base/aws/ec2_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging

from dataall.base.aws.sts import SessionHelper
from botocore.exceptions import ClientError

log = logging.getLogger(__name__)


class EC2:
    """Static helpers for VPC and subnet lookups in a target AWS account.

    Every method obtains an ``ec2`` client through :meth:`get_client`, which
    builds it from the session returned by ``SessionHelper.remote_session``.
    """

    @staticmethod
    def get_client(account_id: str, region: str, role=None):
        """Return an ``ec2`` client for ``account_id`` bound to ``region``.

        Args:
            account_id: AWS account id passed to ``SessionHelper.remote_session``.
            region: region name used for the client.
            role: forwarded unchanged to ``SessionHelper.remote_session``.
        """
        session = SessionHelper.remote_session(accountid=account_id, role=role)
        return session.client('ec2', region_name=region)

    @staticmethod
    def check_default_vpc_exists(AwsAccountId: str, region: str, role=None):
        """Look up the default VPC of the account/region and its subnets.

        Returns:
            ``(vpc_id, subnet_ids)`` when a default VPC exists and has at
            least one subnet; ``False`` otherwise.
        """
        log.info("Check that default VPC exists..")
        client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
        response = client.describe_vpcs(
            Filters=[{'Name': 'isDefault', 'Values': ['true']}]
        )
        vpcs = response['Vpcs']
        log.info(f"Default VPCs response: {vpcs}")
        if vpcs:
            # Only the first VPC of the filtered response is considered.
            vpc_id = vpcs[0]['VpcId']
            subnetIds = EC2._get_vpc_subnets(
                AwsAccountId=AwsAccountId, region=region, vpc_id=vpc_id, role=role
            )
            if subnetIds:
                return vpc_id, subnetIds
        return False

    @staticmethod
    def _get_vpc_subnets(AwsAccountId: str, region: str, vpc_id: str, role=None):
        """Return the list of subnet ids reported for ``vpc_id``."""
        client = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
        response = client.describe_subnets(
            Filters=[{'Name': 'vpc-id', 'Values': [vpc_id]}]
        )
        return [subnet['SubnetId'] for subnet in response['Subnets']]

    @staticmethod
    def check_vpc_exists(AwsAccountId, region, vpc_id, role=None, subnet_ids=None):
        """Validate that ``vpc_id`` exists and that all ``subnet_ids`` belong to it.

        Args:
            AwsAccountId: AWS account id in which to look.
            region: region name used for the client.
            vpc_id: id of the VPC to validate.
            role: forwarded unchanged to :meth:`get_client`.
            subnet_ids: subnet ids expected inside the VPC. ``None`` is
                treated like an empty list.

        Returns:
            ``None`` when validation succeeds.

        Raises:
            Exception: ``VPCNotFound`` when creating the client or describing
                the VPC raises ``ClientError``; ``VPCSubnetsNotFound`` when
                describing the subnets raises ``ClientError``; ``Not All
                Subnets ...`` when no subnet ids were supplied or the number
                of subnets returned differs from the number requested.
        """
        # A None default avoids sharing one mutable list between calls.
        if subnet_ids is None:
            subnet_ids = []

        try:
            ec2 = EC2.get_client(account_id=AwsAccountId, region=region, role=role)
            ec2.describe_vpcs(VpcIds=[vpc_id])
        except ClientError as e:
            log.exception(f'VPC Id {vpc_id} Not Found: {e}')
            # Chain the original error so the AWS failure stays in the traceback.
            raise Exception(f'VPCNotFound: {vpc_id}') from e

        mismatch_message = (
            f'Not All Subnets: {subnet_ids} Are Within the Specified VPC Id {vpc_id}'
        )

        # An empty subnet list is rejected, but only after the VPC lookup
        # above has succeeded, so a missing VPC is still reported first.
        if not subnet_ids:
            raise Exception(mismatch_message)

        try:
            response = ec2.describe_subnets(
                Filters=[
                    {
                        'Name': 'vpc-id',
                        'Values': [vpc_id]
                    },
                ],
                SubnetIds=subnet_ids
            )
        except ClientError as e:
            log.exception(f'Subnet Id {subnet_ids} Not Found: {e}')
            raise Exception(f'VPCSubnetsNotFound: {subnet_ids}') from e

        # The vpc-id filter drops subnets from other VPCs, so a shorter
        # result than requested signals a mismatch.
        if len(response['Subnets']) != len(subnet_ids):
            raise Exception(mismatch_message)
60 changes: 43 additions & 17 deletions backend/dataall/base/aws/quicksight.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@
class QuicksightClient:

DEFAULT_GROUP_NAME = 'dataall'
QUICKSIGHT_IDENTITY_REGIONS = [
{"name": 'US East (N. Virginia)', "code": 'us-east-1'},
{"name": 'US East (Ohio)', "code": 'us-east-2'},
{"name": 'US West (Oregon)', "code": 'us-west-2'},
{"name": 'Europe (Frankfurt)', "code": 'eu-central-1'},
{"name": 'Europe (Stockholm)', "code": 'eu-north-1'},
{"name": 'Europe (Ireland)', "code": 'eu-west-1'},
{"name": 'Europe (London)', "code": 'eu-west-2'},
{"name": 'Europe (Paris)', "code": 'eu-west-3'},
{"name": 'Asia Pacific (Singapore)', "code": 'ap-southeast-1'},
{"name": 'Asia Pacific (Sydney)', "code": 'ap-southeast-2'},
{"name": 'Asia Pacific (Tokyo)', "code": 'ap-northeast-1'},
{"name": 'Asia Pacific (Seoul)', "code": 'ap-northeast-2'},
{"name": 'South America (São Paulo)', "code": 'sa-east-1'},
{"name": 'Canada (Central)', "code": 'ca-central-1'},
{"name": 'Asia Pacific (Mumbai)', "code": 'ap-south-1'},
]

def __init__(self):
pass
Expand Down Expand Up @@ -37,21 +54,29 @@ def get_identity_region(AwsAccountId):
the region quicksight uses as identity region
"""
identity_region_rex = re.compile('Please use the (?P<region>.*) endpoint.')
identity_region = 'us-east-1'
client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=identity_region)
try:
response = client.describe_group(
AwsAccountId=AwsAccountId, GroupName=QuicksightClient.DEFAULT_GROUP_NAME, Namespace='default'
)
except client.exceptions.AccessDeniedException as e:
match = identity_region_rex.findall(str(e))
if match:
identity_region = match[0]
else:
raise e
except client.exceptions.ResourceNotFoundException:
pass
return identity_region
scp = 'with an explicit deny in a service control policy'
index = 0
while index < len(QuicksightClient.QUICKSIGHT_IDENTITY_REGIONS):
try:
identity_region = QuicksightClient.QUICKSIGHT_IDENTITY_REGIONS[index].get("code")
index += 1
client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=identity_region)
response = client.describe_account_settings(AwsAccountId=AwsAccountId)
logger.info(f'Returning identity region = {identity_region} for account {AwsAccountId}')
return identity_region
except client.exceptions.AccessDeniedException as e:
if scp in str(e):
logger.info(f'Quicksight SCP found in {identity_region} for account {AwsAccountId}. Trying next region...')
else:
logger.info(f'Quicksight identity region is not {identity_region}, selecting correct region endpoint...')
match = identity_region_rex.findall(str(e))
if match:
identity_region = match[0]
logger.info(f'Returning identity region = {identity_region} for account {AwsAccountId}')
return identity_region
else:
raise e
raise Exception(f'Quicksight subscription is inactive or the identity region has SCPs preventing access from data.all to account {AwsAccountId}')

@staticmethod
def get_quicksight_client_in_identity_region(AwsAccountId):
Expand Down Expand Up @@ -99,10 +124,11 @@ def check_quicksight_enterprise_subscription(AwsAccountId, region=None):
return False

@staticmethod
def create_quicksight_group(AwsAccountId, GroupName=DEFAULT_GROUP_NAME):
def create_quicksight_group(AwsAccountId, region, GroupName=DEFAULT_GROUP_NAME):
"""Creates a Quicksight group called GroupName
Args:
AwsAccountId(str): aws account
region: aws region
GroupName(str): name of the QS group
Returns:dict
Expand All @@ -113,7 +139,7 @@ def create_quicksight_group(AwsAccountId, GroupName=DEFAULT_GROUP_NAME):
if not group:
if GroupName == QuicksightClient.DEFAULT_GROUP_NAME:
logger.info(f'Initializing data.all default group = {GroupName}')
QuicksightClient.check_quicksight_enterprise_subscription(AwsAccountId)
QuicksightClient.check_quicksight_enterprise_subscription(AwsAccountId, region)

logger.info(f'Attempting to create Quicksight group `{GroupName}...')
response = client.create_group(
Expand Down
1 change: 1 addition & 0 deletions backend/dataall/base/utils/naming_convention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class NamingConventionPattern(Enum):
GLUE = {'regex': '[^a-zA-Z0-9_]', 'separator': '_', 'max_length': 63}
GLUE_ETL = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 52}
NOTEBOOK = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
MLSTUDIO_DOMAIN = {'regex': '[^a-zA-Z0-9-]', 'separator': '-', 'max_length': 63}
DEFAULT = {'regex': '[^a-zA-Z0-9-_]', 'separator': '-', 'max_length': 63}
OPENSEARCH = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 27}
OPENSEARCH_SERVERLESS = {'regex': '[^a-z0-9-]', 'separator': '-', 'max_length': 31}
Expand Down
15 changes: 6 additions & 9 deletions backend/dataall/core/environment/api/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@
gql.Argument('description', gql.String),
gql.Argument('AwsAccountId', gql.NonNullableType(gql.String)),
gql.Argument('region', gql.NonNullableType(gql.String)),
gql.Argument('vpcId', gql.String),
gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('EnvironmentDefaultIAMRoleArn', gql.String),
gql.Argument('resourcePrefix', gql.String),
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))

gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
gql.Argument('vpcId', gql.String),
gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)

Expand All @@ -45,11 +43,10 @@
gql.Argument('description', gql.String),
gql.Argument('tags', gql.ArrayType(gql.String)),
gql.Argument('SamlGroupName', gql.String),
gql.Argument('vpcId', gql.String),
gql.Argument('privateSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('publicSubnetIds', gql.ArrayType(gql.String)),
gql.Argument('resourcePrefix', gql.String),
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput))
gql.Argument('parameters', gql.ArrayType(ModifyEnvironmentParameterInput)),
gql.Argument('vpcId', gql.String),
gql.Argument('subnetIds', gql.ArrayType(gql.String))
],
)

Expand Down
28 changes: 23 additions & 5 deletions backend/dataall/core/environment/api/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dataall.core.stacks.aws.cloudformation import CloudFormation
from dataall.core.stacks.db.stack_repositories import Stack
from dataall.core.vpc.db.vpc_repositories import Vpc
from dataall.base.aws.ec2_client import EC2
from dataall.base.db import exceptions
from dataall.core.permissions import permissions
from dataall.base.feature_toggle_checker import is_feature_enabled
Expand All @@ -43,7 +44,7 @@ def get_pivot_role_as_part_of_environment(context: Context, source, **kwargs):
return True if ssm_param == "True" else False


def check_environment(context: Context, source, account_id, region):
def check_environment(context: Context, source, account_id, region, data):
""" Checks necessary resources for environment deployment.
- Check CDKToolkit exists in Account assuming cdk_look_up_role
- Check Pivot Role exists in Account if pivot_role_as_part_of_environment is False
Expand Down Expand Up @@ -71,11 +72,25 @@ def check_environment(context: Context, source, account_id, region):
action='CHECK_PIVOT_ROLE',
message='Pivot Role has not been created in the Environment AWS Account',
)
mlStudioEnabled = None
for parameter in data.get("parameters", []):
if parameter['key'] == 'mlStudiosEnabled':
mlStudioEnabled = parameter['value']

if mlStudioEnabled and data.get("vpcId", None) and data.get("subnetIds", []):
log.info("Check if ML Studio VPC Exists in the Account")
EC2.check_vpc_exists(
AwsAccountId=account_id,
region=region,
role=cdk_look_up_role_arn,
vpc_id=data.get("vpcId", None),
subnet_ids=data.get('subnetIds', []),
)

return cdk_role_name


def create_environment(context: Context, source, input=None):
def create_environment(context: Context, source, input={}):
if input.get('SamlGroupName') and input.get('SamlGroupName') not in context.groups:
raise exceptions.UnauthorizedOperation(
action=permissions.LINK_ENVIRONMENT,
Expand All @@ -85,8 +100,10 @@ def create_environment(context: Context, source, input=None):
with context.engine.scoped_session() as session:
cdk_role_name = check_environment(context, source,
account_id=input.get('AwsAccountId'),
region=input.get('region')
region=input.get('region'),
data=input
)

input['cdk_role_name'] = cdk_role_name
env = EnvironmentService.create_environment(
session=session,
Expand Down Expand Up @@ -119,7 +136,8 @@ def update_environment(
environment = EnvironmentService.get_environment_by_uri(session, environmentUri)
cdk_role_name = check_environment(context, source,
account_id=environment.AwsAccountId,
region=environment.region
region=environment.region,
data=input
)

previous_resource_prefix = environment.resourcePrefix
Expand All @@ -130,7 +148,7 @@ def update_environment(
data=input,
)

if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment):
if EnvironmentResourceManager.deploy_updated_stack(session, previous_resource_prefix, environment, data=input):
stack_helper.deploy_stack(targetUri=environment.environmentUri)

return environment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def get_statements(self):
actions=[
's3:ListAllMyBuckets',
's3:GetBucketLocation',
's3:PutBucketTagging'
's3:PutBucketTagging',
's3:GetEncryptionConfiguration'
],
resources=['*'],
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def delete_env(session, environment):
pass

@staticmethod
def update_env(session, environment):
def create_env(session, environment, **kwargs):
pass

@staticmethod
def update_env(session, environment, **kwargs):
return False

@staticmethod
Expand All @@ -39,10 +43,10 @@ def count_group_resources(cls, session, environment, group_uri) -> int:
return counter

@classmethod
def deploy_updated_stack(cls, session, prev_prefix, environment):
def deploy_updated_stack(cls, session, prev_prefix, environment, **kwargs):
deploy_stack = prev_prefix != environment.resourcePrefix
for resource in cls._resources:
deploy_stack |= resource.update_env(session, environment)
deploy_stack |= resource.update_env(session, environment, **kwargs)

return deploy_stack

Expand All @@ -51,6 +55,11 @@ def delete_env(cls, session, environment):
for resource in cls._resources:
resource.delete_env(session, environment)

@classmethod
def create_env(cls, session, environment, **kwargs):
for resource in cls._resources:
resource.create_env(session, environment, **kwargs)

@classmethod
def count_consumption_role_resources(cls, session, role_uri):
counter = 0
Expand Down
24 changes: 1 addition & 23 deletions backend/dataall/core/environment/services/environment_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def create_environment(session, uri, data=None):
session.commit()

EnvironmentService._update_env_parameters(session, env, data)
EnvironmentResourceManager.create_env(session, env, data=data)

env.EnvironmentDefaultBucketName = NamingConventionService(
target_uri=env.environmentUri,
Expand Down Expand Up @@ -98,29 +99,6 @@ def create_environment(session, uri, data=None):
env.EnvironmentDefaultIAMRoleArn = data['EnvironmentDefaultIAMRoleArn']
env.EnvironmentDefaultIAMRoleImported = True

if data.get('vpcId'):
vpc = Vpc(
environmentUri=env.environmentUri,
region=env.region,
AwsAccountId=env.AwsAccountId,
VpcId=data.get('vpcId'),
privateSubnetIds=data.get('privateSubnetIds', []),
publicSubnetIds=data.get('publicSubnetIds', []),
SamlGroupName=data['SamlGroupName'],
owner=context.username,
label=f"{env.name}-{data.get('vpcId')}",
name=f"{env.name}-{data.get('vpcId')}",
default=True,
)
session.add(vpc)
session.commit()
ResourcePolicy.attach_resource_policy(
session=session,
group=data['SamlGroupName'],
permissions=permissions.NETWORK_ALL,
resource_uri=vpc.vpcUri,
resource_type=Vpc.__name__,
)
env_group = EnvironmentGroup(
environmentUri=env.environmentUri,
groupUri=data['SamlGroupName'],
Expand Down
9 changes: 0 additions & 9 deletions backend/dataall/core/organizations/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@
test_scope='Organization',
)

listOrganizationInvitedGroups = gql.QueryField(
name='listOrganizationInvitedGroups',
type=gql.Ref('GroupSearchResult'),
args=[
gql.Argument(name='organizationUri', type=gql.NonNullableType(gql.String)),
gql.Argument(name='filter', type=gql.Ref('GroupFilter')),
],
resolver=list_organization_invited_groups,
)

listOrganizationGroups = gql.QueryField(
name='listOrganizationGroups',
Expand Down
13 changes: 0 additions & 13 deletions backend/dataall/core/organizations/api/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,6 @@ def remove_group(context: Context, source, organizationUri=None, groupUri=None):
return organization


def list_organization_invited_groups(
context: Context, source, organizationUri=None, filter=None
):
if filter is None:
filter = {}
with context.engine.scoped_session() as session:
return Organization.paginated_organization_invited_groups(
session=session,
uri=organizationUri,
data=filter,
)


def list_organization_groups(
context: Context, source, organizationUri=None, filter=None
):
Expand Down
Loading

0 comments on commit 692e5be

Please sign in to comment.