diff --git a/tfx/orchestration/experimental/core/pipeline_state.py b/tfx/orchestration/experimental/core/pipeline_state.py
index 9db976639d..8c4ddd9eb7 100644
--- a/tfx/orchestration/experimental/core/pipeline_state.py
+++ b/tfx/orchestration/experimental/core/pipeline_state.py
@@ -89,288 +89,6 @@
     metadata_store_pb2.Execution.State.CANCELED:
         run_state_pb2.RunState.STOPPED,
 }
-
-
-@dataclasses.dataclass
-class StateRecord(json_utils.Jsonable):
-  state: str
-  backfill_token: str
-  status_code: Optional[int]
-  update_time: float
-  # TODO(b/242083811) Some status_msg have already been written into MLMD.
-  # Keeping this field is for backward compatibility to avoid json failing to
-  # parse existing status_msg. We can remove it once we are sure no status_msg
-  # in MLMD is in use.
-  status_msg: str = ''
-
-
-# TODO(b/228198652): Stop using json_util.Jsonable. Before we do,
-# this class MUST NOT be moved out of this module.
-@attr.s(auto_attribs=True, kw_only=True)
-class NodeState(json_utils.Jsonable):
-  """Records node state.
-
-  Attributes:
-    state: Current state of the node.
-    status: Status of the node in state STOPPING or STOPPED.
-  """
-
-  STARTED = 'started'  # Node is ready for execution.
-  STOPPING = 'stopping'  # Pending work before state can change to STOPPED.
-  STOPPED = 'stopped'  # Node execution is stopped.
-  RUNNING = 'running'  # Node is under active execution (i.e. triggered).
-  COMPLETE = 'complete'  # Node execution completed successfully.
-  # Node execution skipped due to condition not satisfied when pipeline has
-  # conditionals.
-  SKIPPED = 'skipped'
-  # Node execution skipped due to partial run.
-  SKIPPED_PARTIAL_RUN = 'skipped_partial_run'
-  FAILED = 'failed'  # Node execution failed due to errors.
-
-  state: str = attr.ib(
-      default=STARTED,
-      validator=attr.validators.in_([
-          STARTED,
-          STOPPING,
-          STOPPED,
-          RUNNING,
-          COMPLETE,
-          SKIPPED,
-          SKIPPED_PARTIAL_RUN,
-          FAILED,
-      ]),
-      on_setattr=attr.setters.validate,
-  )
-  backfill_token: str = ''
-  status_code: Optional[int] = None
-  status_msg: str = ''
-  last_updated_time: float = attr.ib(factory=lambda: time.time())  # pylint:disable=unnecessary-lambda
-
-  state_history: List[StateRecord] = attr.ib(default=attr.Factory(list))
-
-  @property
-  def status(self) -> Optional[status_lib.Status]:
-    if self.status_code is not None:
-      return status_lib.Status(code=self.status_code, message=self.status_msg)
-    return None
-
-  def update(
-      self,
-      state: str,
-      status: Optional[status_lib.Status] = None,
-      backfill_token: str = '',
-  ) -> None:
-    if self.state != state:
-      self.state_history.append(
-          StateRecord(
-              state=self.state,
-              backfill_token=self.backfill_token,
-              status_code=self.status_code,
-              update_time=self.last_updated_time,
-          )
-      )
-      if len(self.state_history) > _MAX_STATE_HISTORY_LEN:
-        self.state_history = self.state_history[-_MAX_STATE_HISTORY_LEN:]
-      self.last_updated_time = time.time()
-
-    self.state = state
-    self.backfill_token = backfill_token
-    self.status_code = status.code if status is not None else None
-    self.status_msg = (status.message or '') if status is not None else ''
-
-  def is_startable(self) -> bool:
-    """Returns True if the node can be started."""
-    return self.state in set([self.STOPPING, self.STOPPED, self.FAILED])
-
-  def is_stoppable(self) -> bool:
-    """Returns True if the node can be stopped."""
-    return self.state in set([self.STARTED, self.RUNNING])
-
-  def is_backfillable(self) -> bool:
-    """Returns True if the node can be backfilled."""
-    return self.state in set([self.STOPPED, self.FAILED])
-
-  def is_programmatically_skippable(self) -> bool:
-    """Returns True if the node can be skipped via programmatic operation."""
-    return self.state in set([self.STARTED, self.STOPPED])
-
-  def is_success(self) -> bool:
-    return is_node_state_success(self.state)
-
-  def is_failure(self) -> bool:
-    return is_node_state_failure(self.state)
-
-  def to_run_state(self) -> run_state_pb2.RunState:
-    """Returns this NodeState converted to a RunState."""
-    status_code_value = None
-    if self.status_code is not None:
-      status_code_value = run_state_pb2.RunState.StatusCodeValue(
-          value=self.status_code)
-    return run_state_pb2.RunState(
-        state=_NODE_STATE_TO_RUN_STATE_MAP.get(
-            self.state, run_state_pb2.RunState.UNKNOWN
-        ),
-        status_code=status_code_value,
-        status_msg=self.status_msg,
-        update_time=int(self.last_updated_time * 1000),
-    )
-
-  def to_run_state_history(self) -> List[run_state_pb2.RunState]:
-    run_state_history = []
-    for state in self.state_history:
-      # STARTING, PAUSING and PAUSED has been deprecated but may still be
-      # present in state_history.
-      if (
-          state.state == 'starting'
-          or state.state == 'pausing'
-          or state.state == 'paused'
-      ):
-        continue
-      run_state_history.append(
-          NodeState(
-              state=state.state,
-              status_code=state.status_code,
-              last_updated_time=state.update_time).to_run_state())
-    return run_state_history
-
-  # By default, json_utils.Jsonable serializes and deserializes objects using
-  # obj.__dict__, which prevents attr.ib from populating default fields.
-  # Overriding this function to ensure default fields are populated.
-  @classmethod
-  def from_json_dict(cls, dict_data: Dict[str, Any]) -> Any:
-    """Convert from dictionary data to an object."""
-    return cls(**dict_data)
-
-  def latest_predicate_time_s(self, predicate: Callable[[StateRecord], bool],
-                              include_current_state: bool) -> Optional[int]:
-    """Returns the latest time the StateRecord satisfies the given predicate.
-
-    Args:
-      predicate: Predicate that takes the state string.
-      include_current_state: Whether to include the current node state when
-        checking the node state history (the node state history doesn't include
-        the current node state).
-
-    Returns:
-      The latest time (in the state history) the StateRecord satisfies the given
-      predicate, or None if the predicate is never satisfied.
-    """
-    if include_current_state:
-      current_record = StateRecord(
-          state=self.state,
-          backfill_token=self.backfill_token,
-          status_code=self.status_code,
-          update_time=self.last_updated_time,
-      )
-      if predicate(current_record):
-        return int(current_record.update_time)
-
-    for s in reversed(self.state_history):
-      if predicate(s):
-        return int(s.update_time)
-    return None
-
-  def latest_running_time_s(self) -> Optional[int]:
-    """Returns the latest time the node entered a RUNNING state.
-
-    Returns:
-      The latest time (in the state history) the node entered a RUNNING
-      state, or None if the node never entered a RUNNING state.
-    """
-    return self.latest_predicate_time_s(
-        lambda s: is_node_state_running(s.state), include_current_state=True)
-
-
-class _NodeStatesProxy:
-  """Proxy for reading and updating deserialized NodeState dicts from Execution.
-
-  This proxy contains an internal write-back cache. Changes are not saved back
-  to the `Execution` until `save()` is called; cache would not be updated if
-  changes were made outside of the proxy, either. This is primarily used to
-  reduce JSON serialization/deserialization overhead for getting node state
-  execution property from pipeline execution.
-  """
-
-  def __init__(self, execution: metadata_store_pb2.Execution):
-    self._custom_properties = execution.custom_properties
-    self._deserialized_cache: Dict[str, Dict[str, NodeState]] = {}
-    self._changed_state_types: Set[str] = set()
-
-  def get(self, state_type: str = _NODE_STATES) -> Dict[str, NodeState]:
-    """Gets node states dict from pipeline execution with the specified type."""
-    if state_type not in [_NODE_STATES, _PREVIOUS_NODE_STATES]:
-      raise status_lib.StatusNotOkError(
-          code=status_lib.Code.INVALID_ARGUMENT,
-          message=(
-              f'Expected state_type is {_NODE_STATES} or'
-              f' {_PREVIOUS_NODE_STATES}, got {state_type}.'
-          ),
-      )
-    if state_type not in self._deserialized_cache:
-      node_states_json = _get_metadata_value(
-          self._custom_properties.get(state_type)
-      )
-      self._deserialized_cache[state_type] = (
-          json_utils.loads(node_states_json) if node_states_json else {}
-      )
-    return self._deserialized_cache[state_type]
-
-  def set(
-      self, node_states: Dict[str, NodeState], state_type: str = _NODE_STATES
-  ) -> None:
-    """Sets node states dict with the specified type."""
-    self._deserialized_cache[state_type] = node_states
-    self._changed_state_types.add(state_type)
-
-  def save(self) -> None:
-    """Saves all changed node states dicts to pipeline execution."""
-    max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
-
-    for state_type in self._changed_state_types:
-      node_states = self._deserialized_cache[state_type]
-      node_states_json = json_utils.dumps(node_states)
-
-      # Removes state history from node states if it's too large to avoid
-      # hitting MLMD limit.
-      if (
-          max_mlmd_str_value_len
-          and len(node_states_json) > max_mlmd_str_value_len
-      ):
-        logging.info(
-            'Node states length %d is too large (> %d); Removing state history'
-            ' from it.',
-            len(node_states_json),
-            max_mlmd_str_value_len,
-        )
-        node_states_no_history = {}
-        for node, old_state in node_states.items():
-          new_state = copy.deepcopy(old_state)
-          new_state.state_history.clear()
-          node_states_no_history[node] = new_state
-        node_states_json = json_utils.dumps(node_states_no_history)
-        logging.info(
-            'Node states length after removing state history: %d',
-            len(node_states_json),
-        )
-
-      data_types_utils.set_metadata_value(
-          self._custom_properties[state_type], node_states_json
-      )
-
-
-def is_node_state_success(state: str) -> bool:
-  return state in (NodeState.COMPLETE, NodeState.SKIPPED,
-                   NodeState.SKIPPED_PARTIAL_RUN)
-
-
-def is_node_state_failure(state: str) -> bool:
-  return state == NodeState.FAILED
-
-
-def is_node_state_running(state: str) -> bool:
-  return state == NodeState.RUNNING
-
-
 _NODE_STATE_TO_RUN_STATE_MAP = {
     NodeState.STARTED: run_state_pb2.RunState.READY,
     NodeState.STOPPING: run_state_pb2.RunState.UNKNOWN,