Skip to content

Commit

Permalink
add mlstudio integ tests (data-dot-all#1535)
Browse files Browse the repository at this point in the history
### Feature or Bugfix
Feature

### Detail
Adding integration tests for ML Studio

PENDING TESTS PASSING IN DEV AWS ENV

### Relates
related to data-dot-all#1220 and resolves data-dot-all#1534

### Security
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
  - Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
  - Is injection prevented by parametrizing queries?
  - Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
  - Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use a standard proven implementations?
  - Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
  - Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.
  • Loading branch information
petrkalos authored Sep 13, 2024
1 parent 405019d commit 3d4d648
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests_new/integration_tests/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def query(self, query: str):
graphql_endpoint = os.path.join(os.environ['API_ENDPOINT'], 'graphql', 'api')
headers = {'AccessKeyId': 'none', 'SecretKey': 'none', 'authorization': self.token}
r = requests.post(graphql_endpoint, json=query, headers=headers)
r.raise_for_status()
if errors := r.json().get('errors'):
raise GqlError(errors)
r.raise_for_status()

return DefaultMunch.fromDict(r.json())

Expand Down
36 changes: 36 additions & 0 deletions tests_new/integration_tests/core/environment/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from integration_tests.core.environment.queries import update_environment
from integration_tests.core.stack.utils import check_stack_ready, check_stack_in_progress


def set_env_params(client, env, **new_params):
should_update = False
new_params_list = []
for param in env.parameters:
new_param_value = new_params.get(param.key, param.value)
if new_param_value != param.value:
should_update = True
new_params_list.append({'key': param.key, 'value': new_param_value})
if should_update:
env_uri = env.environmentUri
stack_uri = env.stack.stackUri
check_stack_ready(client, env_uri, stack_uri)
update_environment(
client,
env.environmentUri,
{
k: v
for k, v in env.items()
if k
in [
'description',
'label',
'resourcePrefix',
'subnetIds',
'tags',
'vpcId',
]
}
| {'parameters': new_params_list},
)
check_stack_in_progress(client, env_uri, stack_uri)
check_stack_ready(client, env_uri, stack_uri)
22 changes: 22 additions & 0 deletions tests_new/integration_tests/modules/mlstudio/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from integration_tests.core.environment.utils import set_env_params
from integration_tests.core.stack.utils import check_stack_ready
from integration_tests.modules.mlstudio.mutations import create_smstudio_user, delete_smstudio_user
from integration_tests.modules.mlstudio.queries import get_smstudio_user


@pytest.fixture(scope='session')
def smstudio_user1(session_id, client1, persistent_env1):
set_env_params(client1, persistent_env1, mlStudiosEnabled='true')
env_uri = persistent_env1.environmentUri
smstudio = create_smstudio_user(
client1,
environmentUri=env_uri,
groupName=persistent_env1.SamlGroupName,
label=session_id,
)
smstudio_uri = smstudio.sagemakerStudioUserUri
check_stack_ready(client1, env_uri, smstudio.stack.stackUri, smstudio_uri, 'mlstudio')
yield get_smstudio_user(client1, smstudio_uri)
delete_smstudio_user(client1, smstudio_uri)
58 changes: 58 additions & 0 deletions tests_new/integration_tests/modules/mlstudio/mutations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
def create_smstudio_user(
client, environmentUri, groupName, label, description='integtestmlstudio', tags=[], topics=None
):
query = {
'operationName': 'createSagemakerStudioUser',
'variables': {
'input': {
'environmentUri': environmentUri,
'SamlAdminGroupName': groupName,
'label': label,
'description': description,
'tags': tags,
'topics': topics,
}
},
'query': f"""
mutation createSagemakerStudioUser($input: NewSagemakerStudioUserInput!) {{
createSagemakerStudioUser(input: $input) {{
sagemakerStudioUserUri
name
label
created
description
tags
stack {{
stack
status
stackUri
}}
}}
}}
""",
}
response = client.query(query=query)
return response.data.createSagemakerStudioUser


def delete_smstudio_user(client, uri, delete_from_aws=True):
query = {
'operationName': 'deleteSagemakerStudioUser',
'variables': {
'sagemakerStudioUserUri': uri,
'deleteFromAWS': delete_from_aws,
},
'query': f"""
mutation deleteSagemakerStudioUser(
$sagemakerStudioUserUri: String!
$deleteFromAWS: Boolean
) {{
deleteSagemakerStudioUser(
sagemakerStudioUserUri: $sagemakerStudioUserUri
deleteFromAWS: $deleteFromAWS
)
}}
""",
}
response = client.query(query=query)
return response.data.deleteSagemakerStudioUser
152 changes: 152 additions & 0 deletions tests_new/integration_tests/modules/mlstudio/queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
def get_smstudio_user(client, uri):
query = {
'operationName': 'getSagemakerStudioUser',
'variables': {
'sagemakerStudioUserUri': uri,
},
'query': f"""
query getSagemakerStudioUser($sagemakerStudioUserUri: String!) {{
getSagemakerStudioUser(sagemakerStudioUserUri: $sagemakerStudioUserUri) {{
sagemakerStudioUserUri
name
owner
description
label
created
tags
userRoleForSagemakerStudioUser
sagemakerStudioUserStatus
SamlAdminGroupName
sagemakerStudioUserApps {{
DomainId
UserName
AppType
AppName
Status
}}
environment {{
label
name
environmentUri
AwsAccountId
region
EnvironmentDefaultIAMRoleArn
}}
organization {{
label
name
organizationUri
}}
stack {{
stack
status
stackUri
targetUri
accountid
region
stackid
link
outputs
resources
}}
}}
}}
""",
}
response = client.query(query=query)
return response.data.getSagemakerStudioUser


def list_smstudio_users(client, term=None):
query = {
'operationName': 'listSagemakerStudioUsers',
'variables': {
'filter': {'term': term},
},
'query': f"""
query listSagemakerStudioUsers($filter: SagemakerStudioUserFilter) {{
listSagemakerStudioUsers(filter: $filter) {{
count
page
pages
hasNext
hasPrevious
nodes {{
sagemakerStudioUserUri
name
owner
description
label
created
tags
sagemakerStudioUserStatus
userRoleForSagemakerStudioUser
environment {{
label
name
environmentUri
AwsAccountId
region
SamlGroupName
}}
organization {{
label
name
organizationUri
}}
stack {{
stack
status
}}
}}
}}
}}
""",
}
response = client.query(query=query)
return response.data.listSagemakerStudioUsers


def get_smstudio_user_presigned_url(client, uri):
query = {
'operationName': 'getSagemakerStudioUserPresignedUrl',
'variables': {
'sagemakerStudioUserUri': uri,
},
'query': f"""
query getSagemakerStudioUserPresignedUrl($sagemakerStudioUserUri: String!) {{
getSagemakerStudioUserPresignedUrl(
sagemakerStudioUserUri: $sagemakerStudioUserUri
)
}}
""",
}
response = client.query(query=query)
return response.data.getSagemakerStudioUserPresignedUrl


def get_environment_mlstudio_domain(client, uri):
query = {
'operationName': 'getEnvironmentMLStudioDomain',
'variables': {
'environmentUri': uri,
},
'query': f"""
query getEnvironmentMLStudioDomain($environmentUri: String!) {{
getEnvironmentMLStudioDomain(environmentUri: $environmentUri) {{
sagemakerStudioUri
environmentUri
label
sagemakerStudioDomainName
DefaultDomainRoleName
vpcType
vpcId
subnetIds
owner
created
}}
}}
""",
}
response = client.query(query=query)
return response.data.getEnvironmentMLStudioDomain
33 changes: 33 additions & 0 deletions tests_new/integration_tests/modules/mlstudio/test_mlstudio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from assertpy import assert_that

from integration_tests.errors import GqlError
from integration_tests.modules.mlstudio.queries import (
list_smstudio_users,
get_smstudio_user_presigned_url,
get_environment_mlstudio_domain,
)


def test_create_smstudio_user(smstudio_user1):
assert_that(smstudio_user1.stack.status).is_in('CREATE_COMPLETE', 'UPDATE_COMPLETE')


def test_list_smstudio_users(client1, client2, session_id, smstudio_user1):
assert_that(list_smstudio_users(client1, term=session_id).nodes).is_length(1)
assert_that(list_smstudio_users(client2, term=session_id).nodes).is_length(0)


def test_get_smstudio_user_presigned_url(client1, smstudio_user1):
assert_that(get_smstudio_user_presigned_url(client1, smstudio_user1.sagemakerStudioUserUri)).starts_with('https://')


def test_get_smstudio_user_presigned_url_unauthorized(client2, smstudio_user1):
assert_that(get_smstudio_user_presigned_url).raises(GqlError).when_called_with(
client2, smstudio_user1.sagemakerStudioUserUri
).contains('UnauthorizedOperation', 'SGMSTUDIO_USER_URL')


def test_get_environment_mlstudio_domain(client1, smstudio_user1):
assert_that(
get_environment_mlstudio_domain(client1, smstudio_user1.environment.environmentUri).sagemakerStudioDomainName
).starts_with('dataall')

0 comments on commit 3d4d648

Please sign in to comment.