From 93a846670f57f444b590551b2d67a3c6a95302aa Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 14 Dec 2022 09:09:46 -0800 Subject: [PATCH] feature: Add SageMaker Experiment (#3536) * feature: Add experiment plus Run class (#691) * feature: Add Experiment helper classes (#646) * feature: Add Experiment helper classes feature: Add helper class _RunEnvironment * change: Change sleep retry to backoff retry for get TC * minor fixes in backoff retry Co-authored-by: Dewen Qi * feature: Add helper classes and methods for Run class (#660) * feature: Add helper classes and methods for Run class * Add Parent class to address comment * fix docstyle check * Add arg docstrings in _helper Co-authored-by: Dewen Qi * feature: Add Experiment Run class (#651) Co-authored-by: Dewen Qi * change: Add integ tests for Run (#673) Co-authored-by: Dewen Qi * Update run log metric to use MetricsManager (#678) * Update run.log_metric to use _MetricsManager * fix several metrics issues * Add doc strings to metrics.py Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: Dewen Qi Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> * change: Simplify exp plus integ test configuration (#694) Co-authored-by: Dewen Qi * feature: add RunName to expeirment_config (#696) * change: Update Run init and add Run load and _RunContext (#707) * change: Update Run init and add Run load Add exp name and run group name to load and address comments * Address nit comments Co-authored-by: Dewen Qi * fix: Fix run name uniqueness issue (#730) Co-authored-by: Dewen Qi * change: Update integ tests for Exp Plus M1 changes (#741) Co-authored-by: Dewen Qi * add metrics client to session object (#745) Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * change: Add integ test for using Run in Transform Job (#749) Co-authored-by: Dewen Qi * Add async metrics sink (#739) Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * use metrics client provided by session (#754) * fix flaky metrics test (#753) * change: Change Run.init and Run.load to constructor and module method respectively (#752) Co-authored-by: Dewen Qi * feature: Add latest metric service model (#757) Co-authored-by: Dewen Qi Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * fix: lowercase run name (#767) * Change: Minimize use of lower case tc name (#769) * change: Clean up test resources to remove model files (#756) * change: Clean up test resources to remove model files * fix: Change experiment enums to upper case * change: Upgrade boto3 and update test to validate mixed case name * fix: Update as per latest botocore release and backend change Co-authored-by: Dewen Qi * lowercase trial component name (#776) * change: Expose sagemaker experiment doc strings * fix: Fix exp name mixed case in issue Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: Yifei Zhu <66866419+yzhu0@users.noreply.github.com> --- .gitignore | 5 +- doc/experiments/index.rst | 10 + 
doc/experiments/sagemaker.experiments.rst | 20 + doc/index.rst | 10 + requirements/extras/test_requirements.txt | 1 + setup.py | 2 +- src/sagemaker/amazon/amazon_estimator.py | 7 +- src/sagemaker/apiutils/_base_types.py | 6 +- src/sagemaker/apiutils/_boto_functions.py | 4 +- src/sagemaker/dataset_definition/inputs.py | 6 +- src/sagemaker/estimator.py | 16 +- src/sagemaker/experiments/__init__.py | 20 + src/sagemaker/experiments/_api_types.py | 251 +++++ src/sagemaker/experiments/_environment.py | 132 +++ src/sagemaker/experiments/_helper.py | 266 +++++ src/sagemaker/experiments/_metrics.py | 413 ++++++++ src/sagemaker/experiments/_run_context.py | 58 ++ src/sagemaker/experiments/_utils.py | 218 ++++ src/sagemaker/experiments/experiment.py | 237 +++++ src/sagemaker/experiments/run.py | 882 ++++++++++++++++ src/sagemaker/experiments/trial.py | 289 ++++++ src/sagemaker/experiments/trial_component.py | 341 +++++++ src/sagemaker/lineage/_utils.py | 17 - src/sagemaker/lineage/artifact.py | 3 +- src/sagemaker/processing.py | 9 +- src/sagemaker/session.py | 23 +- src/sagemaker/transformer.py | 7 +- src/sagemaker/utilities/search_expression.py | 133 +++ src/sagemaker/utils.py | 66 ++ tests/data/experiment/inference.py | 85 ++ .../process_job_script_for_run_clz.py | 37 + .../train_job_script_for_run_clz.py | 71 ++ .../transform_job_materials/data.csv | 1 + .../transform_job_materials/xgb_model.tar.gz | Bin 0 -> 35946 bytes tests/integ/sagemaker/experiments/__init__.py | 0 tests/integ/sagemaker/experiments/conftest.py | 177 ++++ tests/integ/sagemaker/experiments/helpers.py | 42 + .../sagemaker/experiments/test_experiment.py | 56 ++ .../sagemaker/experiments/test_metrics.py | 39 + tests/integ/sagemaker/experiments/test_run.py | 662 ++++++++++++ .../integ/sagemaker/experiments/test_trial.py | 75 ++ .../experiments/test_trial_component.py | 144 +++ tests/integ/sagemaker/lineage/conftest.py | 5 +- tests/integ/sagemaker/lineage/helpers.py | 14 - .../integ/sagemaker/lineage/test_artifact.py | 4 +- tests/integ/sagemaker/utilities/__init__.py | 0 .../utilities/test_search_expression.py | 67 ++ tests/integ/test_marketplace.py | 4 +- tests/integ/test_multidatamodel.py | 21 +- tests/integ/utils.py | 20 + tests/unit/conftest.py | 66 ++ tests/unit/sagemaker/experiments/__init__.py | 0 tests/unit/sagemaker/experiments/conftest.py | 86 ++ tests/unit/sagemaker/experiments/helpers.py | 44 + .../sagemaker/experiments/test_environment.py | 107 ++ .../sagemaker/experiments/test_experiment.py | 306 ++++++ .../unit/sagemaker/experiments/test_helper.py | 195 ++++ .../sagemaker/experiments/test_metrics.py | 178 ++++ tests/unit/sagemaker/experiments/test_run.py | 941 ++++++++++++++++++ .../sagemaker/experiments/test_run_context.py | 191 ++++ .../unit/sagemaker/experiments/test_trial.py | 276 +++++ .../experiments/test_trial_component.py | 384 +++++++ .../unit/sagemaker/experiments/test_utils.py | 36 + .../sagemaker/huggingface/test_estimator.py | 1 + .../sagemaker/tensorflow/test_estimator.py | 1 + .../test_huggingface_pytorch_compiler.py | 1 + .../test_huggingface_tensorflow_compiler.py | 1 + .../test_tensorflow_compiler.py | 1 + .../utilities/test_search_expression.py | 80 ++ .../workflow/test_clarify_check_step.py | 44 - .../unit/sagemaker/workflow/test_entities.py | 43 - .../workflow/test_quality_check_step.py | 46 - tests/unit/sagemaker/workflow/test_steps.py | 47 +- tests/unit/test_amazon_estimator.py | 13 +- tests/unit/test_estimator.py | 9 +- tests/unit/test_mxnet.py | 1 + tests/unit/test_pytorch.py | 1 + 
tests/unit/test_rl.py | 1 + tests/unit/test_session.py | 15 + tests/unit/test_sklearn.py | 1 + tests/unit/test_utils.py | 64 +- tests/unit/test_xgboost.py | 1 + 82 files changed, 7894 insertions(+), 263 deletions(-) create mode 100644 doc/experiments/index.rst create mode 100644 doc/experiments/sagemaker.experiments.rst create mode 100644 src/sagemaker/experiments/__init__.py create mode 100644 src/sagemaker/experiments/_api_types.py create mode 100644 src/sagemaker/experiments/_environment.py create mode 100644 src/sagemaker/experiments/_helper.py create mode 100644 src/sagemaker/experiments/_metrics.py create mode 100644 src/sagemaker/experiments/_run_context.py create mode 100644 src/sagemaker/experiments/_utils.py create mode 100644 src/sagemaker/experiments/experiment.py create mode 100644 src/sagemaker/experiments/run.py create mode 100644 src/sagemaker/experiments/trial.py create mode 100644 src/sagemaker/experiments/trial_component.py create mode 100644 src/sagemaker/utilities/search_expression.py create mode 100644 tests/data/experiment/inference.py create mode 100644 tests/data/experiment/process_job_script_for_run_clz.py create mode 100644 tests/data/experiment/train_job_script_for_run_clz.py create mode 100644 tests/data/experiment/transform_job_materials/data.csv create mode 100644 tests/data/experiment/transform_job_materials/xgb_model.tar.gz create mode 100644 tests/integ/sagemaker/experiments/__init__.py create mode 100644 tests/integ/sagemaker/experiments/conftest.py create mode 100644 tests/integ/sagemaker/experiments/helpers.py create mode 100644 tests/integ/sagemaker/experiments/test_experiment.py create mode 100644 tests/integ/sagemaker/experiments/test_metrics.py create mode 100644 tests/integ/sagemaker/experiments/test_run.py create mode 100644 tests/integ/sagemaker/experiments/test_trial.py create mode 100644 tests/integ/sagemaker/experiments/test_trial_component.py create mode 100644 tests/integ/sagemaker/utilities/__init__.py create mode 100644 tests/integ/sagemaker/utilities/test_search_expression.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/sagemaker/experiments/__init__.py create mode 100644 tests/unit/sagemaker/experiments/conftest.py create mode 100644 tests/unit/sagemaker/experiments/helpers.py create mode 100644 tests/unit/sagemaker/experiments/test_environment.py create mode 100644 tests/unit/sagemaker/experiments/test_experiment.py create mode 100644 tests/unit/sagemaker/experiments/test_helper.py create mode 100644 tests/unit/sagemaker/experiments/test_metrics.py create mode 100644 tests/unit/sagemaker/experiments/test_run.py create mode 100644 tests/unit/sagemaker/experiments/test_run_context.py create mode 100644 tests/unit/sagemaker/experiments/test_trial.py create mode 100644 tests/unit/sagemaker/experiments/test_trial_component.py create mode 100644 tests/unit/sagemaker/experiments/test_utils.py create mode 100644 tests/unit/sagemaker/utilities/test_search_expression.py diff --git a/.gitignore b/.gitignore index 9829ed9781..cae8f890ea 100644 --- a/.gitignore +++ b/.gitignore @@ -30,5 +30,6 @@ env/ .vscode/ **/tmp .python-version -**/_repack_model.py -**/_repack_script_launcher.sh \ No newline at end of file +**/_repack_script_launcher.sh +tests/data/**/_repack_model.py +tests/data/experiment/sagemaker-dev-1.0.tar.gz diff --git a/doc/experiments/index.rst b/doc/experiments/index.rst new file mode 100644 index 0000000000..8c12f30edc --- /dev/null +++ b/doc/experiments/index.rst @@ -0,0 +1,10 @@ 
+############################ +Amazon SageMaker Experiments +############################ + +The SageMaker Python SDK supports tracking and organizing your machine learning workflows across SageMaker jobs, such as Processing, Training, and Transform, or locally. + +.. toctree:: + :maxdepth: 2 + + sagemaker.experiments diff --git a/doc/experiments/sagemaker.experiments.rst b/doc/experiments/sagemaker.experiments.rst new file mode 100644 index 0000000000..f0776ec43b --- /dev/null +++ b/doc/experiments/sagemaker.experiments.rst @@ -0,0 +1,20 @@ +Experiments +============ + +Run +------------- + +.. autoclass:: sagemaker.experiments.Run + :members: + +.. automethod:: sagemaker.experiments.load_run + +.. automethod:: sagemaker.experiments.list_runs + +.. autoclass:: sagemaker.experiments.SortByType + :members: + :undoc-members: + +.. autoclass:: sagemaker.experiments.SortOrderType + :members: + :undoc-members: diff --git a/doc/index.rst b/doc/index.rst index 2d4ebe32c1..69038056b0 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -60,6 +60,16 @@ Orchestrate your SageMaker training and inference workflows with Airflow and Kub workflows/index +**************************** +Amazon SageMaker Experiments +**************************** +You can use Amazon SageMaker Experiments to track machine learning experiments. + +.. toctree:: + :maxdepth: 2 + + experiments/index + ************************* Amazon SageMaker Debugger ************************* diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index fe93fd4d0e..494b6dca11 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -20,3 +20,4 @@ requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 pandas>=1.3.5,<1.5 +scikit-learn==1.0.2 diff --git a/setup.py b/setup.py index 4327045760..e2adb6b433 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ "attrs>=20.3.0,<23", - "boto3>=1.26.20,<2.0", + "boto3>=1.26.28,<2.0", "google-pasta", "numpy>=1.9.0,<2.0", "protobuf>=3.1,<4.0", diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index b156f2e65f..1abea5e48c 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -27,7 +27,7 @@ from sagemaker.deprecations import renamed_warning from sagemaker.estimator import EstimatorBase, _TrainingJob from sagemaker.inputs import FileSystemInput, TrainingInput -from sagemaker.utils import sagemaker_timestamp +from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow import is_pipeline_variable @@ -242,8 +242,8 @@ def fit( generates a default job name, based on the training image name and current timestamp. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName', and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be automatically created and the job's Trial Component associated with the Trial.
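To make the new key concrete, here is a minimal sketch (not part of the patch itself) of passing `RunName` through `experiment_config` when calling `fit`; the image URI, role, and S3 paths are placeholders:

```python
from sagemaker.estimator import Estimator

# Placeholder estimator arguments; substitute real values.
estimator = Estimator(
    image_uri="<training-image-uri>",
    role="<execution-role-arn>",
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# experiment_config may now carry four keys; 'RunName' associates the
# training job's trial component with an experiment run.
estimator.fit(
    inputs="s3://<bucket>/train",
    experiment_config={
        "ExperimentName": "my-experiment",
        "TrialName": "my-trial",
        "TrialComponentDisplayName": "Training",
        "RunName": "my-run",
    },
)
```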
@@ -255,6 +255,7 @@ def fit( """ self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new( self, records, experiment_config=experiment_config ) diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py index e920797b18..9a7359e12b 100644 --- a/src/sagemaker/apiutils/_base_types.py +++ b/src/sagemaker/apiutils/_base_types.py @@ -173,8 +173,10 @@ def _search( search_items = search_method_response.get("Results", []) next_token = search_method_response.get(boto_next_token_name) for item in search_items: - if cls.__name__ in item: - yield search_item_factory(item[cls.__name__]) + # _TrialComponent class in experiments module is not public currently + class_name = cls.__name__.lstrip("_") + if class_name in item: + yield search_item_factory(item[class_name]) if not next_token: break except StopIteration: diff --git a/src/sagemaker/apiutils/_boto_functions.py b/src/sagemaker/apiutils/_boto_functions.py index 1e29f2ebea..a227d30ca8 100644 --- a/src/sagemaker/apiutils/_boto_functions.py +++ b/src/sagemaker/apiutils/_boto_functions.py @@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type): api_type, is_collection = member_name_to_type[member_name] if is_collection: if isinstance(boto_value, dict): - member_value = api_type.from_boto(boto_value) + member_value = { + key: api_type.from_boto(value) for key, value in boto_value.items() + } else: member_value = [api_type.from_boto(item) for item in boto_value] else: diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py index 90a272c4d7..468be22ac3 100644 --- a/src/sagemaker/dataset_definition/inputs.py +++ b/src/sagemaker/dataset_definition/inputs.py @@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject): """DatasetDefinition input.""" _custom_boto_types = { - "redshift_dataset_definition": (RedshiftDatasetDefinition, True), - "athena_dataset_definition": (AthenaDatasetDefinition, True), + # RedshiftDatasetDefinition and AthenaDatasetDefinition are not collections. + # Instead, they are singleton objects. Thus, set the is_collection flag to False. + "redshift_dataset_definition": (RedshiftDatasetDefinition, False), + "athena_dataset_definition": (AthenaDatasetDefinition, False), } def __init__( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6f729267de..e3b06950aa 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -79,6 +79,7 @@ get_config_value, name_from_base, to_string, + check_and_get_run_experiment_config, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -1103,8 +1104,8 @@ def fit( job_name (str): Training job name. If not specified, the estimator generates a default job name based on the training image name and current timestamp. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName', and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be automatically created and the job's Trial Component associated with the Trial.
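The `check_and_get_run_experiment_config` call added above also enables an implicit path: if `fit` is invoked inside a `Run` context without an explicit `experiment_config`, the config is resolved from the current run. A hedged sketch, reusing the placeholder `estimator` from the previous example:

```python
from sagemaker.experiments import Run

with Run(experiment_name="my-experiment", run_name="my-run") as run:
    run.log_parameter("learning_rate", 0.1)
    # No experiment_config is passed here; check_and_get_run_experiment_config
    # falls back to the experiment config of the active Run context, so the
    # training job is tracked under "my-run".
    estimator.fit(inputs="s3://<bucket>/train")  # placeholder S3 path
```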
@@ -1122,6 +1123,7 @@ def fit( """ self._prepare_for_training(job_name=job_name) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) self.jobs.append(self.latest_training_job) if wait: @@ -2023,8 +2025,8 @@ def start_new(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName', and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -2033,6 +2035,7 @@ * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started training job. @@ -2053,8 +2056,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName', and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not, a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -2063,6 +2066,7 @@ * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. + * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: Dict: dict for `sagemaker.session.Session.train` method diff --git a/src/sagemaker/experiments/__init__.py b/src/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..b87656b1ab --- /dev/null +++ b/src/sagemaker/experiments/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
+"""Sagemaker Experiment Module""" +from __future__ import absolute_import + +from sagemaker.experiments.run import Run # noqa: F401 +from sagemaker.experiments.run import load_run # noqa: F401 +from sagemaker.experiments.run import list_runs # noqa: F401 +from sagemaker.experiments.run import SortOrderType # noqa: F401 +from sagemaker.experiments.run import SortByType # noqa: F401 diff --git a/src/sagemaker/experiments/_api_types.py b/src/sagemaker/experiments/_api_types.py new file mode 100644 index 0000000000..78f82565aa --- /dev/null +++ b/src/sagemaker/experiments/_api_types.py @@ -0,0 +1,251 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains API objects for SageMaker experiments.""" +from __future__ import absolute_import + +import enum +import numbers + +from sagemaker.apiutils import _base_types + + +class TrialComponentMetricSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + metric_name (str): The name of the metric. + source_arn (str): The ARN of the source. + time_stamp (datetime): Metric last updated value. + max (float): The max value of the metric. + min (float): The min value of the metric. + last (float): The last value of the metric. + count (float): The number of samples used to generate the metric. + avg (float): The average value of the metric. + std_dev (float): The standard deviation of the metric. + """ + + metric_name = None + source_arn = None + time_stamp = None + max = None + min = None + last = None + count = None + avg = None + std_dev = None + + def __init__(self, metric_name=None, source_arn=None, **kwargs): + super(TrialComponentMetricSummary, self).__init__( + metric_name=metric_name, source_arn=source_arn, **kwargs + ) + + +class TrialComponentParameters(_base_types.ApiObject): + """A dictionary of TrialComponentParameterValues""" + + @classmethod + def from_boto(cls, boto_dict, **kwargs): + """Converts a boto dict to a dictionary of TrialComponentParameterValues + + Args: + boto_dict (dict): boto response dictionary. + **kwargs: Arbitrary keyword arguments. + + Returns: + dict: Dictionary of parameter values. + """ + return_map = {} + for key, value in boto_dict.items(): + return_map[key] = value.get("NumberValue", value.get("StringValue", None)) + return return_map + + @classmethod + def to_boto(cls, parameters): + """Converts TrialComponentParameters to dict. + + Args: + parameters (TrialComponentParameters): Dictionary to convert. + + Returns: + dict: Dictionary of trial component parameters in boto format. + """ + boto_map = {} + for key, value in parameters.items(): + if isinstance(value, numbers.Number): + boto_map[key] = {"NumberValue": value} + else: + boto_map[key] = {"StringValue": str(value)} + return boto_map + + +class TrialComponentArtifact(_base_types.ApiObject): + """Trial component artifact. + + Attributes: + value (str): The artifact value. + media_type (str): The media type. 
+    """ + + value = None + media_type = None + + def __init__(self, value=None, media_type=None, **kwargs): + super(TrialComponentArtifact, self).__init__(value=value, media_type=media_type, **kwargs) + + +class _TrialComponentStatusType(enum.Enum): + """The type of trial component status""" + + InProgress = "InProgress" + Completed = "Completed" + Failed = "Failed" + + +class TrialComponentStatus(_base_types.ApiObject): + """Status of the trial component. + + Attributes: + primary_status (str): The status of a trial component. + message (str): Status message. + """ + + primary_status = None + message = None + + def __init__(self, primary_status=None, message=None, **kwargs): + super(TrialComponentStatus, self).__init__( + primary_status=primary_status, message=message, **kwargs + ) + + +class TrialComponentSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + trial_component_name (str): Name of trial component. + trial_component_arn (str): ARN of the trial component. + display_name (str): Friendly display name in UI. + source_arn (str): ARN of the trial component source. + status (str): Status. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + """ + + _custom_boto_types = { + "status": (TrialComponentStatus, False), + } + trial_component_name = None + trial_component_arn = None + display_name = None + source_arn = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + + +class TrialComponentSource(_base_types.ApiObject): + """Trial Component Source + + Attributes: + source_arn (str): The ARN of the source. + """ + + source_arn = None + + def __init__(self, source_arn=None, **kwargs): + super(TrialComponentSource, self).__init__(source_arn=source_arn, **kwargs) + + +class Parent(_base_types.ApiObject): + """The trial/experiment/run that a trial component is associated with. + + Attributes: + trial_name (str): Name of the trial. + experiment_name (str): Name of the experiment. + run_name (str): Name of the run. + """ + + trial_name = None + experiment_name = None + run_name = None + + +class TrialComponentSearchResult(_base_types.ApiObject): + """Summary model of a Trial Component search result. + + Attributes: + trial_component_arn (str): ARN of the trial component. + trial_component_name (str): Name of the trial component. + display_name (str): Display name of the trial component for UI display. + source (dict): The source of the trial component. + status (dict): The status of the trial component. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + parameters (dict): The hyperparameters of the component. + input_artifacts (dict): The input artifacts of the component. + output_artifacts (dict): The output artifacts of the component. + metrics (list): The metrics for the component. + source_detail (dict): The source of the trial component. + tags (list): The list of tags that are associated with the trial component. + parents (list[Parent]): The parents of the trial component.
+ """ + + _custom_boto_types = { + "parents": (Parent, True), # parents is a collection (list) of Parent objects + } + trial_component_arn = None + trial_component_name = None + display_name = None + source = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + parameters = None + input_artifacts = None + output_artifacts = None + metrics = None + source_detail = None + tags = None + parents = None + + +class TrialSummary(_base_types.ApiObject): + """Summary model of a trial. + + Attributes: + trial_arn (str): The ARN of the trial. + trial_name (str): The name of the trial. + creation_time (datetime): When the trial was created. + last_modified_time (datetime): When the trial was last modified. + """ + + trial_arn = None + trial_name = None + creation_time = None + last_modified_time = None diff --git a/src/sagemaker/experiments/_environment.py b/src/sagemaker/experiments/_environment.py new file mode 100644 index 0000000000..441661ae5a --- /dev/null +++ b/src/sagemaker/experiments/_environment.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the _RunEnvironment class.""" +from __future__ import absolute_import + +import enum +import json +import logging +import os + +from sagemaker.experiments import trial_component +from sagemaker.utils import retry_with_backoff + +TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN" +PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json" +TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH" +MAX_RETRY_ATTEMPTS = 7 + +logger = logging.getLogger(__name__) + + +class _EnvironmentType(enum.Enum): + """SageMaker jobs which data can be pulled from the environment.""" + + SageMakerTrainingJob = 1 + SageMakerProcessingJob = 2 + SageMakerTransformJob = 3 + + +class _RunEnvironment(object): + """Retrieves job specific data from the environment.""" + + def __init__(self, environment_type, source_arn): + """Init for _RunEnvironment. + + Args: + environment_type (_EnvironmentType): The environment type. + source_arn (str): The ARN of the current job. + """ + self.environment_type = environment_type + self.source_arn = source_arn + + @classmethod + def load( + cls, + training_job_arn_env=TRAINING_JOB_ARN_ENV, + processing_job_config_path=PROCESSING_JOB_CONFIG_PATH, + transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR, + ): + """Loads source arn of current job from environment. + + Args: + training_job_arn_env (str): The environment key for training job ARN + (default: `TRAINING_JOB_ARN`). + processing_job_config_path (str): The processing job config path + (default: `/opt/ml/config/processingjobconfig.json`). + transform_job_batch_var (str): The environment variable indicating if + it is a transform job (default: `SAGEMAKER_BATCH`). + + Returns: + _RunEnvironment: Job data loaded from the environment. None if config does not exist. 
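+ + Example (illustrative; it mirrors the checks in the method body below): + + env = _RunEnvironment.load() + # In a SageMaker training job container, TRAINING_JOB_ARN is set, so: + # env.environment_type == _EnvironmentType.SageMakerTrainingJob + # env.source_arn == os.environ["TRAINING_JOB_ARN"]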
+        """ + if training_job_arn_env in os.environ: + environment_type = _EnvironmentType.SageMakerTrainingJob + source_arn = os.environ.get(training_job_arn_env) + return _RunEnvironment(environment_type, source_arn) + if os.path.exists(processing_job_config_path): + environment_type = _EnvironmentType.SageMakerProcessingJob + with open(processing_job_config_path) as config_file: + source_arn = json.load(config_file)["ProcessingJobArn"] + return _RunEnvironment(environment_type, source_arn) + if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true": + environment_type = _EnvironmentType.SageMakerTransformJob + # TODO: need to figure out how to get source_arn from job env + # with Transform team's help. + source_arn = "" + return _RunEnvironment(environment_type, source_arn) + + return None + + def get_trial_component(self, sagemaker_session): + """Retrieves the trial component from the job in the environment. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + _TrialComponent: The trial component created from the job. None if not found. + """ + # TODO: Remove this condition check once we have a way to retrieve source ARN + # from transform job env + if self.environment_type == _EnvironmentType.SageMakerTransformJob: + logger.error( + "Currently getting the job trial component from the transform job environment " + "is not supported. Returning None." + ) + return None + + def _get_trial_component(): + summaries = list( + trial_component._TrialComponent.list( + source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session + ) + ) + if summaries: + summary = summaries[0] + return trial_component._TrialComponent.load( + trial_component_name=summary.trial_component_name, + sagemaker_session=sagemaker_session, + ) + return None + + job_tc = None + try: + job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS) + except Exception as ex: # pylint: disable=broad-except + logger.error( + "Failed to get trial component in the current environment due to %s", str(ex) + ) + return job_tc diff --git a/src/sagemaker/experiments/_helper.py b/src/sagemaker/experiments/_helper.py new file mode 100644 index 0000000000..0c689b1125 --- /dev/null +++ b/src/sagemaker/experiments/_helper.py @@ -0,0 +1,266 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
+"""Contains the helper classes for SageMaker Experiment.""" +from __future__ import absolute_import + +import json +import logging +import os + +import botocore + +from sagemaker.experiments._utils import is_already_exist_error + +logger = logging.getLogger(__name__) + + +_DEFAULT_ARTIFACT_PREFIX = "trial-component-artifacts" +_DEFAULT_ARTIFACT_TYPE = "Tracker" + + +class _ArtifactUploader(object): + """Artifact uploader""" + + def __init__( + self, + trial_component_name, + sagemaker_session, + artifact_bucket=None, + artifact_prefix=_DEFAULT_ARTIFACT_PREFIX, + ): + """Initialize a `_ArtifactUploader` instance. + + Args: + trial_component_name (str): The name of the trial component, + which is used to generate the S3 path to upload the artifact to. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + artifact_bucket (str): The S3 bucket to upload the artifact to. + If not specified, the default bucket defined in `sagemaker_session` + will be used. + artifact_prefix (str): The S3 key prefix used to generate the S3 path + to upload the artifact to (default: "trial-component-artifacts"). + """ + self.sagemaker_session = sagemaker_session + self.trial_component_name = trial_component_name + self.artifact_bucket = artifact_bucket + self.artifact_prefix = artifact_prefix + self._s3_client = self.sagemaker_session.boto_session.client("s3") + + def upload_artifact(self, file_path): + """Upload an artifact file to S3. + + Args: + file_path (str): the file path of the artifact + + Returns: + (str, str): The s3 URI of the uploaded file and the etag of the file. + + Raises: + ValueError: If file does not exist. + """ + file_path = os.path.expanduser(file_path) + if not os.path.isfile(file_path): + raise ValueError( + "{} does not exist or is not a file. Please supply a file path.".format(file_path) + ) + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + artifact_name = os.path.basename(file_path) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None): + """Upload an artifact object to S3. + + Args: + artifact_name (str): the name of the artifact. + artifact_object (obj): the object of the artifact + file_extension (str): Optional file extension. + + Returns: + str: The s3 URI of the uploaded file and the version of the file. + """ + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + if file_extension: + artifact_name = ( + artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension + ) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.put_object( + Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key + ) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def _try_get_etag(self, key): + """Get ETag of given key and return None if not allowed + + Args: + key (str): The S3 object key. + + Returns: + str: The S3 object ETag if it allows, otherwise return None. 
+        """ + try: + response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key) + return response["ETag"] + except botocore.exceptions.ClientError as error: + # requires read permissions + logger.warning("Failed to get ETag of %s due to %s", key, error) + return None + + +class _LineageArtifactManager(object): + """A helper class to manage Lineage Artifacts""" + + def __init__( + self, + name, + source_uri, + etag, + source_arn=None, + dest_arn=None, + artifact_type=_DEFAULT_ARTIFACT_TYPE, + ): + """Initialize a `_LineageArtifactManager` instance. + + Args: + name (str): The name of the Lineage artifact to be created. + source_uri (str): The source URI used to create the Lineage artifact. + etag (str): The S3 ETag used to create the Lineage artifact. + source_arn (str): The source ARN of a trial component to associate + this Lineage artifact with (default: None). + dest_arn (str): The destination ARN of a trial component to associate + this Lineage artifact with (default: None). + artifact_type (str): The type of the Lineage artifact (default: "Tracker"). + """ + self.name = name + self.source_uri = source_uri + self.etag = etag + self.source_arn = source_arn + self.dest_arn = dest_arn + self.artifact_arn = None + self.artifact_type = artifact_type + + def create_artifact(self, sagemaker_session): + """Create the artifact by calling the `CreateArtifact` API. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_ids = [] + if self.etag: + source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag}) + + try: + response = sagemaker_session.sagemaker_client.create_artifact( + ArtifactName=self.name, + ArtifactType=self.artifact_type, + Source={"SourceUri": self.source_uri, "SourceTypes": source_ids}, + ) + self.artifact_arn = response["ArtifactArn"] + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip creating the artifact since it already exists: %s", err_info["Message"] + ) + + def add_association(self, sagemaker_session): + """Associate the artifact with a source/destination ARN (e.g., a trial component ARN). + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_arn = self.source_arn if self.source_arn else self.artifact_arn + dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn + # if the trial component (job) is the source then it produced the artifact, + # otherwise the artifact contributed to the trial component (job) + association_edge_type = "Produced" if self.source_arn else "ContributedTo" + try: + sagemaker_session.sagemaker_client.add_association( + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type + ) + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip associating since the association already exists: %s", err_info["Message"] + ) + + +class _LineageArtifactTracker(object): + """Lineage Artifact Tracker""" + + def __init__(self, trial_component_arn, sagemaker_session): + """Initialize a `_LineageArtifactTracker` instance. + + Args: + trial_component_arn (str): The ARN of the trial component to be + associated with the input/output artifacts.
+            sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + self.trial_component_arn = trial_component_arn + self.sagemaker_session = sagemaker_session + self.artifacts = [] + + def add_input_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage input artifact locally + + Args: + name (str): The name of the Lineage input artifact to be added. + source_uri (str): The source URI used to create the Lineage input artifact. + etag (str): The S3 ETag used to create the Lineage input artifact. + artifact_type (str): The type of the Lineage input artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def add_output_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage output artifact locally + + Args: + name (str): The name of the Lineage output artifact to be added. + source_uri (str): The source URI used to create the Lineage output artifact. + etag (str): The S3 ETag used to create the Lineage output artifact. + artifact_type (str): The type of the Lineage output artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def save(self): + """Persist any artifact data saved locally""" + for artifact in self.artifacts: + artifact.create_artifact(self.sagemaker_session) + artifact.add_association(self.sagemaker_session) diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py new file mode 100644 index 0000000000..f80c43f337 --- /dev/null +++ b/src/sagemaker/experiments/_metrics.py @@ -0,0 +1,413 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to manage metrics for SageMaker Experiments""" +from __future__ import absolute_import + +import datetime +import json +import logging +import os +import time +import threading +import queue + +import dateutil.tz + +from sagemaker.session import Session + +METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".") +METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # in seconds +METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # in seconds + +BATCH_SIZE = 10 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# TODO: remove this _SageMakerFileMetricsWriter class +# when _MetricsManager is fully ready +class _SageMakerFileMetricsWriter(object): + """Write metric data to file.""" + + def __init__(self, metrics_file_path=None): + """Construct a `_SageMakerFileMetricsWriter` object""" + self._metrics_file_path = metrics_file_path + self._file = None + self._closed = False + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Write a metric to file. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric.
+ timestamp (datetime.datetime): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + + Raises: + SageMakerMetricsWriterException: If the metrics file is closed. + AttributeError: If file has been initialized and the writer hasn't been closed. + """ + raw_metric_data = _RawMetricData( + metric_name=metric_name, value=value, timestamp=timestamp, step=step + ) + try: + logger.debug("Writing metric: %s", raw_metric_data) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + except AttributeError as attr_err: + if self._closed: + raise SageMakerMetricsWriterException("log_metric called on a closed writer") + if not self._file: + self._file = open(self._get_metrics_file_path(), "a", buffering=1) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + else: + raise attr_err + + def close(self): + """Closes the metric file.""" + if not self._closed and self._file: + self._file.close() + self._file = None # invalidate reference, causing subsequent log_metric to fail. + self._closed = True + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.close() + + def __del__(self): + """Execute self.close()""" + self.close() + + def _get_metrics_file_path(self): + """Get file path to store metrics""" + pid_filename = "{}.json".format(str(os.getpid())) + metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename) + logger.debug("metrics_file_path = %s", metrics_file_path) + return metrics_file_path + + +class SageMakerMetricsWriterException(Exception): + """SageMakerMetricsWriterException""" + + def __init__(self, message, errors=None): + """Construct a `SageMakerMetricsWriterException` instance""" + super().__init__(message) + if errors: + self.errors = errors + + +class _RawMetricData(object): + """A Raw Metric Data Object""" + + MetricName = None + Value = None + Timestamp = None + Step = None + + def __init__(self, metric_name, value, timestamp=None, step=None): + """Construct a `_RawMetricData` instance. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime or float or str): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + """ + if timestamp is None: + timestamp = time.time() + elif isinstance(timestamp, datetime.datetime): + # If the input is a datetime then convert it to UTC time. + # Assume a naive datetime is in local timezone + if not timestamp.tzinfo: + timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal()) + timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc) + timestamp = timestamp.timestamp() + else: + timestamp = float(timestamp) + + if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > ( + time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW + ): + raise ValueError( + "Supplied timestamp %f is invalid." + " Timestamps must be between two weeks before and two hours from now." 
% timestamp + ) + value = float(value) + + self.MetricName = metric_name + self.Value = float(value) + self.Timestamp = timestamp + if step is not None: + if not isinstance(step, int): + raise ValueError("step must be int.") + self.Step = step + + def to_record(self): + """Convert the `_RawMetricData` object to dict""" + return self.__dict__ + + def to_raw_metric_data(self): + """Converts the metric data to a BatchPutMetrics RawMetricData item""" + # Convert timestamp from float to timestamp str. + # Otherwise will get ParamValidationError + raw_metric_data = { + "MetricName": self.MetricName, + "Value": self.Value, + "Timestamp": str(int(self.Timestamp)), + } + if self.Step is not None: + raw_metric_data["Step"] = int(self.Step) + return raw_metric_data + + def __str__(self): + """String representation of the `_RawMetricData` object.""" + return repr(self) + + def __repr__(self): + """Return a string representation of this _RawMetricData` object.""" + return "{}({})".format( + type(self).__name__, + ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]), + ) + + +class _MetricsManager(object): + """Collects metrics and sends them directly to SageMaker Metrics data plane APIs.""" + + def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None: + """Initialize a `_MetricsManager` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + sink (object): The metrics sink to use. + """ + if sink is None: + self.sink = _SyncMetricsSink( + trial_component_name, sagemaker_session.sagemaker_metrics_client + ) + else: + self.sink = sink + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Sends a metric to metrics service.""" + + metric_data = _RawMetricData(metric_name, value, timestamp, step) + self.sink.log_metric(metric_data) + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.sink.close() + + def close(self): + """Close the metrics object.""" + self.sink.close() + + +class _SyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_SyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics. 
+ metrics_client (boto3.client): boto client for metrics service + """ + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._buffer = [] + + def log_metric(self, metric_data): + """Sends a metric to metrics service.""" + + # this is a simplistic solution which calls BatchPutMetrics + # on the same thread as the client code + self._buffer.append(metric_data) + self._drain() + + def _drain(self, close=False): + """Pops off all metrics in the buffer and starts sending them to metrics service.""" + + if not self._buffer: + return + + if len(self._buffer) < BATCH_SIZE and not close: + return + + # pop all the available metrics + available_metrics, self._buffer = self._buffer, [] + + self._send_metrics(available_metrics) + + def _send_metrics(self, metrics): + """Calls BatchPutMetrics directly on the metrics service.""" + while metrics: + batch, metrics = ( + metrics[:BATCH_SIZE], + metrics[BATCH_SIZE:], + ) + request = self._construct_batch_put_metrics_request(batch) + response = self._metrics_client.batch_put_metrics(**request) + errors = response["Errors"] if "Errors" in response else None + if errors: + message = errors[0]["Message"] + raise Exception(f'{len(errors)} errors with message "{message}"') + + def _construct_batch_put_metrics_request(self, batch): + """Creates dictionary object used as request to metrics service.""" + return { + "TrialComponentName": self._trial_component_name.lower(), + "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)), + } + + def close(self): + """Drains any remaining metrics.""" + self._drain(close=True) + + +class _MetricQueue(object): + """A thread safe queue for sending metrics to SageMaker. + + Args: + trial_component_name (str): the ARN of the resource + metric_name (str): the name of the metric + metrics_client (boto_client): the boto client for SageMaker Metrics service + """ + + _CONSUMER_SLEEP_SECONDS = 5 + + def __init__(self, trial_component_name, metric_name, metrics_client): + # infinite queue size + self._queue = queue.Queue() + self._buffer = [] + self._thread = threading.Thread(target=self._run) + self._started = False + self._finished = False + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._metric_name = metric_name + self._logged_metrics = 0 + + def log_metric(self, metric_data): + """Adds a metric data point to the queue""" + self._buffer.append(metric_data) + + if len(self._buffer) < BATCH_SIZE: + return + + self._enqueue_all() + + if not self._started: + self._thread.start() + self._started = True + + def _run(self): + """Starts the metric thread which sends metrics to SageMaker in batches""" + + while not self._queue.empty() or not self._finished: + if self._queue.empty(): + time.sleep(self._CONSUMER_SLEEP_SECONDS) + else: + batch = self._queue.get() + self._send_metrics(batch) + + def _send_metrics(self, metrics_batch): + """Calls BatchPutMetrics directly on the metrics service.""" + request = self._construct_batch_put_metrics_request(metrics_batch) + self._logged_metrics += len(metrics_batch) + self._metrics_client.batch_put_metrics(**request) + + def _construct_batch_put_metrics_request(self, batch): + """Creates dictionary object used as request to metrics service.""" + + return { + "TrialComponentName": self._trial_component_name, + "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)), + } + + def _enqueue_all(self): + """Enqueue all buffered metrics to be sent to SageMaker""" + + available_metrics, self._buffer 
= self._buffer, [] + if available_metrics: + self._queue.put(available_metrics) + + def close(self): + """Flushes any buffered metrics""" + + self._enqueue_all() + self._finished = True + + def is_active(self): + """Is the thread active (still draining metrics to SageMaker)""" + + return self._thread.is_alive() + + +class _AsyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + _COMPLETE_SLEEP_SECONDS = 1.0 + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_AsyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to. + metrics_client (boto3.client): boto client for metrics service + """ + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._buffer = [] + self._is_draining = False + self._metric_queues = {} + + def log_metric(self, metric_data): + """Sends a metric to metrics service.""" + + if metric_data.MetricName in self._metric_queues: + self._metric_queues[metric_data.MetricName].log_metric(metric_data) + else: + cur_metric_queue = _MetricQueue( + self._trial_component_name, metric_data.MetricName, self._metrics_client + ) + self._metric_queues[metric_data.MetricName] = cur_metric_queue + cur_metric_queue.log_metric(metric_data) + + def close(self): + """Closes the metric file.""" + logging.debug("Closing") + for q in self._metric_queues.values(): + q.close() + + # TODO should probably use join + while any(map(lambda x: x.is_active(), self._metric_queues.values())): + time.sleep(self._COMPLETE_SLEEP_SECONDS) + logging.debug("Closed") diff --git a/src/sagemaker/experiments/_run_context.py b/src/sagemaker/experiments/_run_context.py new file mode 100644 index 0000000000..9a7dada5f4 --- /dev/null +++ b/src/sagemaker/experiments/_run_context.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment _RunContext class.""" +from __future__ import absolute_import + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sagemaker.experiments import Run + + +class _RunContext: + """A static context variable to keep track of the current Run object""" + + _context_run = None + + @classmethod + def add_run_object(cls, run: "Run"): + """Keep track of the current executing Run object + + by adding it to a class static variable. + + Args: + run (Run): The current Run object to be tracked. + """ + cls._context_run = run + + @classmethod + def drop_current_run(cls) -> "Run": + """Drop the Run object tracked in the global static variable + + as its execution finishes (its "with" block ends). + + Return: + Run: the dropped Run object. + """ + current_run = cls._context_run + cls._context_run = None + return current_run + + @classmethod + def get_current_run(cls) -> "Run": + """Return the current Run object without dropping it. + + Return: + Run: the current Run object to be returned. 
+ """ + return cls._context_run diff --git a/src/sagemaker/experiments/_utils.py b/src/sagemaker/experiments/_utils.py new file mode 100644 index 0000000000..5ef5d99dad --- /dev/null +++ b/src/sagemaker/experiments/_utils.py @@ -0,0 +1,218 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment utility methods.""" +from __future__ import absolute_import + +import logging +import os + +import mimetypes +import urllib +from functools import wraps +from typing import Optional + +from sagemaker import Session +from sagemaker.apiutils import _utils +from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression +from sagemaker.utils import retry_with_backoff + + +def resolve_artifact_name(file_path): + """Resolve artifact name from given file path. + + If not specified, will auto create one. + + Args: + file_path (str): Path to the file. + + Returns: + str: The resolved artifact name. + """ + _, filename = os.path.split(file_path) + if filename: + return filename + + return _utils.name("artifact") + + +def guess_media_type(file_path): + """Infer the media type of a file based on its file name. + + Args: + file_path (str): Path to the file. + + Returns: + str: The guessed media type. + """ + file_url = urllib.parse.urljoin("file:", urllib.request.pathname2url(file_path)) + guessed_media_type, _ = mimetypes.guess_type(file_url, strict=False) + return guessed_media_type + + +def verify_length_of_true_and_predicted(true_labels, predicted_attrs, predicted_attrs_name): + """Verify if lengths match between lists of true labels and predicted attributes. + + Args: + true_labels (list or array): The list of the true labels. + predicted_attrs (list or array): The list of the predicted labels/probabilities/scores. + predicted_attrs_name (str): The name of the predicted attributes. + + Raises: + ValueError: If lengths mismatch between true labels and predicted attributes. 
+ """ + if len(true_labels) != len(predicted_attrs): + raise ValueError( + "Lengths mismatch between true labels and {}: " + "({} vs {}).".format(predicted_attrs_name, len(true_labels), len(predicted_attrs)) + ) + + +def validate_invoked_inside_run_context(func): + """A Decorator to force the decorated method called under Run context.""" + + @wraps(func) + def wrapper(*args, **kwargs): + self_instance = args[0] + if not self_instance._inside_load_context and not self_instance._inside_init_context: + raise RuntimeError("This method should be called inside context of 'with' statement.") + return func(*args, **kwargs) + + return wrapper + + +def is_already_exist_error(error): + """Check if the error indicates resource already exists + + Args: + error (dict): The "Error" field in the response of the + `botocore.exceptions.ClientError` + """ + return error["Code"] == "ValidationException" and "already exists" in error["Message"] + + +def get_tc_and_exp_config_from_job_env( + environment: _RunEnvironment, + sagemaker_session: Session, +) -> dict: + """Retrieve an experiment config from the job environment. + + Args: + environment (_RunEnvironment): The run environment object with job specific data. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + job_name = environment.source_arn.split("/")[-1] + if environment.environment_type == _EnvironmentType.SageMakerTrainingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_training_job(job_name), + num_attempts=4, + ) + elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_processing_job(job_name), + num_attempts=4, + ) + else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob + raise RuntimeError( + "Failed to load the Run as loading experiment config " + "from transform job environment is not currently supported. " + "As a workaround, please explicitly pass in " + "the experiment_name and run_name in load_run." + ) + + job_exp_config = job_response.get("ExperimentConfig", dict()) + from sagemaker.experiments.run import RUN_NAME + + if job_exp_config.get(RUN_NAME, None): + return job_exp_config + raise RuntimeError( + "Not able to fetch RunName in ExperimentConfig of the sagemaker job. " + "Please make sure the ExperimentConfig is correctly set." + ) + + +def verify_load_input_names( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, +): + """Verify the run_name and the experiment_name inputs in load_run. + + Args: + run_name (str): The run_name supplied by the user (default: None). + experiment_name (str): The experiment_name supplied by the user + (default: None). + + Raises: + ValueError: If run_name is supplied while experiment_name is not. + """ + if not run_name and experiment_name: + logging.warning( + "No run_name is supplied. Ignoring the provided experiment_name " + "since it only takes effect along with run_name. " + "Will load the Run object from the job environment or current Run context." + ) + if run_name and not experiment_name: + raise ValueError( + "Invalid input: experiment_name is missing when run_name is supplied. " + "Please supply a valid experiment_name when the run_name is not None." 
+ ) + + +def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool: + """Check if a trial component is generated by `sagemaker.experiments.Run` + + Args: + trial_component_name (str): The name of the trial component. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + bool: Indicate whether the trial component is created by + `sagemaker.experiments.Run` or not. + """ + search_filter = Filter( + name="TrialComponentName", + operator=Operator.EQUALS, + value=trial_component_name, + ) + search_expression = SearchExpression(filters=[search_filter]) + + def search(): + return list( + _TrialComponent.search( + search_expression=search_expression, + max_results=1, # TrialComponentName is unique in an account + sagemaker_session=sagemaker_session, + ) + )[0] + + try: + tc_search_res = retry_with_backoff(search, 4) + from sagemaker.experiments.run import RUN_TC_TAG + + if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags: + return False + return True + except Exception as ex: # pylint: disable=broad-except + logging.warning( + "Failed to inspect the type of the trial component (%s), due to (%s)", + trial_component_name, + str(ex), + ) + return False diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py new file mode 100644 index 0000000000..8f59ff36b3 --- /dev/null +++ b/src/sagemaker/experiments/experiment.py @@ -0,0 +1,237 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Experiment(_base_types.Record): + """An Amazon SageMaker experiment, which is a collection of related trials. + + New experiments are created by calling `experiments.experiment._Experiment.create`. + Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`. + + Attributes: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. + description (str): A description of the experiment. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment. 
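+
+    A minimal usage sketch of the methods below (the experiment name is
+    illustrative):
+
+    .. code:: python
+
+        experiment = _Experiment.create(experiment_name="my-experiment")
+        experiment.description = "My description"
+        experiment.save()
+        experiment = _Experiment.load(experiment_name="my-experiment")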
+ """ + + experiment_name = None + display_name = None + description = None + tags = None + + _boto_create_method = "create_experiment" + _boto_load_method = "describe_experiment" + _boto_update_method = "update_experiment" + _boto_delete_method = "delete_experiment" + + _boto_update_members = ["experiment_name", "description", "display_name"] + _boto_delete_members = ["experiment_name"] + + _MAX_DELETE_ALL_ATTEMPTS = 3 + + def save(self): + """Save the state of this Experiment to SageMaker. + + Returns: + dict: Update experiment API response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Experiment from SageMaker. + + Deleting an Experiment does not delete associated Trials and their Trial Components. + It requires that each Trial in the Experiment is first deleted. + + Returns: + dict: Delete experiment API response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, experiment_name, sagemaker_session=None): + """Load an existing experiment and return an `_Experiment` object representing it. + + Args: + experiment_name: (str): Name of the experiment + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_load_method, + experiment_name=experiment_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Create a new experiment in SageMaker and return an `_Experiment` object. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). + description: (str): Description of the experiment (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_create_method, + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def _load_or_create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Load an experiment by name and create a new one if it does not exist. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). This is used only when the + given `experiment_name` does not exist and a new experiment has to be created. + description: (str): Description of the experiment (default: None). + This is used only when the given `experiment_name` does not exist and + a new experiment has to be created. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). This is used only when the given `experiment_name` does not + exist and a new experiment has to be created. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + experiment = _Experiment.load(experiment_name, sagemaker_session) + except sagemaker_client.exceptions.ResourceNotFound: + experiment = _Experiment.create( + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return experiment + + def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): + """List trials in this experiment matching the specified criteria. + + Args: + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + + Returns: + collections.Iterator[experiments._api_types.TrialSummary] : + An iterator over trials matching the criteria. + """ + return _Trial.list( + experiment_name=self.experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=self.sagemaker_session, + ) + + def _delete_all(self, action): + """Force to delete the experiment and associated trials, trial components. + + Args: + action (str): The string '--force' is required to pass in to confirm recursively + delete the experiments, and all its trials and trial components. + """ + if action != "--force": + raise ValueError( + "Must confirm with string '--force' in order to delete the experiment and " + "associated trials, trial components." + ) + + delete_attempt_count = 0 + last_exception = None + while True: + if delete_attempt_count == self._MAX_DELETE_ALL_ATTEMPTS: + raise Exception("Failed to delete, please try again.") from last_exception + try: + for trial_summary in self.list_trials(): + trial = _Trial.load( + sagemaker_session=self.sagemaker_session, + trial_name=trial_summary.trial_name, + ) + for ( + trial_component_summary + ) in trial.list_trial_components(): # pylint: disable=no-member + tc = _TrialComponent.load( + sagemaker_session=self.sagemaker_session, + trial_component_name=trial_component_summary.trial_component_name, + ) + tc.delete(force_disassociate=True) + # to prevent throttling + time.sleep(1.2) + trial.delete() # pylint: disable=no-member + # to prevent throttling + time.sleep(1.2) + self.delete() + break + except Exception as ex: # pylint: disable=broad-except + last_exception = ex + finally: + delete_attempt_count = delete_attempt_count + 1 diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py new file mode 100644 index 0000000000..1492b6bafa --- /dev/null +++ b/src/sagemaker/experiments/run.py @@ -0,0 +1,882 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment Run class.""" +from __future__ import absolute_import + +import datetime +import logging +from enum import Enum +from math import isnan, isinf +from numbers import Number +from typing import Optional, List, Dict, TYPE_CHECKING, Union + +import dateutil +from numpy import array + +from sagemaker.apiutils import _utils +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType +from sagemaker.experiments._helper import ( + _ArtifactUploader, + _LineageArtifactTracker, +) +from sagemaker.experiments._environment import _RunEnvironment +from sagemaker.experiments._run_context import _RunContext +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments._metrics import _MetricsManager +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + +from sagemaker.utils import ( + get_module, + unique_name_from_base, +) + +from sagemaker.experiments._utils import ( + guess_media_type, + resolve_artifact_name, + verify_length_of_true_and_predicted, + validate_invoked_inside_run_context, + get_tc_and_exp_config_from_job_env, + verify_load_input_names, + is_run_trial_component, +) + +if TYPE_CHECKING: + from sagemaker import Session + +logger = logging.getLogger(__name__) + +RUN_NAME_BASE = "Sagemaker-Run".lower() +TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}" +MAX_RUN_TC_ARTIFACTS_LEN = 30 +MAX_NAME_LEN_IN_BACKEND = 120 +EXPERIMENT_NAME = "ExperimentName" +TRIAL_NAME = "TrialName" +RUN_NAME = "RunName" +DELIMITER = "-" +RUN_TC_TAG_KEY = "sagemaker:trial-component-source" +RUN_TC_TAG_VALUE = "run" +RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE} + + +class SortByType(Enum): + """The type of property by which to sort the `list_runs` results.""" + + CREATION_TIME = "CreationTime" + NAME = "Name" + + +class SortOrderType(Enum): + """The type of order to sort the list or search results.""" + + ASCENDING = "Ascending" + DESCENDING = "Descending" + + +class Run(object): + """A collection of parameters, metrics, and artifacts to create a ML model.""" + + def __init__( + self, + experiment_name: str, + run_name: Optional[str] = None, + experiment_display_name: Optional[str] = None, + run_display_name: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + sagemaker_session: Optional["Session"] = None, + ): + """Construct a `Run` instance. + + SageMaker Experiments automatically tracks the inputs, parameters, configurations, + and results of your iterations as runs. + You can assign, group, and organize these runs into experiments. + You can also create, compare, and evaluate runs. + + The code sample below shows how to initialize a run, log parameters to the Run object + and invoke a training job under the context of this Run object, which automatically + passes the run's ``experiment_config`` (including the experiment name, run name etc.) + to the training job. + + Note: + All log methods (e.g. ``log_parameter``, ``log_metric``, etc.) 
have to be called within + the run context (i.e. the ``with`` statement). Otherwise, a ``RuntimeError`` is thrown. + + .. code:: python + + with Run(experiment_name="my-exp", run_name="my-run", ...) as run: + run.log_parameter(...) + ... + estimator.fit(job_name="my-job") # Create a training job + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + The code snippet below displays how to load the run initialized above + in a custom training job script, where no ``run_name`` or ``experiment_name`` + is presented as they are automatically retrieved from the experiment config + in the job environment. + + Note: + Instead of the ``Run`` constructor, the ``load_run`` is recommended to use + in a job script to load the existing run created before the job launch. + Otherwise, a new run may be created each time you launch a job. + + .. code:: python + + with load_run() as run: + run.log_metric(...) + ... + + Args: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + run_name (str): The name of the run. If it is not specified, one is auto generated. + experiment_display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. (default: None). This display name is used in + a create experiment call. If an experiment with the specified name already exists, + this display name won't take effect. + run_display_name (str): The display name of the run used in UI (default: None). + This display name is used in a create run call. If a run with the + specified name already exists, this display name won't take effect. + tags (List[Dict[str, str]]): A list of tags to be used for all create calls, + e.g. to create an experiment, a run group, etc. (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + # TODO: we should revert the lower casting once backend fix reaches prod + self.experiment_name = experiment_name.lower() + sagemaker_session = sagemaker_session or _utils.default_session() + self.run_name = run_name or unique_name_from_base(RUN_NAME_BASE) + + # avoid confusion due to mis-match in casing between run name and TC name + self.run_name = self.run_name.lower() + + trial_component_name = Run._generate_trial_component_name( + run_name=self.run_name, experiment_name=self.experiment_name + ) + self.run_group_name = Run._generate_trial_name(self.experiment_name) + + self._experiment = _Experiment._load_or_create( + experiment_name=self.experiment_name, + display_name=experiment_display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial = _Trial._load_or_create( + experiment_name=self.experiment_name, + trial_name=self.run_group_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial_component, is_existed = _TrialComponent._load_or_create( + trial_component_name=trial_component_name, + display_name=run_display_name, + tags=Run._append_run_tc_label_to_tags(tags), + sagemaker_session=sagemaker_session, + ) + if is_existed: + logger.info( + "The run (%s) under experiment (%s) already exists. Loading it. 
" + "Note: sagemaker.experiments.load_run is recommended to use when " + "the desired run already exists.", + self.run_name, + self.experiment_name, + ) + self._trial.add_trial_component(self._trial_component) + + self._artifact_uploader = _ArtifactUploader( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._lineage_artifact_tracker = _LineageArtifactTracker( + trial_component_arn=self._trial_component.trial_component_arn, + sagemaker_session=sagemaker_session, + ) + self._metrics_manager = _MetricsManager( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._inside_init_context = False + self._inside_load_context = False + self._in_load = False + + @property + def experiment_config(self) -> dict: + """Get experiment config from run attributes.""" + return { + EXPERIMENT_NAME: self.experiment_name, + TRIAL_NAME: self.run_group_name, + RUN_NAME: self._trial_component.trial_component_name, + } + + @validate_invoked_inside_run_context + def log_parameter(self, name: str, value: Union[str, int, float]): + """Record a single parameter value for this run. + + Overwrites any previous value recorded for the specified parameter name. + + Args: + name (str): The name of the parameter. + value (str or int or float): The value of the parameter. + """ + if self._is_input_valid("parameter", name, value): + self._trial_component.parameters[name] = value + + @validate_invoked_inside_run_context + def log_parameters(self, parameters: Dict[str, Union[str, int, float]]): + """Record a collection of parameter values for this run. + + Args: + parameters (dict[str, str or int or float]): The parameters to record. + """ + filtered_parameters = { + key: value + for (key, value) in parameters.items() + if self._is_input_valid("parameter", key, value) + } + self._trial_component.parameters.update(filtered_parameters) + + @validate_invoked_inside_run_context + def log_metric( + self, + name: str, + value: float, + timestamp: Optional[datetime.datetime] = None, + step: Optional[int] = None, + ): + """Record a custom scalar metric value for this run. + + Note: + This method is for manual custom metrics, for automatic metrics see the + ``enable_sagemaker_metrics`` parameter on the ``estimator`` class. + + Args: + name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime): The timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): The integer iteration number of the metric value (default: None). + """ + if self._is_input_valid("metric", name, value): + self._metrics_manager.log_metric( + metric_name=name, value=value, timestamp=timestamp, step=step + ) + + @validate_invoked_inside_run_context + def log_precision_recall( + self, + y_true: Union[list, array], + predicted_probabilities: Union[list, array], + positive_label: Optional[Union[str, int]] = None, + title: Optional[str] = None, + is_output: bool = True, + no_skill: Optional[int] = None, + ): + """Create and log a precision recall graph artifact for Studio UI to render. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. 
+ + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + predicted_probabilities (list or array): Estimated/predicted probabilities. + positive_label (str or int): Label of the positive class (default: None). + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + no_skill (int): The precision threshold under which the classifier cannot discriminate + between the classes and would predict a random class or a constant class in + all cases (default: None). + """ + + verify_length_of_true_and_predicted( + true_labels=y_true, + predicted_attrs=predicted_probabilities, + predicted_attrs_name="predicted probabilities", + ) + + get_module("sklearn") + from sklearn.metrics import precision_recall_curve, average_precision_score + + kwargs = {} + if positive_label is not None: + kwargs["pos_label"] = positive_label + + precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs) + + kwargs["average"] = "micro" + ap = average_precision_score(y_true, predicted_probabilities, **kwargs) + + data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": precision.tolist(), + "recall": recall.tolist(), + "averagePrecisionScore": ap, + "noSkill": no_skill, + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_roc_curve( + self, + y_true: Union[list, array], + y_score: Union[list, array], + title: Optional[str] = None, + is_output: bool = True, + ): + """Create and log a receiver operating characteristic (ROC curve) artifact. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_score (list or array): Estimated/predicted probabilities. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores" + ) + + get_module("sklearn") + from sklearn.metrics import roc_curve, auc + + fpr, tpr, _ = roc_curve(y_true, y_score) + + auc = auc(fpr, tpr) + + data = { + "type": "ROCCurve", + "version": 0, + "title": title, + "falsePositiveRate": fpr.tolist(), + "truePositiveRate": tpr.tolist(), + "areaUnderCurve": auc, + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_confusion_matrix( + self, + y_true: Union[list, array], + y_pred: Union[list, array], + title: Optional[str] = None, + is_output: bool = True, + ): + """Create and log a confusion matrix artifact. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. 
+ If your job is created by a pipeline execution you can view the + artifact by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_pred (list or array): Predicted labels. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels" + ) + + get_module("sklearn") + from sklearn.metrics import confusion_matrix + + matrix = confusion_matrix(y_true, y_pred) + + data = { + "type": "ConfusionMatrix", + "version": 0, + "title": title, + "confusionMatrix": matrix.tolist(), + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_artifact( + self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True + ): + """Record a single artifact for this run. + + Overwrites any previous value recorded for the specified name. + + Args: + name (str): The name of the artifact. + value (str): The value. + media_type (str): The MediaType (MIME type) of the value (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + self._verify_trial_component_artifacts_length(is_output=is_output) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + + @validate_invoked_inside_run_context + def log_file( + self, + file_path: str, + name: Optional[str] = None, + media_type: Optional[str] = None, + is_output: bool = True, + ): + """Upload a file to s3 and store it as an input/output artifact in this run. + + Args: + file_path (str): The path of the local file to upload. + name (str): The name of the artifact (default: None). + media_type (str): The MediaType (MIME type) of the file. + If not specified, this library will attempt to infer the media type + from the file extension of ``file_path``. + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. 
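+
+        A minimal usage sketch (the file path is illustrative):
+
+        .. code:: python
+
+            with load_run() as run:
+                run.log_file("output/metrics.json", name="metrics")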
+ """ + self._verify_trial_component_artifacts_length(is_output) + media_type = media_type or guess_media_type(file_path) + name = name or resolve_artifact_name(file_path) + s3_uri, _ = self._artifact_uploader.upload_artifact(file_path) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value=s3_uri, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value=s3_uri, media_type=media_type + ) + + def close(self): + """Persist any data saved locally.""" + try: + # Update the trial component with additions from the Run object + self._trial_component.save() + # Create Lineage entities for the artifacts + self._lineage_artifact_tracker.save() + finally: + if self._metrics_manager: + self._metrics_manager.close() + + @staticmethod + def _generate_trial_name(base_name) -> str: + """Generate the reserved trial name based on experiment name + + Args: + base_name (str): The ``experiment_name`` of this ``Run`` object. + """ + available_length = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE) + return TRIAL_NAME_TEMPLATE.format(base_name[:available_length]) + + @staticmethod + def _is_input_valid(input_type, field_name, field_value) -> bool: + """Check if the input is valid or not + + Args: + input_type (str): The type of the input, one of ``parameter``, ``metric``. + field_name (str): The name of the field to be checked. + field_value (str or int or float): The value of the field to be checked. + """ + if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)): + logger.warning( + "Failed to log %s %s. Received invalid value: %s.", + input_type, + field_name, + field_value, + ) + return False + return True + + def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None): + """Log an artifact. + + Logs an artifact by uploading data to S3, creating an artifact, and associating that + artifact with the run trial component. + + Args: + data (dict): Artifacts data that will be saved to S3. + graph_type (str): The type of the artifact. + is_output (bool): Determines direction of association to the + trial component. Defaults to True (output artifact). + If set to False then represented as input association. + artifact_name (str): Name of the artifact (default: None). + """ + # generate an artifact name + if not artifact_name: + unique_name_from_base(graph_type) + + # create a json file in S3 + s3_uri, etag = self._artifact_uploader.upload_object_artifact( + artifact_name, data, file_extension="json" + ) + + # create an artifact and association for the table + if is_output: + self._lineage_artifact_tracker.add_output_artifact( + name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type + ) + else: + self._lineage_artifact_tracker.add_input_artifact( + name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type + ) + + def _verify_trial_component_artifacts_length(self, is_output): + """Verify the length of trial component artifacts + + Args: + is_output (bool): Determines direction of association to the + trial component. + + Raises: + ValueError: If the length of trial component artifacts exceeds the limit. 
+ """ + err_msg_template = "Cannot add more than {} {}_artifacts under run" + if is_output: + if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output")) + else: + if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input")) + + @staticmethod + def _generate_trial_component_name(run_name: str, experiment_name: str) -> str: + """Generate the TrialComponentName based on run_name and experiment_name + + Args: + run_name (str): The run_name supplied by the user. + experiment_name (str): The experiment_name supplied by the user, + which is prepended to the run_name to generate the TrialComponentName. + + Returns: + str: The TrialComponentName used to create a trial component + which is unique in an account. + + Raises: + ValueError: If either the run_name or the experiment_name exceeds + the length limit. + """ + buffer = 1 # leave length buffers for delimiters + max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer + err_msg_template = "The {} (length: {}) must have length less than or equal to {}" + if len(run_name) > max_len: + raise ValueError(err_msg_template.format("run_name", len(run_name), max_len)) + if len(experiment_name) > max_len: + raise ValueError( + err_msg_template.format("experiment_name", len(experiment_name), max_len) + ) + trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name) + # due to mixed-case concerns on the backend + trial_component_name = trial_component_name.lower() + return trial_component_name + + @staticmethod + def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str: + """Extract the user supplied run name from a trial component name. + + Args: + trial_component_name (str): The name of a run trial component. + experiment_name (str): The experiment_name supplied by the user, + which was prepended to the run_name to generate the trial_component_name. + + Returns: + str: The name of the Run object supplied by a user. + """ + return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1) + + @staticmethod + def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: + """Append the run trial component label to tags used to create a trial component. + + Args: + tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object. + + Returns: + list: The updated tags with the appended run trial component label. + """ + if not tags: + tags = [] + tags.append(RUN_TC_TAG) + return tags + + def __enter__(self): + """Updates the start time of the run. + + Returns: + object: self. + """ + nested_with_err_msg_template = ( + "It is not allowed to use nested 'with' statements on the {}." 
+ ) + if self._in_load: + if self._inside_load_context: + raise RuntimeError(nested_with_err_msg_template.format("load_run")) + self._inside_load_context = True + else: + if _RunContext.get_current_run(): + raise RuntimeError(nested_with_err_msg_template.format("Run")) + self._inside_init_context = True + _RunContext.add_run_object(self) + + if not self._trial_component.start_time: + start_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.start_time = start_time + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, + message="Within a run context", + ) + # Save the start_time and status changes to backend + self._trial_component.save() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Updates the end time of the run. + + Args: + exc_type (str): The exception type. + exc_value (str): The exception value. + exc_traceback (str): The stack trace of the exception. + """ + if self._in_load: + self._inside_load_context = False + self._in_load = False + else: + self._inside_init_context = False + _RunContext.drop_current_run() + + end_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.end_time = end_time + if exc_value: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value) + ) + else: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Completed.value + ) + + self.close() + + +def load_run( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, + sagemaker_session: Optional["Session"] = None, +) -> Run: + """Load an existing run. + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + It can be used in several ways: + + 1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``. + + If ``run_name`` and ``experiment_name`` are passed in, they are honored over + the default experiment config in the job environment or the run context + (i.e. within the ``with`` block). + + Note: + Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work. + Otherwise, you may get a ``ValueError``. + + .. code:: python + + with load_run(experiment_name="my-exp", run_name="my-run") as run: + run.log_metric(...) + ... + + 2. Use the ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``. + + In this case, the default experiment config (specified when creating the job) is fetched + from the job environment to load the run. + + .. code:: python + + # In a job script + with load_run() as run: + run.log_metric(...) + ... + + 3. Use the ``load_run`` in a notebook within a run context (i.e. the ``with`` block) + but without supplying ``run_name`` and ``experiment_name``. + + Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked + in the run context. Then when we call ``load_run()`` under this with statement, the ``run1`` + in the context is loaded by default. + + .. code:: python + + # In a notebook + with Run(experiment_name="my-exp", run_name="my-run", ...) as run1: + run1.log_parameter(...) + + with load_run() as run2: # run2 is the same object as run1 + run2.log_metric(...) + ... + + Args: + run_name (str): The name of the run to be loaded (default: None). + If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be + fetched to load the run. 
+ experiment_name (str): The name of the Experiment that the to be loaded run + is associated with (default: None). + Note: the experiment_name must be supplied along with a valid run_name. + Otherwise, it will be ignored. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + Run: The loaded Run object. + """ + sagemaker_session = sagemaker_session or _utils.default_session() + environment = _RunEnvironment.load() + + verify_load_input_names(run_name=run_name, experiment_name=experiment_name) + + if run_name or environment: + if run_name: + logger.warning( + "run_name is explicitly supplied in load_run, " + "which will be prioritized to load the Run object. " + "In other words, the run name in the experiment config, fetched from the " + "job environment or the current run context, will be ignored." + ) + else: + exp_config = get_tc_and_exp_config_from_job_env( + environment=environment, sagemaker_session=sagemaker_session + ) + run_name = Run._extract_run_name_from_tc_name( + trial_component_name=exp_config[RUN_NAME], + experiment_name=exp_config[EXPERIMENT_NAME], + ) + experiment_name = exp_config[EXPERIMENT_NAME] + + run_instance = Run( + experiment_name=experiment_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) + elif _RunContext.get_current_run(): + run_instance = _RunContext.get_current_run() + else: + raise RuntimeError( + "Failed to load a Run object. " + "Please make sure a Run object has been initialized already." + ) + + run_instance._in_load = True + return run_instance + + +def list_runs( + experiment_name: str, + created_before: Optional[datetime.datetime] = None, + created_after: Optional[datetime.datetime] = None, + sagemaker_session: Optional["Session"] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + sort_by: SortByType = SortByType.CREATION_TIME, + sort_order: SortOrderType = SortOrderType.DESCENDING, +) -> list: + """Return a list of ``Run`` objects matching the given criteria. + + Args: + experiment_name (str): Only Run objects related to the specified experiment + are returned. + created_before (datetime.datetime): Return Run objects created before this instant + (default: None). + created_after (datetime.datetime): Return Run objects created after this instant + (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + max_results (int): Maximum number of Run objects to retrieve (default: None). + next_token (str): Token for next page of results (default: None). + sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME + (default: CREATION_TIME). + sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING). + + Returns: + list: A list of ``Run`` objects. 
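+
+    A minimal usage sketch (the experiment name is illustrative):
+
+    .. code:: python
+
+        from sagemaker.experiments.run import list_runs, SortByType
+
+        runs = list_runs(experiment_name="my-exp", sort_by=SortByType.NAME)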
+ """ + tc_summaries = _TrialComponent.list( + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by.value, + sort_order=sort_order.value, + sagemaker_session=sagemaker_session, + max_results=max_results, + next_token=next_token, + ) + run_list = [] + for tc_summary in tc_summaries: + if not is_run_trial_component( + trial_component_name=tc_summary.trial_component_name, + sagemaker_session=sagemaker_session, + ): + continue + run_instance = Run( + experiment_name=experiment_name, + run_name=Run._extract_run_name_from_tc_name( + trial_component_name=tc_summary.trial_component_name, + experiment_name=experiment_name, + ), + sagemaker_session=sagemaker_session, + ) + run_list.append(run_instance) + return run_list diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py new file mode 100644 index 0000000000..146b24f18b --- /dev/null +++ b/src/sagemaker/experiments/trial.py @@ -0,0 +1,289 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the Trial class.""" +from __future__ import absolute_import + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Trial(_base_types.Record): + """An execution of a data-science workflow with an experiment. + + Consists of a list of trial component objects, which document individual + activities within the workflow. + + Attributes: + trial_name (str): The name of the trial. + experiment_name (str): The name of the trial's experiment. + display_name (str): The name of the trial that will appear in UI, + such as SageMaker Studio. + tags (List[Dict[str, str]]): A list of tags to associate with the trial. + """ + + trial_name = None + experiment_name = None + display_name = None + tags = None + + _boto_create_method = "create_trial" + _boto_load_method = "describe_trial" + _boto_delete_method = "delete_trial" + _boto_update_method = "update_trial" + + _boto_update_members = ["trial_name", "display_name"] + _boto_delete_members = ["trial_name"] + + @classmethod + def _boto_ignore(cls): + """Response fields to ignore by default.""" + return super(_Trial, cls)._boto_ignore() + ["CreatedBy"] + + def save(self): + """Save the state of this Trial to SageMaker. + + Returns: + dict: Update trial response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Trial from SageMaker. + + Does not delete associated Trial Components. + + Returns: + dict: Delete trial response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_name, sagemaker_session=None): + """Load an existing trial and return a `_Trial` object. + + Args: + trial_name: (str): Name of the Trial. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + return super(_Trial, cls)._construct( + cls._boto_load_method, + trial_name=trial_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None + ): + """Create a new trial and return a `_Trial` object. + + Args: + experiment_name: (str): Name of the experiment to create this trial in. + trial_name: (str): Name of the Trial. + display_name (str): Name of the trial that will appear in UI, + such as SageMaker Studio (default: None). + tags (List[dict]): A list of tags to associate with the trial (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + trial = super(_Trial, cls)._construct( + cls._boto_create_method, + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial + + @classmethod + def list( + cls, + experiment_name=None, + trial_component_name=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + ): + """List all trials matching the specified criteria. + + Args: + experiment_name (str): Name of the experiment. If specified, only trials in + the experiment will be returned (default: None). + trial_component_name (str): Name of the trial component. If specified, only + trials with this trial component name will be returned (default: None). + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + collections.Iterator[experiments._api_types.TrialSummary]: An iterator over trials + matching the specified criteria. + """ + return super(_Trial, cls)._list( + "list_trials", + _api_types.TrialSummary.from_boto, + "TrialSummaries", + experiment_name=experiment_name, + trial_component_name=trial_component_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + + def add_trial_component(self, trial_component): + """Add the specified trial component to this trial. + + A trial component may belong to many trials and a trial may have many trial components. + + Args: + trial_component (str or _TrialComponent): The trial component to add. + Can be one of a _TrialComponent instance, or a string containing + the name of the trial component to add. 
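+
+        For example, both ``trial.add_trial_component("my-tc-name")`` and
+        ``trial.add_trial_component(trial_component_instance)`` are accepted;
+        any other type raises a ``TypeError``.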
+ """ + if isinstance(trial_component, _TrialComponent): + trial_component_name = trial_component.trial_component_name + elif isinstance(trial_component, str): + trial_component_name = trial_component + else: + raise TypeError( + "Unsupported type of trail component {}. " + "It has to be one type of _TrialComponent or str".format(trial_component) + ) + self.sagemaker_session.sagemaker_client.associate_trial_component( + TrialName=self.trial_name, TrialComponentName=trial_component_name + ) + + def remove_trial_component(self, trial_component): + """Remove the specified trial component from this trial. + + Args: + trial_component (str or _TrialComponent): The trial component to add. + Can be one of a _TrialComponent instance, or a string containing + the name of the trial component to add. + """ + if isinstance(trial_component, _TrialComponent): + trial_component_name = trial_component.trial_component_name + elif isinstance(trial_component, str): + trial_component_name = trial_component + else: + raise TypeError( + "Unsupported type of trail component {}. " + "It has to be one type of _TrialComponent or str".format(trial_component) + ) + self.sagemaker_session.sagemaker_client.disassociate_trial_component( + TrialName=self.trial_name, TrialComponentName=trial_component_name + ) + + def list_trial_components( + self, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + max_results=None, + next_token=None, + ): + """List trial components in this trial matching the specified criteria. + + Args: + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', + 'CreationTime' (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + max_results (int): maximum number of trial components to retrieve (default: None). + next_token (str): token for next page of results (default: None). + + Returns: + collections.Iterator[experiments._api_types.TrialComponentSummary] : An iterator over + trials matching the criteria. + """ + return _TrialComponent.list( + trial_name=self.trial_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + max_results=max_results, + next_token=next_token, + sagemaker_session=self.sagemaker_session, + ) + + @classmethod + def _load_or_create( + cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None + ): + """Load a trial by name and create a new one if it does not exist. + + Args: + experiment_name: (str): Name of the experiment to create this trial in. + trial_name: (str): Name of the Trial. + display_name (str): Name of the trial that will appear in UI, + such as SageMaker Studio (default: None). This is used only when the given + `trial_name` does not exist and a new trial has to be created. + tags (List[dict]): A list of tags to associate with the trial (default: None). + This is used only when the given `trial_name` does not exist and + a new trial has to be created. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. 
+ + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + trial = _Trial.load(trial_name, sagemaker_session) + if trial.experiment_name != experiment_name: # pylint: disable=no-member + raise ValueError( + "The given experiment_name {} ".format(experiment_name) + + "does not match that in the loaded trial {}".format( + trial.experiment_name # pylint: disable=no-member + ) + ) + except sagemaker_client.exceptions.ResourceNotFound: + trial = _Trial.create( + experiment_name=experiment_name, + trial_name=trial_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py new file mode 100644 index 0000000000..e5701b2119 --- /dev/null +++ b/src/sagemaker/experiments/trial_component.py @@ -0,0 +1,341 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the TrialComponent class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentSearchResult + + +class _TrialComponent(_base_types.Record): + """This class represents a SageMaker trial component object. + + A trial component is a stage in a trial. + Trial components are created automatically within the SageMaker runtime and + may not be created directly. To automatically associate trial components with + a trial and experiment, supply an experiment config when creating a job. + For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html + + Attributes: + trial_component_name (str): The name of the trial component. Generated by SageMaker + from the name of the source job with a suffix specific to the type of source job. + trial_component_arn (str): The ARN of the trial component. + display_name (str): The name of the trial component that will appear in UI, + such as SageMaker Studio. + source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute. + status (str): Status of the source job. + start_time (datetime): When the source job started. + end_time (datetime): When the source job ended. + creation_time (datetime): When the source job was created. + created_by (obj): Contextual info on which account created the trial component. + last_modified_time (datetime): When the trial component was last modified. + last_modified_by (obj): Contextual info on which account last modified the trial component. + parameters (dict): Dictionary of parameters to the source job. + input_artifacts (dict): Dictionary of input artifacts. + output_artifacts (dict): Dictionary of output artifacts. + metrics (obj): Aggregated metrics for the job. + parameters_to_remove (list): The hyperparameters to remove from the component. + input_artifacts_to_remove (list): The input artifacts to remove from the component. 
+        output_artifacts_to_remove (list): The output artifacts to remove from the component.
+        tags (List[Dict[str, str]]): A list of tags to associate with the trial component.
+    """
+
+    trial_component_name = None
+    trial_component_arn = None
+    display_name = None
+    source = None
+    status = None
+    start_time = None
+    end_time = None
+    creation_time = None
+    created_by = None
+    last_modified_time = None
+    last_modified_by = None
+    parameters = None
+    input_artifacts = None
+    output_artifacts = None
+    metrics = None
+    parameters_to_remove = None
+    input_artifacts_to_remove = None
+    output_artifacts_to_remove = None
+    tags = None
+
+    _boto_load_method = "describe_trial_component"
+    _boto_create_method = "create_trial_component"
+    _boto_update_method = "update_trial_component"
+    _boto_delete_method = "delete_trial_component"
+
+    _custom_boto_types = {
+        "source": (_api_types.TrialComponentSource, False),
+        "status": (_api_types.TrialComponentStatus, False),
+        "parameters": (_api_types.TrialComponentParameters, False),
+        "input_artifacts": (_api_types.TrialComponentArtifact, True),
+        "output_artifacts": (_api_types.TrialComponentArtifact, True),
+        "metrics": (_api_types.TrialComponentMetricSummary, True),
+    }
+
+    _boto_update_members = [
+        "trial_component_name",
+        "display_name",
+        "status",
+        "start_time",
+        "end_time",
+        "parameters",
+        "input_artifacts",
+        "output_artifacts",
+        "parameters_to_remove",
+        "input_artifacts_to_remove",
+        "output_artifacts_to_remove",
+    ]
+    _boto_delete_members = ["trial_component_name"]
+
+    def __init__(self, sagemaker_session=None, **kwargs):
+        """Init for _TrialComponent"""
+        super().__init__(sagemaker_session, **kwargs)
+        self.parameters = self.parameters or {}
+        self.input_artifacts = self.input_artifacts or {}
+        self.output_artifacts = self.output_artifacts or {}
+
+    @classmethod
+    def _boto_ignore(cls):
+        """Response fields to ignore by default."""
+        return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"]
+
+    def save(self):
+        """Save the state of this TrialComponent to SageMaker."""
+        return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+    def delete(self, force_disassociate=False):
+        """Delete this TrialComponent from SageMaker.
+
+        Args:
+            force_disassociate (boolean): Indicates whether to force disassociate the
+                trial component from its trials before deletion (default: False).
+                If set to true, the trial component is first disassociated from all
+                associated trials, then deleted.
+                If not set or set to false, the trial component is deleted directly,
+                without disassociation.
+
+        Returns:
+            dict: Delete trial component response.
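+
+        Example (an illustrative sketch; the component name is hypothetical):
+
+            tc = _TrialComponent.load("my-trial-component")
+            # Disassociate the component from every trial referencing it,
+            # then delete it.
+            tc.delete(force_disassociate=True)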
+ """ + if force_disassociate: + next_token = None + + while True: + if next_token: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name, NextToken=next_token + ) + else: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name + ) + + # Disassociate the trials and trial components + for per_trial in list_trials_response["TrialSummaries"]: + # to prevent DisassociateTrialComponent throttling + time.sleep(1.2) + self.sagemaker_session.sagemaker_client.disassociate_trial_component( + TrialName=per_trial["TrialName"], + TrialComponentName=self.trial_component_name, + ) + + if "NextToken" in list_trials_response: + next_token = list_trials_response["NextToken"] + else: + break + + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_component_name, sagemaker_session=None): + """Load an existing trial component and return an `_TrialComponent` object representing it. + + Args: + trial_component_name (str): Name of the trial component + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object + """ + trial_component = cls._construct( + cls._boto_load_method, + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + return trial_component + + @classmethod + def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None): + """Create a trial component and return a `_TrialComponent` object representing it. + + Args: + trial_component_name (str): The name of the trial component. + display_name (str): Display name of the trial component used by Studio (default: None). + tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. + """ + return super(_TrialComponent, cls)._construct( + cls._boto_create_method, + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def list( + cls, + source_arn=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + trial_name=None, + experiment_name=None, + max_results=None, + next_token=None, + ): + """Return a list of trial component summaries. + + Args: + source_arn (str): A SageMaker Training or Processing Job ARN (default: None). + created_before (datetime.datetime): Return trial components created before this instant + (default: None). + created_after (datetime.datetime): Return trial components created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). 
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+            trial_name (str): If provided, only trial components related to the trial are
+                returned (default: None).
+            experiment_name (str): If provided, only trial components related to the
+                experiment are returned (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator
+                over `TrialComponentSummary` objects.
+        """
+        return super(_TrialComponent, cls)._list(
+            "list_trial_components",
+            _api_types.TrialComponentSummary.from_boto,
+            "TrialComponentSummaries",
+            source_arn=source_arn,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            sagemaker_session=sagemaker_session,
+            trial_name=trial_name,
+            experiment_name=experiment_name,
+            max_results=max_results,
+            next_token=next_token,
+        )
+
+    @classmethod
+    def search(
+        cls,
+        search_expression=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        sagemaker_session=None,
+    ):
+        """Search Experiment Trial Components.
+
+        Returns SearchResults in the account matching the search criteria.
+
+        Args:
+            search_expression (SearchExpression): A Boolean conditional statement (default: None).
+                Resource objects must satisfy this condition to be included in search results.
+                You must provide at least one subexpression, filter, or nested filter.
+            sort_by (str): The name of the resource property used to sort the SearchResults
+                (default: None).
+            sort_order (str): How SearchResults are ordered. Valid values are Ascending or
+                Descending (default: None).
+            max_results (int): The maximum number of results to return in a SearchResponse
+                (default: None).
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            collections.Iterator[SearchResult]: An iterator over search results matching
+                the search criteria.
+        """
+        return super(_TrialComponent, cls)._search(
+            search_resource="ExperimentTrialComponent",
+            search_item_factory=TrialComponentSearchResult.from_boto,
+            search_expression=None if search_expression is None else search_expression.to_boto(),
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            sagemaker_session=sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial component by name and create a new one if it does not exist.
+
+        Args:
+            trial_component_name (str): The name of the trial component.
+            display_name (str): Display name of the trial component used by Studio (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
+            bool: A boolean variable indicating whether the trial component already exists.
+        """
+        sagemaker_client = sagemaker_session.sagemaker_client
+        is_existed = False
+        try:
+            run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
+            is_existed = True
+        except sagemaker_client.exceptions.ResourceNotFound:
+            run_tc = _TrialComponent.create(
+                trial_component_name=trial_component_name,
+                display_name=display_name,
+                tags=tags,
+                sagemaker_session=sagemaker_session,
+            )
+        return run_tc, is_existed
diff --git a/src/sagemaker/lineage/_utils.py b/src/sagemaker/lineage/_utils.py
index 28732b0174..7c833a468e 100644
--- a/src/sagemaker/lineage/_utils.py
+++ b/src/sagemaker/lineage/_utils.py
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 """SageMaker lineage utility methods."""
 from __future__ import absolute_import
-from importlib import import_module

 from sagemaker.lineage import association

@@ -38,22 +37,6 @@ def _disassociate(source_arn=None, destination_arn=None, sagemaker_session=None)
         curr_association.delete()


-def get_module(module_name):
-    """Import a module.
-
-    Args:
-        module_name (str): name of the module to import.
-
-    Returns:
-        [obj]: The imported module.
-        Raises exceptions when the module name is not found
-    """
-    try:
-        return import_module(module_name)
-    except ImportError:
-        raise Exception("Cannot import module {}, please try again.".format(module_name))
-
-
 def get_resource_name_from_arn(arn):
     """Extract the resource name from an ARN string.
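Taken together, the two `_load_or_create` helpers (on `_Trial` and `_TrialComponent` above) give the new `Run` class an idempotent get-or-create primitive. A minimal usage sketch, assuming a configured `sagemaker_session` (the helpers dereference its `sagemaker_client`, so it cannot be omitted); the names are illustrative, not from this patch:

    from sagemaker import Session
    from sagemaker.experiments.trial import _Trial
    from sagemaker.experiments.trial_component import _TrialComponent

    session = Session()
    # Load "my-run-group" if it already exists (validating that it belongs
    # to "my-experiment"), otherwise create it.
    trial = _Trial._load_or_create(
        experiment_name="my-experiment",
        trial_name="my-run-group",
        sagemaker_session=session,
    )
    # Returns the trial component plus a flag saying whether it already existed.
    run_tc, is_existed = _TrialComponent._load_or_create(
        trial_component_name="my-run", sagemaker_session=session
    )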
diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index 3921562beb..718344095a 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -29,8 +29,9 @@ LineageEntityEnum, LineageQueryDirectionEnum, ) -from sagemaker.lineage._utils import get_module, _disassociate, get_resource_name_from_arn +from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn from sagemaker.lineage.association import Association +from sagemaker.utils import get_module LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 01d4361197..af52da6288 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -33,7 +33,12 @@ from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig -from sagemaker.utils import base_name_from_image, get_config_value, name_from_base +from sagemaker.utils import ( + base_name_from_image, + get_config_value, + name_from_base, + check_and_get_run_experiment_config, +) from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.functions import Join @@ -203,6 +208,7 @@ def run( outputs=outputs, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, @@ -605,6 +611,7 @@ def run( kms_key=kms_key, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 72df570496..ce6a3b99cd 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -89,6 +89,7 @@ def __init__( sagemaker_featurestore_runtime_client=None, default_bucket=None, settings=SessionSettings(), + sagemaker_metrics_client=None, ): """Initialize a SageMaker ``Session``. @@ -116,6 +117,10 @@ def __init__( Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional parameters to apply to the session. + sagemaker_metrics_client (boto3.SageMakerMetrics.Client): + Client which makes SageMaker Metrics related calls to Amazon SageMaker + (default: None). If not provided, one will be created using + this instance's ``boto_session``. """ self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -130,6 +135,7 @@ def __init__( sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client, sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client, + sagemaker_metrics_client=sagemaker_metrics_client, ) def _initialize( @@ -138,6 +144,7 @@ def _initialize( sagemaker_client, sagemaker_runtime_client, sagemaker_featurestore_runtime_client, + sagemaker_metrics_client, ): """Initialize this SageMaker Session. @@ -172,6 +179,12 @@ def _initialize( "sagemaker-featurestore-runtime" ) + if sagemaker_metrics_client: + self.sagemaker_metrics_client = sagemaker_metrics_client + else: + self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") + prepend_user_agent(self.sagemaker_metrics_client) + self.local_mode = False @property @@ -548,8 +561,8 @@ def train( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. 
- Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -558,6 +571,7 @@ def train( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries @@ -703,8 +717,8 @@ def _get_train_request( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -713,6 +727,7 @@ def _get_train_request( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 97278abdd0..40ed143ebc 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -27,7 +27,11 @@ from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.utils import base_name_from_image, name_from_base +from sagemaker.utils import ( + base_name_from_image, + name_from_base, + check_and_get_run_experiment_config, +) class Transformer(object): @@ -251,6 +255,7 @@ def transform( ) self._reset_output_path = True + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_transform_job = _TransformJob.start_new( self, data, diff --git a/src/sagemaker/utilities/search_expression.py b/src/sagemaker/utilities/search_expression.py new file mode 100644 index 0000000000..5b2aaf3226 --- /dev/null +++ b/src/sagemaker/utilities/search_expression.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Simplify Search Expression by providing a simplified DSL"""
+from __future__ import absolute_import
+
+from enum import Enum, unique
+
+from sagemaker.apiutils._base_types import ApiObject
+
+
+# TODO: we should update the lineage to use search expressions
+# defined here in a separate change
+@unique
+class Operator(Enum):
+    """Search operators"""
+
+    EQUALS = "Equals"
+    NOT_EQUALS = "NotEquals"
+    GREATER_THAN = "GreaterThan"
+    GREATER_THAN_OR_EQUAL = "GreaterThanOrEqualTo"
+    LESS_THAN = "LessThan"
+    LESS_THAN_OR_EQUAL = "LessThanOrEqualTo"
+    CONTAINS = "Contains"
+    EXISTS = "Exists"
+    NOT_EXISTS = "NotExists"
+
+
+@unique
+class BooleanOperator(Enum):
+    """Boolean search operation enum"""
+
+    AND = "And"
+    OR = "Or"
+
+
+class SearchObject(ApiObject):
+    """Search Object"""
+
+    def to_boto(self):
+        """Convert a search object to boto"""
+        return ApiObject.to_boto(self)
+
+
+class Filter(SearchObject):
+    """A Python class representing a Search Filter object."""
+
+    name = None
+    operator = None
+    value = None
+
+    def __init__(self, name, operator=None, value=None, **kwargs):
+        """Construct a Filter object.
+
+        Args:
+            name (str): filter field name
+            operator (Operator): one of the Operator enum values
+            value (str): value of the field
+        """
+        super().__init__(**kwargs)
+        self.name = name
+        self.operator = None if operator is None else operator.value
+        self.value = value
+
+
+class NestedFilter(SearchObject):
+    """A Python class representing a Nested Filter object."""
+
+    nested_property_name = None
+    filters = None
+
+    def __init__(self, property_name, filters, **kwargs):
+        """Construct a Nested Filter object.
+
+        Args:
+            property_name (str): nested property name
+            filters (List[Filter]): list of Filter objects
+        """
+        super().__init__(**kwargs)
+        self.nested_property_name = property_name
+        self.filters = list(map(lambda x: x.to_boto(), filters))
+
+
+class SearchExpression(SearchObject):
+    """A Python class representation of a Search Expression object.
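+
+    Example (an illustrative sketch; the filter name and value are hypothetical):
+
+        search_expression = SearchExpression(
+            filters=[
+                Filter(name="TrialComponentName", operator=Operator.CONTAINS, value="my-run")
+            ],
+        )
+        # boolean_operator defaults to BooleanOperator.AND when several
+        # filters are combined.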
+
+    A sample search expression is defined here:
+    https://boto3.amazonaws.com/v1/documentation/api/1.12.8/reference/services/sagemaker.html#SageMaker.Client.search
+    """
+
+    filters = None
+    nested_filters = None
+    operator = None
+    sub_expressions = None
+
+    def __init__(
+        self,
+        filters=None,
+        nested_filters=None,
+        sub_expressions=None,
+        boolean_operator=BooleanOperator.AND,
+        **kwargs
+    ):
+        """Construct a Search Expression object.
+
+        Args:
+            filters (List[Filter]): list of Filter objects
+            nested_filters (List[NestedFilter]): list of Nested Filters objects
+            sub_expressions (List[SearchExpression]): list of Search Expression objects
+            boolean_operator (BooleanOperator): one of the boolean operator enums
+        """
+        super().__init__(**kwargs)
+        if filters is None and nested_filters is None and sub_expressions is None:
+            raise ValueError(
+                "You must specify at least one subexpression, filter, or nested filter"
+            )
+        self.filters = None if filters is None else list(map(lambda x: x.to_boto(), filters))
+        self.nested_filters = (
+            None if nested_filters is None else list(map(lambda x: x.to_boto(), nested_filters))
+        )
+        self.sub_expressions = (
+            None if sub_expressions is None else list(map(lambda x: x.to_boto(), sub_expressions))
+        )
+        self.operator = boolean_operator.value
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index e668b2a8ed..9d28e3bf4e 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -29,6 +29,7 @@
 from datetime import datetime
 from typing import Optional
+from importlib import import_module

 import botocore
 from six.moves.urllib import parse
@@ -590,6 +591,27 @@ def retries(
     )


+def retry_with_backoff(callable_func, num_attempts=8):
+    """Retry with backoff until maximum attempts are reached.
+
+    Args:
+        callable_func (callable): The callable function to retry.
+        num_attempts (int): The maximum number of attempts to retry.
+
+    Returns:
+        The return value of ``callable_func`` once it succeeds.
+    """
+    if num_attempts < 1:
+        raise ValueError(
+            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
+        )
+    for i in range(num_attempts):
+        try:
+            return callable_func()
+        except Exception as ex:  # pylint: disable=broad-except
+            if i == num_attempts - 1:
+                raise ex
+            logger.error("Retrying after failed attempt %s, due to %s", (i + 1), str(ex))
+            time.sleep(2**i)
+
+
 def _botocore_resolver():
     """Get the DNS suffix for the given region.
@@ -874,3 +896,47 @@ def _start_waiting(waiting_time: int):
         print(progress, end="\r")
         time.sleep(interval)
     print(len(progress) * " ", end="\r")
+
+
+def get_module(module_name):
+    """Import a module.
+
+    Args:
+        module_name (str): name of the module to import.
+
+    Returns:
+        object: The imported module.
+
+    Raises:
+        Exception: when the module name is not found
+    """
+    try:
+        return import_module(module_name)
+    except ImportError:
+        raise Exception("Cannot import module {}, please try again.".format(module_name))
+
+
+def check_and_get_run_experiment_config(
+    experiment_config: Optional[dict] = None,
+) -> Optional[dict]:
+    """Check the user input experiment_config, or get it from the current Run object if one exists.
+
+    Args:
+        experiment_config (dict): The experiment_config supplied by the user.
+
+    Returns:
+        dict: Return the user supplied experiment_config if it is not None.
+            Otherwise, fetch the experiment_config from the current Run object,
+            if one exists; None is returned when neither is available.
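+
+    Example (an illustrative sketch; the experiment and run names are hypothetical):
+
+        with Run(experiment_name="my-exp", run_name="my-run"):
+            # With no override supplied, the current Run's own config is
+            # returned, e.g. {"ExperimentName": "my-exp", "TrialName": "...",
+            # "TrialComponentDisplayName": "...", "RunName": "..."}.
+            check_and_get_run_experiment_config(None)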
+ """ + from sagemaker.experiments._run_context import _RunContext + + run_obj = _RunContext.get_current_run() + if experiment_config: + if run_obj: + logger.warning( + "The function is invoked within an Experiment Run context " + "but another experiment_config (%s) was supplied, so " + "ignoring the experiment_config fetched from the Run object.", + experiment_config, + ) + return experiment_config + + return run_obj.experiment_config if run_obj else None diff --git a/tests/data/experiment/inference.py b/tests/data/experiment/inference.py new file mode 100644 index 0000000000..cdb9a7b8c6 --- /dev/null +++ b/tests/data/experiment/inference.py @@ -0,0 +1,85 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +import logging +import os +import pickle as pkl + +import boto3 +import numpy as np +import sagemaker_xgboost_container.encoder as xgb_encoders + +sdk_name = "sagemaker-dev-1.0.tar.gz" +code_dir = "/opt/ml/code" + +sdk_file = f"{code_dir}/{sdk_name}" +os.system(f"pip install {sdk_file}") + +from sagemaker.session import Session +from sagemaker.experiments import load_run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +def model_fn(model_dir): + """ + Deserialize and return fitted model. + """ + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) + + model_file = "xgboost-model" + booster = pkl.load(open(os.path.join(model_dir, model_file), "rb")) + return booster + + +def input_fn(request_body, request_content_type): + """ + The SageMaker XGBoost model server receives the request data body and the content type, + and invokes the `input_fn`. + Return a DMatrix (an object that can be passed to predict_fn). + """ + if request_content_type == "text/libsvm": + return xgb_encoders.libsvm_to_dmatrix(request_body) + else: + raise ValueError("Content type {} is not supported.".format(request_content_type)) + + +def predict_fn(input_data, model): + """ + SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`. + Return a two-dimensional NumPy array where the first columns are predictions + and the remaining columns are the feature contributions (SHAP values) for that prediction. + """ + prediction = model.predict(input_data) + feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False) + output = np.hstack((prediction[:, np.newaxis], feature_contribs)) + return output + + +def output_fn(predictions, content_type): + """ + After invoking predict_fn, the model server invokes `output_fn`. 
+ """ + if content_type == "text/csv" or content_type == "application/json": + return ",".join(str(x) for x in predictions[0]) + else: + raise ValueError("Content type {} is not supported.".format(content_type)) diff --git a/tests/data/experiment/process_job_script_for_run_clz.py b/tests/data/experiment/process_job_script_for_run_clz.py new file mode 100644 index 0000000000..32fd0ab4f6 --- /dev/null +++ b/tests/data/experiment/process_job_script_for_run_clz.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This script file runs on SageMaker processing job""" +from __future__ import absolute_import + +import logging +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + + +from sagemaker import Session +from sagemaker.experiments import load_run + + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +with load_run(sagemaker_session=sagemaker_session) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) diff --git a/tests/data/experiment/train_job_script_for_run_clz.py b/tests/data/experiment/train_job_script_for_run_clz.py new file mode 100644 index 0000000000..34c86e0993 --- /dev/null +++ b/tests/data/experiment/train_job_script_for_run_clz.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This script file runs on SageMaker training job""" +from __future__ import absolute_import + +import logging +import time +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + +from sagemaker import Session +from sagemaker.experiments import load_run, Run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + +if os.environ["RUN_OPERATION"] == "init": + logging.info("Initializing a Run") + with Run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameter("p1", 1.0) + run.log_parameter("p2", 2) + + for i in range(2): + run.log_metric("A", i) + for i in range(2): + run.log_metric("B", i) + for i in range(2): + run.log_metric("C", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("D", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("E", i) + time.sleep(15) + +else: + logging.info("Loading a Run") + logging.info("Invoking load_run with name arguments") + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + run.log_parameters({"p3": 3.0, "p4": 4}) + run.log_metric("test-job-load-log-metric", 0.1) + + if os.environ.get("CALL_RUN_LOAD_WITH_NO_NAME_ARGS", None) == "True": + logging.info("Invoking load_run without name arguments") + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameters({"p5": 5.0, "p6": 6}) diff --git a/tests/data/experiment/transform_job_materials/data.csv b/tests/data/experiment/transform_job_materials/data.csv new file mode 100644 index 0000000000..9f1b6c0bb0 --- /dev/null +++ b/tests/data/experiment/transform_job_materials/data.csv @@ -0,0 +1 @@ +-99 1:3 2:0.37 3:0.29 4:0.095 5:0.249 6:0.1045 7:0.058 8:0.067 \ No newline at end of file diff --git a/tests/data/experiment/transform_job_materials/xgb_model.tar.gz b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..3969bede9e315f8f51d27f3df2de623e670459c6 GIT binary patch literal 35946 zcmV(%K;pk2iwFP!00000|Lncnj+{r6B-pdF*o%IOhOTA+W+6Pz(U-H}<=!vQXb_ZC zlGsC$`bny%dm0VwKGeM1Uap(Fd1O{G>s&nigAq$)Rvx}nei32rZf^E3zyA3C{l`y- z-{1dy`Sx$V%zsHz>b3q|^8c>>TSgtZ{+-lZavP&c|GOl))bTfem%h;PT>0UA{(t}SAO8I>|J#51 zzyA+?+i$Pm{rvXwFaPnUAOC#w_S2hpAOH5pfBkg%`oo9U|N6Io`QQJ`|M>s@!{7Yd z5C7-;cfY*(^@qRzw;$f>OYbf};Nh>A`ryq^ul{)b;q~u$;_}_=AKzZSy8M?v^!eW} z-+g-h_SHXqeE;s%NB#Zv+c*FH^`|%Q-~I6SKgbU+e)#3}o42p@wKwnnSzr43)vtg2 z`RdKj`eM=-_b&(GZ-2c0_43unH~;;T?|6p~eB8hM)B7LsOT{mre*19w`e*+3K~L}f z@2~V$_!J-gczOBr$-`HAw6EUim5HOn3wBKU5kI+0ef6Uq_rL4We0cry8$G#KKl%l= zOY!5U50{s(K7Dxo=H2mny!-Qa`{nK1%eSxa2A=vKA1;4>^V6sI`q3X=|M}^?J@x*l z%a0#lefrlw@UEXPe|i1q+fVv}ZC4NAxA!>a%YS+O6Q2F^4;&mm|LceM|LxUJ@8ACZ z`oCU&c=5-FKi^#*U)Kx&mmA)rC-P2D46oX;<6``A`O}+U-slzCoB!qY+mDwgyZg7R z+rks^`1;Ae|KroEpD*8iyu`cX+Fg6k>$iXW_L^7xr`JFIcKJ#_#fzk$`uO_yKlFb3 z@n7Fw{`2MAGta-%sqM_uCzqNR^3~68K3x99@7Hc&eT}#4FPE=B{rTbYm7dw3f4q73 zE05>$lO5Sb=CCig5nuoQ`)l4_o5}X3=l<@a zo%AQ%KlwkIm;~N}tZk@mD_I}Y_e%hbl^E$CtltGDkz##sCP{k!*{-oLxNcAo#~KV#%w{_^4S<8RpOFFfB5m%qL`{m$|G zl?!5DAuhJK0OL{?6M_ zpIJYgZ0*{J*~{v2W%~KeN4t;XR>zs(`YHP0=6|+RJg%qjfBYv5hJVJ+^eg!LkAL~| 
z55RtO1ctmzHCz@d?zE^G1qj6x%I0Ua%`cb6H@3=O<;wuVr&YO!SoEP;Rgx`ht z!t8;5>Ai(F!jz3NHf^|fm{IhjvRy+8V4~4h&8ZoLI?#mhTJ}IzhCjgeUk<;zo35Kf^48{1(=LaxKPAw> z@L(~i)`mxRPDsv6eWHKxnf0UWgb&*Rf0XCovG@U1)_caWukr=T?O_~ygS^_cpsjSk zzOr@3kzMs^Ft#P1^rEbYoZ*S1YuL&j7_+64t=kf^F<{tfX|a?TaN&6?w$}yLzWp9v z&8gZkCA(Lp)HEspY>`c_snlAHbsz_$fD@L&>6H+YiPfPef zJ@ls_=q?+?zRDK}w1+|L4T^0KWgVD)pnpyPU3GnvHW=9F@QOP;WTn7`LNuG`Oc>T@ zA&Nf3qbdQNv;DLkSq?#hc$20V3d_h2vrFZG;vESIOX*ej!CgoJBpCP`AK=XjM@u%} zXXw~&W$Q2+6mV-QT@$y|J@MBI1y=ImWLr5o(^aPci@sCVfKNRnV;aO+oNFO>3u>-B z5irR1RHMLn-7a4Wm9(960P#ZilT6UA8=?3)mPv8-&j3>DS8QNAZD3KokH~}!8gQ@$ z7NK~T{`oLEvX9mL*}Ch~EncZ1k7U>315UNKpP+)!d*l<~zL^;9G@~fYs7K%pSk{xg zO0qx;*6ffx#AH&GZcAvMVh2P`N6Tbs)Sh5 zmuw=+-hEmpW7C4BJcs#E*3^Svw!xH@SA8(o`fRQVDZ8~fLSHVR=Hb4qQZSa7c1bPM zYa;o$0}ty=^uE%}q<0p1FRVsY{Db6M$+=BwWfxYFoeu%89|ElZK;7TwdOh=N1902?I8H4~QI}H+}T)-8&!csB3c1IIP%s`Yp+u*dj zxe3$-w&x01BCzx`(vGD4qjuAP?Yj+1fXL48Tgj9!ltx^I+TYx7c#@6$t>M3;HaKu` zBhYkmGLY}0I)}XBy@8;oZ#Y12qLY&H_&7dg@h!CQkAH>0BiSN2b4Md4~QH> zpp_41d%r=xsnu;RDhLq)nM&}S$knT*%2Kp}9gtg%=|I_OO>dVhgmf1ssss28H_*&m z?Xxez;m!W4(r$e*u2xl_xM;He-ZAt6T?i-@R`wZ^M#Uxhh(YF*I!8@G9VhK&zfjHg z9Fu4lHMKh}lWb4?Q>E0_+}U9eDCNg?y?{KWQ0JJYfjBYP9bZbzVH2&* z(H-=*B@IvbQT&oVPaeF;jEgKBFBMB^9cK}#d2FptY+V*38tjM?d1puLX*-xh~w;FkU_*EK)vnHbQ}>A+`pxIkOi!T>?a(+YAA#Oh}s^jTr#@XrC?!HC`S$ zbSZ6;ksQpU)RI2{U z+4K}PpFyCvh`%+dfQz@*)Q&?ln>J_Edy-%kF||k)xX(<1sF>Ox)f0U@Nw`fnT@j%z z(_jR#u#e`9t&vumK6QYEQrnYr^OeW)9Jfk(Y>J(59+5@9%r5{V5ZQ5Uxz_o@pvz7} z9EKws^rXXNoM7~(oVn3`3rPOFd$;1Jw-C#KSkoBOwJfC4UeypJIIril83>rh?3+R{ zbTNBfr7mLzY;Z4%rN0|fgqd8`yh!%ETk%Osrs)cjWDu4sePZWG5S|XF0@YU3rv{>) zd{t;OtxhyH7y(&T4gpMn@iU9+4kZk}2?gHHeuku7a-kn|&cM(4yHl#9_>@%bBlGX8 zN<1bPtkS8C3k3i>*JGjsIj`Pq^%XhEmR5Ptpv!I9i0X$Pbv#nkKQQ#TEeHA@lAOnq zp9iH!?{^O@_`PLNcM^Cq0n!&$cUsEJ&bE^zN6(PM&_*<76UNIr57kk^S!0-EsiWsdpzL%{a^NW%q_#wE)ONbx zG2v=Ep%TyXPNq%z;@VE?i^FM2Qwv&lSdzQurI2r0K|9dnuY{(OTF$!gA5Q8NWab`1 zrA)+Nqg16X_Tm>^&3v-SLvtpEE(>X6BXbNT0>qF`7qKK~_p5lz8!!}`c8$4AyrRBi ze&n+yDSP4n2;>0zM8h=-0jC8<+DG-!>1Jk|Aixplc}hNHkpQI{f0o|@jY2>G&=-@T6LfGXo`#3*oQgZ7qB~Z0@r)V`d68Wd&mvdiUH@M)=b;+MEeZPfqnHJX{7H z|H)JIt!yvssEUwr$AufF~D@8to$ zKAu-I-8aAaMxO*bi_i0$ufKTn=E{}GqkZ$`f6Fue?ti}0Yq#h6_N(9i`Q{Gv{coTD z8GfGb{>EYx@Ex^xZ%IGJpPa|M4I6Q~vsQzyIAI@BZbN 0 + assert list(filter(lambda x: x.metric_name == "test-x-step", metrics)) + assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics)) + + # metrics -> eureka propagation + retry_with_backoff(verify_metrics) diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py new file mode 100644 index 0000000000..713a6a3792 --- /dev/null +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -0,0 +1,662 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime +import os + +import pytest + +from tests.integ.sagemaker.experiments.conftest import TAGS +from sagemaker.experiments._api_types import _TrialComponentStatusType +from sagemaker.experiments._utils import is_run_trial_component +from sagemaker.processing import FrameworkProcessor +from sagemaker.pytorch import PyTorch +from sagemaker.s3 import S3Uploader +from sagemaker.xgboost import XGBoostModel +from tests.integ import DATA_DIR +from sagemaker.experiments._metrics import BATCH_SIZE +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.sklearn import SKLearn +from sagemaker.utils import retry_with_backoff, unique_name_from_base +from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources +from sagemaker.experiments.run import ( + RUN_NAME_BASE, + DELIMITER, +) +from sagemaker.experiments import Run, load_run, list_runs +from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX + + +# when running integration tests locally modify this to your test account's execution role +EXECUTION_ROLE = "SageMakerRole" + + +@pytest.fixture +def artifact_file_path(tempdir): + file_contents = "test artifact file" + file_path = os.path.join(tempdir, "artifact_file.txt") + with open(file_path, "w") as foo_file: + foo_file.write(file_contents) + return file_path + + +artifact_name = unique_name_from_base("Test-Artifact") +file_artifact_name = f"File-Artifact-{name()}" +metric_name = "Test-Local-Init-Log-Metric" + + +def test_local_run_with_load(sagemaker_session, artifact_file_path): + exp_name = f"My-Local-Exp-{name()}" + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + # Run name is not provided, will create a new TC + with Run(experiment_name=exp_name, sagemaker_session=sagemaker_session) as run1: + run1_name = run1.run_name + assert RUN_NAME_BASE in run1_name + _local_run_log_behaviors( + artifact_file_path=artifact_file_path, + sagemaker_session=sagemaker_session, + ) + + def verify_load_run(): + with load_run( + experiment_name=exp_name, + run_name=run1_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.run_name == run1_name + assert ( + run2._trial_component.trial_component_name + == f"{run2.experiment_name}{DELIMITER}{run1_name}" + ) + _check_run_from_local_end_result( + sagemaker_session=sagemaker_session, tc=run2._trial_component + ) + + # Add retry to make sure metrics -> eureka propagation is consistent + retry_with_backoff(verify_load_run, 4) + + +def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session): + exp_name1 = f"my-two-local-exp1-{name()}" + exp_name2 = f"my-two-local-exp2-{name()}" + run_name = "test-run" + with cleanup_exp_resources( + exp_names=[exp_name1, exp_name2], sagemaker_session=sagemaker_session + ): + # Run name is not provided, will create a new TC + with Run( + experiment_name=exp_name1, run_name=run_name, sagemaker_session=sagemaker_session + ) as run1: + pass + with Run( + experiment_name=exp_name2, run_name=run_name, sagemaker_session=sagemaker_session + ) as run2: + pass + + assert run1.experiment_name != run2.experiment_name + assert run1.run_name == run2.run_name + assert ( + run1._trial_component.trial_component_name != run2._trial_component.trial_component_name + ) + assert run1._trial_component.trial_component_name == f"{exp_name1}{DELIMITER}{run_name}" + assert run2._trial_component.trial_component_name == f"{exp_name2}{DELIMITER}{run_name}" + + 
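The last assertions above pin down the run-to-trial-component naming rule: the backing trial component's name is the experiment name and the run name joined by DELIMITER, which is what lets the same run name coexist in different experiments. A minimal sketch of that rule (a hypothetical helper mirroring Run._generate_trial_component_name; the concrete DELIMITER character is defined in sagemaker/experiments/run.py):

    def trial_component_name_for(experiment_name: str, run_name: str, delimiter: str) -> str:
        # One trial component per (experiment, run) pair; embedding both names
        # in the TC name is what makes run names reusable across experiments.
        return f"{experiment_name}{delimiter}{run_name}"

    # e.g. trial_component_name_for("exp-a", "test-run", "-") == "exp-a-test-run"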
+@pytest.mark.parametrize( + "input_names", + [ + (f"my-local-exp-{name()}", "test-run", None), # both have delimiter - + ("my-test-1", "my-test-1", None), # exp_name equals run_name + ("my-test-3", "my-test-3-run", None), # is subset of run_name + ("x" * 59, "test-run", None), # long exp_name + ("test-exp", "y" * 59, None), # long run_name + ("e" * 59, "y" * 59, None), # long exp_name and run_name + ("my-test4", "test-run", "run-display-name-test"), # with supplied display name + ], +) +def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_names): + exp_name, run_name, run_display_name = input_names + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + sagemaker_session=sagemaker_session, + run_name=run_name, + run_display_name=run_display_name, + ) as run1: + assert not run1._experiment.tags + assert not run1._trial.tags + is_run_tc = is_run_trial_component( + trial_component_name=run1._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert is_run_tc + + with load_run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.experiment_name == exp_name + assert run2.run_name == run_name + assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run_name}" + assert run2._trial_component.display_name in ( + run_display_name, + run2._trial_component.trial_component_name, + ) + + +_EXP_NAME_BASE_IN_SCRIPT = "job-exp-in-script" +_RUN_NAME_IN_SCRIPT = "job-run-in-script" + +_EXP_DIR = os.path.join(DATA_DIR, "experiment") +_ENTRY_POINT_PATH = os.path.join(_EXP_DIR, "train_job_script_for_run_clz.py") +_PYTHON_PROCESS_SCRIPT = "process_job_script_for_run_clz.py" +_TRANSFORM_MATERIALS = os.path.join(_EXP_DIR, "transform_job_materials") + +_RUN_INIT = "init" +_RUN_LOAD = "load" + + +def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, the same exp and run names are given in the Run constructor + # which will load the 1st Run TC in training job and log parameters + # and metrics there + # 3. In a different training job, load the same Run TC and log more parameters there. 
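+    # For illustration only (an assumed shape, not the actual entry point
+    # script): inside the training job, the script is expected to do roughly
+    #
+    #     with load_run(sagemaker_session=session) as run:
+    #         run.log_parameter("p1", 1.0)
+    #         run.log_metric(name="some-metric", value=0.1, step=0)
+    #
+    # so that the local logging above and the in-job logging land on the
+    # same Run TC.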
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name == exp_name + assert run.run_name == _RUN_NAME_IN_SCRIPT + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.environment["CALL_RUN_LOAD_WITH_NO_NAME_ARGS"] = "True" + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + is_init=False, + has_extra_load=True, + ) + + +def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, different exp and run names (i.e. 2nd Run TC) are given + # in the Run constructor which will create a Run TC according to the run_name + # passed in there and ignore the exp config in the job + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the 2nd Run TC and log more parameters there. 
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + exp_name2 = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources( + exp_names=[exp_name, exp_name2], sagemaker_session=sagemaker_session + ): + with Run( + experiment_name=exp_name2, + run_name=f"{_RUN_NAME_IN_SCRIPT}2", + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name != exp_name + assert run.run_name != _RUN_NAME_IN_SCRIPT + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + old_end_time=old_end_time, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +def test_run_from_train_job_only(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. No Run TC created locally or specified in experiment config + # 2. In training job, Run is initialized + # which will create a Run TC according to the run_name passed in there + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the same Run TC and log more parameters there. 
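+    # For illustration only (assumed shape): with no experiment config passed
+    # in from a local _RunContext, the in-job script constructs the run
+    # itself from the environment variables set on the estimator, e.g.
+    #
+    #     with Run(experiment_name=os.environ["EXPERIMENT_NAME"],
+    #              run_name=os.environ["RUN_NAME"]) as run:
+    #         run.log_parameter("p1", 1.0)
+    #
+    # which creates the Run TC from inside the training job.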
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, + sagemaker_session=sagemaker_session, + exp_name=exp_name, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def test_run_from_processing_job_and_override_default_exp_config( + sagemaker_session, dev_sdk_tar, run_obj +): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. Within the 2nd Run TC (run_obj)'s context, invoke processor.run + # but override the default experiment config in context of 2nd Run TC + # with the experiment config of the 1st Run TC + # 3. In the processing job script, load the 1st Run TC via the experiment config + # fetched from the job env + # 4. All data are logged in the Run TC either locally or in the processing job + exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + processor = FrameworkProcessor( + estimator_cls=PyTorch, + framework_version="1.10", + py_version="py38", + instance_count=1, + instance_type="ml.m5.xlarge", + role=EXECUTION_ROLE, + sagemaker_session=sagemaker_session, + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + + with run_obj: + # Override the default experiment_config in _RunContext of run_obj + # with the experiment_config of run + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + experiment_config=run.experiment_config, + ) + + assert run_obj.experiment_name != run.experiment_name + assert run_obj.run_name != run.run_name + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run.experiment_name, run_name=run.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + with run_obj: + # Not to override the exp config and use the default one in the context + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + ) + + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def 
test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost_latest_version): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. In the inference script running in a transform job, load the 1st Run TC + # via explicitly passing the experiment_name and run_name of the 1st Run TC + # TODO: once we're able to retrieve exp config from the transform job env, + # we should expand this test and add the load_run() without explicitly supplying the names + # 3. All data are logged in the Run TC either locally or in the transform job + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_TRANSFORM_MATERIALS, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + xgboost_model = XGBoostModel( + sagemaker_session=sagemaker_session, + model_data=xgb_model_data_s3, + role=EXECUTION_ROLE, + entry_point="inference.py", + source_dir=_EXP_DIR, + framework_version=xgboost_latest_version, + env={ + "EXPERIMENT_NAME": run_obj.experiment_name, + "RUN_NAME": run_obj.run_name, + }, + ) + transformer = xgboost_model.transformer( + instance_count=1, + instance_type="ml.m5.4xlarge", + max_concurrent_transforms=5, + max_payload=1, + strategy="MultiRecord", + ) + uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "transform-test", + unique_name_from_base("json-data"), + ) + input_data = S3Uploader.upload( + os.path.join(_TRANSFORM_MATERIALS, "data.csv"), uri, sagemaker_session=sagemaker_session + ) + + with run_obj: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + transformer.transform( + data=input_data, + content_type="text/libsvm", + split_type="Line", + wait=True, + job_name=f"transform-job-{name()}", + ) + + _check_run_from_local_end_result( + tc=run_obj._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result(tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False) + + +def test_list(run_obj, sagemaker_session): + tc1 = _TrialComponent.create( + trial_component_name=f"non-run-tc1-{name()}", + sagemaker_session=sagemaker_session, + ) + tc2 = _TrialComponent.create( + trial_component_name=f"non-run-tc2-{name()}", + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + run_obj._trial.add_trial_component(tc1) + run_obj._trial.add_trial_component(tc2) + + run_tcs = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs[0].run_name == run_obj.run_name + assert run_tcs[0].experiment_name == run_obj.experiment_name + assert run_tcs[0].experiment_config == run_obj.experiment_config + + +def _generate_estimator(exp_name, sdk_tar, sagemaker_session): + return SKLearn( + framework_version="0.23-1", + entry_point=_ENTRY_POINT_PATH, + dependencies=[sdk_tar], + role=EXECUTION_ROLE, + instance_type="ml.m5.large", + instance_count=1, + volume_size=10, + max_run=900, + enable_sagemaker_metrics=True, + environment={ + "EXPERIMENT_NAME": exp_name, + "RUN_NAME": _RUN_NAME_IN_SCRIPT, + "RUN_OPERATION": _RUN_INIT, + }, + sagemaker_session=sagemaker_session, + ) + + +def _local_run_log_behaviors( + sagemaker_session, + artifact_file_path=None, + is_complete_log=True, +): + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameter("pa", 1.0) + run.log_parameter("pb", "p2-value") + run.log_parameters({"pc": 2.0, 
"pd": "p4-value"}) + + if is_complete_log: + run.log_file(file_path=artifact_file_path, name=file_artifact_name) + run.log_artifact(name=artifact_name, value="s3://Output") + run.log_artifact(name=artifact_name, value="s3://Input", is_output=False) + + for i in range(BATCH_SIZE): + run.log_metric(name=metric_name, value=i, step=i) + + +def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True): + assert tc.parameters == {"pa": 1.0, "pb": "p2-value", "pc": 2.0, "pd": "p4-value"} + + if not is_complete_log: + return + + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}" + assert s3_prefix in tc.output_artifacts[file_artifact_name].value + assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type + assert "s3://Output" == tc.output_artifacts[artifact_name].value + assert not tc.output_artifacts[artifact_name].media_type + assert "s3://Input" == tc.input_artifacts[artifact_name].value + assert not tc.input_artifacts[artifact_name].media_type + + # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod + assert len(tc.metrics) > 0 + metric_summary = tc.metrics[0] + assert metric_summary.metric_name == metric_name + assert metric_summary.max == 9.0 + assert metric_summary.min == 0.0 + + +def _check_run_from_job_result(sagemaker_session, tc_name=None, is_init=True, has_extra_load=False): + def validate_tc_updated_in_init(): + assert tc.start_time + assert tc.end_time + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert tc.parameters["p1"] == 1.0 + assert tc.parameters["p2"] == 2.0 + # TODO: revert to assert len(tc.metrics) == 5 once + # backend fix hits prod + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + # metrics deletion is not supported at this point + # so its count would accumulate + assert metric_summary.count > 0 + assert metric_summary.min == 0.0 + assert metric_summary.max == 1.0 + + def validate_tc_updated_in_load(): + assert tc.parameters["p3"] == 3.0 + assert tc.parameters["p4"] == 4.0 + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "test-job-load-log-metric": + continue + assert metric_summary.last == 0.1 + assert metric_summary.max == 0.1 + assert metric_summary.min == 0.1 + if has_extra_load: + assert tc.parameters["p5"] == 5.0 + assert tc.parameters["p6"] == 6.0 + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + if is_init: + # Add retry since the load behavior is inconsistent sometimes + retry_with_backoff(validate_tc_updated_in_init, 4) + else: + retry_with_backoff(validate_tc_updated_in_load, 4) + + +def _check_tc_status_when_entering(trial_component): + assert isinstance(trial_component.start_time, datetime.datetime) + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + return trial_component.start_time + + +def _check_tc_status_when_exiting( + trial_component_name, sagemaker_session, init_start_time, old_end_time=None +): + tc = _TrialComponent.load( + trial_component_name=trial_component_name, sagemaker_session=sagemaker_session + ) + # There will be deviation (< 1s) caused by different TS precisions used in Backend and SDK + assert abs(tc.start_time.timestamp() - init_start_time.timestamp()) < 1 + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert isinstance(tc.end_time, datetime.datetime) + if old_end_time: + assert 
tc.end_time > old_end_time
+    return tc.end_time
+
+
+def _check_tc_status_intermediate(
+    trial_component, sagemaker_session, init_start_time, old_end_time=None
+):
+    tc_load = _TrialComponent.load(
+        trial_component_name=trial_component.trial_component_name,
+        sagemaker_session=sagemaker_session,
+    )
+    assert abs(tc_load.start_time.timestamp() - init_start_time.timestamp()) < 1
+    assert tc_load.status.primary_status == _TrialComponentStatusType.InProgress.value
+    if not old_end_time:
+        assert not trial_component.end_time
+        return
+    assert isinstance(tc_load.end_time, datetime.datetime)
+    assert tc_load.end_time == old_end_time
diff --git a/tests/integ/sagemaker/experiments/test_trial.py b/tests/integ/sagemaker/experiments/test_trial.py
new file mode 100644
index 0000000000..08f646c086
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import logging
+
+from sagemaker.experiments import trial
+from sagemaker.utils import retry_with_backoff
+
+
+def test_create_delete(trial_obj):
+    # The fixture creates / deletes the trial; just ensure it was used at least once.
+    assert trial_obj.trial_name
+
+
+def test_create_tags(trial_obj, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    while True:
+        actual_tags = client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
+        if actual_tags:
+            break
+    for tag in actual_tags:
+        if "aws:tag" in tag.get("Key"):
+            actual_tags.remove(tag)
+    assert actual_tags == trial_obj.tags
+
+
+def test_save_load(trial_obj, sagemaker_session):
+    trial_obj.display_name = "foo"
+    trial_obj.save()
+    assert (
+        "foo"
+        == trial._Trial.load(
+            trial_name=trial_obj.trial_name,
+            sagemaker_session=sagemaker_session,
+        ).display_name
+    )
+
+
+def test_add_remove_trial_component(trial_obj, trial_component_obj):
+    trial_obj.add_trial_component(trial_component_obj)
+    logging.info(
+        f"Added trial component {trial_component_obj.trial_component_name} to trial {trial_obj.trial_name}"
+    )
+
+    def validate_add():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 1 == len(
+            trial_components
+        ), "Expected the trial component to appear in the trial's list of trial components"
+
+    retry_with_backoff(validate_add)
+
+    trial_obj.remove_trial_component(trial_component_obj)
+    logging.info(
+        f"Removed trial component {trial_component_obj.trial_component_name} from trial {trial_obj.trial_name}"
+    )
+
+    def validate_remove():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 0 == len(
+            trial_components
+        ), "Expected the trial component to be removed from the trial's list of trial components"
+
+    retry_with_backoff(validate_remove)
diff --git a/tests/integ/sagemaker/experiments/test_trial_component.py b/tests/integ/sagemaker/experiments/test_trial_component.py
new file mode 100644
index 0000000000..3d79e41cc4
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial_component.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates.
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import datetime +import uuid + +from sagemaker.experiments._api_types import _TrialComponentStatusType +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments import _api_types, trial_component +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression + + +def test_create_delete(trial_component_obj): + # Fixture does create / delete, just need to ensure called at least once + assert trial_component_obj.trial_component_name + assert trial_component_obj.input_artifacts == {} + assert trial_component_obj.parameters == {} + assert trial_component_obj.output_artifacts == {} + + +def test_create_tags(trial_component_obj, sagemaker_session): + client = sagemaker_session.sagemaker_client + while True: + actual_tags = client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"] + if actual_tags: + break + for tag in actual_tags: + if "aws:tag" in tag.get("Key"): + actual_tags.remove(tag) + assert actual_tags == trial_component_obj.tags + + +def test_delete_with_force_disassociate( + trial_component_with_force_disassociation_obj, sagemaker_session +): + assert trial_component_with_force_disassociation_obj.trial_component_name + trials = sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=trial_component_with_force_disassociation_obj.trial_component_name + )["TrialSummaries"] + assert len(trials) == 3 + + +def test_save(trial_component_obj, sagemaker_session): + trial_component_obj.display_name = str(uuid.uuid4()) + trial_component_obj.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="Message" + ) + trial_component_obj.start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(days=1) + trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc) + trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1} + trial_component_obj.input_artifacts = { + "snizz": _api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"), + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"), + } + trial_component_obj.output_artifacts = { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"), + "fly2": _api_types.TrialComponentArtifact( + value="s3:/sky/far2", media_type="away/tomorrow2" + ), + } + trial_component_obj.parameters_to_remove = ["foo"] + trial_component_obj.input_artifacts_to_remove = ["snizz"] + trial_component_obj.output_artifacts_to_remove = ["fly2"] + + trial_component_obj.save() + + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + + assert trial_component_obj.trial_component_name == loaded.trial_component_name + assert trial_component_obj.status == loaded.status + + assert trial_component_obj.start_time - loaded.start_time < 
datetime.timedelta(seconds=1) + assert trial_component_obj.end_time - loaded.end_time < datetime.timedelta(seconds=1) + + assert loaded.parameters == {"whizz": 100.1} + assert loaded.input_artifacts == { + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2") + } + assert loaded.output_artifacts == { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow") + } + + +def test_load(trial_component_obj, sagemaker_session): + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert trial_component_obj.trial_component_arn == loaded.trial_component_arn + + +def test_list_sort(trial_components, sagemaker_session): + slack = datetime.timedelta(minutes=1) + now = datetime.datetime.now(datetime.timezone.utc) + trial_component_names = [tc.trial_component_name for tc in trial_components] + + for sort_order in ["Ascending", "Descending"]: + trial_component_names_listed = [ + s.trial_component_name + for s in trial_component._TrialComponent.list( + created_after=now - slack, + created_before=now + slack, + sort_by="CreationTime", + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + if s.trial_component_name in trial_component_names + ] + + if sort_order == "Descending": + trial_component_names_listed = trial_component_names_listed[::-1] + assert trial_component_names == trial_component_names_listed + assert trial_component_names # sanity test + + +def test_search(sagemaker_session): + trial_component_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for s in trial_component._TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + trial_component_names_searched.append(s.trial_component_name) + + assert len(trial_component_names_searched) > 0 + assert trial_component_names_searched # sanity test diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 3c416ffd36..abfe6f6d0d 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -26,6 +26,7 @@ artifact, ) from sagemaker.model import ModelPackage +from sagemaker.utils import retry_with_backoff from tests.integ.sagemaker.workflow.test_workflow import ( test_end_to_end_pipeline_successful_execution, ) @@ -43,7 +44,7 @@ ) from sagemaker.lineage.lineage_trial_component import LineageTrialComponent -from tests.integ.sagemaker.lineage.helpers import name, names, retry +from tests.integ.sagemaker.lineage.helpers import name, names SLEEP_TIME_SECONDS = 1 SLEEP_TIME_TWO_SECONDS = 2 @@ -400,7 +401,7 @@ def model_obj(sagemaker_session): yield model time.sleep(SLEEP_TIME_SECONDS) - retry(lambda: model.delete(disassociate=True), num_attempts=4) + retry_with_backoff(lambda: model.delete(disassociate=True), num_attempts=4) @pytest.fixture diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..5548c63cff 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -15,7 +15,6 @@ import uuid from datetime import datetime -import time def name(): @@ -33,19 +32,6 @@ def names(): ] -def retry(callable, num_attempts=8): - assert num_attempts >= 1 - for i in 
range(num_attempts): - try: - return callable() - except Exception as ex: - if i == num_attempts - 1: - raise ex - print("Retrying", ex) - time.sleep(2**i) - assert False, "logic error in retry" - - def traverse_graph_back(start_arn, sagemaker_session): def visit(arn, visited: set): visited.add(arn) diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py index c629fcdc30..1980b51da2 100644 --- a/tests/integ/sagemaker/lineage/test_artifact.py +++ b/tests/integ/sagemaker/lineage/test_artifact.py @@ -20,7 +20,7 @@ import pytest from sagemaker.lineage import artifact -from tests.integ.sagemaker.lineage.helpers import retry +from sagemaker.utils import retry_with_backoff def test_create_delete(artifact_obj): @@ -125,7 +125,7 @@ def validate(): assert len(trials) == 1 assert trial_obj.trial_name in trials - retry(validate, num_attempts=3) + retry_with_backoff(validate, num_attempts=3) def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session): diff --git a/tests/integ/sagemaker/utilities/__init__.py b/tests/integ/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/utilities/test_search_expression.py b/tests/integ/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..ea7f4476bf --- /dev/null +++ b/tests/integ/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,67 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
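These tests build search requests from the helper classes added in src/sagemaker/utilities/search_expression.py. For orientation, a filter-plus-expression pair composes like this (a minimal sketch; the dictionary shown is the approximate boto3 Search payload the expression stands for, under the assumption that the helpers serialize to the SageMaker Search API's SearchExpression structure):

    from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression

    # "TrialComponentName CONTAINS <prefix>", the same filter the tests use.
    search_filter = Filter(
        name="TrialComponentName", operator=Operator.CONTAINS, value="my-prefix"
    )
    expression = SearchExpression(filters=[search_filter])
    # Roughly equivalent Search API payload:
    # {"Filters": [{"Name": "TrialComponentName", "Operator": "Contains", "Value": "my-prefix"}]}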
+from __future__ import absolute_import + +import pytest + +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression, NestedFilter + + +def test_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +@pytest.mark.skip(reason="failed validation, need to wait for NestedFilter bug to be fixed") +def test_nested_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + nested_filter = NestedFilter(property_name="TrialComponentName", filters=[search_filter]) + search_expression = SearchExpression(nested_filters=[nested_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +def test_sub_expression(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + sub_expression = SearchExpression(filters=[search_filter]) + search_expression = SearchExpression(sub_expressions=[sub_expression]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py index b9ff13c50e..28b537c1ea 100644 --- a/tests/integ/test_marketplace.py +++ b/tests/integ/test_marketplace.py @@ -23,6 +23,7 @@ import sagemaker import tests.integ +from tests.integ.utils import create_repository from sagemaker import AlgorithmEstimator, ModelPackage, Model from sagemaker.serializers import CSVSerializer from sagemaker.tuner import IntegerParameter, HyperparameterTuner @@ -33,7 +34,6 @@ from tests.integ.test_multidatamodel import ( _ecr_image_uri, _ecr_login, - _create_repository, _delete_repository, ) from tests.integ.retry import retries @@ -214,7 +214,7 @@ def iris_image(sagemaker_session): rm=True, ) image.tag(ecr_image, tag="latest") - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index 78ba62c3db..d6c14037a7 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -19,8 +19,8 @@ import docker import numpy import pytest -from botocore.exceptions import ClientError +from tests.integ.utils import create_repository from sagemaker import utils from sagemaker.amazon.randomcutforest import RandomCutForest from sagemaker.deserializers import StringDeserializer @@ -59,7 +59,7 @@ def 
container_image(sagemaker_session):
     image.tag(ecr_image, tag="latest")
 
     # Create AWS ECR and push the local docker image to it
-    _create_repository(ecr_client, algorithm_name)
+    create_repository(ecr_client, algorithm_name)
 
     # Retry docker image push
     for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10):
@@ -90,23 +90,6 @@ def _ecr_image_uri(sagemaker_session, algorithm_name):
     return "{}.dkr.{}/{}:latest".format(account_id, endpoint_data["hostname"], algorithm_name)
 
 
-def _create_repository(ecr_client, repository_name):
-    """
-    Creates an ECS Repository (ECR). When a new transform is being registered,
-    we'll need a repository to push the image (and composed model images) to
-    """
-    try:
-        response = ecr_client.create_repository(repositoryName=repository_name)
-        return response["repository"]["repositoryUri"]
-    except ClientError as e:
-        # Handle when the repository already exists
-        if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"):
-            response = ecr_client.describe_repositories(repositoryNames=[repository_name])
-            return response["repositories"][0]["repositoryUri"]
-        else:
-            raise
-
-
 def _delete_repository(ecr_client, repository_name):
     """
     Deletes an ECS Repository (ECR). After the integration test completes
diff --git a/tests/integ/utils.py b/tests/integ/utils.py
index 53440f96f5..d7891321f2 100644
--- a/tests/integ/utils.py
+++ b/tests/integ/utils.py
@@ -14,6 +14,8 @@
 import logging
 from functools import wraps
 
+from botocore.exceptions import ClientError
+
 from tests.conftest import NO_P3_REGIONS, NO_M4_REGIONS
 from sagemaker.exceptions import CapacityError
 
@@ -69,3 +71,21 @@ def wrapper(*args, **kwargs):
         return wrapper
 
     return decorator
+
+
+def create_repository(ecr_client, repository_name):
+    """Creates an ECR (Elastic Container Registry) repository.
+
+    When a new transform is being registered, we'll need a repository
+    to push the image (and composed model images) to.
+    """
+    try:
+        response = ecr_client.create_repository(repositoryName=repository_name)
+        return response["repository"]["repositoryUri"]
+    except ClientError as e:
+        # Handle when the repository already exists
+        if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"):
+            response = ecr_client.describe_repositories(repositoryNames=[repository_name])
+            return response["repositories"][0]["repositoryUri"]
+        else:
+            raise
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 0000000000..21fe49cc97
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,66 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import sagemaker
+
+from mock import Mock, PropertyMock
+
+_ROLE = "DummyRole"
+_REGION = "us-west-2"
+_DEFAULT_BUCKET = "my-bucket"
+
+
+@pytest.fixture(scope="session")
+def client():
+    """Mock client.
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="session") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=_ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=_REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="session") +def sagemaker_session(boto_session, client): + # ideally this would mock Session instead of instantiating it + # most unit tests do mock the session correctly + return sagemaker.session.Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=_DEFAULT_BUCKET, + sagemaker_metrics_client=client, + ) diff --git a/tests/unit/sagemaker/experiments/__init__.py b/tests/unit/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/experiments/conftest.py b/tests/unit/sagemaker/experiments/conftest.py new file mode 100644 index 0000000000..4d33ad759d --- /dev/null +++ b/tests/unit/sagemaker/experiments/conftest.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest +from unittest.mock import patch, MagicMock, Mock + +import pytest + +from sagemaker import Session +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import RUN_NAME_BASE +from sagemaker.experiments import Run +from tests.unit.sagemaker.experiments.helpers import ( + mock_tc_load_or_create_func, + mock_trial_load_or_create_func, + TEST_EXP_NAME, +) + + +@pytest.fixture +def client(): + """Mock client. 
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = unittest.mock.Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def sagemaker_session(client): + return Session( + sagemaker_client=client, + ) + + +@pytest.fixture +def run_obj(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.update_trial_component.return_value = {} + client.associate_trial_component.return_value = {} + with patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock( + return_value=_Experiment( + experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session + ) + ), + ): + with patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), + ): + with patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), + ): + run = Run( + experiment_name=TEST_EXP_NAME, + sagemaker_session=sagemaker_session, + ) + run._artifact_uploader = Mock() + run._lineage_artifact_tracker = Mock() + run._metrics_manager = Mock() + + assert run.run_name.startswith(RUN_NAME_BASE) + assert run.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + + return run diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py new file mode 100644 index 0000000000..b7914010e5 --- /dev/null +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +TEST_EXP_NAME = "my-experiment" +TEST_RUN_NAME = "my-run" + + +def mock_tc_load_or_create_func( + trial_component_name, display_name=None, tags=None, sagemaker_session=None +): + tc = _TrialComponent( + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return tc, True + + +def mock_trial_load_or_create_func( + experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None +): + return _Trial( + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) diff --git a/tests/unit/sagemaker/experiments/test_environment.py b/tests/unit/sagemaker/experiments/test_environment.py new file mode 100644 index 0000000000..8bb23db7b6 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_environment.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import shutil +import tempfile +import unittest.mock + +import pytest + +from sagemaker.experiments import _environment +from sagemaker.utils import retry_with_backoff + + +@pytest.fixture +def tempdir(): + dir = tempfile.mkdtemp() + yield dir + shutil.rmtree(dir) + + +@pytest.fixture +def training_job_env(): + old_value = os.environ.get("TRAINING_JOB_ARN") + os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe" + yield os.environ + del os.environ["TRAINING_JOB_ARN"] + if old_value: + os.environ["TRAINING_JOB_ARN"] = old_value + + +@pytest.fixture +def transform_job_env(): + old_value = os.environ.get("SAGEMAKER_BATCH") + os.environ["SAGEMAKER_BATCH"] = "true" + yield os.environ + del os.environ["SAGEMAKER_BATCH"] + if old_value: + os.environ["SAGEMAKER_BATCH"] = old_value + + +def test_processing_job_environment(tempdir): + config_path = os.path.join(tempdir, "config.json") + with open(config_path, "w") as f: + f.write(json.dumps({"ProcessingJobArn": "arn:1234aBcDe"})) + environment = _environment._RunEnvironment.load(processing_job_config_path=config_path) + + assert _environment._EnvironmentType.SageMakerProcessingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_training_job_environment(training_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTrainingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_transform_job_environment(transform_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type + # TODO: update if we figure out how to get source_arn from the transform job + assert not environment.source_arn + + +def test_no_environment(): + assert _environment._RunEnvironment.load() is None + + +def test_resolve_trial_component(training_job_env, sagemaker_session): + trial_component_name = "foo-bar" + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentName": trial_component_name}] + } + client.describe_trial_component.return_value = {"TrialComponentName": trial_component_name} + environment = _environment._RunEnvironment.load() + tc = environment.get_trial_component(sagemaker_session) + + assert trial_component_name == tc.trial_component_name + client.describe_trial_component.assert_called_with(TrialComponentName=trial_component_name) + client.list_trial_components.assert_called_with(SourceArn="arn:1234abcde") + + +@unittest.mock.patch("sagemaker.experiments._environment.retry_with_backoff") +def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_job_env): + mock_retry.side_effect = lambda func: retry_with_backoff(func, 2) + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = Exception("Failed test") + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None + + +def 
test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session): + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py new file mode 100644 index 0000000000..b0ad55c27f --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_experiment.py @@ -0,0 +1,306 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import unittest.mock +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments import experiment +from sagemaker.experiments._api_types import TrialSummary + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_experiment.return_value = {"Description": "description-value"} + experiment_obj = experiment._Experiment.load( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + assert experiment_obj.description == "description-value" + + client.describe_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + tags = [{"Key": "foo", "Value": "bar"}] + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value", Tags=tags) + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.update_experiment.return_value = {} + obj.save() + client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar") + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.delete_experiment.return_value = {} + obj.delete() + client.delete_experiment.assert_called_with(ExperimentName="foo") + + +@patch("sagemaker.experiments.experiment._Experiment.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + exp_name = "exp_name" + 
experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + mock_load.assert_called_once_with(exp_name, sagemaker_session) + + +@patch("sagemaker.experiments.experiment._Experiment.load") +@patch("sagemaker.experiments.experiment._Experiment.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + exp_name = "exp_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + + mock_create.assert_called_once_with( + experiment_name=exp_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_list_trials_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []} + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + assert list(experiment_obj.list_trials()) == [] + + +def test_list_trials_single(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj} + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary(name="trial-foo", creation_time=datetime_obj, last_modified_time=datetime_obj) + ] + + +def test_list_trials_two_values(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + {"Name": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + +def test_next_token(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session) + client = sagemaker_session.sagemaker_client + client.list_trials.side_effect = [ + { + "TrialSummaries": [ + { + "Name": "trial-foo-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "Name": "trial-foo-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ], + "NextToken": "foo", + }, + { + "TrialSummaries": [ + { + "Name": "trial-foo-3", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + } + ] + }, + ] + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + client.list_trials.assert_any_call(**{}) + client.list_trials.assert_any_call(NextToken="foo") + + +def test_list_trials_call_args(sagemaker_session): + client = 
sagemaker_session.sagemaker_client
+ created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
+ created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
+ experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+ client.list_trials.return_value = {}
+ assert [] == list(
+ experiment_obj.list_trials(created_after=created_after, created_before=created_before)
+ )
+ client.list_trials.assert_called_with(CreatedBefore=created_before, CreatedAfter=created_after)
+
+
+def test_delete_all_with_incorrect_action_name(sagemaker_session):
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ with pytest.raises(ValueError) as err:
+ obj._delete_all(action="abc")
+
+ assert "Must confirm with string '--force'" in str(err)
+
+
+def test_delete_all(sagemaker_session, datetime_obj):
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ client = sagemaker_session.sagemaker_client
+ client.list_trials.return_value = {
+ "TrialSummaries": [
+ {
+ "TrialName": "trial-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialName": "trial-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ }
+ client.describe_trial.side_effect = [
+ {"TrialName": "trial-1", "ExperimentName": "experiment-name-value"},
+ {"TrialName": "trial-2", "ExperimentName": "experiment-name-value"},
+ ]
+ client.list_trial_components.side_effect = [
+ {
+ "TrialComponentSummaries": [
+ {
+ "TrialComponentName": "trial-component-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialComponentName": "trial-component-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ },
+ {
+ "TrialComponentSummaries": [
+ {
+ "TrialComponentName": "trial-component-3",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialComponentName": "trial-component-4",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ },
+ ]
+
+ client.describe_trial_component.side_effect = [
+ {"TrialComponentName": "trial-component-1"},
+ {"TrialComponentName": "trial-component-2"},
+ {"TrialComponentName": "trial-component-3"},
+ {"TrialComponentName": "trial-component-4"},
+ ]
+
+ client.delete_trial_component.return_value = {}
+ client.delete_trial.return_value = {}
+ client.delete_experiment.return_value = {}
+
+ obj._delete_all(action="--force")
+
+ client.delete_experiment.assert_called_with(ExperimentName="foo")
+
+ delete_trial_expected_calls = [
+ unittest.mock.call(TrialName="trial-1"),
+ unittest.mock.call(TrialName="trial-2"),
+ ]
+ assert delete_trial_expected_calls == client.delete_trial.mock_calls
+
+ delete_trial_component_expected_calls = [
+ unittest.mock.call(TrialComponentName="trial-component-1"),
+ unittest.mock.call(TrialComponentName="trial-component-2"),
+ unittest.mock.call(TrialComponentName="trial-component-3"),
+ unittest.mock.call(TrialComponentName="trial-component-4"),
+ ]
+ assert delete_trial_component_expected_calls == client.delete_trial_component.mock_calls
+
+
+def test_delete_all_fail(sagemaker_session):
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
+ with pytest.raises(Exception) as e:
+ obj._delete_all(action="--force")
+
+ assert str(e.value) == "Failed to delete, please try again."
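+
+
+# The exact message asserted above is the contract of _delete_all's error handling:
+# a failure while listing or deleting entities is surfaced as a new exception with
+# this generic text rather than as the original client error.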
diff --git a/tests/unit/sagemaker/experiments/test_helper.py b/tests/unit/sagemaker/experiments/test_helper.py
new file mode 100644
index 0000000000..a11f67389b
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_helper.py
@@ -0,0 +1,195 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import os
+import shutil
+import tempfile
+
+from mock import Mock, PropertyMock, call
+import pytest
+
+from sagemaker.experiments._helper import (
+ _LineageArtifactTracker,
+ _ArtifactUploader,
+)
+from sagemaker.experiments._utils import resolve_artifact_name
+from sagemaker.session import Session
+
+
+@pytest.fixture
+def client():
+ """Mock client.
+
+ Considerations when appropriate:
+
+ * utilize botocore.stub.Stubber
+ * separate runtime client from client
+ """
+ client_mock = Mock()
+ client_mock._client_config.user_agent = (
+ "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+ )
+ return client_mock
+
+
+@pytest.fixture
+def boto_session(client):
+ role_mock = Mock()
+ type(role_mock).arn = PropertyMock(return_value="DummyRole")
+
+ resource_mock = Mock()
+ resource_mock.Role.return_value = role_mock
+
+ session_mock = Mock(region_name="us-west-2")
+ session_mock.resource.return_value = resource_mock
+ session_mock.client.return_value = client
+
+ return session_mock
+
+
+@pytest.fixture
+def sagemaker_session(client, boto_session):
+ return Session(
+ sagemaker_client=client,
+ boto_session=boto_session,
+ )
+
+
+@pytest.fixture
+def lineage_artifact_tracker(sagemaker_session):
+ return _LineageArtifactTracker("test_trial_component_arn", sagemaker_session)
+
+
+def test_lineage_artifact_tracker(lineage_artifact_tracker, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ lineage_artifact_tracker.add_input_artifact(
+ "input_name", "input_source_uri", "input_etag", "text/plain"
+ )
+ lineage_artifact_tracker.add_output_artifact(
+ "output_name", "output_source_uri", "output_etag", "text/plain"
+ )
+ client.create_artifact.side_effect = [
+ {"ArtifactArn": "created_arn_1"},
+ {"ArtifactArn": "created_arn_2"},
+ ]
+
+ lineage_artifact_tracker.save()
+
+ expected_calls = [
+ call(
+ ArtifactName="input_name",
+ ArtifactType="text/plain",
+ Source={
+ "SourceUri": "input_source_uri",
+ "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "input_etag"}],
+ },
+ ),
+ call(
+ ArtifactName="output_name",
+ ArtifactType="text/plain",
+ Source={
+ "SourceUri": "output_source_uri",
+ "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "output_etag"}],
+ },
+ ),
+ ]
+ assert expected_calls == client.create_artifact.mock_calls
+
+ expected_calls = [
+ call(
+ SourceArn="created_arn_1",
+ DestinationArn="test_trial_component_arn",
+ AssociationType="ContributedTo",
+ ),
+ call(
+ SourceArn="test_trial_component_arn",
+ DestinationArn="created_arn_2",
+ AssociationType="Produced",
+ ),
+ ]
+ assert expected_calls == client.add_association.mock_calls
+
+
+@pytest.fixture
+def
artifact_uploader(sagemaker_session): + return _ArtifactUploader( + trial_component_name="trial_component_name", + artifact_bucket="artifact_bucket", + artifact_prefix="artifact_prefix", + sagemaker_session=sagemaker_session, + ) + + +@pytest.fixture +def tempdir(): + tmp_dir = tempfile.mkdtemp() + yield tmp_dir + shutil.rmtree(tmp_dir) + + +def test_artifact_uploader_init(artifact_uploader): + assert "trial_component_name" == artifact_uploader.trial_component_name + assert "artifact_bucket" == artifact_uploader.artifact_bucket + assert "artifact_prefix" == artifact_uploader.artifact_prefix + + +def test_artifact_uploader_upload_artifact_file_not_exists(tempdir, artifact_uploader): + not_exist_file = os.path.join(tempdir, "not.exists") + with pytest.raises(ValueError) as error: + artifact_uploader.upload_artifact(not_exist_file) + assert "does not exist or is not a file" in str(error) + + +def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader): + path = os.path.join(tempdir, "exists") + with open(path, "a") as f: + f.write("boo") + + name = resolve_artifact_name(path) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + s3_uri, etag = artifact_uploader.upload_artifact(path) + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.upload_file.assert_called_with( + path, artifact_uploader.artifact_bucket, expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri + + +def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader): + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + artifact_name = "my-artifact" + artifact_object = {"key": "value"} + file_extension = ".csv" + s3_uri, etag = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + name = artifact_name + file_extension + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.put_object.assert_called_with( + Body=json.dumps(artifact_object), Bucket=artifact_uploader.artifact_bucket, Key=expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri diff --git a/tests/unit/sagemaker/experiments/test_metrics.py b/tests/unit/sagemaker/experiments/test_metrics.py new file mode 100644 index 0000000000..21556f70fd --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_metrics.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
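+
+# The _RawMetricData tests below pin down timestamp normalization: timezone-aware
+# datetimes are converted to their UTC epoch value, naive datetimes are treated as
+# local time before conversion, and plain numbers are passed through as epoch
+# seconds. The writer tests then check that each log_metric call is flushed to the
+# metrics file as one JSON line.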
+from __future__ import absolute_import
+
+import os
+import pytest
+import tempfile
+import shutil
+import datetime
+import dateutil
+import json
+import time
+
+from sagemaker.experiments._metrics import (
+ _RawMetricData,
+ _SageMakerFileMetricsWriter,
+ SageMakerMetricsWriterException,
+)
+
+
+@pytest.fixture
+def tempdir():
+ dir = tempfile.mkdtemp()
+ yield dir
+ shutil.rmtree(dir)
+
+
+@pytest.fixture
+def filepath(tempdir):
+ return os.path.join(tempdir, "foo.json")
+
+
+@pytest.fixture
+def timestamp():
+ return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
+
+
+def test_raw_metric_data_utc_timestamp():
+ utcnow = datetime.datetime.now(datetime.timezone.utc)
+ assert utcnow.tzinfo
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow)
+ assert utcnow.timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_aware_timestamp():
+ aware_datetime = datetime.datetime.now(dateutil.tz.gettz("America/Chicago"))
+ assert aware_datetime.tzinfo
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=aware_datetime)
+ assert (aware_datetime - aware_datetime.utcoffset()).replace(
+ tzinfo=datetime.timezone.utc
+ ).timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_naive_timestamp():
+ naive_datetime = datetime.datetime.now()
+ assert naive_datetime.tzinfo is None
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=naive_datetime)
+ local_datetime = naive_datetime.replace(tzinfo=dateutil.tz.tzlocal())
+ assert (local_datetime - local_datetime.utcoffset()).replace(
+ tzinfo=datetime.timezone.utc
+ ).timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_number_timestamp():
+ time_now = time.time()
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now)
+ assert time_now == metric.Timestamp
+
+
+def test_raw_metric_data_request_item():
+ time_now = time.time()
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now, step=10)
+ expected = {
+ "MetricName": "foo",
+ "Value": 1.0,
+ "Timestamp": str(int(time_now)),
+ "Step": 10,
+ }
+ assert expected == metric.to_raw_metric_data()
+
+
+def test_raw_metric_data_invalid_timestamp():
+ with pytest.raises(ValueError) as error1:
+ _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() - 2000000)
+ assert "Timestamps must be between two weeks before and two hours from now" in str(error1)
+
+ with pytest.raises(ValueError) as error2:
+ _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() + 10000)
+ assert "Timestamps must be between two weeks before and two hours from now" in str(error2)
+
+
+def test_file_metrics_writer_log_metric(timestamp, filepath):
+ now = datetime.datetime.now(datetime.timezone.utc)
+ writer = _SageMakerFileMetricsWriter(filepath)
+ writer.log_metric(metric_name="foo", value=1.0)
+ writer.log_metric(metric_name="foo", value=2.0, step=1)
+ writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp)
+ writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, step=2)
+ writer.close()
+
+ lines = [x for x in open(filepath).read().split("\n") if x]
+ [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines]
+
+ assert "foo" == entry_one["MetricName"]
+ assert 1.0 == entry_one["Value"]
+ assert
(now.timestamp() - entry_one["Timestamp"]) < 1 + assert "Step" not in entry_one + + assert 1 == entry_two["Step"] + assert timestamp.timestamp() == entry_three["Timestamp"] + assert 2 == entry_four["Step"] + + +def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + + writer.log_metric(metric_name="foo", value=1.0) + + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one] = [json.loads(line) for line in lines] + assert "foo" == entry_one["MetricName"] + assert 1.0 == entry_one["Value"] + + writer.log_metric(metric_name="bar", value=2.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two] = [json.loads(line) for line in lines] + assert "bar" == entry_two["MetricName"] + assert 2.0 == entry_two["Value"] + + writer.log_metric(metric_name="biz", value=3.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two, entry_three] = [json.loads(line) for line in lines] + assert "biz" == entry_three["MetricName"] + assert 3.0 == entry_three["Value"] + + writer.close() + + +def test_file_metrics_writer_context_manager(timestamp, filepath): + with _SageMakerFileMetricsWriter(filepath) as writer: + writer.log_metric("foo", value=1.0, timestamp=timestamp) + entry = json.loads(open(filepath, "r").read().strip()) + assert { + "MetricName": "foo", + "Value": 1.0, + "Timestamp": timestamp.timestamp(), + }.items() <= entry.items() + + +def test_file_metrics_writer_fail_write_on_close(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.log_metric(metric_name="foo", value=1.0) + writer.close() + with pytest.raises(SageMakerMetricsWriterException): + writer.log_metric(metric_name="foo", value=1.0) + + +def test_file_metrics_writer_no_write(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.close() + assert not os.path.exists(filepath) diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py new file mode 100644 index 0000000000..0e4ebee181 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -0,0 +1,941 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
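+
+# These tests stub out _Experiment._load_or_create, _Trial._load_or_create and
+# _TrialComponent._load_or_create (see helpers.py) so that no SageMaker API calls
+# are made; each Run under test exercises only local state transitions plus the
+# calls recorded on the mocked sagemaker_client.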
+from __future__ import absolute_import
+
+import datetime
+import unittest
+from math import inf, nan
+from unittest.mock import patch, Mock, MagicMock
+
+import dateutil
+import pytest
+
+from sagemaker.experiments import _environment, SortOrderType
+from sagemaker.experiments._api_types import (
+ TrialComponentArtifact,
+ TrialComponentSummary,
+ TrialComponentStatus,
+ _TrialComponentStatusType,
+ TrialComponentSearchResult,
+)
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import (
+ TRIAL_NAME_TEMPLATE,
+ MAX_RUN_TC_ARTIFACTS_LEN,
+ MAX_NAME_LEN_IN_BACKEND,
+ EXPERIMENT_NAME,
+ RUN_NAME,
+ TRIAL_NAME,
+ DELIMITER,
+ RUN_TC_TAG,
+ SortByType,
+)
+from sagemaker.experiments import Run, load_run, list_runs
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+from tests.unit.sagemaker.experiments.helpers import (
+ mock_trial_load_or_create_func,
+ mock_tc_load_or_create_func,
+ TEST_EXP_NAME,
+ TEST_RUN_NAME,
+)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save")
+def test_run_init(mock_tc_save, sagemaker_session):
+ with Run(
+ experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
+ ) as run_obj:
+ assert not run_obj._in_load
+ assert not run_obj._inside_load_context
+ assert run_obj._inside_init_context
+ assert not run_obj._trial_component.parameters
+
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ assert run_obj.experiment_name == TEST_EXP_NAME
+ assert run_obj.run_name == TEST_RUN_NAME
+ assert run_obj.run_group_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+ assert run_obj._trial_component.trial_component_name == expected_tc_name
+ assert run_obj._trial.trial_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+ assert run_obj._experiment.experiment_name == TEST_EXP_NAME
+ assert run_obj.experiment_config == {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: run_obj.run_group_name,
+ RUN_NAME: expected_tc_name,
+ }
+
+ # trial_component.save is called when entering/exiting the with block
+ mock_tc_save.assert_called()
+
+
+def test_run_init_name_length_exceed_limit(sagemaker_session):
+ invalid_name = "x" * MAX_NAME_LEN_IN_BACKEND
+
+ # experiment_name exceeds
+ with pytest.raises(ValueError) as err:
+ Run(
+ experiment_name=invalid_name,
+ run_name=TEST_RUN_NAME,
+ sagemaker_session=sagemaker_session,
+ )
+
+ assert (
+ f"The experiment_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than"
+ in str(err)
+ )
+
+ # run_name exceeds
+ with pytest.raises(ValueError) as err:
+ Run(
+ experiment_name=TEST_EXP_NAME,
+ run_name=invalid_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ assert f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than" in str(
+ err
+ )
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
"sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +@patch("sagemaker.experiments.run._RunEnvironment") +def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session): + client = sagemaker_session.sagemaker_client + job_name = "my-train-job" + rv = Mock() + rv.source_arn = f"arn:1234/{job_name}" + rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob + mock_run_env.load.return_value = rv + + expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}" + exp_config = { + EXPERIMENT_NAME: TEST_EXP_NAME, + TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME), + RUN_NAME: expected_tc_name, + } + client.describe_training_job.return_value = { + "TrainingJobName": "train-job-experiments", + # The Run object has been created else where + "ExperimentConfig": exp_config, + } + with load_run(sagemaker_session=sagemaker_session) as run_obj: + assert run_obj._in_load + assert not run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj.run_name == TEST_RUN_NAME + assert run_obj._trial_component.trial_component_name == expected_tc_name + assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + assert run_obj._trial + assert run_obj.experiment_name == TEST_EXP_NAME + assert run_obj._experiment + assert run_obj.experiment_config == exp_config + + client.describe_training_job.assert_called_once_with(TrainingJobName=job_name) + + +@patch("sagemaker.experiments.run._RunEnvironment") +def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg( + mock_run_env, sagemaker_session +): + rv = Mock() + rv.source_arn = "arn:1234/my-train-job" + rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob + mock_run_env.load.return_value = rv + + # No Run object is created else where + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobName": "train-job-experiments", + } + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(err) + + +def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session): + with run_obj: + with load_run(sagemaker_session=sagemaker_session) as run: + assert run_obj == run + + +def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session): + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert "Failed to load a Run object" in str(err) + + # experiment_name is given but is not supplied along with the run_name so it's ignored. 
+ with pytest.raises(RuntimeError) as err:
+ with load_run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+ pass
+
+ assert "Failed to load a Run object" in str(err)
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+def test_run_load_with_run_name_and_exp_name(sagemaker_session):
+ with load_run(
+ run_name=TEST_RUN_NAME,
+ experiment_name=TEST_EXP_NAME,
+ sagemaker_session=sagemaker_session,
+ ) as run_obj:
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ expected_exp_config = {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+ RUN_NAME: expected_tc_name,
+ }
+
+ assert run_obj.run_name == TEST_RUN_NAME
+ assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+ assert run_obj.experiment_name == TEST_EXP_NAME
+ assert run_obj._trial_component.trial_component_name == expected_tc_name
+ assert run_obj._trial
+ assert run_obj._experiment
+ assert run_obj.experiment_config == expected_exp_config
+
+
+def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
+ with pytest.raises(ValueError) as err:
+ with load_run(
+ run_name=TEST_RUN_NAME,
+ sagemaker_session=sagemaker_session,
+ ):
+ pass
+
+ assert "Invalid input: experiment_name is missing" in str(err)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ job_name = "my-process-job"
+ rv = unittest.mock.Mock()
+ rv.source_arn = f"arn:1234/{job_name}"
+ rv.environment_type = _environment._EnvironmentType.SageMakerProcessingJob
+ mock_run_env.load.return_value = rv
+
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ exp_config = {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+ RUN_NAME: expected_tc_name,
+ }
+ client.describe_processing_job.return_value = {
+ "ProcessingJobName": "process-job-experiments",
+ # The Run object has been created elsewhere
+ "ExperimentConfig": exp_config,
+ }
+
+ with load_run(sagemaker_session=sagemaker_session):
+ pass
+
+ client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
+ # TODO: update this test once we figure out how to get source_arn from transform job
+ rv = unittest.mock.Mock()
+ rv.environment_type =
_environment._EnvironmentType.SageMakerTransformJob + rv.source_arn = "" + mock_run_env.load.return_value = rv + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert ( + "loading experiment config from transform job environment is not currently supported" + ) in str(err) + + +def test_log_parameter_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameter("foo", "bar") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameter(run_obj): + with run_obj: + run_obj.log_parameter("foo", "bar") + assert run_obj._trial_component.parameters["foo"] == "bar" + run_obj.log_parameter("whizz", 1) + assert run_obj._trial_component.parameters["whizz"] == 1 + + +def test_log_parameter_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_parameter("key", nan) + assert "key" not in run_obj._trial_component.parameters + + +def test_log_parameters_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameters(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_parameters_skip_invalid_values(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_input_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert run_obj._trial_component.input_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_output_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text") + assert run_obj._trial_component.output_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_metric_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_metric(name="foo", value=1.0, step=1) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_metric(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj.log_metric(name="foo", value=1.0, step=1, timestamp=now) + run_obj._metrics_manager.log_metric.assert_called_with( + metric_name="foo", value=1.0, step=1, timestamp=now + ) + + +def test_log_metric_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_metric(None, nan, None, None) + assert not run_obj._metrics_manager.log_metric.called + + +def test_log_metric_attribute_error(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj._metrics_manager.log_metric.side_effect = AttributeError + + with pytest.raises(AttributeError): + run_obj.log_metric("foo", 1.0, 1, now) + + +def test_log_output_artifact_outside_run_context(run_obj): + with 
pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type + + run_obj.log_file("foo.txt") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.output_artifacts + assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type + + +def test_log_input_artifact_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type + + run_obj.log_file("foo.txt", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.input_artifacts + assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type + + +def test_log_multiple_inputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.input_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang", False) + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_outputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.output_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_multiple_input_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file( + file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False + ) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an output artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=True) + + # log an extra input artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert 
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_output_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index)) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an input artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + + # log an extra output artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_precision_recall_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + with pytest.raises(RuntimeError) as err: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_precision_recall(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + + expected_data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": [0.5, 0.3333333333333333, 0.5, 0.0, 1.0], + "recall": [1.0, 0.5, 0.5, 0.0, 0.0], + "averagePrecisionScore": 0.5, + "noSkill": 0.1, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + title, expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name=title, + source_uri="s3uri_value", + etag="etag_value", + artifact_type="PrecisionRecallCurve", + ) + + +def test_log_precision_recall_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + no_skill = 0.1 + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_precision_recall( + y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False + ) + assert "Lengths mismatch between true labels and predicted probabilities" in str(error) + + +def test_log_confusion_matrix_outside_run_context(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + with pytest.raises(RuntimeError) as err: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_confusion_matrix(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + + expected_data = { + "type": "ConfusionMatrix", + "version": 0, + "title": "TestConfusionMatrix", + "confusionMatrix": [[2, 0, 0], [0, 0, 1], [1, 0, 2]], + } + + 
run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestConfusionMatrix", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_output_artifact.assert_called_with( + name="TestConfusionMatrix", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ConfusionMatrix", + ) + + +def test_log_confusion_matrix_invalid_input(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "Lengths mismatch between true labels and predicted labels" in str(error) + + +def test_log_roc_curve_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + + with pytest.raises(RuntimeError) as err: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_roc_curve(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + with run_obj: + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + + expected_data = { + "type": "ROCCurve", + "version": 0, + "title": "TestROCCurve", + "falsePositiveRate": [0.0, 0.0, 0.5, 0.5, 1.0], + "truePositiveRate": [0.0, 0.5, 0.5, 1.0, 1.0], + "areaUnderCurve": 0.75, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestROCCurve", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name="TestROCCurve", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ROCCurve", + ) + + +def test_log_roc_curve_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "Lengths mismatch between true labels and predicted scores" in str(error) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +@patch("sagemaker.experiments.run._TrialComponent.list") +@patch("sagemaker.experiments.run._TrialComponent.search") +def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + tc_list_len = 20 + tc_list_len_half = int(tc_list_len / 2) + mock_tc_search.side_effect = [ + [ + TrialComponentSearchResult( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + tags=[RUN_TC_TAG] if i < tc_list_len_half else None, + ) + ] + for i in range(tc_list_len) + ] + mock_tc_list.return_value = [ + TrialComponentSummary( + trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(tc_list_len) + ] + mock_tc_load.side_effect = [ + ( + _TrialComponent( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ), + True, + ) + for i in range(tc_list_len_half) + ] + + run_list = list_runs( + experiment_name=TEST_EXP_NAME, + sort_by=SortByType.CREATION_TIME, + sort_order=SortOrderType.ASCENDING, + sagemaker_session=sagemaker_session, + ) + + mock_tc_list.assert_called_once_with( + experiment_name=TEST_EXP_NAME, + created_before=None, + created_after=None, + sort_by="CreationTime", + sort_order="Ascending", + sagemaker_session=sagemaker_session, + max_results=None, + next_token=None, + ) + assert len(run_list) == tc_list_len_half + for i in range(tc_list_len_half): + run = run_list[i] + assert run.experiment_name == TEST_EXP_NAME + assert run.run_name == "a" + str(i) + assert run._experiment + assert run._trial + assert isinstance(run._trial_component, _TrialComponent) + assert run._trial_component.trial_component_name == Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ) + assert run._in_load is False + assert run._inside_load_context is False + assert run._inside_init_context is False + assert run._artifact_uploader + assert run._lineage_artifact_tracker + assert run._metrics_manager + + +@patch("sagemaker.experiments.run._TrialComponent.list") +def test_list_empty(mock_tc_list, sagemaker_session): + mock_tc_list.return_value = [] + assert [] == list_runs(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +def test_enter_exit_locally(mock_load_tc, sagemaker_session, run_obj): + mock_load_tc.return_value = run_obj._trial_component, False + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + _verify_tc_status_before_enter_init(run_obj._trial_component) + + with run_obj: + _verify_tc_status_when_entering(run_obj._trial_component) + 
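+ # Capture the start_time stamped on first entry; re-entering the run via
+ # load_run below must keep this value (asserted by passing init_start_time
+ # into _verify_tc_status_when_entering) instead of resetting it.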
init_start_time = run_obj._trial_component.start_time + + with load_run(sagemaker_session=sagemaker_session): + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + old_end_time=old_end_time, + ) + + # Re-load to verify: + # 1. if it works when load_run and with are not in one line + # 2. if re-entering the load will change the "Completed" TC status + # to "InProgress" + # 3. when exiting the load, the end_time and status will be overridden again + run_load = load_run( + experiment_name=run_obj.experiment_name, + run_name=run_obj.run_name, + sagemaker_session=sagemaker_session, + ) + with run_load: + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + has_completed=True, + ) + _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, old_end_time=old_end_time + ) + + +def test_exit_fail(sagemaker_session, run_obj): + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + try: + with run_obj: + raise ValueError("Foo") + except ValueError: + pass + + assert run_obj._trial_component.status.primary_status == _TrialComponentStatusType.Failed.value + assert run_obj._trial_component.status.message + assert isinstance(run_obj._trial_component.end_time, datetime.datetime) + + +@pytest.mark.parametrize( + "metric_value", + [1.3, "nan", "inf", "-inf", None], +) +def test_is_input_valid(run_obj, metric_value): + assert run_obj._is_input_valid("metric", "Name", metric_value) + + +@pytest.mark.parametrize( + "metric_value", + [nan, inf, -inf], +) +def test_is_input_valid_false(run_obj, metric_value): + assert not run_obj._is_input_valid("parameter", "Name", metric_value) + + +def test_generate_trial_name(): + base_name = "x" * MAX_NAME_LEN_IN_BACKEND + trial_name = Run._generate_trial_name(base_name=base_name) + assert len(trial_name) <= MAX_NAME_LEN_IN_BACKEND + + +def test_append_run_tc_label_to_tags(): + expected_tc_tag = RUN_TC_TAG + + tags = None + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [{"Key": "foo", "Value": "bar"}] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 2 + assert expected_tc_tag in ret + + +def _verify_tc_status_before_enter_init(trial_component): + assert not trial_component.start_time + assert not trial_component.end_time + assert not trial_component.status + + +def _verify_tc_status_when_entering(trial_component, init_start_time=None, has_completed=False): + if not init_start_time: + assert isinstance(trial_component.start_time, datetime.datetime) + now = datetime.datetime.now(dateutil.tz.tzlocal()) + assert (now.timestamp() - trial_component.start_time.timestamp()) < 1 + else: + assert trial_component.start_time == init_start_time + + if not has_completed: + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + + +def _verify_tc_status_when_successfully_exit(trial_component, old_end_time=None): + assert trial_component.status.primary_status == _TrialComponentStatusType.Completed.value + assert 
isinstance(trial_component.start_time, datetime.datetime) + assert isinstance(trial_component.end_time, datetime.datetime) + if old_end_time: + assert trial_component.end_time > old_end_time + return trial_component.end_time diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py new file mode 100644 index 0000000000..7e068136a1 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run_context.py @@ -0,0 +1,191 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import patch, MagicMock + +import pytest + +from sagemaker.estimator import Estimator, _TrainingJob +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import _RunContext +from sagemaker.experiments import load_run, Run +from sagemaker.experiments.trial import _Trial +from tests.unit.sagemaker.experiments.helpers import ( + TEST_EXP_NAME, + mock_trial_load_or_create_func, + mock_tc_load_or_create_func, +) + +_bucket = "my-bucket" +_train_input_path = f"s3://{_bucket}/data.csv" +_train_output_path = f"s3://{_bucket}" + + +@patch.object(_TrainingJob, "start_new") +def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + expected_exp_config = run_obj.experiment_config + mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +@patch.object(_TrainingJob, "start_new") +def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + supplied_exp_cfg = { + "ExperimentName": "my-supplied-exp-name", + "TrialName": "my-supplied-run-group-name", + "RunName": "my-supplied-run-name", + } + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + experiment_config=supplied_exp_cfg, + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +def test_auto_fetch_created_run_obj_from_context(run_obj, sagemaker_session): + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + def 
train(): + with load_run(sagemaker_session=sagemaker_session) as run_load: + assert run_load == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + run_load.log_parameter("foo", "bar") + run_load.log_parameter("whizz", 1) + + with run_obj: + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + train() + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + run_obj.log_parameters({"a": "b", "c": 2}) + + assert run_obj._trial_component.parameters["foo"] == "bar" + assert run_obj._trial_component.parameters["whizz"] == 1 + assert run_obj._trial_component.parameters["a"] == "b" + assert run_obj._trial_component.parameters["c"] == 2 + + # Verify that separating load_run and the with statement onto different lines still works + run_load2 = load_run(sagemaker_session=sagemaker_session) + with run_load2: + assert run_load2 == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + +def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +def test_nested_run_init_context_on_different_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session): + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +def test_nested_run_load_context(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with load_run(): + run_load = load_run() + with run_load: + pass + assert "It is not allowed to use nested 'with' statements on the load_run" in str(err) diff --git a/tests/unit/sagemaker/experiments/test_trial.py b/tests/unit/sagemaker/experiments/test_trial.py new file mode 100644 index 0000000000..f6996fefc3 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial.py @@ -0,0 +1,276 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License.
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments._api_types import TrialSummary +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_trial.return_value = {"ExperimentName": "experiment-name-value"} + trial_obj = _Trial.load(trial_name="name-value", sagemaker_session=sagemaker_session) + assert trial_obj.trial_name == "name-value" + assert trial_obj.experiment_name == "experiment-name-value" + client.describe_trial.assert_called_with(TrialName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", ExperimentName="experiment-name-value" + ) + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + tags = [{"Key": "foo", "Value": "bar"}] + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + tags=tags, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", + ExperimentName="experiment-name-value", + Tags=[{"Key": "foo", "Value": "bar"}], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial(sagemaker_session, trial_name="foo") + client.delete_trial.return_value = {} + obj.delete() + client.delete_trial.assert_called_with(TrialName="foo") + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial( + sagemaker_session, + trial_name="foo", + experiment_name="whizz", + display_name="bar", + tags=[{"Key": "foo", "Value": "bar"}], + ) + client.update_trial.return_value = {} + obj.save() + + client.update_trial.assert_called_with( + TrialName="foo", + DisplayName="bar", + ) + + +def test_add_trial_component(sagemaker_session): + client = sagemaker_session.sagemaker_client + trial = _Trial(sagemaker_session=sagemaker_session) + trial.trial_name = "bar" + trial.add_trial_component("foo") + client.associate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="foo") + + tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + trial.add_trial_component(tc) + client.associate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName=tc.trial_component_name + ) + + +def test_remove_trial_component(sagemaker_session): + client = 
sagemaker_session.sagemaker_client + trial = _Trial(sagemaker_session=sagemaker_session) + trial.trial_name = "bar" + trial.remove_trial_component("foo") + client.disassociate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName="foo" + ) + + tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + trial.remove_trial_component(tc) + client.disassociate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName=tc.trial_component_name + ) + + +@patch("sagemaker.experiments.trial._Trial.load") +def test_load_or_create_when_exist(mock_load): + sagemaker_session = Session() + trial_name = "trial_name" + exp_name = "exp_name" + + # The trial exists and experiment matches + mock_load.return_value = _Trial( + trial_name=trial_name, + experiment_name=exp_name, + sagemaker_session=sagemaker_session, + ) + _Trial._load_or_create( + trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + mock_load.assert_called_once_with(trial_name, sagemaker_session) + + # The trial exists but experiment does not match + mock_load.return_value = _Trial( + trial_name=trial_name, + experiment_name="another_exp_name", + sagemaker_session=sagemaker_session, + ) + with pytest.raises(ValueError) as err: + _Trial._load_or_create( + trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + assert "The given experiment_name {} does not match that in the loaded trial".format( + exp_name + ) in str(err) + + +@patch("sagemaker.experiments.trial._Trial.load") +@patch("sagemaker.experiments.trial._Trial.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + trial_name = "trial_name" + exp_name = "exp_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + _Trial._load_or_create( + trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + + mock_create.assert_called_once_with( + trial_name=trial_name, + experiment_name=exp_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list(_Trial.list(sagemaker_session=sagemaker_session)) + client.list_trials.assert_called_with(**{}) + + +def test_list_trials_with_experiment_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1",
creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list(_Trial.list(experiment_name="foo", sagemaker_session=sagemaker_session)) + client.list_trials.assert_called_with(ExperimentName="foo") + + +def test_list_trials_with_trial_component_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list( + _Trial.list(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + ) + client.list_trials.assert_called_with(TrialComponentName="tc-foo") diff --git a/tests/unit/sagemaker/experiments/test_trial_component.py b/tests/unit/sagemaker/experiments/test_trial_component.py new file mode 100644 index 0000000000..c14663893e --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial_component.py @@ -0,0 +1,384 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime +import unittest.mock + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import ( + TrialComponentSearchResult, + Parent, + _TrialComponentStatusType, +) +from sagemaker.experiments.trial_component import _TrialComponent + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial_component.return_value = { + "TrialComponentArn": "bazz", + } + obj = _TrialComponent.create( + trial_component_name="foo", display_name="bar", sagemaker_session=sagemaker_session + ) + client.create_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar") + assert "foo" == obj.trial_component_name + assert "bar" == obj.display_name + assert "bazz" == obj.trial_component_arn + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial_component.return_value = { + "TrialComponentArn": "bazz", + } + tags = [{"Key": "foo", "Value": "bar"}] + _TrialComponent.create( + trial_component_name="foo", + display_name="bar", + sagemaker_session=sagemaker_session, + tags=tags, + ) + client.create_trial_component.assert_called_with( + TrialComponentName="foo", DisplayName="bar", Tags=tags + ) + + +def test_load(sagemaker_session): + now = datetime.datetime.now(datetime.timezone.utc) + client = sagemaker_session.sagemaker_client + client.describe_trial_component.return_value = { + "TrialComponentArn": "A", + "TrialComponentName": "B", + "DisplayName": "C", + "Status": {"PrimaryStatus": _TrialComponentStatusType.InProgress.value, "Message": "D"}, + "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}}, + "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}}, + "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}}, + "Metrics": [ + { + "MetricName": "J", + "Count": 1, + "Min": 1.0, + "Max": 2.0, + "Avg": 3.0, + "StdDev": 4.0, + "SourceArn": "K", + "Timestamp": now, + } + ], + } + obj = _TrialComponent.load(trial_component_name="foo", sagemaker_session=sagemaker_session) + client.describe_trial_component.assert_called_with(TrialComponentName="foo") + assert "A" == obj.trial_component_arn + assert "B" == obj.trial_component_name + assert "C" == obj.display_name + assert ( + _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="D" + ) + == obj.status + ) + assert {"E": 1.0, "F": "G"} == obj.parameters + assert { + "H": _api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain") + } == obj.input_artifacts + assert { + "I": _api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain") + } == obj.output_artifacts + assert [ + _api_types.TrialComponentMetricSummary( + metric_name="J", + count=1, + min=1.0, + max=2.0, + avg=3.0, + std_dev=4.0, + source_arn="K", + timestamp=now, + ) + ] == obj.metrics + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent( + sagemaker_session, + trial_component_name="foo", + display_name="bar", + parameters_to_remove=["E"], + input_artifacts_to_remove=["F"], + output_artifacts_to_remove=["G"], + ) + client.update_trial_component.return_value = {} + obj.save() + + client.update_trial_component.assert_called_with( + TrialComponentName="foo", + DisplayName="bar", + Parameters={}, + ParametersToRemove=["E"], + InputArtifacts={}, + InputArtifactsToRemove=["F"], + OutputArtifacts={},
+ OutputArtifactsToRemove=["G"], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + obj.delete() + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_delete_with_force_disassociate(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + + client.list_trials.side_effect = [ + {"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"}, + {"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]}, + ] + + obj.delete(force_disassociate=True) + expected_calls = [ + unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"), + ] + assert expected_calls == client.disassociate_trial_component.mock_calls + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_list(sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = [ + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10) + ], + "NextToken": "100", + }, + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10, 20) + ] + }, + ] + + expected = [ + _api_types.TrialComponentSummary( + trial_component_name="A" + str(i), + trial_component_arn="B" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=_api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(20) + ] + result = list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + source_arn="foo", + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + assert expected == result + expected_calls = [ + unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"), + unittest.mock.call( + NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo" + ), + ] + assert expected_calls == client.list_trial_components.mock_calls + + +def test_list_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [] + } + assert [] == list(_TrialComponent.list(sagemaker_session=sagemaker_session)) + + +def test_list_trial_components_call_args(sagemaker_session): + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + trial_name = "foo-trial" + experiment_name = "foo-experiment" + next_token = "thetoken" + max_results = 99 + + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = {} + assert [] == list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + trial_name=trial_name, + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + next_token=next_token, + max_results=max_results, + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + expected_calls = [ + unittest.mock.call( + TrialName="foo-trial", + ExperimentName="foo-experiment", + CreatedBefore=created_before, + CreatedAfter=created_after, + SortBy="CreationTime", + SortOrder="Ascending", + NextToken="thetoken", + MaxResults=99, + ) + ] + assert expected_calls == client.list_trial_components.mock_calls + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + tc_name = "tc_name" + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + assert is_existed + mock_load.assert_called_once_with( + tc_name, + sagemaker_session, + ) + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +@patch("sagemaker.experiments.trial_component._TrialComponent.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + tc_name = "tc_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + + assert not is_existed + mock_create.assert_called_once_with( + trial_component_name=tc_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_search(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [ + { + "ExperimentName": "e-1", + "TrialName": "t-1", + }, + { + "ExperimentName": "e-2", + "TrialName": "t-2", + }, + ], + } + }, + { + "TrialComponent": { + "TrialComponentName": "tc-2", + "TrialComponentArn": "arn::tc-2", + "DisplayName": "TC2", + } + }, 
+ ] + } + expected = [ + TrialComponentSearchResult( + trial_component_name="tc-1", + trial_component_arn="arn::tc-1", + display_name="TC1", + parents=[ + Parent(experiment_name="e-1", trial_name="t-1"), + Parent(experiment_name="e-2", trial_name="t-2"), + ], + ), + TrialComponentSearchResult( + trial_component_name="tc-2", trial_component_arn="arn::tc-2", display_name="TC2" + ), + ] + assert expected == list(_TrialComponent.search(sagemaker_session=sagemaker_session)) diff --git a/tests/unit/sagemaker/experiments/test_utils.py b/tests/unit/sagemaker/experiments/test_utils.py new file mode 100644 index 0000000000..a63c96c0fe --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_utils.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.experiments._utils import resolve_artifact_name, guess_media_type + + +def test_resolve_artifact_name(): + file_names = { + "a": "a", + "a.txt": "a.txt", + "b.": "b.", + ".c": ".c", + "/x/a/a.txt": "a.txt", + "/a/b/c.": "c.", + "./.a": ".a", + "../b.txt": "b.txt", + "~/a.txt": "a.txt", + "c/d.txt": "d.txt", + } + for file_name, artifact_name in file_names.items(): + assert artifact_name == resolve_artifact_name(file_name) + + +def test_guess_media_type(): + assert "text/plain" == guess_media_type("foo.txt") diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index c391d45382..0088e34c58 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -48,6 +48,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 2e7576421f..fea80b7ea9 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -56,6 +56,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index af46cf4360..d35c0a51dd 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -52,6 +52,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 5aef9316da..7645c4fe23 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -50,6 +50,7 @@ "ExperimentName": "exp", "TrialName":
"trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 7517f3a641..1ce58a19b4 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -50,6 +50,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/utilities/test_search_expression.py b/tests/unit/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..98a52a992a --- /dev/null +++ b/tests/unit/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from sagemaker.utilities.search_expression import ( + Filter, + Operator, + NestedFilter, + SearchExpression, + BooleanOperator, +) + + +def test_filters(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + + assert { + "Name": "learning_rate", + "Operator": "Equals", + "Value": "0.1", + } == search_filter.to_boto() + + +def test_partial_filters(): + search_filter = Filter(name="learning_rate") + + assert {"Name": "learning_rate"} == search_filter.to_boto() + + +def test_nested_filters(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + filters = [search_filter] + nested_filters = NestedFilter(property_name="hyper_param", filters=filters) + + assert { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedPropertyName": "hyper_param", + } == nested_filters.to_boto() + + +def test_search_expression(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + nested_filter = NestedFilter(property_name="hyper_param", filters=[search_filter]) + search_expression = SearchExpression( + filters=[search_filter], + nested_filters=[nested_filter], + sub_expressions=[], + boolean_operator=BooleanOperator.AND, + ) + + assert { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedFilters": [ + { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedPropertyName": "hyper_param", + } + ], + "SubExpressions": [], + "Operator": "And", + } == search_expression.to_boto() + + +def test_illegal_search_expression(): + with pytest.raises( + ValueError, match="You must specify at least one subexpression, filter, or nested filter" + ): + SearchExpression() diff --git a/tests/unit/sagemaker/workflow/test_clarify_check_step.py b/tests/unit/sagemaker/workflow/test_clarify_check_step.py index feadaa03dc..54b354b71e 100644 --- a/tests/unit/sagemaker/workflow/test_clarify_check_step.py +++ b/tests/unit/sagemaker/workflow/test_clarify_check_step.py @@ -16,10 +16,6 @@ import re import pytest -import sagemaker - -from mock 
import Mock, PropertyMock - from sagemaker.clarify import ( DataConfig, BiasConfig, @@ -50,46 +46,6 @@ _S3_ANALYSIS_CONFIG_OUTPUT_PATH = "s3://my_bucket/analysis_cfg_output" -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_DEFAULT_BUCKET, - ) - - _expected_data_bias_dsl = { "Name": "DataBiasCheckStep", "Type": "ClarifyCheck", diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py index 6f0be2ccca..a36207b241 100644 --- a/tests/unit/sagemaker/workflow/test_entities.py +++ b/tests/unit/sagemaker/workflow/test_entities.py @@ -19,9 +19,6 @@ from enum import Enum -from mock.mock import Mock, PropertyMock - -import sagemaker from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.conditions import ConditionGreaterThan from sagemaker.workflow.entities import ( @@ -58,46 +55,6 @@ def custom_entity_list(): return [CustomEntity(1), CustomEntity(2)] -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value="role") - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name="us-west-2") - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket="my-bucket", - ) - - def test_entity(custom_entity): request_struct = {"foo": 1} assert custom_entity.to_request() == request_struct diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py index b60e2de8fa..dc104d71df 100644 --- a/tests/unit/sagemaker/workflow/test_quality_check_step.py +++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py @@ -15,10 +15,6 @@ import json import pytest -import sagemaker - -from mock import Mock, PropertyMock - from sagemaker.model_monitor import DatasetFormat from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -31,49 +27,7 @@ from sagemaker.workflow.steps import CacheConfig from sagemaker.workflow.check_job_config import CheckJobConfig -_REGION = "us-west-2" _ROLE = "DummyRole" -_BUCKET = "my-bucket" - - -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_BUCKET, - ) _expected_data_quality_dsl = { diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 9887d43078..ba712d11d7 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -16,15 +16,10 @@ import json import pytest -import sagemaker import os import warnings -from mock import ( - Mock, - PropertyMock, - patch, -) +from mock import patch from sagemaker.debugger import ProfilerConfig from sagemaker.estimator import Estimator @@ -94,46 +89,6 @@ def create_predictor(self, endpoint_name): return Predictor(endpoint_name, self.sagemaker_session) -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=BUCKET, - ) - - @pytest.fixture def script_processor(sagemaker_session): return ScriptProcessor( diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 82b154317d..44b5818fc8 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -225,6 +225,9 @@ def test_fit_ndarray(time, sagemaker_session): assert mock_object.put.call_count == 4 + called_args = sagemaker_session.train.call_args + assert not called_args[1]["experiment_config"] + def test_fit_pass_experiment_config(sagemaker_session): kwargs = dict(COMMON_ARGS) @@ -239,12 +242,18 @@ def test_fit_pass_experiment_config(sagemaker_session): labels = [99, 85, 87, 2] pca.fit( pca.record_set(np.array(train), np.array(labels)), - experiment_config={"ExperimentName": "exp"}, + experiment_config={ + "ExperimentName": "exp", + "RunName": "rn", + }, ) called_args = sagemaker_session.train.call_args - assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"} + assert called_args[1]["experiment_config"] == { + "ExperimentName": "exp", + "RunName": "rn", + } def test_build_shards(): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 34e6a43fcf..868da88d78 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2489,7 +2489,12 @@ def test_start_new(sagemaker_session): hyperparameters=hyperparameters, ) - exp_config = {"ExperimentName": "exp", "TrialName": "t", "TrialComponentDisplayName": "tc"} + exp_config = { + "ExperimentName": "exp", + "TrialName": "t", + "TrialComponentDisplayName": "tc", + "RunName": "rn", + } started_training_job = training_job.start_new(estimator, inputs, experiment_config=exp_config) called_args = sagemaker_session.train.call_args @@ -2680,6 +2685,7 @@ def test_unsupported_type_in_dict(): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } } ) @@ -2884,6 +2890,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", }, ) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 99b0e839b7..9ba3e17ff3 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -62,6 +62,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"} diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 082f699d63..c8aad13774 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 4efc2e5bf8..2035636e76 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -49,6 +49,7 @@ "ExperimentName": "exp", "TrialName": "trial", 
"TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d7c94470f5..ec4a21cbc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -588,11 +588,16 @@ def test_user_agent_injected(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi(boto_session): @@ -607,10 +612,14 @@ def test_user_agent_injected_with_nbi(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi_ioerror(boto_session): @@ -625,11 +634,16 @@ def test_user_agent_injected_with_nbi_ioerror(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_training_input_all_defaults(): @@ -700,6 +714,7 @@ def test_training_input_all_arguments(): "ExperimentName": "dummyExp", "TrialName": "dummyT", "TrialComponentDisplayName": "dummyTC", + "RunName": "dummyRN", } MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60} diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 13cc755336..c3e984e0b7 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -51,6 +51,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 0eb81be584..8bcbed41c2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -25,10 +25,12 @@ from boto3 import exceptions import botocore import pytest -from mock import call, patch, Mock, MagicMock +from mock import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings +from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config from tests.unit.sagemaker.workflow.helpers import 
CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -795,3 +797,63 @@ def test_start_waiting(capfd): out, _ = capfd.readouterr() assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out + + +def test_retry_with_backoff(): + callable_func = Mock() + + # Invalid input + with pytest.raises(ValueError) as value_err: + retry_with_backoff(callable_func, 0) + assert "The num_attempts must be >= 1" in str(value_err) + callable_func.assert_not_called() + + # All retries fail + run_err_msg = "Test Retry Error" + callable_func.side_effect = RuntimeError(run_err_msg) + with pytest.raises(RuntimeError) as run_err: + retry_with_backoff(callable_func, 2) + assert run_err_msg in str(run_err) + + # One retry passes + func_return_val = "Test Return" + callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val] + assert retry_with_backoff(callable_func, 2) == func_return_val + + # No retry + callable_func.side_effect = None + callable_func.return_value = func_return_val + assert retry_with_backoff(callable_func, 2) == func_return_val + + +def test_check_and_get_run_experiment_config(): + supplied_exp_cfg = {"ExperimentName": "my-supplied-exp-name", "RunName": "my-supplied-run-name"} + run_exp_cfg = {"ExperimentName": "my-run-exp-name", "RunName": "my-run-run-name"} + + # No user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg1 = check_and_get_run_experiment_config(None) + assert exp_cfg1 is None + + # With user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg2 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg2 == supplied_exp_cfg + + run = Mock() + type(run).experiment_config = PropertyMock(return_value=run_exp_cfg) + _RunContext.add_run_object(run) + + try: + # No user supplied exp config and with current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg3 = check_and_get_run_experiment_config(None) + assert exp_cfg3 == run_exp_cfg + + # With user supplied exp config and current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg4 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg4 == supplied_exp_cfg + finally: + # Clean up the global static variable in case it affects other tests + _RunContext.drop_current_run() diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 82f27c19ae..d58c4992cd 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", }
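For reference, the user-facing pattern that the unit tests above exercise looks roughly like the sketch below. This is a minimal illustration only, assuming the public API added in this patch (Run, load_run, log_parameter/log_parameters, and the automatically passed experiment_config); session, role, image_uri, and the S3 paths are hypothetical placeholders rather than part of the change.

    # Minimal sketch, assuming the Run/load_run API introduced in this patch.
    # "session", "role", "image_uri", and the S3 paths are hypothetical placeholders.
    from sagemaker.estimator import Estimator
    from sagemaker.experiments import Run, load_run

    # Driver side: entering the Run sets the _RunContext, so estimator.fit()
    # picks up the run's experiment_config (ExperimentName, TrialName, RunName)
    # automatically, as test_auto_pass_in_exp_config_to_train_job verifies.
    with Run(experiment_name="my-experiment", sagemaker_session=session) as run:
        run.log_parameter("learning_rate", 0.1)
        estimator = Estimator(
            role=role,
            image_uri=image_uri,
            sagemaker_session=session,
            output_path="s3://my-bucket",
        )
        estimator.fit(inputs="s3://my-bucket/data.csv", wait=False)

    # Job side: inside the training script, load_run() resolves the run from the
    # job's experiment_config and re-opens the same trial component; its status
    # and end_time are updated again on exit (see the TC status verifications
    # in test_run.py).
    with load_run(sagemaker_session=session) as run:
        run.log_parameters({"epochs": 10, "optimizer": "sgd"})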