From 67c8d6aa25c9ff78da5e7341b784583583b2f6a7 Mon Sep 17 00:00:00 2001 From: Ruidi Peng Date: Tue, 14 Nov 2023 23:17:21 +0000 Subject: [PATCH] fix artifact not being cleaned up --- tests/helpers.py | 33 +++++++++++++++++++++++++++++++++ tests/integ/test_tracker.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index b7b2f6a..3621b74 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -163,3 +163,36 @@ def wait_for_trial_component(sagemaker_client, training_job_name=None, trial_com except sagemaker_client.exceptions.ResourceNotFound: logging.info("Trial component %s not created yet.", trial_component_name) time.sleep(5) + + +def delete_artifact(sagemaker_client, artifact_arn, disassociate: bool = False): + """Delete the artifact object. + + Args: + disassociate (bool): When set to true, disassociate incoming and outgoing association. + """ + if disassociate: + _disassociate(sagemaker_client, source_arn=artifact_arn) + _disassociate(sagemaker_client, destination_arn=artifact_arn) + sagemaker_client.delete_artifact(ArtifactArn=artifact_arn) + + +def _disassociate(sagemaker_client, source_arn=None, destination_arn=None): + """Remove the association. + + Remove incoming association when source_arn is provided, remove outgoing association when + destination_arn is provided. + """ + params = { + "SourceArn": source_arn, + "DestinationArn": destination_arn, + } + not_none_params = {k: v for k, v in params.items() if v is not None} + + # list_associations() returns a maximum of 10 associations by default. Test case would not exceed 10. + association_summaries = sagemaker_client.list_associations(**not_none_params) + for association_summary in association_summaries["AssociationSummaries"]: + sagemaker_client.delete_association( + SourceArn=association_summary["SourceArn"], + DestinationArn=association_summary["DestinationArn"], + ) diff --git a/tests/integ/test_tracker.py b/tests/integ/test_tracker.py index 421eafd..a0a7b6d 100644 --- a/tests/integ/test_tracker.py +++ b/tests/integ/test_tracker.py @@ -17,7 +17,7 @@ from tests.helpers import name, wait_for_trial_component from smexperiments import tracker, trial_component, _utils -from tests.helpers import retry +from tests.helpers import retry, delete_artifact def test_load_trial_component(trial_component_obj, sagemaker_boto_client): @@ -96,6 +96,7 @@ def test_log_artifact(trial_component_obj, bucket, tempdir, sagemaker_boto_clien sagemaker_boto_client=sagemaker_boto_client, ) as tracker_obj: tracker_obj.log_artifact(file_path, name=artifact_name) + artifacts = tracker_obj._lineage_artifact_tracker.artifacts loaded = trial_component.TrialComponent.load( trial_component_name=trial_component_obj.trial_component_name, @@ -104,6 +105,13 @@ def test_log_artifact(trial_component_obj, bucket, tempdir, sagemaker_boto_clien assert "text/plain" == loaded.output_artifacts[artifact_name].media_type assert prefix in loaded.output_artifacts[artifact_name].value + try: + delete_artifact( + sagemaker_boto_client, artifacts[0].artifact_arn, disassociate=True + ) + except: + logging.info("Artifacts are not deleted properly.") + def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_client): prefix = name() @@ -122,6 +130,7 @@ def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_clie sagemaker_boto_client=sagemaker_boto_client, ) as tracker_obj: tracker_obj.log_artifacts(tempdir) + artifacts = tracker_obj._lineage_artifact_tracker.artifacts loaded = trial_component.TrialComponent.load( trial_component_name=trial_component_obj.trial_component_name, sagemaker_boto_client=sagemaker_boto_client, @@ -131,6 +140,14 @@ def test_log_artifacts(trial_component_obj, bucket, tempdir, sagemaker_boto_clie assert "text/plain" == loaded.output_artifacts["bar"].media_type assert prefix in loaded.output_artifacts["bar"].value + try: + for artifact in artifacts: + delete_artifact( + sagemaker_boto_client, artifact.artifact_arn, disassociate=True + ) + except: + logging.info("Artifacts are not deleted properly.") + def test_create_default_bucket(boto3_session): bucket_name_prefix = _utils.name("sm-test") @@ -158,6 +175,7 @@ def test_create_lineage_artifacts(trial_component_obj, bucket, tempdir, sagemake sagemaker_boto_client=sagemaker_boto_client, ) as tracker_obj: tracker_obj.log_output_artifact(file_path, name=artifact_name) + artifacts = tracker_obj._lineage_artifact_tracker.artifacts response = sagemaker_boto_client.list_associations(SourceArn=trial_component_obj.trial_component_arn) associations = response["AssociationSummaries"] @@ -171,6 +189,14 @@ def validate(): retry(validate, num_attempts=4) + try: + for artifact in artifacts: + delete_artifact( + sagemaker_boto_client, artifact.artifact_arn, disassociate=True + ) + except: + logging.info("Artifacts are not deleted properly") + def test_log_table_artifact(trial_component_obj, bucket, sagemaker_boto_client): prefix = name() @@ -185,6 +211,7 @@ def test_log_table_artifact(trial_component_obj, bucket, sagemaker_boto_client): sagemaker_boto_client=sagemaker_boto_client, ) as tracker_obj: tracker_obj.log_table(title=artifact_name, values=values) + artifacts = tracker_obj._lineage_artifact_tracker.artifacts response = sagemaker_boto_client.list_associations(SourceArn=trial_component_obj.trial_component_arn) associations = response["AssociationSummaries"] @@ -197,3 +224,10 @@ def validate(): assert summary["DestinationName"] == artifact_name retry(validate, num_attempts=4) + + try: + delete_artifact( + sagemaker_boto_client, artifacts[0].artifact_arn, disassociate=True + ) + except: + logging.info("Artifacts are not deleted properly.")