diff --git a/tests_new/integration_tests/client.py b/tests_new/integration_tests/client.py index b288cff52..75fa0455e 100644 --- a/tests_new/integration_tests/client.py +++ b/tests_new/integration_tests/client.py @@ -29,9 +29,9 @@ def query(self, query: str): graphql_endpoint = os.path.join(os.environ['API_ENDPOINT'], 'graphql', 'api') headers = {'AccessKeyId': 'none', 'SecretKey': 'none', 'authorization': self.token} r = requests.post(graphql_endpoint, json=query, headers=headers) - r.raise_for_status() if errors := r.json().get('errors'): raise GqlError(errors) + r.raise_for_status() return DefaultMunch.fromDict(r.json()) diff --git a/tests_new/integration_tests/core/environment/utils.py b/tests_new/integration_tests/core/environment/utils.py new file mode 100644 index 000000000..99e55c959 --- /dev/null +++ b/tests_new/integration_tests/core/environment/utils.py @@ -0,0 +1,36 @@ +from integration_tests.core.environment.queries import update_environment +from integration_tests.core.stack.utils import check_stack_ready, check_stack_in_progress + + +def set_env_params(client, env, **new_params): + should_update = False + new_params_list = [] + for param in env.parameters: + new_param_value = new_params.get(param.key, param.value) + if new_param_value != param.value: + should_update = True + new_params_list.append({'key': param.key, 'value': new_param_value}) + if should_update: + env_uri = env.environmentUri + stack_uri = env.stack.stackUri + check_stack_ready(client, env_uri, stack_uri) + update_environment( + client, + env.environmentUri, + { + k: v + for k, v in env.items() + if k + in [ + 'description', + 'label', + 'resourcePrefix', + 'subnetIds', + 'tags', + 'vpcId', + ] + } + | {'parameters': new_params_list}, + ) + check_stack_in_progress(client, env_uri, stack_uri) + check_stack_ready(client, env_uri, stack_uri) diff --git a/tests_new/integration_tests/modules/mlstudio/conftest.py b/tests_new/integration_tests/modules/mlstudio/conftest.py new file mode 100644 index 000000000..58935f1fc --- /dev/null +++ b/tests_new/integration_tests/modules/mlstudio/conftest.py @@ -0,0 +1,22 @@ +import pytest + +from integration_tests.core.environment.utils import set_env_params +from integration_tests.core.stack.utils import check_stack_ready +from integration_tests.modules.mlstudio.mutations import create_smstudio_user, delete_smstudio_user +from integration_tests.modules.mlstudio.queries import get_smstudio_user + + +@pytest.fixture(scope='session') +def smstudio_user1(session_id, client1, persistent_env1): + set_env_params(client1, persistent_env1, mlStudiosEnabled='true') + env_uri = persistent_env1.environmentUri + smstudio = create_smstudio_user( + client1, + environmentUri=env_uri, + groupName=persistent_env1.SamlGroupName, + label=session_id, + ) + smstudio_uri = smstudio.sagemakerStudioUserUri + check_stack_ready(client1, env_uri, smstudio.stack.stackUri, smstudio_uri, 'mlstudio') + yield get_smstudio_user(client1, smstudio_uri) + delete_smstudio_user(client1, smstudio_uri) diff --git a/tests_new/integration_tests/modules/mlstudio/mutations.py b/tests_new/integration_tests/modules/mlstudio/mutations.py new file mode 100644 index 000000000..f752973f1 --- /dev/null +++ b/tests_new/integration_tests/modules/mlstudio/mutations.py @@ -0,0 +1,58 @@ +def create_smstudio_user( + client, environmentUri, groupName, label, description='integtestmlstudio', tags=[], topics=None +): + query = { + 'operationName': 'createSagemakerStudioUser', + 'variables': { + 'input': { + 'environmentUri': environmentUri, + 'SamlAdminGroupName': groupName, + 'label': label, + 'description': description, + 'tags': tags, + 'topics': topics, + } + }, + 'query': f""" + mutation createSagemakerStudioUser($input: NewSagemakerStudioUserInput!) {{ + createSagemakerStudioUser(input: $input) {{ + sagemakerStudioUserUri + name + label + created + description + tags + stack {{ + stack + status + stackUri + }} + }} + }} + """, + } + response = client.query(query=query) + return response.data.createSagemakerStudioUser + + +def delete_smstudio_user(client, uri, delete_from_aws=True): + query = { + 'operationName': 'deleteSagemakerStudioUser', + 'variables': { + 'sagemakerStudioUserUri': uri, + 'deleteFromAWS': delete_from_aws, + }, + 'query': f""" + mutation deleteSagemakerStudioUser( + $sagemakerStudioUserUri: String! + $deleteFromAWS: Boolean + ) {{ + deleteSagemakerStudioUser( + sagemakerStudioUserUri: $sagemakerStudioUserUri + deleteFromAWS: $deleteFromAWS + ) + }} + """, + } + response = client.query(query=query) + return response.data.deleteSagemakerStudioUser diff --git a/tests_new/integration_tests/modules/mlstudio/queries.py b/tests_new/integration_tests/modules/mlstudio/queries.py new file mode 100644 index 000000000..cec2c058a --- /dev/null +++ b/tests_new/integration_tests/modules/mlstudio/queries.py @@ -0,0 +1,152 @@ +def get_smstudio_user(client, uri): + query = { + 'operationName': 'getSagemakerStudioUser', + 'variables': { + 'sagemakerStudioUserUri': uri, + }, + 'query': f""" + query getSagemakerStudioUser($sagemakerStudioUserUri: String!) {{ + getSagemakerStudioUser(sagemakerStudioUserUri: $sagemakerStudioUserUri) {{ + sagemakerStudioUserUri + name + owner + description + label + created + tags + userRoleForSagemakerStudioUser + sagemakerStudioUserStatus + SamlAdminGroupName + sagemakerStudioUserApps {{ + DomainId + UserName + AppType + AppName + Status + }} + environment {{ + label + name + environmentUri + AwsAccountId + region + EnvironmentDefaultIAMRoleArn + }} + organization {{ + label + name + organizationUri + }} + stack {{ + stack + status + stackUri + targetUri + accountid + region + stackid + link + outputs + resources + }} + }} + }} + """, + } + response = client.query(query=query) + return response.data.getSagemakerStudioUser + + +def list_smstudio_users(client, term=None): + query = { + 'operationName': 'listSagemakerStudioUsers', + 'variables': { + 'filter': {'term': term}, + }, + 'query': f""" + query listSagemakerStudioUsers($filter: SagemakerStudioUserFilter) {{ + listSagemakerStudioUsers(filter: $filter) {{ + count + page + pages + hasNext + hasPrevious + nodes {{ + sagemakerStudioUserUri + name + owner + description + label + created + tags + sagemakerStudioUserStatus + userRoleForSagemakerStudioUser + environment {{ + label + name + environmentUri + AwsAccountId + region + SamlGroupName + }} + organization {{ + label + name + organizationUri + }} + stack {{ + stack + status + }} + }} + }} + }} + """, + } + response = client.query(query=query) + return response.data.listSagemakerStudioUsers + + +def get_smstudio_user_presigned_url(client, uri): + query = { + 'operationName': 'getSagemakerStudioUserPresignedUrl', + 'variables': { + 'sagemakerStudioUserUri': uri, + }, + 'query': f""" + query getSagemakerStudioUserPresignedUrl($sagemakerStudioUserUri: String!) {{ + getSagemakerStudioUserPresignedUrl( + sagemakerStudioUserUri: $sagemakerStudioUserUri + ) + }} + """, + } + response = client.query(query=query) + return response.data.getSagemakerStudioUserPresignedUrl + + +def get_environment_mlstudio_domain(client, uri): + query = { + 'operationName': 'getEnvironmentMLStudioDomain', + 'variables': { + 'environmentUri': uri, + }, + 'query': f""" + query getEnvironmentMLStudioDomain($environmentUri: String!) {{ + getEnvironmentMLStudioDomain(environmentUri: $environmentUri) {{ + sagemakerStudioUri + environmentUri + label + sagemakerStudioDomainName + DefaultDomainRoleName + vpcType + vpcId + subnetIds + owner + created + }} + }} + """, + } + response = client.query(query=query) + return response.data.getEnvironmentMLStudioDomain diff --git a/tests_new/integration_tests/modules/mlstudio/test_mlstudio.py b/tests_new/integration_tests/modules/mlstudio/test_mlstudio.py new file mode 100644 index 000000000..d1f78be45 --- /dev/null +++ b/tests_new/integration_tests/modules/mlstudio/test_mlstudio.py @@ -0,0 +1,33 @@ +from assertpy import assert_that + +from integration_tests.errors import GqlError +from integration_tests.modules.mlstudio.queries import ( + list_smstudio_users, + get_smstudio_user_presigned_url, + get_environment_mlstudio_domain, +) + + +def test_create_smstudio_user(smstudio_user1): + assert_that(smstudio_user1.stack.status).is_in('CREATE_COMPLETE', 'UPDATE_COMPLETE') + + +def test_list_smstudio_users(client1, client2, session_id, smstudio_user1): + assert_that(list_smstudio_users(client1, term=session_id).nodes).is_length(1) + assert_that(list_smstudio_users(client2, term=session_id).nodes).is_length(0) + + +def test_get_smstudio_user_presigned_url(client1, smstudio_user1): + assert_that(get_smstudio_user_presigned_url(client1, smstudio_user1.sagemakerStudioUserUri)).starts_with('https://') + + +def test_get_smstudio_user_presigned_url_unauthorized(client2, smstudio_user1): + assert_that(get_smstudio_user_presigned_url).raises(GqlError).when_called_with( + client2, smstudio_user1.sagemakerStudioUserUri + ).contains('UnauthorizedOperation', 'SGMSTUDIO_USER_URL') + + +def test_get_environment_mlstudio_domain(client1, smstudio_user1): + assert_that( + get_environment_mlstudio_domain(client1, smstudio_user1.environment.environmentUri).sagemakerStudioDomainName + ).starts_with('dataall')