Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to terminate execution of a step-fn state machine #1695

Merged
merged 5 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def delete(name):

return schedule_deleted, sensor_deleted, workflow_deleted

@staticmethod
def terminate(flow_name, name):
@classmethod
madhur-ob marked this conversation as resolved.
Show resolved Hide resolved
def terminate(cls, flow_name, name):
client = ArgoClient(namespace=KUBERNETES_NAMESPACE)

response = client.terminate_workflow(name)
Expand Down
49 changes: 49 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def delete(cls, name):

return schedule_deleted, sfn_deleted

@classmethod
saikonen marked this conversation as resolved.
Show resolved Hide resolved
def terminate(cls, flow_name, name):
    """
    Stop a running execution of a state machine.

    Parameters
    ----------
    flow_name : str
        Name of the state machine, forwarded to `get_execution`.
    name : str
        Name of the execution to stop.

    Returns
    -------
    The value returned by `StepFunctionsClient.terminate_execution`.

    Raises
    ------
    StepFunctionsException
        If no RUNNING execution called `name` is found.
    """
    client = StepFunctionsClient()
    execution = cls.get_execution(flow_name, name)
    # get_execution returns None when nothing matches (e.g. the run finished
    # in the meantime); fail with a readable error instead of a TypeError
    # from unpacking None.
    if execution is None:
        raise StepFunctionsException(
            "Could not find the execution *%s* (in RUNNING state) "
            "on AWS Step Functions" % name
        )
    # Only the ARN (first element of the tuple) is needed here.
    execution_arn = execution[0]
    return client.terminate_execution(execution_arn)

@classmethod
def trigger(cls, name, parameters):
try:
Expand Down Expand Up @@ -234,6 +241,48 @@ def get_existing_deployment(cls, name):
)
return None

@classmethod
def get_execution(cls, state_machine_name, name):
    """
    Look up a RUNNING execution of a state machine by execution name.

    Parameters
    ----------
    state_machine_name : str
        Name of the state machine to inspect.
    name : str
        Name of the execution to find.

    Returns
    -------
    tuple or None
        `(execution_arn, owner, production_token, flow_name, branch_name,
        project_name)` for the matching execution, or None when no RUNNING
        execution has that name. All values except the ARN are read from the
        `States.start.Parameters.Parameters` section of the state machine
        definition.

    Raises
    ------
    StepFunctionsException
        If the state machine cannot be fetched, does not exist, or its
        definition / executions cannot be read.
    """
    client = StepFunctionsClient()
    try:
        state_machine = client.get(state_machine_name)
    except Exception as e:
        raise StepFunctionsException(repr(e))
    if state_machine is None:
        # Report the state machine that is missing, not the execution name.
        raise StepFunctionsException(
            "The workflow *%s* doesn't exist "
            "on AWS Step Functions." % state_machine_name
        )
    try:
        state_machine_arn = state_machine.get("stateMachineArn")
        parameters = (
            json.loads(state_machine.get("definition"))
            .get("States")
            .get("start")
            .get("Parameters")
            .get("Parameters")
        )

        executions = client.list_executions(state_machine_arn, states=["RUNNING"])
        for execution in executions:
            if execution.get("name") == name:
                # Every lookup uses .get(), which cannot raise KeyError, so
                # no KeyError handler is needed; absent keys yield None.
                return (
                    execution.get("executionArn"),
                    parameters.get("metaflow.owner"),
                    parameters.get("metaflow.production_token"),
                    parameters.get("metaflow.flow_name"),
                    parameters.get("metaflow.branch_name", None),
                    parameters.get("metaflow.project_name", None),
                )
        return None
    except Exception as e:
        # Covers malformed definitions (e.g. a missing section makes the
        # .get() chain fail on None) as well as client errors.
        raise StepFunctionsException(repr(e))

def _compile(self):
if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get(
"trigger_on_finish"
Expand Down
102 changes: 102 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException):
headline = "Incorrect production token"


# Raised by run-ID validation when the supplied run ID does not start with
# "sfn-" or when the execution belongs to a different flow, project or branch.
class RunIdMismatch(MetaflowException):
    # Short title for this error category.
    headline = "Run ID mismatch"


class IncorrectMetadataServiceVersion(MetaflowException):
headline = "Incorrect version for metaflow service"

Expand Down Expand Up @@ -614,6 +618,104 @@ def _token_instructions(flow_name, prev_user):
)


@step_functions.command(help="Terminate flow execution on Step Functions.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit: for consistency, this should read "AWS Step Functions". Not blocking for merging this PR, but it would be good to sweep this and the other CLIs to polish the help strings and make them consistent.

@click.option(
"--authorize",
default=None,
type=str,
help="Authorize the termination with a production token",
)
@click.argument("run-id", required=True, type=str)
@click.pass_obj
def terminate(obj, run_id, authorize=None):
    """
    Validate the caller's right to stop `run_id`, then terminate it.

    `run_id` is checked with `validate_run_id`; the part of the ID after its
    first four characters is the execution name handed to
    `StepFunctions.terminate`. The stop date is echoed when a truthy
    response is returned.
    """

    def _explain_token_requirement(flow_name, prev_user):
        # Parameter names must stay as-is: the validator invokes this
        # callback with keyword arguments.
        notices = (
            (
                "There is an existing version of *%s* on AWS Step Functions which was "
                "deployed by the user *%s*." % (flow_name, prev_user),
                {},
            ),
            (
                "To terminate this flow, you need to use the same production token that they used.",
                {},
            ),
            (
                "Please reach out to them to get the token. Once you have it, call "
                "this command:",
                {},
            ),
            (
                " step-functions terminate --authorize MY_TOKEN RUN_ID",
                {"fg": "green"},
            ),
            (
                'See "Organizing Results" at docs.metaflow.org for more information '
                "about production tokens.",
                {},
            ),
        )
        for text, style in notices:
            obj.echo(text, **style)

    validate_run_id(
        obj.state_machine_name,
        obj.token_prefix,
        authorize,
        run_id,
        _explain_token_requirement,
    )

    banner = "Terminating run *{run_id}* for {flow_name} ...".format(
        flow_name=obj.flow.name, run_id=run_id
    )
    obj.echo(banner, bold=True)

    # Drop the four-character prefix to obtain the execution name.
    execution_name = run_id[4:]
    terminated = StepFunctions.terminate(obj.state_machine_name, execution_name)
    if not terminated:
        return
    obj.echo(f"\nRun terminated at {terminated.get('stopDate')}.")
madhur-ob marked this conversation as resolved.
Show resolved Hide resolved


def validate_run_id(
    state_machine_name, token_prefix, authorize, run_id, instructions_fn=None
):
    """
    Check that `run_id` names a RUNNING execution the caller may act on.

    Parameters
    ----------
    state_machine_name : str
        State machine whose executions are searched.
    token_prefix : str
        Passed to `load_token` when no `authorize` token is supplied.
    authorize : str or None
        Production token supplied by the user; a leading "production:" is
        stripped before comparison.
    run_id : str
        Run ID, expected to start with "sfn-".
    instructions_fn : callable, optional
        Called with `flow_name` and `prev_user` keyword arguments just before
        an authorization failure is raised.

    Returns
    -------
    bool
        True when every check passes.

    Raises
    ------
    RunIdMismatch
        If the run ID lacks the "sfn-" prefix, or the execution belongs to a
        different flow, project or branch than the current one.
    MetaflowException
        If no RUNNING execution with that name is found.
    IncorrectProductionToken
        If the caller is not the owner and the token does not match.
    """
    if not run_id.startswith("sfn-"):
        raise RunIdMismatch(
            "Run IDs for flows executed through AWS Step Functions begin with 'sfn-'"
        )

    # Execution name is the run ID without the "sfn-" prefix.
    name = run_id[4:]
    execution = StepFunctions.get_execution(state_machine_name, name)
    if execution is None:
        raise MetaflowException(
            "Could not find the execution *%s* (in RUNNING state) on AWS Step Functions"
            % name
        )

    _, owner, token, flow_name, branch_name, project_name = execution

    if current.flow_name != flow_name:
        raise RunIdMismatch(
            "The workflow with the run_id *%s* belongs to the flow *%s*, not for the flow *%s*."
            % (run_id, flow_name, current.flow_name)
        )

    if project_name is not None:
        if current.get("project_name") != project_name:
            raise RunIdMismatch(
                "The workflow belongs to the project *%s*. "
                "Please use the project decorator or --name to target the correct project"
                % project_name
            )

        if current.get("branch_name") != branch_name:
            raise RunIdMismatch(
                "The workflow belongs to the branch *%s*. "
                "Please use --branch, --production or --name to target the correct branch"
                % branch_name
            )

    if authorize is None:
        authorize = load_token(token_prefix)
    elif authorize.startswith("production:"):
        # Strip the 11-character "production:" prefix.
        authorize = authorize[11:]

    if owner != get_username() and authorize != token:
        if instructions_fn:
            # Pass the flow's name (not the execution name) so the
            # instructions identify the deployed flow correctly.
            instructions_fn(flow_name=flow_name, prev_user=owner)
        raise IncorrectProductionToken("Try again with the correct production token.")

    return True


def validate_token(name, token_prefix, authorize, instruction_fn=None):
"""
Validate that the production token matches that of the deployed flow.
Expand Down
11 changes: 8 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ def list_executions(self, state_machine_arn, states):
for execution in page["executions"]
)

def terminate_execution(self, state_machine_arn, execution_arn):
# TODO
pass
def terminate_execution(self, execution_arn):
    """
    Stop the execution identified by `execution_arn`.

    Calls `stop_execution` on the underlying client and returns its
    response unchanged.

    Raises
    ------
    ValueError
        If the client reports that the execution does not exist.
    Any other exception raised by the client propagates as-is.
    """
    try:
        return self._client.stop_execution(executionArn=execution_arn)
    except self._client.exceptions.ExecutionDoesNotExist as err:
        # Keep the original client error attached as the cause.
        raise ValueError(
            f"The execution ARN {execution_arn} does not exist."
        ) from err

def _default_logging_configuration(self, log_execution_history):
if log_execution_history:
Expand Down
Loading