Distributed Map Support in AWS Step Functions #1720

Merged
merged 13 commits, Feb 16, 2024
8 changes: 7 additions & 1 deletion metaflow/metaflow_config.py
@@ -268,7 +268,13 @@
# machine execution logs. This needs to be available when using the
# `step-functions create --log-execution-history` command.
SFN_EXECUTION_LOG_GROUP_ARN = from_conf("SFN_EXECUTION_LOG_GROUP_ARN")

# Amazon S3 path for storing the results of AWS Step Functions Distributed Map
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = from_conf(
"SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH",
os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
if DATASTORE_SYSROOT_S3
else None,
)
###
# Kubernetes configuration
###
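Note: a minimal sketch of how the new default resolves and is later split into a bucket/prefix pair for ResultWriter (the DATASTORE_SYSROOT_S3 value below is hypothetical):

import os

DATASTORE_SYSROOT_S3 = "s3://my-metaflow-bucket/metaflow"  # hypothetical datastore root

# Mirrors the default computed above.
output_path = os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
# -> "s3://my-metaflow-bucket/metaflow/sfn_distributed_map_output"

# Mirrors the bucket/prefix split performed in step_functions.py when wiring up
# result_writer, assuming the path carries a prefix after the bucket name.
bucket, prefix = (
    output_path[len("s3://") :] if output_path.startswith("s3://") else output_path
).split("/", 1)
# bucket == "my-metaflow-bucket"
# prefix == "metaflow/sfn_distributed_map_output"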
34 changes: 28 additions & 6 deletions metaflow/plugins/aws/step_functions/dynamo_db_client.py
@@ -1,5 +1,8 @@
import os
import time

import requests

from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE


@@ -25,12 +28,31 @@ def save_foreach_cardinality(self, foreach_split_task_id, foreach_cardinality, t
def save_parent_task_id_for_foreach_join(
self, foreach_split_task_id, foreach_join_parent_task_id
):
return self._client.update_item(
TableName=self.name,
Key={"pathspec": {"S": foreach_split_task_id}},
UpdateExpression="ADD parent_task_ids_for_foreach_join :val",
ExpressionAttributeValues={":val": {"SS": [foreach_join_parent_task_id]}},
)
ex = None
for attempt in range(10):
try:
return self._client.update_item(
TableName=self.name,
Key={"pathspec": {"S": foreach_split_task_id}},
UpdateExpression="ADD parent_task_ids_for_foreach_join :val",
ExpressionAttributeValues={
":val": {"SS": [foreach_join_parent_task_id]}
},
)
except self._client.exceptions.ClientError as error:
ex = error
if (
error.response["Error"]["Code"]
== "ProvisionedThroughputExceededException"
):
# Hopefully, this gives AWS enough time to scale up! Otherwise,
# ensure that sufficient on-demand throughput for the DynamoDB
# table is provisioned ahead of time.
sleep_time = min((2**attempt) * 10, 60)
time.sleep(sleep_time)
else:
raise
raise ex

def get_parent_task_ids_for_foreach_join(self, foreach_split_task_id):
response = self._client.get_item(
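For reference, a quick sketch of the worst-case wait this retry loop introduces, derived from the min((2 ** attempt) * 10, 60) cap above (actual waits are usually shorter since DynamoDB typically recovers earlier):

# Sleep schedule across the 10 attempts: 10s, 20s, 40s, then capped at 60s.
sleeps = [min((2 ** attempt) * 10, 60) for attempt in range(10)]
print(sleeps)       # [10, 20, 40, 60, 60, 60, 60, 60, 60, 60]
print(sum(sleeps))  # 490 seconds (~8 minutes) before the last error is re-raised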
2 changes: 1 addition & 1 deletion metaflow/plugins/aws/step_functions/production_token.py
@@ -1,5 +1,5 @@
import os
import json
import os
import random
import string
import zlib
172 changes: 168 additions & 4 deletions metaflow/plugins/aws/step_functions/step_functions.py
@@ -15,6 +15,7 @@
SFN_DYNAMO_DB_TABLE,
SFN_EXECUTION_LOG_GROUP_ARN,
SFN_IAM_ROLE,
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
)
from metaflow.parameters import deploy_time_eval
from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -52,6 +53,7 @@ def __init__(
max_workers=None,
workflow_timeout=None,
is_project=False,
use_distributed_map=False,
):
self.name = name
self.graph = graph
@@ -70,6 +72,9 @@ def __init__(
self.max_workers = max_workers
self.workflow_timeout = workflow_timeout

# https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
self.use_distributed_map = use_distributed_map

self._client = StepFunctionsClient()
self._workflow = self._compile()
self._cron = self._cron()
@@ -365,17 +370,80 @@ def _visit(node, workflow, exit_node=None):
.parameter("SplitParentTaskId.$", "$.JobId")
.parameter("Parameters.$", "$.Parameters")
.parameter("Index.$", "$$.Map.Item.Value")
.next(node.matching_join)
.next(
"%s_*GetManifest" % iterator_name
if self.use_distributed_map
else node.matching_join
)
.iterator(
_visit(
self.graph[node.out_funcs[0]],
Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
Workflow(node.out_funcs[0])
.start_at(node.out_funcs[0])
.mode(
"DISTRIBUTED" if self.use_distributed_map else "INLINE"
),
node.matching_join,
)
)
.max_concurrency(self.max_workers)
.output_path("$.[0]")
# AWS Step Functions currently has a shortcoming for DistributedMap
# that does not allow us to subset the output of the for-each
# to just a single element. We have to rely on a rather terrible
# hack and resort to using ResultWriter to write the state to
# Amazon S3 and process it in another task. But, well, what can we
# do...
.result_writer(
*(
(
(
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :]
if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith(
"s3://"
)
else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
).split("/", 1)
+ [""]
)[:2]
if self.use_distributed_map
else (None, None)
)
)
.output_path("$" if self.use_distributed_map else "$.[0]")
)
if self.use_distributed_map:
workflow.add_state(
State("%s_*GetManifest" % iterator_name)
.resource("arn:aws:states:::aws-sdk:s3:getObject")
.parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
.parameter("Key.$", "$.ResultWriterDetails.Key")
.next("%s_*Map" % iterator_name)
.result_selector("Body.$", "States.StringToJson($.Body)")
)
workflow.add_state(
Map("%s_*Map" % iterator_name)
.iterator(
Workflow("%s_*PassWorkflow" % iterator_name)
.mode("DISTRIBUTED")
.start_at("%s_*Pass" % iterator_name)
.add_state(
Pass("%s_*Pass" % iterator_name)
.end()
.parameter("Output.$", "States.StringToJson($.Output)")
.output_path("$.Output")
)
)
.next(node.matching_join)
.max_concurrency(1000)
.item_reader(
JSONItemReader()
.resource("arn:aws:states:::s3:getObject")
.parameter("Bucket.$", "$.Body.DestinationBucket")
.parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
)
.output_path("$.[0]")
)

# Continue the traversal from the matching_join.
_visit(self.graph[node.matching_join], workflow, exit_node)
# We shouldn't ideally ever get here.
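For context on the states added above: when use_distributed_map is on, ResultWriter drops a manifest object in Amazon S3, the "%s_*GetManifest" state fetches it via s3:getObject, and the "%s_*Map" state reads a single SUCCEEDED result file through a JSONItemReader before the "%s_*Pass" state unwraps the output. A rough sketch of the manifest shape (values hypothetical; only DestinationBucket and ResultFiles.SUCCEEDED[0].Key are consumed here):

manifest = {
    "DestinationBucket": "my-metaflow-bucket",  # hypothetical
    "MapRunArn": "arn:aws:states:us-east-1:123456789012:mapRun:...",  # hypothetical
    "ResultFiles": {
        "FAILED": [],
        "PENDING": [],
        "SUCCEEDED": [
            {
                # hypothetical key under SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
                "Key": "metaflow/sfn_distributed_map_output/.../SUCCEEDED_0.json",
                "Size": 12345,
            }
        ],
    },
}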
@@ -444,7 +512,6 @@ def _batch(self, node):
"metaflow.owner": self.username,
"metaflow.flow_name": self.flow.name,
"metaflow.step_name": node.name,
"metaflow.run_id.$": "$$.Execution.Name",
# Unfortunately we can't set the task id here since AWS Step
# Functions lacks any notion of run-scoped task identifiers. We
# instead co-opt the AWS Batch job id as the task id. This also
@@ -456,6 +523,10 @@
# `$$.State.RetryCount` resolves to an int dynamically and
# AWS Batch job specification only accepts strings. We handle
# retries/catch within AWS Batch to get around this limitation.
# Also, we cannot set the run id here since the run id maps to
# the execution name of the AWS Step Functions state machine, which
# is different when executing inside a distributed map. We set it once
# in the start step and pass it along to be consumed by all the children.
"metaflow.version": self.environment.get_environment_info()[
"metaflow_version"
],
@@ -492,6 +563,12 @@ def _batch(self, node):
env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL

if node.name == "start":
# metaflow.run_id maps to the AWS Step Functions state machine execution in all
# cases except when running within a for-each construct that relies on
# Distributed Map. To work around this issue, we pass the run id from the
# start step to all subsequent tasks.
attrs["metaflow.run_id.$"] = "$$.Execution.Name"

# Initialize parameters for the flow in the `start` step.
parameters = self._process_parameters()
if parameters:
@@ -550,6 +627,8 @@ def _batch(self, node):
env["METAFLOW_SPLIT_PARENT_TASK_ID"] = (
"$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
)
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
else:
# Set appropriate environment variables for runtime replacement.
if len(node.in_funcs) == 1:
@@ -558,6 +637,8 @@
% node.in_funcs[0]
)
env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
else:
# Generate the input paths in a quasi-compressed format.
# See util.decompress_list for why this is written the way
@@ -567,6 +648,8 @@
"${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx)
for idx, _ in enumerate(node.in_funcs)
)
# Inherit the run id from the parent and pass it along to children.
attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
for idx, _ in enumerate(node.in_funcs):
env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
env["METAFLOW_PARENT_%s_STEP" % idx] = (
@@ -893,6 +976,12 @@ def __init__(self, name):
tree = lambda: defaultdict(tree)
self.payload = tree()

def mode(self, mode):
self.payload["ProcessorConfig"] = {"Mode": mode}
if mode == "DISTRIBUTED":
self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
return self
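A small usage sketch of the new mode() helper, showing the processor configuration it yields for the distributed case (step name hypothetical):

processor = Workflow("foreach_step").start_at("foreach_step").mode("DISTRIBUTED")
# processor.payload now contains, among other keys:
# {
#     "StartAt": "foreach_step",
#     "ProcessorConfig": {"Mode": "DISTRIBUTED", "ExecutionType": "STANDARD"},
# }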

def start_at(self, start_at):
self.payload["StartAt"] = start_at
return self
@@ -940,10 +1029,18 @@ def result_path(self, result_path):
self.payload["ResultPath"] = result_path
return self

def result_selector(self, name, value):
self.payload["ResultSelector"][name] = value
return self

def _partition(self):
# This is needed to support AWS Gov Cloud and AWS CN regions
return SFN_IAM_ROLE.split(":")[1]
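A quick worked example of the partition extraction (the role ARN below is hypothetical):

SFN_IAM_ROLE = "arn:aws-us-gov:iam::123456789012:role/sfn-role"  # hypothetical
SFN_IAM_ROLE.split(":")[1]  # -> "aws-us-gov"; "aws-cn" in China regions, "aws" elsewhere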

def retry_strategy(self, retry_strategy):
self.payload["Retry"] = [retry_strategy]
return self

def batch(self, job):
self.resource(
"arn:%s:states:::batch:submitJob.sync" % self._partition()
@@ -963,6 +1060,19 @@ def batch(self, job):
# tags may not be present in all scenarios
if "tags" in job.payload:
self.parameter("Tags", job.payload["tags"])
# Set a retry strategy for AWS Batch job submission to account for the
# measly 50 jobs/second queue admission limit, which people can
# run into very quickly.
self.retry_strategy(
{
"ErrorEquals": ["Batch.AWSBatchException"],
"BackoffRate": 2,
"IntervalSeconds": 2,
"MaxDelaySeconds": 60,
"MaxAttempts": 10,
"JitterStrategy": "FULL",
}
)
return self

def dynamo_db(self, table_name, primary_key, values):
@@ -976,6 +1086,26 @@ def dynamo_db(self, table_name, primary_key, values):
return self


class Pass(object):
def __init__(self, name):
self.name = name
tree = lambda: defaultdict(tree)
self.payload = tree()
self.payload["Type"] = "Pass"

def end(self):
self.payload["End"] = True
return self

def parameter(self, name, value):
self.payload["Parameters"][name] = value
return self

def output_path(self, output_path):
self.payload["OutputPath"] = output_path
return self


class Parallel(object):
def __init__(self, name):
self.name = name
@@ -1037,3 +1167,37 @@ def output_path(self, output_path):
def result_path(self, result_path):
self.payload["ResultPath"] = result_path
return self

def item_reader(self, item_reader):
self.payload["ItemReader"] = item_reader.payload
return self

def result_writer(self, bucket, prefix):
if bucket is not None and prefix is not None:
self.payload["ResultWriter"] = {
"Resource": "arn:aws:states:::s3:putObject",
"Parameters": {
"Bucket": bucket,
"Prefix": prefix,
},
}
return self


class JSONItemReader(object):
def __init__(self):
tree = lambda: defaultdict(tree)
self.payload = tree()
self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}

def resource(self, resource):
self.payload["Resource"] = resource
return self

def parameter(self, name, value):
self.payload["Parameters"][name] = value
return self

def output_path(self, output_path):
self.payload["OutputPath"] = output_path
return self