Skip to content

Commit

Permalink
add ability to terminate execution of a step-fn state machine
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed Jan 24, 2024
1 parent a46f4a3 commit 14943d9
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 5 deletions.
4 changes: 2 additions & 2 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def delete(name):

return schedule_deleted, sensor_deleted, workflow_deleted

@staticmethod
def terminate(flow_name, name):
@classmethod
def terminate(cls, flow_name, name):
client = ArgoClient(namespace=KUBERNETES_NAMESPACE)

response = client.terminate_workflow(name)
Expand Down
49 changes: 49 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def delete(cls, name):

return schedule_deleted, sfn_deleted

@classmethod
def terminate(cls, flow_name, name):
client = StepFunctionsClient()
execution_arn, _, _, _, _, _ = cls.get_execution(flow_name, name)
response = client.terminate_execution(execution_arn)
return response

@classmethod
def trigger(cls, name, parameters):
try:
Expand Down Expand Up @@ -234,6 +241,48 @@ def get_existing_deployment(cls, name):
)
return None

@classmethod
def get_execution(cls, state_machine_name, name):
client = StepFunctionsClient()
try:
state_machine = client.get(state_machine_name)
except Exception as e:
raise StepFunctionsException(repr(e))
if state_machine is None:
raise StepFunctionsException(
"The workflow *%s* doesn't exist " "on AWS Step Functions." % name
)
try:
state_machine_arn = state_machine.get("stateMachineArn")
parameters = (
json.loads(state_machine.get("definition"))
.get("States")
.get("start")
.get("Parameters")
.get("Parameters")
)

executions = client.list_executions(state_machine_arn, states=["RUNNING"])
for execution in executions:
if execution.get("name") == name:
try:
return (
execution.get("executionArn"),
parameters.get("metaflow.owner"),
parameters.get("metaflow.production_token"),
parameters.get("metaflow.flow_name"),
parameters.get("metaflow.branch_name", None),
parameters.get("metaflow.project_name", None),
)
except KeyError:
raise StepFunctionsException(
"A non-metaflow workflow *%s* already exists in AWS Step Functions."
% name
)
return None
except Exception as e:
raise StepFunctionsException(repr(e))

def _compile(self):
if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get(
"trigger_on_finish"
Expand Down
102 changes: 102 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException):
headline = "Incorrect production token"


class RunIdMismatch(MetaflowException):
headline = "Run ID mismatch"


class IncorrectMetadataServiceVersion(MetaflowException):
headline = "Incorrect version for metaflow service"

Expand Down Expand Up @@ -614,6 +618,104 @@ def _token_instructions(flow_name, prev_user):
)


@step_functions.command(help="Terminate flow execution on Step Functions.")
@click.option(
"--authorize",
default=None,
type=str,
help="Authorize the termination with a production token",
)
@click.argument("run-id", required=True, type=str)
@click.pass_obj
def terminate(obj, run_id, authorize=None):
def _token_instructions(flow_name, prev_user):
obj.echo(
"There is an existing version of *%s* on AWS Step Functions which was "
"deployed by the user *%s*." % (flow_name, prev_user)
)
obj.echo(
"To terminate this flow, you need to use the same production token that they used."
)
obj.echo(
"Please reach out to them to get the token. Once you have it, call "
"this command:"
)
obj.echo(" step-functions terminate --authorize MY_TOKEN RUN_ID", fg="green")
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_run_id(
obj.state_machine_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Trim prefix from run_id
name = run_id[4:]
obj.echo(
"Terminating run *{run_id}* for {flow_name} ...".format(
run_id=run_id, flow_name=obj.flow.name
),
bold=True,
)

terminated = StepFunctions.terminate(obj.state_machine_name, name)
if terminated:
obj.echo(f"\nRun terminated at {terminated.get('stopDate')}.")


def validate_run_id(
state_machine_name, token_prefix, authorize, run_id, instructions_fn=None
):
if not run_id.startswith("sfn-"):
raise RunIdMismatch(
"Run IDs for flows executed through AWS Step Functions begin with 'sfn-'"
)

name = run_id[4:]
execution = StepFunctions.get_execution(state_machine_name, name)
if execution is None:
raise MetaflowException(
"Could not find the execution *%s* (in RUNNING state) on AWS Step Functions"
% name
)

_, owner, token, flow_name, branch_name, project_name = execution

if current.flow_name != flow_name:
raise RunIdMismatch(
"The workflow with the run_id *%s* belongs to the flow *%s*, not for the flow *%s*."
% (run_id, flow_name, current.flow_name)
)

if project_name is not None:
if current.get("project_name") != project_name:
raise RunIdMismatch(
"The workflow belongs to the project *%s*. "
"Please use the project decorator or --name to target the correct project"
% project_name
)

if current.get("branch_name") != branch_name:
raise RunIdMismatch(
"The workflow belongs to the branch *%s*. "
"Please use --branch, --production or --name to target the correct branch"
% branch_name
)

if authorize is None:
authorize = load_token(token_prefix)
elif authorize.startswith("production:"):
authorize = authorize[11:]

if owner != get_username() and authorize != token:
if instructions_fn:
instructions_fn(flow_name=name, prev_user=owner)
raise IncorrectProductionToken("Try again with the correct production token.")

return True


def validate_token(name, token_prefix, authorize, instruction_fn=None):
"""
Validate that the production token matches that of the deployed flow.
Expand Down
11 changes: 8 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ def list_executions(self, state_machine_arn, states):
for execution in page["executions"]
)

def terminate_execution(self, state_machine_arn, execution_arn):
# TODO
pass
def terminate_execution(self, execution_arn):
try:
response = self._client.stop_execution(executionArn=execution_arn)
return response
except self._client.exceptions.ExecutionDoesNotExist:
raise ValueError(f"The execution ARN {execution_arn} does not exist.")
except Exception as e:
raise e

def _default_logging_configuration(self, log_execution_history):
if log_execution_history:
Expand Down

0 comments on commit 14943d9

Please sign in to comment.