From 0d89a1a1628ba0eaa9bc5ed1444bf7aa58c6b535 Mon Sep 17 00:00:00 2001
From: savin
Date: Thu, 11 Jan 2024 20:57:26 -0800
Subject: [PATCH 01/13] support git repos

---
 metaflow/plugins/pypi/pip.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip.py
index b2a773919b6..da96574c2cb 100644
--- a/metaflow/plugins/pypi/pip.py
+++ b/metaflow/plugins/pypi/pip.py
@@ -257,6 +257,13 @@ def indices(self, prefix):

         return index, extras

+    def build(self, id_, packages, python, platform):
+        prefix = self.micromamba.path_to_environment(id_)
+        build_metadata_file = BUILD_METADATA_FILE.format(prefix=prefix)
+        # skip build if already built.
+        if os.path.isfile(build_metadata_file):
+            return
+
     def _call(self, prefix, args, env=None, isolated=True):
         if env is None:
             env = {}

From 708f06d2fdf495b1f59a29619b13a06732fba115 Mon Sep 17 00:00:00 2001
From: savin
Date: Thu, 11 Jan 2024 20:59:36 -0800
Subject: [PATCH 02/13] remove spurious function

---
 metaflow/plugins/pypi/pip.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip.py
index da96574c2cb..b2a773919b6 100644
--- a/metaflow/plugins/pypi/pip.py
+++ b/metaflow/plugins/pypi/pip.py
@@ -257,13 +257,6 @@ def indices(self, prefix):

         return index, extras

-    def build(self, id_, packages, python, platform):
-        prefix = self.micromamba.path_to_environment(id_)
-        build_metadata_file = BUILD_METADATA_FILE.format(prefix=prefix)
-        # skip build if already built.
-        if os.path.isfile(build_metadata_file):
-            return
-
     def _call(self, prefix, args, env=None, isolated=True):
         if env is None:
             env = {}

From f30238d0702762558a18674fb17c2a99a70073e8 Mon Sep 17 00:00:00 2001
From: savin
Date: Sun, 14 Jan 2024 19:45:10 -0800
Subject: [PATCH 03/13] fix formatting

---
 metaflow/plugins/pypi/bootstrap.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/metaflow/plugins/pypi/bootstrap.py b/metaflow/plugins/pypi/bootstrap.py
index a26a07ac578..ba676da5301 100644
--- a/metaflow/plugins/pypi/bootstrap.py
+++ b/metaflow/plugins/pypi/bootstrap.py
@@ -114,6 +114,7 @@
             [package["path"] for package in env["pypi"]]
         ) as results:
             for key, tmpfile, _ in results:
+                print(key)
                 dest = os.path.join(pypi_pkgs_dir, os.path.basename(key))
                 os.makedirs(os.path.dirname(dest), exist_ok=True)
                 shutil.move(tmpfile, dest)

From 8c07a053da9e9761e67f96cd0ae79f543c414a05 Mon Sep 17 00:00:00 2001
From: savin
Date: Tue, 16 Jan 2024 12:52:00 -0800
Subject: [PATCH 04/13] handle race condition with local packages metadata

---
 metaflow/plugins/pypi/pip.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip.py
index b2a773919b6..5b03efe4cbf 100644
--- a/metaflow/plugins/pypi/pip.py
+++ b/metaflow/plugins/pypi/pip.py
@@ -85,6 +85,7 @@ def solve(self, id_, packages, python, platform):

         def _format(dl_info):
             res = {k: v for k, v in dl_info.items() if k in ["url"]}
+            print(res)
             # If source url is not a wheel, we need to build the target.
res["require_build"] = not res["url"].endswith(".whl") From fbea2ce111348080576c9f98ef68103022c69b79 Mon Sep 17 00:00:00 2001 From: savin Date: Tue, 16 Jan 2024 12:54:57 -0800 Subject: [PATCH 05/13] remove print --- metaflow/plugins/pypi/bootstrap.py | 1 - metaflow/plugins/pypi/pip.py | 1 - 2 files changed, 2 deletions(-) diff --git a/metaflow/plugins/pypi/bootstrap.py b/metaflow/plugins/pypi/bootstrap.py index ba676da5301..a26a07ac578 100644 --- a/metaflow/plugins/pypi/bootstrap.py +++ b/metaflow/plugins/pypi/bootstrap.py @@ -114,7 +114,6 @@ [package["path"] for package in env["pypi"]] ) as results: for key, tmpfile, _ in results: - print(key) dest = os.path.join(pypi_pkgs_dir, os.path.basename(key)) os.makedirs(os.path.dirname(dest), exist_ok=True) shutil.move(tmpfile, dest) diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip.py index 5b03efe4cbf..b2a773919b6 100644 --- a/metaflow/plugins/pypi/pip.py +++ b/metaflow/plugins/pypi/pip.py @@ -85,7 +85,6 @@ def solve(self, id_, packages, python, platform): def _format(dl_info): res = {k: v for k, v in dl_info.items() if k in ["url"]} - print(res) # If source url is not a wheel, we need to build the target. res["require_build"] = not res["url"].endswith(".whl") From 9bb485f6cf69b5c0d2318216177e89f848935340 Mon Sep 17 00:00:00 2001 From: savin Date: Mon, 5 Feb 2024 19:39:50 -0800 Subject: [PATCH 06/13] Support distributed map in AWS Step Functions --- .../aws/step_functions/dynamo_db_client.py | 2 ++ .../aws/step_functions/production_token.py | 2 +- .../aws/step_functions/step_functions.py | 34 +++++++++++++++++-- .../aws/step_functions/step_functions_cli.py | 29 ++++++++++++---- .../step_functions_decorator.py | 2 +- 5 files changed, 58 insertions(+), 11 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/dynamo_db_client.py b/metaflow/plugins/aws/step_functions/dynamo_db_client.py index caff1b4a343..d14fbbb0de4 100644 --- a/metaflow/plugins/aws/step_functions/dynamo_db_client.py +++ b/metaflow/plugins/aws/step_functions/dynamo_db_client.py @@ -1,5 +1,7 @@ import os + import requests + from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE diff --git a/metaflow/plugins/aws/step_functions/production_token.py b/metaflow/plugins/aws/step_functions/production_token.py index 1561cba1ab5..e2cceedbd04 100644 --- a/metaflow/plugins/aws/step_functions/production_token.py +++ b/metaflow/plugins/aws/step_functions/production_token.py @@ -1,5 +1,5 @@ -import os import json +import os import random import string import zlib diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 261dc98fdcd..f860b953705 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -52,6 +52,7 @@ def __init__( max_workers=None, workflow_timeout=None, is_project=False, + use_distributed_map=False, ): self.name = name self.graph = graph @@ -70,6 +71,9 @@ def __init__( self.max_workers = max_workers self.workflow_timeout = workflow_timeout + # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/ + self.use_distributed_map = use_distributed_map + self._client = StepFunctionsClient() self._workflow = self._compile() self._cron = self._cron() @@ -369,7 +373,11 @@ def _visit(node, workflow, exit_node=None): .iterator( _visit( self.graph[node.out_funcs[0]], - Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]), + 
Workflow(node.out_funcs[0]) + .start_at(node.out_funcs[0]) + .mode( + "DISTRIBUTED" if self.use_distributed_map else "INLINE" + ), node.matching_join, ) ) @@ -444,7 +452,7 @@ def _batch(self, node): "metaflow.owner": self.username, "metaflow.flow_name": self.flow.name, "metaflow.step_name": node.name, - "metaflow.run_id.$": "$$.Execution.Name", + # "metaflow.run_id.$": "$$.Execution.Name", # Unfortunately we can't set the task id here since AWS Step # Functions lacks any notion of run-scoped task identifiers. We # instead co-opt the AWS Batch job id as the task id. This also @@ -474,6 +482,11 @@ def _batch(self, node): # specification that allows us to set key-values. "step_name": node.name, } + # # metaflow.run_id maps to AWS Step Functions State Machine Execution in all + # # cases except for when within a for-each construct that relies on Distributed + # # Map. To work around this issue, within a for-each, we lean on reading off of + # # AWS DynamoDb to get the run id. + # attrs["metaflow.run_id.$"] = "$$.Execution.Name" # Store production token within the `start` step, so that subsequent # `step-functions create` calls can perform a rudimentary authorization @@ -492,6 +505,13 @@ def _batch(self, node): env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL if node.name == "start": + # metaflow.run_id maps to AWS Step Functions State Machine Execution in all + # cases except for when within a for-each construct that relies on + # Distributed Map. To work around this issue, we pass the run id from the + # start step to all subsequent tasks. + attrs["metaflow.run_id.$"] = "$$.Execution.Name" + attrs["run_id.$"] = "$$.Execution.Name" + # Initialize parameters for the flow in the `start` step. parameters = self._process_parameters() if parameters: @@ -526,7 +546,9 @@ def _batch(self, node): raise StepFunctionsException( "Parallel steps are not supported yet with AWS step functions." ) - + # Inherit the run id from the parent and pass it along to children. + attrs["metaflow.run_id.$"] = "$.Parameters.run_id" + attrs["run_id.$"] = "$.Parameters.run_id" # Handle foreach join. 
if ( node.type == "join" @@ -893,6 +915,12 @@ def __init__(self, name): tree = lambda: defaultdict(tree) self.payload = tree() + def mode(self, mode): + self.payload["ProcessorConfig"] = {"Mode": mode} + if mode == "DISTRIBUTED": + self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD" + return self + def start_at(self, start_at): self.payload["StartAt"] = start_at return self diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 3ee971dd0f8..82e13fe537f 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -1,23 +1,23 @@ import base64 -from metaflow._vendor import click -from hashlib import sha1 import json import re +from hashlib import sha1 -from metaflow import current, decorators, parameters, JSONType +from metaflow import JSONType, current, decorators, parameters +from metaflow._vendor import click +from metaflow.exception import MetaflowException, MetaflowInternalError from metaflow.metaflow_config import ( SERVICE_VERSION_CHECK, SFN_STATE_MACHINE_PREFIX, UI_URL, ) -from metaflow.exception import MetaflowException, MetaflowInternalError from metaflow.package import MetaflowPackage from metaflow.plugins.aws.batch.batch_decorator import BatchDecorator from metaflow.tagging_util import validate_tags from metaflow.util import get_username, to_bytes, to_unicode, version_parse +from .production_token import load_token, new_token, store_token from .step_functions import StepFunctions -from .production_token import load_token, store_token, new_token VALID_NAME = re.compile(r"[^a-zA-Z0-9_\-\.]") @@ -124,6 +124,12 @@ def step_functions(obj, name=None): help="Log AWS Step Functions execution history to AWS CloudWatch " "Logs log group.", ) +@click.option( + "--use-distributed-map/--no-use-distributed-map", + is_flag=True, + help="Use AWS Step Functions Distributed Map instead of Inline Map for " + "defining foreach tasks in Amazon State Language.", +) @click.pass_obj def create( obj, @@ -136,6 +142,7 @@ def create( max_workers=None, workflow_timeout=None, log_execution_history=False, + use_distributed_map=False, ): validate_tags(tags) @@ -165,6 +172,7 @@ def create( max_workers, workflow_timeout, obj.is_project, + use_distributed_map, ) if only_json: @@ -273,7 +281,15 @@ def attach_prefix(name): def make_flow( - obj, token, name, tags, namespace, max_workers, workflow_timeout, is_project + obj, + token, + name, + tags, + namespace, + max_workers, + workflow_timeout, + is_project, + use_distributed_map, ): if obj.flow_datastore.TYPE != "s3": raise MetaflowException("AWS Step Functions requires --datastore=s3.") @@ -309,6 +325,7 @@ def make_flow( username=get_username(), workflow_timeout=workflow_timeout, is_project=is_project, + use_distributed_map=use_distributed_map, ) diff --git a/metaflow/plugins/aws/step_functions/step_functions_decorator.py b/metaflow/plugins/aws/step_functions/step_functions_decorator.py index ba71c0af8df..89a7de79857 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_decorator.py +++ b/metaflow/plugins/aws/step_functions/step_functions_decorator.py @@ -1,5 +1,5 @@ -import os import json +import os import time from metaflow.decorators import StepDecorator From 164c2ad3fad4ad8c47134ceea676ef8e4f881760 Mon Sep 17 00:00:00 2001 From: savin Date: Mon, 5 Feb 2024 19:42:50 -0800 Subject: [PATCH 07/13] add comments --- metaflow/plugins/aws/step_functions/step_functions.py | 10 ++++------ 1 file 
changed, 4 insertions(+), 6 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index f860b953705..c20868b0224 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -452,7 +452,6 @@ def _batch(self, node): "metaflow.owner": self.username, "metaflow.flow_name": self.flow.name, "metaflow.step_name": node.name, - # "metaflow.run_id.$": "$$.Execution.Name", # Unfortunately we can't set the task id here since AWS Step # Functions lacks any notion of run-scoped task identifiers. We # instead co-opt the AWS Batch job id as the task id. This also @@ -464,6 +463,10 @@ def _batch(self, node): # `$$.State.RetryCount` resolves to an int dynamically and # AWS Batch job specification only accepts strings. We handle # retries/catch within AWS Batch to get around this limitation. + # And, we also cannot set the run id here since the run id maps to + # the execution name of the AWS Step Functions State Machine, which + # is different when executing inside a distributed map. We set it once + # in the start step and move it along to be consumed by all the children. "metaflow.version": self.environment.get_environment_info()[ "metaflow_version" ], @@ -482,11 +485,6 @@ def _batch(self, node): # specification that allows us to set key-values. "step_name": node.name, } - # # metaflow.run_id maps to AWS Step Functions State Machine Execution in all - # # cases except for when within a for-each construct that relies on Distributed - # # Map. To work around this issue, within a for-each, we lean on reading off of - # # AWS DynamoDb to get the run id. - # attrs["metaflow.run_id.$"] = "$$.Execution.Name" # Store production token within the `start` step, so that subsequent # `step-functions create` calls can perform a rudimentary authorization From 9d7ec351ebde58dd344dd0b17021d1928b253389 Mon Sep 17 00:00:00 2001 From: savin Date: Mon, 5 Feb 2024 20:54:47 -0800 Subject: [PATCH 08/13] add plenty of jitters and retries --- .../aws/step_functions/step_functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index c20868b0224..6f6b68eac79 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -544,9 +544,11 @@ def _batch(self, node): raise StepFunctionsException( "Parallel steps are not supported yet with AWS step functions." ) + # Inherit the run id from the parent and pass it along to children. attrs["metaflow.run_id.$"] = "$.Parameters.run_id" attrs["run_id.$"] = "$.Parameters.run_id" + # Handle foreach join. if ( node.type == "join" @@ -970,6 +972,10 @@ def _partition(self): # This is needed to support AWS Gov Cloud and AWS CN regions return SFN_IAM_ROLE.split(":")[1] + def retry_strategy(self, retry_strategy): + self.payload["Retry"] = [retry_strategy] + return self + def batch(self, job): self.resource( "arn:%s:states:::batch:submitJob.sync" % self._partition() @@ -989,6 +995,19 @@ def batch(self, job): # tags may not be present in all scenarios if "tags" in job.payload: self.parameter("Tags", job.payload["tags"]) + # set retry strategy for AWS Batch job submission to account for the + # measily 50 jobs / second queue admission limit which people can + # run into very quickly. 
+ self.retry_strategy( + { + "ErrorEquals": ["Batch.AWSBatchException"], + "BackoffRate": 2, + "IntervalSeconds": 2, + "MaxDelaySeconds": 60, + "MaxAttempts": 10, + "JitterStrategy": "FULL", + } + ) return self def dynamo_db(self, table_name, primary_key, values): From b7df18ad219cc4ea20f02855fc927fee3a999e20 Mon Sep 17 00:00:00 2001 From: savin Date: Tue, 6 Feb 2024 09:04:11 -0800 Subject: [PATCH 09/13] add retries --- .../aws/step_functions/dynamo_db_client.py | 33 +++++++++++++++---- .../aws/step_functions/step_functions.py | 13 +++++--- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/dynamo_db_client.py b/metaflow/plugins/aws/step_functions/dynamo_db_client.py index d14fbbb0de4..c36707a2f3c 100644 --- a/metaflow/plugins/aws/step_functions/dynamo_db_client.py +++ b/metaflow/plugins/aws/step_functions/dynamo_db_client.py @@ -1,5 +1,5 @@ import os - +import time import requests from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE @@ -27,12 +27,31 @@ def save_foreach_cardinality(self, foreach_split_task_id, foreach_cardinality, t def save_parent_task_id_for_foreach_join( self, foreach_split_task_id, foreach_join_parent_task_id ): - return self._client.update_item( - TableName=self.name, - Key={"pathspec": {"S": foreach_split_task_id}}, - UpdateExpression="ADD parent_task_ids_for_foreach_join :val", - ExpressionAttributeValues={":val": {"SS": [foreach_join_parent_task_id]}}, - ) + ex = None + for attempt in range(10): + try: + return self._client.update_item( + TableName=self.name, + Key={"pathspec": {"S": foreach_split_task_id}}, + UpdateExpression="ADD parent_task_ids_for_foreach_join :val", + ExpressionAttributeValues={ + ":val": {"SS": [foreach_join_parent_task_id]} + }, + ) + except self._client.exceptions.ClientError as error: + ex = error + if ( + error.response["Error"]["Code"] + == "ProvisionedThroughputExceededException" + ): + # hopefully, enough time for AWS to scale up! otherwise + # ensure sufficient on-demand throughput for dynamo db + # is provisioned ahead of time + sleep_time = min((2**attempt) * 10, 60) + time.sleep(sleep_time) + else: + raise + raise ex def get_parent_task_ids_for_foreach_join(self, foreach_split_task_id): response = self._client.get_item( diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 6f6b68eac79..642b0c5fb46 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -545,10 +545,6 @@ def _batch(self, node): "Parallel steps are not supported yet with AWS step functions." ) - # Inherit the run id from the parent and pass it along to children. - attrs["metaflow.run_id.$"] = "$.Parameters.run_id" - attrs["run_id.$"] = "$.Parameters.run_id" - # Handle foreach join. if ( node.type == "join" @@ -572,6 +568,9 @@ def _batch(self, node): env["METAFLOW_SPLIT_PARENT_TASK_ID"] = ( "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1] ) + # Inherit the run id from the parent and pass it along to children. + attrs["metaflow.run_id.$"] = "$.Parameters.run_id" + attrs["run_id.$"] = "$.Parameters.run_id" else: # Set appropriate environment variables for runtime replacement. if len(node.in_funcs) == 1: @@ -580,6 +579,9 @@ def _batch(self, node): % node.in_funcs[0] ) env["METAFLOW_PARENT_TASK_ID"] = "$.JobId" + # Inherit the run id from the parent and pass it along to children. 
+ attrs["metaflow.run_id.$"] = "$.Parameters.run_id" + attrs["run_id.$"] = "$.Parameters.run_id" else: # Generate the input paths in a quasi-compressed format. # See util.decompress_list for why this is written the way @@ -589,6 +591,9 @@ def _batch(self, node): "${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx) for idx, _ in enumerate(node.in_funcs) ) + # Inherit the run id from the parent and pass it along to children. + attrs["metaflow.run_id.$"] = "$.[0].Parameters.run_id" + attrs["run_id.$"] = "$.[0].Parameters.run_id" for idx, _ in enumerate(node.in_funcs): env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx env["METAFLOW_PARENT_%s_STEP" % idx] = ( From 8b064a9314c5a36e944d6023c8f4b25d3e9c2a58 Mon Sep 17 00:00:00 2001 From: savin Date: Fri, 9 Feb 2024 11:50:50 -0700 Subject: [PATCH 10/13] add s3 path --- metaflow/metaflow_config.py | 8 +- .../aws/step_functions/dynamo_db_client.py | 1 + .../aws/step_functions/step_functions.py | 123 ++++++++++++++++-- 3 files changed, 122 insertions(+), 10 deletions(-) diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index f5a690603e4..7cad6a71e17 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -268,7 +268,13 @@ # machine execution logs. This needs to be available when using the # `step-functions create --log-execution-history` command. SFN_EXECUTION_LOG_GROUP_ARN = from_conf("SFN_EXECUTION_LOG_GROUP_ARN") - +# Amazon S3 path for storing the results of AWS Step Functions Distributed Map +SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = from_conf( + "SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH", + os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output") + if DATASTORE_SYSROOT_S3 + else None, +) ### # Kubernetes configuration ### diff --git a/metaflow/plugins/aws/step_functions/dynamo_db_client.py b/metaflow/plugins/aws/step_functions/dynamo_db_client.py index c36707a2f3c..c1e06de156c 100644 --- a/metaflow/plugins/aws/step_functions/dynamo_db_client.py +++ b/metaflow/plugins/aws/step_functions/dynamo_db_client.py @@ -1,5 +1,6 @@ import os import time + import requests from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 642b0c5fb46..aea04e3afa1 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -15,6 +15,7 @@ SFN_DYNAMO_DB_TABLE, SFN_EXECUTION_LOG_GROUP_ARN, SFN_IAM_ROLE, + SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH, ) from metaflow.parameters import deploy_time_eval from metaflow.util import dict_to_cli_options, to_pascalcase @@ -369,7 +370,11 @@ def _visit(node, workflow, exit_node=None): .parameter("SplitParentTaskId.$", "$.JobId") .parameter("Parameters.$", "$.Parameters") .parameter("Index.$", "$$.Map.Item.Value") - .next(node.matching_join) + .next( + "%s_*GetManifest" % iterator_name + if self.use_distributed_map + else node.matching_join + ) .iterator( _visit( self.graph[node.out_funcs[0]], @@ -382,8 +387,54 @@ def _visit(node, workflow, exit_node=None): ) ) .max_concurrency(self.max_workers) - .output_path("$.[0]") + # AWS Step Functions has a short coming for DistributedMap at the + # moment that does not allow us to subset the output of for-each + # to just a single element. We have to rely on a rather terrible + # hack and resort to using ResultWriter to write the state to + # Amazon S3 and process it in another task. But, well what can we + # do... 
+ .result_writer( + *( + SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.rsplit("/", 1) + if self.use_distributed_map + else () + ) + ) + .output_path("$" if self.use_distributed_map else "$.[0]") ) + if self.use_distributed_map: + workflow.add_state( + State("%s_*GetManifest" % iterator_name) + .resource("arn:aws:states:::aws-sdk:s3:getObject") + .parameter("Bucket.$", "$.ResultWriterDetails.Bucket") + .parameter("Key.$", "$.ResultWriterDetails.Key") + .next("%s_*Map" % iterator_name) + .result_selector("Body.$", "States.StringToJson($.Body)") + ) + workflow.add_state( + Map("%s_*Map" % iterator_name) + .iterator( + Workflow("%s_*PassWorkflow" % iterator_name) + .mode("DISTRIBUTED") + .start_at("%s_*Pass" % iterator_name) + .add_state( + Pass("%s_*Pass" % iterator_name) + .end() + .parameter("Output.$", "States.StringToJson($.Output)") + .output_path("$.Output") + ) + ) + .next(node.matching_join) + .max_concurrency(1000) + .item_reader( + JSONItemReader() + .resource("arn:aws:states:::s3:getObject") + .parameter("Bucket.$", "$.Body.DestinationBucket") + .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key") + ) + .output_path("$.[0]") + ) + # Continue the traversal from the matching_join. _visit(self.graph[node.matching_join], workflow, exit_node) # We shouldn't ideally ever get here. @@ -508,7 +559,6 @@ def _batch(self, node): # Distributed Map. To work around this issue, we pass the run id from the # start step to all subsequent tasks. attrs["metaflow.run_id.$"] = "$$.Execution.Name" - attrs["run_id.$"] = "$$.Execution.Name" # Initialize parameters for the flow in the `start` step. parameters = self._process_parameters() @@ -569,8 +619,7 @@ def _batch(self, node): "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1] ) # Inherit the run id from the parent and pass it along to children. - attrs["metaflow.run_id.$"] = "$.Parameters.run_id" - attrs["run_id.$"] = "$.Parameters.run_id" + attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']" else: # Set appropriate environment variables for runtime replacement. if len(node.in_funcs) == 1: @@ -580,8 +629,7 @@ def _batch(self, node): ) env["METAFLOW_PARENT_TASK_ID"] = "$.JobId" # Inherit the run id from the parent and pass it along to children. - attrs["metaflow.run_id.$"] = "$.Parameters.run_id" - attrs["run_id.$"] = "$.Parameters.run_id" + attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']" else: # Generate the input paths in a quasi-compressed format. # See util.decompress_list for why this is written the way @@ -592,8 +640,7 @@ def _batch(self, node): for idx, _ in enumerate(node.in_funcs) ) # Inherit the run id from the parent and pass it along to children. 
- attrs["metaflow.run_id.$"] = "$.[0].Parameters.run_id" - attrs["run_id.$"] = "$.[0].Parameters.run_id" + attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']" for idx, _ in enumerate(node.in_funcs): env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx env["METAFLOW_PARENT_%s_STEP" % idx] = ( @@ -973,6 +1020,10 @@ def result_path(self, result_path): self.payload["ResultPath"] = result_path return self + def result_selector(self, name, value): + self.payload["ResultSelector"][name] = value + return self + def _partition(self): # This is needed to support AWS Gov Cloud and AWS CN regions return SFN_IAM_ROLE.split(":")[1] @@ -1026,6 +1077,26 @@ def dynamo_db(self, table_name, primary_key, values): return self +class Pass(object): + def __init__(self, name): + self.name = name + tree = lambda: defaultdict(tree) + self.payload = tree() + self.payload["Type"] = "Pass" + + def end(self): + self.payload["End"] = True + return self + + def parameter(self, name, value): + self.payload["Parameters"][name] = value + return self + + def output_path(self, output_path): + self.payload["OutputPath"] = output_path + return self + + class Parallel(object): def __init__(self, name): self.name = name @@ -1087,3 +1158,37 @@ def output_path(self, output_path): def result_path(self, result_path): self.payload["ResultPath"] = result_path return self + + def item_reader(self, item_reader): + self.payload["ItemReader"] = item_reader.payload + return self + + def result_writer(self, bucket, prefix): + if bucket is not None and prefix is not None: + self.payload["ResultWriter"] = { + "Resource": "arn:aws:states:::s3:putObject", + "Parameters": { + "Bucket": bucket, + "Prefix": prefix, + }, + } + return self + + +class JSONItemReader(object): + def __init__(self): + tree = lambda: defaultdict(tree) + self.payload = tree() + self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1} + + def resource(self, resource): + self.payload["Resource"] = resource + return self + + def parameter(self, name, value): + self.payload["Parameters"][name] = value + return self + + def output_path(self, output_path): + self.payload["OutputPath"] = output_path + return self From 7cf9522f6dc791596c5361842160f04c085f7d31 Mon Sep 17 00:00:00 2001 From: savin Date: Fri, 9 Feb 2024 15:31:06 -0700 Subject: [PATCH 11/13] fix bucket --- metaflow/plugins/aws/step_functions/step_functions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index aea04e3afa1..81396eb1a49 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -395,9 +395,11 @@ def _visit(node, workflow, exit_node=None): # do... 
.result_writer( *( - SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.rsplit("/", 1) + "s3://" + + SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[5:].split("/", 1)[0], + SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[5:].split("/", 1)[1] if self.use_distributed_map - else () + else (), ) ) .output_path("$" if self.use_distributed_map else "$.[0]") From dfe468b5c5c6546dc6677f133d69e8948d1e1324 Mon Sep 17 00:00:00 2001 From: savin Date: Fri, 9 Feb 2024 16:00:05 -0700 Subject: [PATCH 12/13] fix --- .../plugins/aws/step_functions/step_functions.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 81396eb1a49..dd062f5078e 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -395,11 +395,18 @@ def _visit(node, workflow, exit_node=None): # do... .result_writer( *( - "s3://" - + SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[5:].split("/", 1)[0], - SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[5:].split("/", 1)[1] + ( + ( + SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :] + if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith( + "s3://" + ) + else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH + ).split("/", 1) + + [""] + )[:2] if self.use_distributed_map - else (), + else () ) ) .output_path("$" if self.use_distributed_map else "$.[0]") From ca8c4851f7575d6ee89367c51073860999334d1a Mon Sep 17 00:00:00 2001 From: savin Date: Fri, 16 Feb 2024 09:33:43 -0800 Subject: [PATCH 13/13] fix --- metaflow/plugins/aws/step_functions/step_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index dd062f5078e..5ad3bfceca3 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -406,7 +406,7 @@ def _visit(node, workflow, exit_node=None): + [""] )[:2] if self.use_distributed_map - else () + else (None, None) ) ) .output_path("$" if self.use_distributed_map else "$.[0]")
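
A minimal standalone sketch (not part of the patch series) of the bucket/prefix split that the last three patches apply to SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH before handing it to the Map state's result_writer(); the helper name split_s3_path and the example paths are hypothetical, whereas the patches inline this expression at the .result_writer() call site:

# Hypothetical helper illustrating the expression inlined in PATCH 12/13;
# assumes paths of the form "s3://<bucket>[/<prefix>]".
def split_s3_path(path):
    # Drop the scheme if present, then split into bucket and prefix; appending
    # [""] guarantees a prefix element even for bucket-only paths.
    stripped = path[len("s3://"):] if path.startswith("s3://") else path
    bucket, prefix = (stripped.split("/", 1) + [""])[:2]
    return bucket, prefix

assert split_s3_path("s3://my-bucket/metaflow/sfn_distributed_map_output") == (
    "my-bucket",
    "metaflow/sfn_distributed_map_output",
)
assert split_s3_path("s3://my-bucket") == ("my-bucket", "")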