Skip to content

Commit 684edfc

Browse files
XinranTang authored and tfx-copybara committed
no-op
PiperOrigin-RevId: 651792884
1 parent ce6f947 commit 684edfc

File tree

1 file changed

+0
-282
lines changed

1 file changed

+0
-282
lines changed

tfx/orchestration/experimental/core/pipeline_state.py

-282
Original file line numberDiff line numberDiff line change
@@ -89,288 +89,6 @@
8989
metadata_store_pb2.Execution.State.CANCELED:
9090
run_state_pb2.RunState.STOPPED,
9191
}
92-
93-
94-
@dataclasses.dataclass
class StateRecord(json_utils.Jsonable):
  """One entry in a node's state history.

  Attributes:
    state: Node state string captured by this record.
    backfill_token: Backfill token captured together with the state.
    status_code: Status code captured together with the state, if any.
    update_time: Timestamp (float) captured together with the state.
    status_msg: Legacy field kept only for backward compatibility; see the
      TODO on the field definition.
  """

  state: str
  backfill_token: str
  status_code: Optional[int]
  update_time: float
  # TODO(b/242083811) Some status_msg have already been written into MLMD.
  # Keeping this field is for backward compatibility to avoid json failing to
  # parse existing status_msg. We can remove it once we are sure no status_msg
  # in MLMD is in use.
  status_msg: str = dataclasses.field(default='')
105-
106-
107-
# TODO(b/228198652): Stop using json_util.Jsonable. Before we do,
108-
# this class MUST NOT be moved out of this module.
109-
@attr.s(auto_attribs=True, kw_only=True)
class NodeState(json_utils.Jsonable):
  """Records node state.

  Attributes:
    state: Current state of the node; validated against the state constants
      defined on this class, both at construction and on assignment.
    backfill_token: Backfill token associated with the current state.
    status_code: Status code associated with the current state, if any.
    status_msg: Status message associated with the current state.
    last_updated_time: `time.time()` value taken at construction (unless
      supplied) and refreshed whenever `update` changes the state.
    state_history: Records of earlier states, oldest first, trimmed by
      `update` to at most `_MAX_STATE_HISTORY_LEN` entries.
    status: Status of the node in state STOPPING or STOPPED.
  """

  STARTED = 'started'  # Node is ready for execution.
  STOPPING = 'stopping'  # Pending work before state can change to STOPPED.
  STOPPED = 'stopped'  # Node execution is stopped.
  RUNNING = 'running'  # Node is under active execution (i.e. triggered).
  COMPLETE = 'complete'  # Node execution completed successfully.
  # Node execution skipped due to condition not satisfied when pipeline has
  # conditionals.
  SKIPPED = 'skipped'
  # Node execution skipped due to partial run.
  SKIPPED_PARTIAL_RUN = 'skipped_partial_run'
  FAILED = 'failed'  # Node execution failed due to errors.

  state: str = attr.ib(
      default=STARTED,
      validator=attr.validators.in_([
          STARTED,
          STOPPING,
          STOPPED,
          RUNNING,
          COMPLETE,
          SKIPPED,
          SKIPPED_PARTIAL_RUN,
          FAILED,
      ]),
      on_setattr=attr.setters.validate,
  )
  backfill_token: str = ''
  status_code: Optional[int] = None
  status_msg: str = ''
  # The lambda defers the `time.time` lookup until an instance is created.
  last_updated_time: float = attr.ib(factory=lambda: time.time())  # pylint:disable=unnecessary-lambda

  state_history: List[StateRecord] = attr.ib(factory=list)

  def _to_state_record(self) -> StateRecord:
    """Builds a StateRecord from the node's current fields."""
    return StateRecord(
        state=self.state,
        backfill_token=self.backfill_token,
        status_code=self.status_code,
        update_time=self.last_updated_time,
    )

  @property
  def status(self) -> Optional[status_lib.Status]:
    """Returns a Status built from the code and message, or None if no code."""
    if self.status_code is None:
      return None
    return status_lib.Status(code=self.status_code, message=self.status_msg)

  def update(
      self,
      state: str,
      status: Optional[status_lib.Status] = None,
      backfill_token: str = '',
  ) -> None:
    """Moves the node to `state`, archiving the previous state if it differs.

    Args:
      state: New state; must be one of the state constants of this class.
      status: Optional status whose code and message are stored with the new
        state. When None, the stored code is cleared and the message is
        reset to the empty string.
      backfill_token: Backfill token to store with the new state.
    """
    if self.state != state:
      history = self.state_history
      history.append(self._to_state_record())
      # Keep only the most recent entries so the history stays bounded.
      if len(history) > _MAX_STATE_HISTORY_LEN:
        self.state_history = history[-_MAX_STATE_HISTORY_LEN:]
      self.last_updated_time = time.time()

    self.state = state
    self.backfill_token = backfill_token
    if status is None:
      self.status_code = None
      self.status_msg = ''
    else:
      self.status_code = status.code
      self.status_msg = status.message or ''

  def is_startable(self) -> bool:
    """Returns True if the node can be started."""
    return self.state in {self.STOPPING, self.STOPPED, self.FAILED}

  def is_stoppable(self) -> bool:
    """Returns True if the node can be stopped."""
    return self.state in {self.STARTED, self.RUNNING}

  def is_backfillable(self) -> bool:
    """Returns True if the node can be backfilled."""
    return self.state in {self.STOPPED, self.FAILED}

  def is_programmatically_skippable(self) -> bool:
    """Returns True if the node can be skipped via programmatic operation."""
    return self.state in {self.STARTED, self.STOPPED}

  def is_success(self) -> bool:
    """Returns the result of `is_node_state_success` for the current state."""
    return is_node_state_success(self.state)

  def is_failure(self) -> bool:
    """Returns the result of `is_node_state_failure` for the current state."""
    return is_node_state_failure(self.state)

  def to_run_state(self) -> run_state_pb2.RunState:
    """Returns this NodeState converted to a RunState."""
    code_value = (
        None
        if self.status_code is None
        else run_state_pb2.RunState.StatusCodeValue(value=self.status_code)
    )
    # States missing from the map are reported as UNKNOWN.
    mapped_state = _NODE_STATE_TO_RUN_STATE_MAP.get(
        self.state, run_state_pb2.RunState.UNKNOWN
    )
    return run_state_pb2.RunState(
        state=mapped_state,
        status_code=code_value,
        status_msg=self.status_msg,
        # The stored time is multiplied by 1000 and truncated to an int.
        update_time=int(self.last_updated_time * 1000),
    )

  def to_run_state_history(self) -> List[run_state_pb2.RunState]:
    """Returns the state history converted to a list of RunState protos."""
    # STARTING, PAUSING and PAUSED have been deprecated but may still be
    # present in state_history.
    deprecated_states = ('starting', 'pausing', 'paused')
    return [
        NodeState(
            state=record.state,
            status_code=record.status_code,
            last_updated_time=record.update_time,
        ).to_run_state()
        for record in self.state_history
        if record.state not in deprecated_states
    ]

  # By default, json_utils.Jsonable serializes and deserializes objects using
  # obj.__dict__, which prevents attr.ib from populating default fields.
  # Overriding this function to ensure default fields are populated.
  @classmethod
  def from_json_dict(cls, dict_data: Dict[str, Any]) -> Any:
    """Builds an instance by passing `dict_data` entries as keyword args."""
    return cls(**dict_data)

  def latest_predicate_time_s(self, predicate: Callable[[StateRecord], bool],
                              include_current_state: bool) -> Optional[int]:
    """Returns the latest time the StateRecord satisfies the given predicate.

    Args:
      predicate: Predicate that takes a StateRecord.
      include_current_state: Whether to include the current node state when
        checking the node state history (the node state history doesn't include
        the current node state).

    Returns:
      The latest time (in the state history) the StateRecord satisfies the given
      predicate, or None if the predicate is never satisfied.
    """
    # Newest first: the current state (if requested), then history reversed.
    newest_first = [self._to_state_record()] if include_current_state else []
    newest_first.extend(reversed(self.state_history))
    for record in newest_first:
      if predicate(record):
        return int(record.update_time)
    return None

  def latest_running_time_s(self) -> Optional[int]:
    """Returns the latest time the node entered a RUNNING state.

    Returns:
      The latest time (in the state history) the node entered a RUNNING
      state, or None if the node never entered a RUNNING state.
    """

    def _is_running(record: StateRecord) -> bool:
      return is_node_state_running(record.state)

    return self.latest_predicate_time_s(
        _is_running, include_current_state=True)
282-
283-
284-
class _NodeStatesProxy:
  """Proxy for reading and updating deserialized NodeState dicts from Execution.

  The proxy keeps an internal write-back cache: changes reach the `Execution`
  only when `save()` is called, and changes made outside of the proxy are not
  reflected in the cache. This is primarily used to reduce JSON
  serialization/deserialization overhead for getting node state execution
  property from pipeline execution.
  """

  def __init__(self, execution: metadata_store_pb2.Execution):
    """Wraps the custom properties of `execution`; nothing is loaded yet."""
    self._custom_properties = execution.custom_properties
    self._deserialized_cache: Dict[str, Dict[str, NodeState]] = {}
    self._changed_state_types: Set[str] = set()

  @staticmethod
  def _without_state_history(
      node_states: Dict[str, NodeState]
  ) -> Dict[str, NodeState]:
    """Returns deep copies of `node_states` with emptied state histories."""
    stripped = {}
    for node, node_state in node_states.items():
      # Copy first so the cached NodeState keeps its history.
      clone = copy.deepcopy(node_state)
      clone.state_history.clear()
      stripped[node] = clone
    return stripped

  def get(self, state_type: str = _NODE_STATES) -> Dict[str, NodeState]:
    """Gets node states dict from pipeline execution with the specified type.

    Args:
      state_type: Either `_NODE_STATES` or `_PREVIOUS_NODE_STATES`.

    Returns:
      The cached node states dict for `state_type`, deserialized on first
      access; an empty dict if no value is stored.

    Raises:
      status_lib.StatusNotOkError: With code INVALID_ARGUMENT if `state_type`
        is not one of the two supported values.
    """
    if state_type not in (_NODE_STATES, _PREVIOUS_NODE_STATES):
      raise status_lib.StatusNotOkError(
          code=status_lib.Code.INVALID_ARGUMENT,
          message=(
              f'Expected state_type is {_NODE_STATES} or'
              f' {_PREVIOUS_NODE_STATES}, got {state_type}.'
          ),
      )

    cache = self._deserialized_cache
    if state_type in cache:
      return cache[state_type]

    serialized = _get_metadata_value(self._custom_properties.get(state_type))
    cache[state_type] = json_utils.loads(serialized) if serialized else {}
    return cache[state_type]

  def set(
      self, node_states: Dict[str, NodeState], state_type: str = _NODE_STATES
  ) -> None:
    """Sets node states dict with the specified type.

    The dict is only cached and marked as changed; call `save()` to write it.
    """
    self._changed_state_types.add(state_type)
    self._deserialized_cache[state_type] = node_states

  def save(self) -> None:
    """Saves all changed node states dicts to pipeline execution."""
    length_limit = env.get_env().max_mlmd_str_value_length()

    for state_type in self._changed_state_types:
      node_states = self._deserialized_cache[state_type]
      serialized = json_utils.dumps(node_states)

      # Removes state history from node states if it's too large to avoid
      # hitting MLMD limit. A falsy limit disables the check.
      if length_limit and len(serialized) > length_limit:
        logging.info(
            'Node states length %d is too large (> %d); Removing state history'
            ' from it.',
            len(serialized),
            length_limit,
        )
        serialized = json_utils.dumps(
            self._without_state_history(node_states)
        )
        logging.info(
            'Node states length after removing state history: %d',
            len(serialized),
        )

      data_types_utils.set_metadata_value(
          self._custom_properties[state_type], serialized
      )
359-
360-
361-
def is_node_state_success(state: str) -> bool:
  """Returns True if `state` is COMPLETE, SKIPPED or SKIPPED_PARTIAL_RUN."""
  successful_states = (
      NodeState.COMPLETE,
      NodeState.SKIPPED,
      NodeState.SKIPPED_PARTIAL_RUN,
  )
  return state in successful_states
364-
365-
366-
def is_node_state_failure(state: str) -> bool:
  """Returns True if `state` equals NodeState.FAILED."""
  failed_state = NodeState.FAILED
  return state == failed_state
368-
369-
370-
def is_node_state_running(state: str) -> bool:
  """Returns True if `state` equals NodeState.RUNNING."""
  running_state = NodeState.RUNNING
  return state == running_state
372-
373-
37492
_NODE_STATE_TO_RUN_STATE_MAP = {
37593
NodeState.STARTED: run_state_pb2.RunState.READY,
37694
NodeState.STOPPING: run_state_pb2.RunState.UNKNOWN,

0 commit comments

Comments
 (0)