diff --git a/tfx/orchestration/experimental/core/pipeline_ops.py b/tfx/orchestration/experimental/core/pipeline_ops.py
index 0f83b9177e..a3e7a38972 100644
--- a/tfx/orchestration/experimental/core/pipeline_ops.py
+++ b/tfx/orchestration/experimental/core/pipeline_ops.py
@@ -467,7 +467,7 @@ def _check_nodes_exist(
     raise status_lib.StatusNotOkError(
         code=status_lib.Code.INVALID_ARGUMENT,
         message=(
-            f'`f{op_name}` operation failed, cannot find node(s) '
+            f'`{op_name}` operation failed, cannot find node(s) '
             f'{", ".join(node_id_set)} in the pipeline IR.'
         ),
     )
@@ -554,6 +554,67 @@ def skip_nodes(
         )
 
 
+@_pipeline_op()
+def skip_failed_nodes(
+    mlmd_handle: metadata.Metadata, node_uids: Sequence[task_lib.NodeUid]
+) -> None:
+  """Marks the given failed nodes as skipped instead."""
+  # All node_uids must have the same pipeline_uid.
+  pipeline_uids_set = set(n.pipeline_uid for n in node_uids)
+  if len(pipeline_uids_set) != 1:
+    raise status_lib.StatusNotOkError(
+        code=status_lib.Code.INVALID_ARGUMENT,
+        message=(
+            'All nodes must belong to the same pipeline, but the given '
+            f'nodes do not. Node UIDs were: {node_uids}'
+        ),
+    )
+  pipeline_uid = pipeline_uids_set.pop()
+  with pstate.PipelineState.load_run(
+      mlmd_handle,
+      pipeline_id=pipeline_uid.pipeline_id,
+      run_id=pipeline_uid.pipeline_run_id,
+  ) as pipeline_state:
+    pipeline = pipeline_state.pipeline
+    if pipeline.execution_mode != pipeline_pb2.Pipeline.SYNC:
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.FAILED_PRECONDITION,
+          message=(
+              'Can only skip failed nodes for SYNC pipelines, but pipeline had '
+              f'execution mode: {pipeline.execution_mode}'
+          ),
+      )
+    if not execution_lib.is_execution_failed(pipeline_state.execution):
+      state_str = metadata_store_pb2.Execution.State.Name(
+          pipeline_state.execution.last_known_state
+      )
+      raise status_lib.StatusNotOkError(
+          code=status_lib.Code.FAILED_PRECONDITION,
+          message=(
+              'Can only skip failed nodes for a pipeline in FAILED state, but '
+              f'pipeline in state: {state_str}'
+          ),
+      )
+    _check_nodes_exist(node_uids, pipeline_state.pipeline, 'skip_failed_nodes')
+    for node_uid in node_uids:
+      with pipeline_state.node_state_update_context(node_uid) as node_state:
+        if node_state.state != pstate.NodeState.FAILED:
+          raise status_lib.StatusNotOkError(
+              code=status_lib.Code.FAILED_PRECONDITION,
+              message=(
+                  'Can only skip nodes that are in a FAILED state, but node '
+                  f'{node_uid} was in state {node_state.state}'
+              ),
+          )
+        node_state.update(
+            pstate.NodeState.SKIPPED,
+            status_lib.Status(
+                code=status_lib.Code.OK,
+                message='Failed node marked as skipped using SkipFailedNodes',
+            ),
+        )
+
+
 @_pipeline_op()
 def resume_manual_node(
     mlmd_handle: metadata.Metadata, node_uid: task_lib.NodeUid
diff --git a/tfx/orchestration/experimental/core/pipeline_ops_test.py b/tfx/orchestration/experimental/core/pipeline_ops_test.py
index a136622f36..c86977784b 100644
--- a/tfx/orchestration/experimental/core/pipeline_ops_test.py
+++ b/tfx/orchestration/experimental/core/pipeline_ops_test.py
@@ -3206,6 +3206,80 @@ def test_skip_nodes(self):
           states_dict[task_lib.NodeUid(pipeline_uid, 'Pusher')].state,
       )
 
+  def test_skip_failed_nodes(self):
+    with self._mlmd_cm as mlmd_connection_manager:
+      m = mlmd_connection_manager.primary_mlmd_handle
+      pipeline = _test_pipeline(
+          'pipeline1', execution_mode=pipeline_pb2.Pipeline.SYNC
+      )
+      pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
+      pipeline.nodes.add().pipeline_node.node_info.id = 'ExampleGen'
+      pipeline.nodes.add().pipeline_node.node_info.id = 'Transform'
+      pipeline_ops.initiate_pipeline_start(m, pipeline)
+
+      # Can't skip failed nodes if the pipeline isn't in a FAILED state.
+      with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
+        pipeline_ops.skip_failed_nodes(
+            m,
+            [task_lib.NodeUid(pipeline_uid, 'ExampleGen')],
+        )
+      self.assertEqual(
+          status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
+      )
+
+      with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
+        # Change state of ExampleGen node to COMPLETE.
+        with pipeline_state.node_state_update_context(
+            task_lib.NodeUid(pipeline_uid, 'ExampleGen')
+        ) as node_state:
+          node_state.state = pstate.NodeState.COMPLETE
+        # Change state of Transform node to FAILED.
+        with pipeline_state.node_state_update_context(
+            task_lib.NodeUid(pipeline_uid, 'Transform')
+        ) as node_state:
+          node_state.state = pstate.NodeState.FAILED
+
+      # Can't skip failed nodes if the pipeline isn't in a FAILED state,
+      # even if the node is in a FAILED state.
+      with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
+        pipeline_ops.skip_failed_nodes(
+            m,
+            [task_lib.NodeUid(pipeline_uid, 'Transform')],
+        )
+      self.assertEqual(
+          status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
+      )
+
+      with pstate.PipelineState.load(m, pipeline_uid) as pipeline_state:
+        # Now mark the pipeline as FAILED.
+        pipeline_state.set_pipeline_execution_state(
+            metadata_store_pb2.Execution.FAILED
+        )
+
+      # Can't skip non-failed nodes.
+      with self.assertRaises(status_lib.StatusNotOkError) as exception_context:
+        pipeline_ops.skip_failed_nodes(
+            m,
+            [task_lib.NodeUid(pipeline_uid, 'ExampleGen')],
+        )
+      self.assertEqual(
+          status_lib.Code.FAILED_PRECONDITION, exception_context.exception.code
+      )
+
+      # Skip Transform.
+      pipeline_ops.skip_failed_nodes(
+          m,
+          [task_lib.NodeUid(pipeline_uid, 'Transform')],
+      )
+
+      pipeline_view = pstate.PipelineView.load(
+          m,
+          pipeline_id=pipeline_uid.pipeline_id,
+          pipeline_run_id=pipeline_uid.pipeline_run_id,
+      )
+      states_dict = pipeline_view.get_node_states_dict()
+      self.assertEqual(pstate.NodeState.SKIPPED, states_dict['Transform'].state)
+
   def test_exception_while_orchestrating_active_pipeline(self):
     with self._mlmd_cm as mlmd_connection_manager:
       m = mlmd_connection_manager.primary_mlmd_handle
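
Usage note (not part of the diff): a minimal sketch of how the new op might be driven from client code, assuming an open metadata.Metadata handle and the pipeline IR of the failed SYNC run are already available; the skip_failed helper name and its arguments are illustrative only.

    from typing import Sequence

    from tfx.orchestration import metadata
    from tfx.orchestration.experimental.core import pipeline_ops
    from tfx.orchestration.experimental.core import task as task_lib
    from tfx.proto.orchestration import pipeline_pb2


    def skip_failed(
        mlmd_handle: metadata.Metadata,
        pipeline: pipeline_pb2.Pipeline,
        node_ids: Sequence[str],
    ) -> None:
      # Build NodeUids for the failed nodes of this SYNC pipeline run and mark
      # them SKIPPED; skip_failed_nodes requires the run to already be FAILED.
      pipeline_uid = task_lib.PipelineUid.from_pipeline(pipeline)
      pipeline_ops.skip_failed_nodes(
          mlmd_handle,
          [task_lib.NodeUid(pipeline_uid, node_id) for node_id in node_ids],
      )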