Distributed Map Support in AWS Step Functions (#1720)
* support git repos

* remove spurious function

* fix formatting

* handle race condition with local packages metadata

* remove print

* Support distributed map in AWS Step Functions

* add comments

* add plenty of jitters and retries

* add retries

* add s3 path

* fix bucket

* fix

* fix
savingoyal authored Feb 16, 2024
1 parent d95e7b6 commit 2588f1d
Showing 6 changed files with 228 additions and 19 deletions.
8 changes: 7 additions & 1 deletion metaflow/metaflow_config.py
@@ -268,7 +268,13 @@
# machine execution logs. This needs to be available when using the
# `step-functions create --log-execution-history` command.
SFN_EXECUTION_LOG_GROUP_ARN = from_conf("SFN_EXECUTION_LOG_GROUP_ARN")

# Amazon S3 path for storing the results of AWS Step Functions Distributed Map
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = from_conf(
"SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH",
os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
if DATASTORE_SYSROOT_S3
else None,
)
###
# Kubernetes configuration
###
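As a quick illustration of the new default above (the bucket name is hypothetical): when DATASTORE_SYSROOT_S3 is configured, the distributed-map output path simply nests under it.

# Assuming DATASTORE_SYSROOT_S3 = "s3://my-metaflow-bucket/metaflow" (made-up value),
# the default above resolves to:
import os
print(os.path.join("s3://my-metaflow-bucket/metaflow", "sfn_distributed_map_output"))
# -> s3://my-metaflow-bucket/metaflow/sfn_distributed_map_output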
34 changes: 28 additions & 6 deletions metaflow/plugins/aws/step_functions/dynamo_db_client.py
@@ -1,5 +1,8 @@
import os
import time

import requests

from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE


@@ -25,12 +28,31 @@ def save_foreach_cardinality(self, foreach_split_task_id, foreach_cardinality, t
def save_parent_task_id_for_foreach_join(
self, foreach_split_task_id, foreach_join_parent_task_id
):
return self._client.update_item(
TableName=self.name,
Key={"pathspec": {"S": foreach_split_task_id}},
UpdateExpression="ADD parent_task_ids_for_foreach_join :val",
ExpressionAttributeValues={":val": {"SS": [foreach_join_parent_task_id]}},
)
ex = None
for attempt in range(10):
try:
return self._client.update_item(
TableName=self.name,
Key={"pathspec": {"S": foreach_split_task_id}},
UpdateExpression="ADD parent_task_ids_for_foreach_join :val",
ExpressionAttributeValues={
":val": {"SS": [foreach_join_parent_task_id]}
},
)
except self._client.exceptions.ClientError as error:
ex = error
if (
error.response["Error"]["Code"]
== "ProvisionedThroughputExceededException"
):
# Hopefully, this allows enough time for AWS to scale up! Otherwise,
# ensure that sufficient on-demand throughput for DynamoDB is
# provisioned ahead of time.
sleep_time = min((2**attempt) * 10, 60)
time.sleep(sleep_time)
else:
raise
raise ex

def get_parent_task_ids_for_foreach_join(self, foreach_split_task_id):
response = self._client.get_item(
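For reference, a minimal sketch of the capped exponential backoff used in the retry loop above; the schedule follows directly from min((2 ** attempt) * 10, 60).

# Sleep schedule over the 10 attempts:
for attempt in range(10):
    print(attempt, min((2 ** attempt) * 10, 60))
# -> 10, 20, 40 seconds, then 60 seconds for each remaining attempt --
#    roughly 8 minutes of waiting in the worst case before `ex` is re-raised.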
2 changes: 1 addition & 1 deletion metaflow/plugins/aws/step_functions/production_token.py
@@ -1,5 +1,5 @@
import os
import json
import os
import random
import string
import zlib
172 changes: 168 additions & 4 deletions metaflow/plugins/aws/step_functions/step_functions.py
@@ -15,6 +15,7 @@
SFN_DYNAMO_DB_TABLE,
SFN_EXECUTION_LOG_GROUP_ARN,
SFN_IAM_ROLE,
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
)
from metaflow.parameters import deploy_time_eval
from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -52,6 +53,7 @@ def __init__(
max_workers=None,
workflow_timeout=None,
is_project=False,
use_distributed_map=False,
):
self.name = name
self.graph = graph
@@ -70,6 +72,9 @@
self.max_workers = max_workers
self.workflow_timeout = workflow_timeout

# https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
self.use_distributed_map = use_distributed_map

self._client = StepFunctionsClient()
self._workflow = self._compile()
self._cron = self._cron()
@@ -365,17 +370,80 @@ def _visit(node, workflow, exit_node=None):
.parameter("SplitParentTaskId.$", "$.JobId")
.parameter("Parameters.$", "$.Parameters")
.parameter("Index.$", "$$.Map.Item.Value")
.next(node.matching_join)
.next(
"%s_*GetManifest" % iterator_name
if self.use_distributed_map
else node.matching_join
)
.iterator(
_visit(
self.graph[node.out_funcs[0]],
Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
Workflow(node.out_funcs[0])
.start_at(node.out_funcs[0])
.mode(
"DISTRIBUTED" if self.use_distributed_map else "INLINE"
),
node.matching_join,
)
)
.max_concurrency(self.max_workers)
.output_path("$.[0]")
# AWS Step Functions' Distributed Map currently has a shortcoming:
# it does not allow us to subset the output of the for-each to just
# a single element. We have to rely on a rather unfortunate hack and
# resort to using ResultWriter to write the state to Amazon S3 and
# process it in a follow-up task.
.result_writer(
*(
(
(
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :]
if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith(
"s3://"
)
else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
).split("/", 1)
+ [""]
)[:2]
if self.use_distributed_map
else (None, None)
)
)
.output_path("$" if self.use_distributed_map else "$.[0]")
)
if self.use_distributed_map:
workflow.add_state(
State("%s_*GetManifest" % iterator_name)
.resource("arn:aws:states:::aws-sdk:s3:getObject")
.parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
.parameter("Key.$", "$.ResultWriterDetails.Key")
.next("%s_*Map" % iterator_name)
.result_selector("Body.$", "States.StringToJson($.Body)")
)
workflow.add_state(
Map("%s_*Map" % iterator_name)
.iterator(
Workflow("%s_*PassWorkflow" % iterator_name)
.mode("DISTRIBUTED")
.start_at("%s_*Pass" % iterator_name)
.add_state(
Pass("%s_*Pass" % iterator_name)
.end()
.parameter("Output.$", "States.StringToJson($.Output)")
.output_path("$.Output")
)
)
.next(node.matching_join)
.max_concurrency(1000)
.item_reader(
JSONItemReader()
.resource("arn:aws:states:::s3:getObject")
.parameter("Bucket.$", "$.Body.DestinationBucket")
.parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
)
.output_path("$.[0]")
)

# Continue the traversal from the matching_join.
_visit(self.graph[node.matching_join], workflow, exit_node)
# Ideally, we should never get here.
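The nested expression handed to result_writer above is dense, so here is a minimal, equivalent sketch of how the configured S3 path is split into the (bucket, prefix) pair that ResultWriter expects; the example path is hypothetical.

path = "s3://my-metaflow-bucket/sfn_distributed_map_output"  # stand-in for SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
stripped = path[len("s3://"):] if path.startswith("s3://") else path
# Pad with "" so a bucket-only path still yields a two-element result.
bucket, prefix = (stripped.split("/", 1) + [""])[:2]
# bucket == "my-metaflow-bucket", prefix == "sfn_distributed_map_output"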
@@ -444,7 +512,6 @@ def _batch(self, node):
"metaflow.owner": self.username,
"metaflow.flow_name": self.flow.name,
"metaflow.step_name": node.name,
"metaflow.run_id.$": "$$.Execution.Name",
# Unfortunately we can't set the task id here since AWS Step
# Functions lacks any notion of run-scoped task identifiers. We
# instead co-opt the AWS Batch job id as the task id. This also
@@ -456,6 +523,10 @@
# `$$.State.RetryCount` resolves to an int dynamically and
# AWS Batch job specification only accepts strings. We handle
# retries/catch within AWS Batch to get around this limitation.
# We also cannot set the run id here since the run id maps to the
# execution name of the AWS Step Functions state machine, which is
# different when executing inside a Distributed Map. We set it once
# in the start step and pass it along to be consumed by all children.
"metaflow.version": self.environment.get_environment_info()[
"metaflow_version"
],
@@ -492,6 +563,12 @@ def _batch(self, node):
env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL

if node.name == "start":
# metaflow.run_id maps to the AWS Step Functions state machine execution
# name in all cases except within a for-each construct that relies on
# Distributed Map. To work around this issue, we pass the run id from the
# start step to all subsequent tasks.
attrs["metaflow.run_id.$"] = "$$.Execution.Name"

# Initialize parameters for the flow in the `start` step.
parameters = self._process_parameters()
if parameters:
@@ -550,6 +627,8 @@ def _batch(self, node):
env["METAFLOW_SPLIT_PARENT_TASK_ID"] = (
"$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
)
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
else:
# Set appropriate environment variables for runtime replacement.
if len(node.in_funcs) == 1:
Expand All @@ -558,6 +637,8 @@ def _batch(self, node):
% node.in_funcs[0]
)
env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
else:
# Generate the input paths in a quasi-compressed format.
# See util.decompress_list for why this is written the way
@@ -567,6 +648,8 @@
"${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx)
for idx, _ in enumerate(node.in_funcs)
)
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
for idx, _ in enumerate(node.in_funcs):
env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
env["METAFLOW_PARENT_%s_STEP" % idx] = (
@@ -893,6 +976,12 @@ def __init__(self, name):
tree = lambda: defaultdict(tree)
self.payload = tree()

def mode(self, mode):
self.payload["ProcessorConfig"] = {"Mode": mode}
if mode == "DISTRIBUTED":
self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
return self

def start_at(self, start_at):
self.payload["StartAt"] = start_at
return self
@@ -940,10 +1029,18 @@ def result_path(self, result_path):
self.payload["ResultPath"] = result_path
return self

def result_selector(self, name, value):
self.payload["ResultSelector"][name] = value
return self

def _partition(self):
# This is needed to support AWS Gov Cloud and AWS CN regions
return SFN_IAM_ROLE.split(":")[1]

def retry_strategy(self, retry_strategy):
self.payload["Retry"] = [retry_strategy]
return self

def batch(self, job):
self.resource(
"arn:%s:states:::batch:submitJob.sync" % self._partition()
@@ -963,6 +1060,19 @@ def batch(self, job):
# tags may not be present in all scenarios
if "tags" in job.payload:
self.parameter("Tags", job.payload["tags"])
# Set a retry strategy for AWS Batch job submission to account for the
# measly 50 jobs/second queue admission limit, which users can run
# into very quickly.
self.retry_strategy(
{
"ErrorEquals": ["Batch.AWSBatchException"],
"BackoffRate": 2,
"IntervalSeconds": 2,
"MaxDelaySeconds": 60,
"MaxAttempts": 10,
"JitterStrategy": "FULL",
}
)
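# For reference: the nominal (pre-jitter) wait schedule this yields is
# roughly 2, 4, 8, 16, 32, 60, 60, ... seconds -- the delay doubles from
# IntervalSeconds per BackoffRate and is capped by MaxDelaySeconds, and
# FULL jitter then randomizes each individual wait.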
return self

def dynamo_db(self, table_name, primary_key, values):
Expand All @@ -976,6 +1086,26 @@ def dynamo_db(self, table_name, primary_key, values):
return self


class Pass(object):
def __init__(self, name):
self.name = name
tree = lambda: defaultdict(tree)
self.payload = tree()
self.payload["Type"] = "Pass"

def end(self):
self.payload["End"] = True
return self

def parameter(self, name, value):
self.payload["Parameters"][name] = value
return self

def output_path(self, output_path):
self.payload["OutputPath"] = output_path
return self


class Parallel(object):
def __init__(self, name):
self.name = name
@@ -1037,3 +1167,37 @@ def output_path(self, output_path):
def result_path(self, result_path):
self.payload["ResultPath"] = result_path
return self

def item_reader(self, item_reader):
self.payload["ItemReader"] = item_reader.payload
return self

def result_writer(self, bucket, prefix):
if bucket is not None and prefix is not None:
self.payload["ResultWriter"] = {
"Resource": "arn:aws:states:::s3:putObject",
"Parameters": {
"Bucket": bucket,
"Prefix": prefix,
},
}
return self


class JSONItemReader(object):
def __init__(self):
tree = lambda: defaultdict(tree)
self.payload = tree()
self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}

def resource(self, resource):
self.payload["Resource"] = resource
return self

def parameter(self, name, value):
self.payload["Parameters"][name] = value
return self

def output_path(self, output_path):
self.payload["OutputPath"] = output_path
return self
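To make the GetManifest / second Map plumbing above easier to follow, here is an illustrative sketch of the data flowing through it; only the fields actually referenced by these states are shown, and the concrete values are made up.

# The Distributed Map's ResultWriter drops a manifest into S3. The *GetManifest
# state fetches it, and States.StringToJson($.Body) yields something shaped like:
manifest = {
    "DestinationBucket": "my-metaflow-bucket",  # hypothetical
    "ResultFiles": {
        "SUCCEEDED": [{"Key": "sfn_distributed_map_output/.../SUCCEEDED_0.json"}],
        # other result categories elided
    },
}
# The follow-up *Map state's JSONItemReader (MaxItems = 1) then reads the single
# SUCCEEDED result file at manifest["ResultFiles"]["SUCCEEDED"][0]["Key"], and the
# *Pass state re-parses each item's "Output" string so that the matching join
# receives roughly the same single-element payload it would have seen with an
# inline map.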