diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 1f1da1b6d09..cb3a01e520d 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -227,8 +227,8 @@ def delete(name): return schedule_deleted, sensor_deleted, workflow_deleted - @staticmethod - def terminate(flow_name, name): + @classmethod + def terminate(cls, flow_name, name): client = ArgoClient(namespace=KUBERNETES_NAMESPACE) response = client.terminate_workflow(name) diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 7dd6854b725..261dc98fdcd 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -166,6 +166,13 @@ def delete(cls, name): return schedule_deleted, sfn_deleted + @classmethod + def terminate(cls, flow_name, name): + client = StepFunctionsClient() + execution_arn, _, _, _ = cls.get_execution(flow_name, name) + response = client.terminate_execution(execution_arn) + return response + @classmethod def trigger(cls, name, parameters): try: @@ -234,6 +241,50 @@ def get_existing_deployment(cls, name): ) return None + @classmethod + def get_execution(cls, state_machine_name, name): + client = StepFunctionsClient() + try: + state_machine = client.get(state_machine_name) + except Exception as e: + raise StepFunctionsException(repr(e)) + if state_machine is None: + raise StepFunctionsException( + "The state machine *%s* doesn't exist on AWS Step Functions." + % state_machine_name + ) + try: + state_machine_arn = state_machine.get("stateMachineArn") + environment_vars = ( + json.loads(state_machine.get("definition")) + .get("States") + .get("start") + .get("Parameters") + .get("ContainerOverrides") + .get("Environment") + ) + parameters = { + item.get("Name"): item.get("Value") for item in environment_vars + } + executions = client.list_executions(state_machine_arn, states=["RUNNING"]) + for execution in executions: + if execution.get("name") == name: + try: + return ( + execution.get("executionArn"), + parameters.get("METAFLOW_OWNER"), + parameters.get("METAFLOW_PRODUCTION_TOKEN"), + parameters.get("SFN_STATE_MACHINE"), + ) + except KeyError: + raise StepFunctionsException( + "A non-metaflow workflow *%s* already exists in AWS Step Functions." + % name + ) + return None + except Exception as e: + raise StepFunctionsException(repr(e)) + def _compile(self): if self.flow._flow_decorators.get("trigger") or self.flow._flow_decorators.get( "trigger_on_finish" diff --git a/metaflow/plugins/aws/step_functions/step_functions_cli.py b/metaflow/plugins/aws/step_functions/step_functions_cli.py index 63b1b645424..3ee971dd0f8 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_cli.py +++ b/metaflow/plugins/aws/step_functions/step_functions_cli.py @@ -26,6 +26,10 @@ class IncorrectProductionToken(MetaflowException): headline = "Incorrect production token" +class RunIdMismatch(MetaflowException): + headline = "Run ID mismatch" + + class IncorrectMetadataServiceVersion(MetaflowException): headline = "Incorrect version for metaflow service" @@ -614,6 +618,83 @@ def _token_instructions(flow_name, prev_user): ) +@step_functions.command(help="Terminate flow execution on Step Functions.") +@click.option( + "--authorize", + default=None, + type=str, + help="Authorize the termination with a production token", +) +@click.argument("run-id", required=True, type=str) +@click.pass_obj +def terminate(obj, run_id, authorize=None): + def _token_instructions(flow_name, prev_user): + obj.echo( + "There is an existing version of *%s* on AWS Step Functions which was " + "deployed by the user *%s*." % (flow_name, prev_user) + ) + obj.echo( + "To terminate this flow, you need to use the same production token that they used." + ) + obj.echo( + "Please reach out to them to get the token. Once you have it, call " + "this command:" + ) + obj.echo(" step-functions terminate --authorize MY_TOKEN RUN_ID", fg="green") + obj.echo( + 'See "Organizing Results" at docs.metaflow.org for more information ' + "about production tokens." + ) + + validate_run_id( + obj.state_machine_name, obj.token_prefix, authorize, run_id, _token_instructions + ) + + # Trim prefix from run_id + name = run_id[4:] + obj.echo( + "Terminating run *{run_id}* for {flow_name} ...".format( + run_id=run_id, flow_name=obj.flow.name + ), + bold=True, + ) + + terminated = StepFunctions.terminate(obj.state_machine_name, name) + if terminated: + obj.echo("\nRun terminated at %s." % terminated.get("stopDate")) + + +def validate_run_id( + state_machine_name, token_prefix, authorize, run_id, instructions_fn=None +): + if not run_id.startswith("sfn-"): + raise RunIdMismatch( + "Run IDs for flows executed through AWS Step Functions begin with 'sfn-'" + ) + + name = run_id[4:] + execution = StepFunctions.get_execution(state_machine_name, name) + if execution is None: + raise MetaflowException( + "Could not find the execution *%s* (in RUNNING state) for the state machine *%s* on AWS Step Functions" + % (name, state_machine_name) + ) + + _, owner, token, _ = execution + + if authorize is None: + authorize = load_token(token_prefix) + elif authorize.startswith("production:"): + authorize = authorize[11:] + + if owner != get_username() and authorize != token: + if instructions_fn: + instructions_fn(flow_name=name, prev_user=owner) + raise IncorrectProductionToken("Try again with the correct production token.") + + return True + + def validate_token(name, token_prefix, authorize, instruction_fn=None): """ Validate that the production token matches that of the deployed flow. diff --git a/metaflow/plugins/aws/step_functions/step_functions_client.py b/metaflow/plugins/aws/step_functions/step_functions_client.py index f7418f15427..ceec8e4d0ce 100644 --- a/metaflow/plugins/aws/step_functions/step_functions_client.py +++ b/metaflow/plugins/aws/step_functions/step_functions_client.py @@ -81,9 +81,14 @@ def list_executions(self, state_machine_arn, states): for execution in page["executions"] ) - def terminate_execution(self, state_machine_arn, execution_arn): - # TODO - pass + def terminate_execution(self, execution_arn): + try: + response = self._client.stop_execution(executionArn=execution_arn) + return response + except self._client.exceptions.ExecutionDoesNotExist: + raise ValueError("The execution ARN %s does not exist." % execution_arn) + except Exception as e: + raise e def _default_logging_configuration(self, log_execution_history): if log_execution_history: