Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix artifact not being cleaned up #194

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,36 @@ def wait_for_trial_component(sagemaker_client, training_job_name=None, trial_com
except sagemaker_client.exceptions.ResourceNotFound:
logging.info("Trial component %s not created yet.", trial_component_name)
time.sleep(5)


def delete_artifact(sagemaker_client, artifact_arn, disassociate: bool = False):
"""Delete the artifact object.

Args:
disassociate (bool): When set to true, disassociate incoming and outgoing association.
"""
if disassociate:
_disassociate(sagemaker_client, source_arn=artifact_arn)
_disassociate(sagemaker_client, destination_arn=artifact_arn)
sagemaker_client.delete_artifact(ArtifactArn=artifact_arn)


def _disassociate(sagemaker_client, source_arn=None, destination_arn=None):
"""Remove the association.

Remove incoming association when source_arn is provided, remove outgoing association when
destination_arn is provided.
"""
params = {
"SourceArn": source_arn,
"DestinationArn": destination_arn,
}
not_none_params = {k: v for k, v in params.items() if v is not None}

# list_associations() returns a maximum of 10 associations by default. Test case would not exceed 10.
association_summaries = sagemaker_client.list_associations(**not_none_params)
for association_summary in association_summaries["AssociationSummaries"]:
sagemaker_client.delete_association(
SourceArn=association_summary["SourceArn"],
DestinationArn=association_summary["DestinationArn"],
)
36 changes: 35 additions & 1 deletion tests/integ/test_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from tests.helpers import name, wait_for_trial_component
from smexperiments import tracker, trial_component, _utils
from tests.helpers import retry
from tests.helpers import retry, delete_artifact


def test_load_trial_component(trial_component_obj, sagemaker_boto_client):
Expand Down Expand Up @@ -96,6 +96,7 @@ def test_log_artifact(trial_component_obj, bucket, tempdir, sagemaker_boto_clien
sagemaker_boto_client=sagemaker_boto_client,
) as tracker_obj:
tracker_obj.log_artifact(file_path, name=artifact_name)
artifacts = tracker_obj._lineage_artifact_tracker.artifacts

loaded = trial_component.TrialComponent.load(
trial_component_name=trial_component_obj.trial_component_name,
Expand All @@ -104,6 +105,13 @@ def test_log_artifact(trial_component_obj, bucket, tempdir, sagemaker_boto_clien
assert "text/plain" == loaded.output_artifacts[artifact_name].media_type
assert prefix in loaded.output_artifacts[artifact_name].value

try:
delete_artifact(
sagemaker_boto_client, artifacts[0].artifact_arn, disassociate=True
)
except:
logging.info("Artifacts are not deleted properly.")


def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_client):
prefix = name()
Expand All @@ -122,6 +130,7 @@ def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_clie
sagemaker_boto_client=sagemaker_boto_client,
) as tracker_obj:
tracker_obj.log_artifacts(tempdir)
artifacts = tracker_obj._lineage_artifact_tracker.artifacts
loaded = trial_component.TrialComponent.load(
trial_component_name=trial_component_obj.trial_component_name,
sagemaker_boto_client=sagemaker_boto_client,
Expand All @@ -131,6 +140,14 @@ def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_clie
assert "text/plain" == loaded.output_artifacts["bar"].media_type
assert prefix in loaded.output_artifacts["bar"].value

try:
for artifact in artifacts:
delete_artifact(
sagemaker_boto_client, artifact.artifact_arn, disassociate=True
)
except:
logging.info("Artifacts are not deleted properly.")


def test_create_default_bucket(boto3_session):
bucket_name_prefix = _utils.name("sm-test")
Expand Down Expand Up @@ -158,6 +175,7 @@ def test_create_lineage_artifacts(trial_component_obj, bucket, tempdir, sagemake
sagemaker_boto_client=sagemaker_boto_client,
) as tracker_obj:
tracker_obj.log_output_artifact(file_path, name=artifact_name)
artifacts = tracker_obj._lineage_artifact_tracker.artifacts

response = sagemaker_boto_client.list_associations(SourceArn=trial_component_obj.trial_component_arn)
associations = response["AssociationSummaries"]
Expand All @@ -171,6 +189,14 @@ def validate():

retry(validate, num_attempts=4)

try:
for artifact in artifacts:
delete_artifact(
sagemaker_boto_client, artifact.artifact_arn, disassociate=True
)
except:
logging.info("Artifacts are not deleted properly")


def test_log_table_artifact(trial_component_obj, bucket, sagemaker_boto_client):
prefix = name()
Expand All @@ -185,6 +211,7 @@ def test_log_table_artifact(trial_component_obj, bucket, sagemaker_boto_client):
sagemaker_boto_client=sagemaker_boto_client,
) as tracker_obj:
tracker_obj.log_table(title=artifact_name, values=values)
artifacts = tracker_obj._lineage_artifact_tracker.artifacts

response = sagemaker_boto_client.list_associations(SourceArn=trial_component_obj.trial_component_arn)
associations = response["AssociationSummaries"]
Expand All @@ -197,3 +224,10 @@ def validate():
assert summary["DestinationName"] == artifact_name

retry(validate, num_attempts=4)

try:
delete_artifact(
sagemaker_boto_client, artifacts[0].artifact_arn, disassociate=True
)
except:
logging.info("Artifacts are not deleted properly.")
Loading