From 4f14a4bba07469019be1f2b7799395a48a0930e6 Mon Sep 17 00:00:00 2001
From: savin
Date: Mon, 5 Feb 2024 19:39:50 -0800
Subject: [PATCH] Support distributed map in AWS Step Functions

---
 .../aws/step_functions/dynamo_db_client.py    |  2 ++
 .../aws/step_functions/production_token.py    |  2 +-
 .../aws/step_functions/step_functions.py      | 34 +++++++++++++++++--
 .../aws/step_functions/step_functions_cli.py  | 29 ++++++++++++----
 .../step_functions_decorator.py               |  2 +-
 5 files changed, 58 insertions(+), 11 deletions(-)

diff --git a/metaflow/plugins/aws/step_functions/dynamo_db_client.py b/metaflow/plugins/aws/step_functions/dynamo_db_client.py
index caff1b4a343..d14fbbb0de4 100644
--- a/metaflow/plugins/aws/step_functions/dynamo_db_client.py
+++ b/metaflow/plugins/aws/step_functions/dynamo_db_client.py
@@ -1,5 +1,7 @@
 import os
+
 import requests
+
 from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE
 
 
diff --git a/metaflow/plugins/aws/step_functions/production_token.py b/metaflow/plugins/aws/step_functions/production_token.py
index 1561cba1ab5..e2cceedbd04 100644
--- a/metaflow/plugins/aws/step_functions/production_token.py
+++ b/metaflow/plugins/aws/step_functions/production_token.py
@@ -1,5 +1,5 @@
-import os
 import json
+import os
 import random
 import string
 import zlib
diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py
index 7dd6854b725..502d5b98f75 100644
--- a/metaflow/plugins/aws/step_functions/step_functions.py
+++ b/metaflow/plugins/aws/step_functions/step_functions.py
@@ -52,6 +52,7 @@ def __init__(
         max_workers=None,
         workflow_timeout=None,
         is_project=False,
+        use_distributed_map=False,
     ):
         self.name = name
         self.graph = graph
@@ -70,6 +71,9 @@ def __init__(
         self.max_workers = max_workers
         self.workflow_timeout = workflow_timeout
 
+        # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
+        self.use_distributed_map = use_distributed_map
+
         self._client = StepFunctionsClient()
         self._workflow = self._compile()
         self._cron = self._cron()
@@ -318,7 +322,11 @@ def _visit(node, workflow, exit_node=None):
                     .iterator(
                         _visit(
                             self.graph[node.out_funcs[0]],
-                            Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
+                            Workflow(node.out_funcs[0])
+                            .start_at(node.out_funcs[0])
+                            .mode(
+                                "DISTRIBUTED" if self.use_distributed_map else "INLINE"
+                            ),
                             node.matching_join,
                         )
                     )
@@ -393,7 +401,7 @@ def _batch(self, node):
             "metaflow.owner": self.username,
             "metaflow.flow_name": self.flow.name,
             "metaflow.step_name": node.name,
-            "metaflow.run_id.$": "$$.Execution.Name",
+            # "metaflow.run_id.$": "$$.Execution.Name",
             # Unfortunately we can't set the task id here since AWS Step
             # Functions lacks any notion of run-scoped task identifiers. We
             # instead co-opt the AWS Batch job id as the task id. This also
@@ -423,6 +431,11 @@ def _batch(self, node):
             # specification that allows us to set key-values.
             "step_name": node.name,
         }
+        # # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
+        # # cases except for when within a for-each construct that relies on Distributed
+        # # Map. To work around this issue, within a for-each, we lean on reading off of
+        # # AWS DynamoDB to get the run id.
+        # attrs["metaflow.run_id.$"] = "$$.Execution.Name"
 
         # Store production token within the `start` step, so that subsequent
         # `step-functions create` calls can perform a rudimentary authorization
@@ -441,6 +454,13 @@ def _batch(self, node):
             env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
 
         if node.name == "start":
+            # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
+            # cases except for when within a for-each construct that relies on
+            # Distributed Map. To work around this issue, we pass the run id from the
+            # start step to all subsequent tasks.
+            attrs["metaflow.run_id.$"] = "$$.Execution.Name"
+            attrs["run_id.$"] = "$$.Execution.Name"
+
             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
             if parameters:
@@ -475,7 +495,9 @@ def _batch(self, node):
                 raise StepFunctionsException(
                     "Parallel steps are not supported yet with AWS step functions."
                 )
-
+            # Inherit the run id from the parent and pass it along to children.
+            attrs["metaflow.run_id.$"] = "$.Parameters.run_id"
+            attrs["run_id.$"] = "$.Parameters.run_id"
             # Handle foreach join.
             if (
                 node.type == "join"
@@ -842,6 +864,12 @@ def __init__(self, name):
         tree = lambda: defaultdict(tree)
         self.payload = tree()
 
+    def mode(self, mode):
+        self.payload["ProcessorConfig"] = {"Mode": mode}
+        if mode == "DISTRIBUTED":
+            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
+        return self
+
     def start_at(self, start_at):
         self.payload["StartAt"] = start_at
         return self
diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py
index 63b1b645424..f28a82e8ef6 100644
--- a/metaflow/plugins/aws/step_functions/step_functions_cli.py
+++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py
@@ -1,23 +1,23 @@
 import base64
-from metaflow._vendor import click
-from hashlib import sha1
 import json
 import re
+from hashlib import sha1
 
-from metaflow import current, decorators, parameters, JSONType
+from metaflow import JSONType, current, decorators, parameters
+from metaflow._vendor import click
+from metaflow.exception import MetaflowException, MetaflowInternalError
 from metaflow.metaflow_config import (
     SERVICE_VERSION_CHECK,
     SFN_STATE_MACHINE_PREFIX,
     UI_URL,
 )
-from metaflow.exception import MetaflowException, MetaflowInternalError
 from metaflow.package import MetaflowPackage
 from metaflow.plugins.aws.batch.batch_decorator import BatchDecorator
 from metaflow.tagging_util import validate_tags
 from metaflow.util import get_username, to_bytes, to_unicode, version_parse
 
+from .production_token import load_token, new_token, store_token
 from .step_functions import StepFunctions
-from .production_token import load_token, store_token, new_token
 
 VALID_NAME = re.compile(r"[^a-zA-Z0-9_\-\.]")
@@ -120,6 +120,12 @@ def step_functions(obj, name=None):
     help="Log AWS Step Functions execution history to AWS CloudWatch "
     "Logs log group.",
 )
+@click.option(
+    "--use-distributed-map/--no-use-distributed-map",
+    is_flag=True,
+    help="Use AWS Step Functions Distributed Map instead of Inline Map for "
+    "defining foreach tasks in Amazon States Language.",
+)
 @click.pass_obj
 def create(
     obj,
     tags=None,
     user_namespace=None,
     only_json=False,
     authorize=None,
     generate_new_token=False,
     given_token=None,
     max_workers=None,
     workflow_timeout=None,
     log_execution_history=False,
+    use_distributed_map=False,
 ):
     validate_tags(tags)
 
         max_workers,
         workflow_timeout,
         obj.is_project,
+        use_distributed_map,
     )
 
     if only_json:
@@ -269,7 +277,15 @@ def attach_prefix(name):
 def make_flow(
-    obj, token, name, tags, namespace, max_workers, workflow_timeout, is_project
+    obj,
+    token,
+    name,
+    tags,
+    namespace,
+    max_workers,
+    workflow_timeout,
+    is_project,
+    use_distributed_map,
 ):
     if obj.flow_datastore.TYPE != "s3":
         raise MetaflowException("AWS Step Functions requires --datastore=s3.")
@@ -305,6 +321,7 @@ def make_flow(
         username=get_username(),
         workflow_timeout=workflow_timeout,
         is_project=is_project,
+        use_distributed_map=use_distributed_map,
     )
diff --git a/metaflow/plugins/aws/step_functions/step_functions_decorator.py b/metaflow/plugins/aws/step_functions/step_functions_decorator.py
index ba71c0af8df..89a7de79857 100644
--- a/metaflow/plugins/aws/step_functions/step_functions_decorator.py
+++ b/metaflow/plugins/aws/step_functions/step_functions_decorator.py
@@ -1,5 +1,5 @@
-import os
 import json
+import os
 import time
 
 from metaflow.decorators import StepDecorator
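
Note (not part of the patch): the new Workflow.mode() builder is what flips a
foreach iterator between the two Map processors. Below is a minimal sketch of
what it emits, assuming a Metaflow checkout with this patch applied; the step
name "process_item" is illustrative, not from the patch.

    # Sketch only: inspect the iterator payload produced by the Workflow.mode()
    # helper added in this patch. "process_item" is an illustrative step name.
    import json

    from metaflow.plugins.aws.step_functions.step_functions import Workflow

    iterator = (
        Workflow("process_item")
        .start_at("process_item")
        .mode("DISTRIBUTED")  # "INLINE" when --use-distributed-map is off
    )

    # For DISTRIBUTED, mode() also pins the child execution type to STANDARD:
    #   {"StartAt": "process_item",
    #    "ProcessorConfig": {"Mode": "DISTRIBUTED", "ExecutionType": "STANDARD"}}
    print(json.dumps(iterator.payload, indent=2))

With the patch applied, something like `python flow.py step-functions create
--use-distributed-map` (flow file name illustrative) applies this distributed
configuration to every foreach Map state in the generated state machine, while
the default `--no-use-distributed-map` keeps the Inline behavior.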