Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to terminate execution of a step-fn state machine #1695

Merged
merged 5 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions metaflow/plugins/argo/argo_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def delete(name):

return schedule_deleted, sensor_deleted, workflow_deleted

@staticmethod
def terminate(flow_name, name):
@classmethod
madhur-ob marked this conversation as resolved.
Show resolved Hide resolved
def terminate(cls, flow_name, name):
client = ArgoClient(namespace=KUBERNETES_NAMESPACE)

response = client.terminate_workflow(name)
Expand Down
51 changes: 51 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def delete(cls, name):

return schedule_deleted, sfn_deleted

@classmethod
saikonen marked this conversation as resolved.
Show resolved Hide resolved
def terminate(cls, flow_name, name):
client = StepFunctionsClient()
execution_arn, _, _, _ = cls.get_execution(flow_name, name)
response = client.terminate_execution(execution_arn)
return response

@classmethod
def trigger(cls, name, parameters):
try:
Expand Down Expand Up @@ -234,6 +241,50 @@ def get_existing_deployment(cls, name):
)
return None

@classmethod
def get_execution(cls, state_machine_name, name):
client = StepFunctionsClient()
try:
state_machine = client.get(state_machine_name)
except Exception as e:
raise StepFunctionsException(repr(e))
if state_machine is None:
raise StepFunctionsException(
"The state machine *%s* doesn't exist on AWS Step Functions."
% state_machine_name
)
try:
state_machine_arn = state_machine.get("stateMachineArn")
environment_vars = (
json.loads(state_machine.get("definition"))
.get("States")
.get("start")
.get("Parameters")
.get("ContainerOverrides")
.get("Environment")
)
parameters = {
item.get("Name"): item.get("Value") for item in environment_vars
}
executions = client.list_executions(state_machine_arn, states=["RUNNING"])
for execution in executions:
if execution.get("name") == name:
try:
return (
execution.get("executionArn"),
parameters.get("METAFLOW_OWNER"),
parameters.get("METAFLOW_PRODUCTION_TOKEN"),
parameters.get("SFN_STATE_MACHINE"),
)
except KeyError:
raise StepFunctionsException(
"A non-metaflow workflow *%s* already exists in AWS Step Functions."
% name
)
return None
except Exception as e:
raise StepFunctionsException(repr(e))

def _compile(self):
if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get(
"trigger_on_finish"
Expand Down
81 changes: 81 additions & 0 deletions metaflow/plugins/aws/step_functions/step_functions_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException):
headline = "Incorrect production token"


class RunIdMismatch(MetaflowException):
headline = "Run ID mismatch"


class IncorrectMetadataServiceVersion(MetaflowException):
headline = "Incorrect version for metaflow service"

Expand Down Expand Up @@ -614,6 +618,83 @@ def _token_instructions(flow_name, prev_user):
)


@step_functions.command(help="Terminate flow execution on Step Functions.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit - but for consistency AWS Step Functions. not blocking for merging this PR in, but would be good if you can sweep this and various CLIs to polish up the help strings and make them consistent.

@click.option(
"--authorize",
default=None,
type=str,
help="Authorize the termination with a production token",
)
@click.argument("run-id", required=True, type=str)
@click.pass_obj
def terminate(obj, run_id, authorize=None):
def _token_instructions(flow_name, prev_user):
obj.echo(
"There is an existing version of *%s* on AWS Step Functions which was "
"deployed by the user *%s*." % (flow_name, prev_user)
)
obj.echo(
"To terminate this flow, you need to use the same production token that they used."
)
obj.echo(
"Please reach out to them to get the token. Once you have it, call "
"this command:"
)
obj.echo(" step-functions terminate --authorize MY_TOKEN RUN_ID", fg="green")
obj.echo(
'See "Organizing Results" at docs.metaflow.org for more information '
"about production tokens."
)

validate_run_id(
obj.state_machine_name, obj.token_prefix, authorize, run_id, _token_instructions
)

# Trim prefix from run_id
name = run_id[4:]
obj.echo(
"Terminating run *{run_id}* for {flow_name} ...".format(
run_id=run_id, flow_name=obj.flow.name
),
bold=True,
)

terminated = StepFunctions.terminate(obj.state_machine_name, name)
if terminated:
obj.echo("\nRun terminated at %s." % terminated.get("stopDate"))


def validate_run_id(
state_machine_name, token_prefix, authorize, run_id, instructions_fn=None
):
if not run_id.startswith("sfn-"):
raise RunIdMismatch(
"Run IDs for flows executed through AWS Step Functions begin with 'sfn-'"
)

name = run_id[4:]
execution = StepFunctions.get_execution(state_machine_name, name)
if execution is None:
raise MetaflowException(
madhur-ob marked this conversation as resolved.
Show resolved Hide resolved
"Could not find the execution *%s* (in RUNNING state) for the state machine *%s* on AWS Step Functions"
% (name, state_machine_name)
)

_, owner, token, _ = execution

if authorize is None:
authorize = load_token(token_prefix)
elif authorize.startswith("production:"):
authorize = authorize[11:]

if owner != get_username() and authorize != token:
if instructions_fn:
instructions_fn(flow_name=name, prev_user=owner)
raise IncorrectProductionToken("Try again with the correct production token.")

return True


def validate_token(name, token_prefix, authorize, instruction_fn=None):
"""
Validate that the production token matches that of the deployed flow.
Expand Down
11 changes: 8 additions & 3 deletions metaflow/plugins/aws/step_functions/step_functions_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ def list_executions(self, state_machine_arn, states):
for execution in page["executions"]
)

def terminate_execution(self, state_machine_arn, execution_arn):
# TODO
pass
def terminate_execution(self, execution_arn):
try:
response = self._client.stop_execution(executionArn=execution_arn)
return response
except self._client.exceptions.ExecutionDoesNotExist:
raise ValueError("The execution ARN %s does not exist." % execution_arn)
except Exception as e:
raise e

def _default_logging_configuration(self, log_execution_history):
if log_execution_history:
Expand Down
Loading