diff --git a/backend/dataall/modules/s3_datasets/db/dataset_bucket_repositories.py b/backend/dataall/modules/s3_datasets/db/dataset_bucket_repositories.py
index d93abfbe5..681aec0f4 100644
--- a/backend/dataall/modules/s3_datasets/db/dataset_bucket_repositories.py
+++ b/backend/dataall/modules/s3_datasets/db/dataset_bucket_repositories.py
@@ -31,3 +31,7 @@ def delete_dataset_buckets(session, dataset_uri) -> bool:
         buckets = session.query(DatasetBucket).filter(DatasetBucket.datasetUri == dataset_uri).all()
         for bucket in buckets:
             session.delete(bucket)
+
+    @staticmethod
+    def get_dataset_bucket_by_name(session, bucket_name) -> DatasetBucket:
+        return session.query(DatasetBucket).filter(DatasetBucket.S3BucketName == bucket_name).first()
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_service.py b/backend/dataall/modules/s3_datasets/services/dataset_service.py
index 2f82dc666..27febeb3a 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_service.py
@@ -106,6 +106,12 @@ def check_dataset_account(session, environment):
 
     @staticmethod
     def check_imported_resources(dataset: S3Dataset):
+        with get_context().db_engine.scoped_session() as session:
+            if DatasetBucketRepository.get_dataset_bucket_by_name(session, dataset.S3BucketName):
+                raise exceptions.ResourceAlreadyExists(
+                    action=IMPORT_DATASET,
+                    message=f'Dataset with bucket {dataset.S3BucketName} already exists',
+                )
         if dataset.importedGlueDatabase:
             if len(dataset.GlueDatabaseName) > NamingConventionPattern.GLUE.value.get('max_length'):
                 raise exceptions.InvalidInput(
diff --git a/tests/modules/s3_datasets/conftest.py b/tests/modules/s3_datasets/conftest.py
index 2ab1e6a06..402b3186d 100644
--- a/tests/modules/s3_datasets/conftest.py
+++ b/tests/modules/s3_datasets/conftest.py
@@ -3,6 +3,7 @@
 
 import pytest
 
+from dataall.base.context import set_context, RequestContext, dispose_context
 from dataall.core.environment.db.environment_models import Environment, EnvironmentGroup
 from dataall.core.organizations.db.organization_models import Organization
 from dataall.core.permissions.services.resource_policy_service import ResourcePolicyService
@@ -437,3 +438,9 @@ def random_tag():
 
 def random_tags():
     return [random_tag() for i in range(1, random.choice([2, 3, 4, 5]))]
+
+
+@pytest.fixture(scope='function')
+def api_context_1(db, user, group):
+    yield set_context(RequestContext(db_engine=db, username=user.username, groups=[group.name], user_id=user.username))
+    dispose_context()
diff --git a/tests/modules/s3_datasets/test_import_dataset_check_unit.py b/tests/modules/s3_datasets/test_import_dataset_check_unit.py
index e8021c528..1eaa466ac 100644
--- a/tests/modules/s3_datasets/test_import_dataset_check_unit.py
+++ b/tests/modules/s3_datasets/test_import_dataset_check_unit.py
@@ -8,7 +8,7 @@
 from dataall.modules.s3_datasets.db.dataset_models import S3Dataset
 
 
-def test_s3_managed_bucket_import(mock_aws_client):
+def test_s3_managed_bucket_import(mock_aws_client, api_context_1):
     dataset = S3Dataset(KmsAlias=None)
 
     mock_encryption_bucket(mock_aws_client, 'AES256', None)
@@ -16,7 +16,7 @@
     assert DatasetService.check_imported_resources(dataset)
 
 
-def test_s3_managed_bucket_but_bucket_encrypted_with_kms(mock_aws_client):
+def test_s3_managed_bucket_but_bucket_encrypted_with_kms(mock_aws_client, api_context_1):
     dataset = S3Dataset(KmsAlias=None)
 
     mock_encryption_bucket(mock_aws_client, 'aws:kms', 'any')
@@ -24,7 +24,7 @@
         DatasetService.check_imported_resources(dataset)
 
 
-def test_s3_managed_bucket_but_alias_provided(mock_aws_client):
+def test_s3_managed_bucket_but_alias_provided(mock_aws_client, api_context_1):
     dataset = S3Dataset(KmsAlias='Key')
 
     mock_encryption_bucket(mock_aws_client, 'AES256', None)
@@ -32,7 +32,7 @@ def test_s3_managed_bucket_but_alias_provided(mock_aws_client):
         DatasetService.check_imported_resources(dataset)
 
 
-def test_kms_encrypted_bucket_but_key_not_exist(mock_aws_client):
+def test_kms_encrypted_bucket_but_key_not_exist(mock_aws_client, api_context_1):
     alias = 'alias'
     dataset = S3Dataset(KmsAlias=alias)
     mock_encryption_bucket(mock_aws_client, 'aws:kms', 'any')
@@ -42,7 +42,7 @@ def test_kms_encrypted_bucket_but_key_not_exist(mock_aws_client):
         DatasetService.check_imported_resources(dataset)
 
 
-def test_kms_encrypted_bucket_but_key_is_wrong(mock_aws_client):
+def test_kms_encrypted_bucket_but_key_is_wrong(mock_aws_client, api_context_1):
     alias = 'key_alias'
     kms_id = 'kms_id'
     dataset = S3Dataset(KmsAlias=alias)
@@ -54,7 +54,7 @@ def test_kms_encrypted_bucket_but_key_is_wrong(mock_aws_client):
         DatasetService.check_imported_resources(dataset)
 
 
-def test_kms_encrypted_bucket_imported(mock_aws_client):
+def test_kms_encrypted_bucket_imported(mock_aws_client, api_context_1):
     alias = 'key_alias'
     kms_id = 'kms_id'
     dataset = S3Dataset(KmsAlias=alias)