Skip to content

Commit

Permalink
add ability to terminate execution of a step-fn state machine (#1695)
Browse files Browse the repository at this point in the history
* add ability to terminate execution of a step-fn state machine

* don't use f-strings

* remove checks for project, branch

* add comment

* suggested changes
  • Loading branch information
madhur-ob authored Jan 30, 2024
1 parent dc6af41 commit 81d8909
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 5 deletions.
4 changes: 2 additions & 2 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def delete(name):

return schedule_deleted, sensor_deleted, workflow_deleted

@staticmethod
def terminate(flow_name, name):
@classmethod
def terminate(cls, flow_name, name):
client = ArgoClient(namespace=KUBERNETES_NAMESPACE)

response = client.terminate_workflow(name)
Expand Down
51 changes: 51 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def delete(cls, name):

return schedule_deleted, sfn_deleted

@classmethod
def terminate(cls, flow_name, name):
client = StepFunctionsClient()
execution_arn, _, _, _ = cls.get_execution(flow_name, name)
response = client.terminate_execution(execution_arn)
return response

@classmethod
def trigger(cls, name, parameters):
try:
Expand Down Expand Up @@ -234,6 +241,50 @@ def get_existing_deployment(cls, name):
)
return None

@classmethod
def get_execution(cls, state_machine_name, name):
client = StepFunctionsClient()
try:
state_machine = client.get(state_machine_name)
except Exception as e:
raise StepFunctionsException(repr(e))
if state_machine is None:
raise StepFunctionsException(
"The state machine *%s* doesn't exist on AWS Step Functions."
% state_machine_name
)
try:
state_machine_arn = state_machine.get("stateMachineArn")
environment_vars = (
json.loads(state_machine.get("definition"))
.get("States")
.get("start")
.get("Parameters")
.get("ContainerOverrides")
.get("Environment")
)
parameters = {
item.get("Name"): item.get("Value") for item in environment_vars
}
executions = client.list_executions(state_machine_arn, states=["RUNNING"])
for execution in executions:
if execution.get("name") == name:
try:
return (
execution.get("executionArn"),
parameters.get("METAFLOW_OWNER"),
parameters.get("METAFLOW_PRODUCTION_TOKEN"),
parameters.get("SFN_STATE_MACHINE"),
)
except KeyError:
raise StepFunctionsException(
"A non-metaflow workflow *%s* already exists in AWS Step Functions."
% name
)
return None
except Exception as e:
raise StepFunctionsException(repr(e))

def _compile(self):
if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get(
"trigger_on_finish"
Expand Down
81 changes: 81 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException):
headline = "Incorrect production token"


class RunIdMismatch(MetaflowException):
headline = "Run ID mismatch"


class IncorrectMetadataServiceVersion(MetaflowException):
headline = "Incorrect version for metaflow service"

Expand Down Expand Up @@ -614,6 +618,83 @@ def _token_instructions(flow_name, prev_user):
)


@step_functions.command(help="Terminate flow execution on Step Functions.")
@click.option(
"--authorize",
default=None,
type=str,
help="Authorize the termination with a production token",
)
@click.argument("run-id", required=True, type=str)
@click.pass_obj
def terminate(obj, run_id, authorize=None):
def _token_instructions(flow_name, prev_user):
obj.echo(
"There is an existing version of *%s* on AWS Step Functions which was "
"deployed by the user *%s*." % (flow_name, prev_user)
)
obj.echo(
"To terminate this flow, you need to use the same production token that they used."
)
obj.echo(
"Please reach out to them to get the token. Once you have it, call "
"this command:"
)
obj.echo(" step-functions terminate --authorize MY_TOKEN RUN_ID", fg="green")
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_run_id(
obj.state_machine_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Trim prefix from run_id
name = run_id[4:]
obj.echo(
"Terminating run *{run_id}* for {flow_name} ...".format(
run_id=run_id, flow_name=obj.flow.name
),
bold=True,
)

terminated = StepFunctions.terminate(obj.state_machine_name, name)
if terminated:
obj.echo("\nRun terminated at %s." % terminated.get("stopDate"))


def validate_run_id(
state_machine_name, token_prefix, authorize, run_id, instructions_fn=None
):
if not run_id.startswith("sfn-"):
raise RunIdMismatch(
"Run IDs for flows executed through AWS Step Functions begin with 'sfn-'"
)

name = run_id[4:]
execution = StepFunctions.get_execution(state_machine_name, name)
if execution is None:
raise MetaflowException(
"Could not find the execution *%s* (in RUNNING state) for the state machine *%s* on AWS Step Functions"
% (name, state_machine_name)
)

_, owner, token, _ = execution

if authorize is None:
authorize = load_token(token_prefix)
elif authorize.startswith("production:"):
authorize = authorize[11:]

if owner != get_username() and authorize != token:
if instructions_fn:
instructions_fn(flow_name=name, prev_user=owner)
raise IncorrectProductionToken("Try again with the correct production token.")

return True


def validate_token(name, token_prefix, authorize, instruction_fn=None):
"""
Validate that the production token matches that of the deployed flow.
Expand Down
11 changes: 8 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ def list_executions(self, state_machine_arn, states):
for execution in page["executions"]
)

def terminate_execution(self, state_machine_arn, execution_arn):
# TODO
pass
def terminate_execution(self, execution_arn):
try:
response = self._client.stop_execution(executionArn=execution_arn)
return response
except self._client.exceptions.ExecutionDoesNotExist:
raise ValueError("The execution ARN %s does not exist." % execution_arn)
except Exception as e:
raise e

def _default_logging_configuration(self, log_execution_history):
if log_execution_history:
Expand Down

0 comments on commit 81d8909

Please sign in to comment.