Support distributed map in AWS Step Functions
savingoyal committed Feb 6, 2024
1 parent 801f6f1 commit 4f14a4b
Showing 5 changed files with 58 additions and 11 deletions.
2 changes: 2 additions & 0 deletions metaflow/plugins/aws/step_functions/dynamo_db_client.py
@@ -1,5 +1,7 @@
 import os
 
+import requests
+
 from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE
 
 
2 changes: 1 addition & 1 deletion metaflow/plugins/aws/step_functions/production_token.py
@@ -1,5 +1,5 @@
-import os
 import json
+import os
 import random
 import string
 import zlib
34 changes: 31 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions.py
@@ -52,6 +52,7 @@ def __init__(
         max_workers=None,
         workflow_timeout=None,
         is_project=False,
+        use_distributed_map=False,
     ):
         self.name = name
         self.graph = graph
@@ -70,6 +71,9 @@ def __init__(
         self.max_workers = max_workers
         self.workflow_timeout = workflow_timeout
 
+        # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
+        self.use_distributed_map = use_distributed_map
+
         self._client = StepFunctionsClient()
         self._workflow = self._compile()
         self._cron = self._cron()
@@ -318,7 +322,11 @@ def _visit(node, workflow, exit_node=None):
                 .iterator(
                     _visit(
                         self.graph[node.out_funcs[0]],
-                        Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
+                        Workflow(node.out_funcs[0])
+                        .start_at(node.out_funcs[0])
+                        .mode(
+                            "DISTRIBUTED" if self.use_distributed_map else "INLINE"
+                        ),
                         node.matching_join,
                     )
                 )
@@ -393,7 +401,7 @@ def _batch(self, node):
"metaflow.owner": self.username,
"metaflow.flow_name": self.flow.name,
"metaflow.step_name": node.name,
"metaflow.run_id.$": "$$.Execution.Name",
# "metaflow.run_id.$": "$$.Execution.Name",
# Unfortunately we can't set the task id here since AWS Step
# Functions lacks any notion of run-scoped task identifiers. We
# instead co-opt the AWS Batch job id as the task id. This also
@@ -423,6 +431,11 @@ def _batch(self, node):
             # specification that allows us to set key-values.
             "step_name": node.name,
         }
+        # # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
+        # # cases except for when within a for-each construct that relies on Distributed
+        # # Map. To work around this issue, within a for-each, we lean on reading off of
+        # # AWS DynamoDb to get the run id.
+        # attrs["metaflow.run_id.$"] = "$$.Execution.Name"
 
         # Store production token within the `start` step, so that subsequent
         # `step-functions create` calls can perform a rudimentary authorization
@@ -441,6 +454,13 @@ def _batch(self, node):
             env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
 
         if node.name == "start":
+            # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
+            # cases except for when within a for-each construct that relies on
+            # Distributed Map. To work around this issue, we pass the run id from the
+            # start step to all subsequent tasks.
+            attrs["metaflow.run_id.$"] = "$$.Execution.Name"
+            attrs["run_id.$"] = "$$.Execution.Name"
+
             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
             if parameters:
@@ -475,7 +495,9 @@ def _batch(self, node):
             raise StepFunctionsException(
                 "Parallel steps are not supported yet with AWS step functions."
             )
-
+        # Inherit the run id from the parent and pass it along to children.
+        attrs["metaflow.run_id.$"] = "$.Parameters.run_id"
+        attrs["run_id.$"] = "$.Parameters.run_id"
         # Handle foreach join.
         if (
             node.type == "join"
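
The two hunks above exist because `$$.` JSONPath expressions read the Step Functions context object of the current execution, while `$.` expressions read the state's input. A Distributed Map runs each iteration as its own STANDARD child execution, so `$$.Execution.Name` inside the map would resolve to the child's name rather than the original run. Condensed from the diff, the two bindings look like this (assuming downstream states receive the start step's run_id under Parameters):

# Condensed from the diff above; illustrative only.
start_attrs = {
    # "$$." = Step Functions context object; reliable in `start`, which always
    # runs in the top-level execution.
    "metaflow.run_id.$": "$$.Execution.Name",
    "run_id.$": "$$.Execution.Name",
}
downstream_attrs = {
    # "$." = state input; assumes run_id was forwarded under Parameters, so
    # tasks inside a Distributed Map still see the parent execution's name.
    "metaflow.run_id.$": "$.Parameters.run_id",
    "run_id.$": "$.Parameters.run_id",
}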
@@ -842,6 +864,12 @@ def __init__(self, name):
         tree = lambda: defaultdict(tree)
         self.payload = tree()
 
+    def mode(self, mode):
+        self.payload["ProcessorConfig"] = {"Mode": mode}
+        if mode == "DISTRIBUTED":
+            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
+        return self
+
     def start_at(self, start_at):
         self.payload["StartAt"] = start_at
         return self
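
A minimal self-contained sketch of the builder above, condensed from this diff; the standalone usage at the bottom is illustrative, not Metaflow's actual call site:

# Condensed from the diff; demonstrates how mode() composes the Map iterator payload.
import json
from collections import defaultdict


class Workflow:
    def __init__(self, name):
        self.name = name
        tree = lambda: defaultdict(tree)
        self.payload = tree()

    def start_at(self, start_at):
        self.payload["StartAt"] = start_at
        return self

    def mode(self, mode):
        self.payload["ProcessorConfig"] = {"Mode": mode}
        if mode == "DISTRIBUTED":
            # Child workflows of a Distributed Map need an explicit execution type.
            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
        return self


wf = Workflow("foreach_step").start_at("foreach_step").mode("DISTRIBUTED")
print(json.dumps(wf.payload))
# {"StartAt": "foreach_step",
#  "ProcessorConfig": {"Mode": "DISTRIBUTED", "ExecutionType": "STANDARD"}}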
29 changes: 23 additions & 6 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
@@ -1,23 +1,23 @@
 import base64
-from metaflow._vendor import click
-from hashlib import sha1
 import json
 import re
+from hashlib import sha1
 
-from metaflow import current, decorators, parameters, JSONType
+from metaflow import JSONType, current, decorators, parameters
+from metaflow._vendor import click
+from metaflow.exception import MetaflowException, MetaflowInternalError
 from metaflow.metaflow_config import (
     SERVICE_VERSION_CHECK,
     SFN_STATE_MACHINE_PREFIX,
     UI_URL,
 )
-from metaflow.exception import MetaflowException, MetaflowInternalError
 from metaflow.package import MetaflowPackage
 from metaflow.plugins.aws.batch.batch_decorator import BatchDecorator
 from metaflow.tagging_util import validate_tags
 from metaflow.util import get_username, to_bytes, to_unicode, version_parse
 
+from .production_token import load_token, new_token, store_token
 from .step_functions import StepFunctions
-from .production_token import load_token, store_token, new_token
 
 VALID_NAME = re.compile(r"[^a-zA-Z0-9_\-\.]")

@@ -120,6 +120,12 @@ def step_functions(obj, name=None):
help="Log AWS Step Functions execution history to AWS CloudWatch "
"Logs log group.",
)
@click.option(
"--use-distributed-map/--no-use-distributed-map",
is_flag=True,
help="Use AWS Step Functions Distributed Map instead of Inline Map for "
"defining foreach tasks in Amazon State Language.",
)
@click.pass_obj
def create(
obj,
@@ -132,6 +138,7 @@ def create(
     max_workers=None,
     workflow_timeout=None,
     log_execution_history=False,
+    use_distributed_map=False,
 ):
     validate_tags(tags)
 
@@ -161,6 +168,7 @@
         max_workers,
         workflow_timeout,
         obj.is_project,
+        use_distributed_map,
     )
 
     if only_json:
@@ -269,7 +277,15 @@ def attach_prefix(name):


 def make_flow(
-    obj, token, name, tags, namespace, max_workers, workflow_timeout, is_project
+    obj,
+    token,
+    name,
+    tags,
+    namespace,
+    max_workers,
+    workflow_timeout,
+    is_project,
+    use_distributed_map,
 ):
     if obj.flow_datastore.TYPE != "s3":
         raise MetaflowException("AWS Step Functions requires --datastore=s3.")
@@ -305,6 +321,7 @@ def make_flow(
         username=get_username(),
         workflow_timeout=workflow_timeout,
         is_project=is_project,
+        use_distributed_map=use_distributed_map,
     )


2 changes: 1 addition & 1 deletion metaflow/plugins/aws/step_functions/step_functions_decorator.py
@@ -1,5 +1,5 @@
-import os
 import json
+import os
 import time
 
 from metaflow.decorators import StepDecorator
