|
89 | 89 | metadata_store_pb2.Execution.State.CANCELED:
|
90 | 90 | run_state_pb2.RunState.STOPPED,
|
91 | 91 | }
|
92 |
| - |
93 |
| - |
94 |
| -@dataclasses.dataclass |
95 |
| -class StateRecord(json_utils.Jsonable): |
96 |
| - state: str |
97 |
| - backfill_token: str |
98 |
| - status_code: Optional[int] |
99 |
| - update_time: float |
100 |
| - # TODO(b/242083811) Some status_msg have already been written into MLMD. |
101 |
| - # Keeping this field is for backward compatibility to avoid json failing to |
102 |
| - # parse existing status_msg. We can remove it once we are sure no status_msg |
103 |
| - # in MLMD is in use. |
104 |
| - status_msg: str = '' |
105 |
| - |
106 |
| - |
107 |
| -# TODO(b/228198652): Stop using json_util.Jsonable. Before we do, |
108 |
| -# this class MUST NOT be moved out of this module. |
109 |
| -@attr.s(auto_attribs=True, kw_only=True) |
110 |
| -class NodeState(json_utils.Jsonable): |
111 |
| - """Records node state. |
112 |
| -
|
113 |
| - Attributes: |
114 |
| - state: Current state of the node. |
115 |
| - status: Status of the node in state STOPPING or STOPPED. |
116 |
| - """ |
117 |
| - |
118 |
| - STARTED = 'started' # Node is ready for execution. |
119 |
| - STOPPING = 'stopping' # Pending work before state can change to STOPPED. |
120 |
| - STOPPED = 'stopped' # Node execution is stopped. |
121 |
| - RUNNING = 'running' # Node is under active execution (i.e. triggered). |
122 |
| - COMPLETE = 'complete' # Node execution completed successfully. |
123 |
| - # Node execution skipped due to condition not satisfied when pipeline has |
124 |
| - # conditionals. |
125 |
| - SKIPPED = 'skipped' |
126 |
| - # Node execution skipped due to partial run. |
127 |
| - SKIPPED_PARTIAL_RUN = 'skipped_partial_run' |
128 |
| - FAILED = 'failed' # Node execution failed due to errors. |
129 |
| - |
130 |
| - state: str = attr.ib( |
131 |
| - default=STARTED, |
132 |
| - validator=attr.validators.in_([ |
133 |
| - STARTED, |
134 |
| - STOPPING, |
135 |
| - STOPPED, |
136 |
| - RUNNING, |
137 |
| - COMPLETE, |
138 |
| - SKIPPED, |
139 |
| - SKIPPED_PARTIAL_RUN, |
140 |
| - FAILED, |
141 |
| - ]), |
142 |
| - on_setattr=attr.setters.validate, |
143 |
| - ) |
144 |
| - backfill_token: str = '' |
145 |
| - status_code: Optional[int] = None |
146 |
| - status_msg: str = '' |
147 |
| - last_updated_time: float = attr.ib(factory=lambda: time.time()) # pylint:disable=unnecessary-lambda |
148 |
| - |
149 |
| - state_history: List[StateRecord] = attr.ib(default=attr.Factory(list)) |
150 |
| - |
151 |
| - @property |
152 |
| - def status(self) -> Optional[status_lib.Status]: |
153 |
| - if self.status_code is not None: |
154 |
| - return status_lib.Status(code=self.status_code, message=self.status_msg) |
155 |
| - return None |
156 |
| - |
157 |
| - def update( |
158 |
| - self, |
159 |
| - state: str, |
160 |
| - status: Optional[status_lib.Status] = None, |
161 |
| - backfill_token: str = '', |
162 |
| - ) -> None: |
163 |
| - if self.state != state: |
164 |
| - self.state_history.append( |
165 |
| - StateRecord( |
166 |
| - state=self.state, |
167 |
| - backfill_token=self.backfill_token, |
168 |
| - status_code=self.status_code, |
169 |
| - update_time=self.last_updated_time, |
170 |
| - ) |
171 |
| - ) |
172 |
| - if len(self.state_history) > _MAX_STATE_HISTORY_LEN: |
173 |
| - self.state_history = self.state_history[-_MAX_STATE_HISTORY_LEN:] |
174 |
| - self.last_updated_time = time.time() |
175 |
| - |
176 |
| - self.state = state |
177 |
| - self.backfill_token = backfill_token |
178 |
| - self.status_code = status.code if status is not None else None |
179 |
| - self.status_msg = (status.message or '') if status is not None else '' |
180 |
| - |
181 |
| - def is_startable(self) -> bool: |
182 |
| - """Returns True if the node can be started.""" |
183 |
| - return self.state in set([self.STOPPING, self.STOPPED, self.FAILED]) |
184 |
| - |
185 |
| - def is_stoppable(self) -> bool: |
186 |
| - """Returns True if the node can be stopped.""" |
187 |
| - return self.state in set([self.STARTED, self.RUNNING]) |
188 |
| - |
189 |
| - def is_backfillable(self) -> bool: |
190 |
| - """Returns True if the node can be backfilled.""" |
191 |
| - return self.state in set([self.STOPPED, self.FAILED]) |
192 |
| - |
193 |
| - def is_programmatically_skippable(self) -> bool: |
194 |
| - """Returns True if the node can be skipped via programmatic operation.""" |
195 |
| - return self.state in set([self.STARTED, self.STOPPED]) |
196 |
| - |
197 |
| - def is_success(self) -> bool: |
198 |
| - return is_node_state_success(self.state) |
199 |
| - |
200 |
| - def is_failure(self) -> bool: |
201 |
| - return is_node_state_failure(self.state) |
202 |
| - |
203 |
| - def to_run_state(self) -> run_state_pb2.RunState: |
204 |
| - """Returns this NodeState converted to a RunState.""" |
205 |
| - status_code_value = None |
206 |
| - if self.status_code is not None: |
207 |
| - status_code_value = run_state_pb2.RunState.StatusCodeValue( |
208 |
| - value=self.status_code) |
209 |
| - return run_state_pb2.RunState( |
210 |
| - state=_NODE_STATE_TO_RUN_STATE_MAP.get( |
211 |
| - self.state, run_state_pb2.RunState.UNKNOWN |
212 |
| - ), |
213 |
| - status_code=status_code_value, |
214 |
| - status_msg=self.status_msg, |
215 |
| - update_time=int(self.last_updated_time * 1000), |
216 |
| - ) |
217 |
| - |
218 |
| - def to_run_state_history(self) -> List[run_state_pb2.RunState]: |
219 |
| - run_state_history = [] |
220 |
| - for state in self.state_history: |
221 |
| - # STARTING, PAUSING and PAUSED has been deprecated but may still be |
222 |
| - # present in state_history. |
223 |
| - if ( |
224 |
| - state.state == 'starting' |
225 |
| - or state.state == 'pausing' |
226 |
| - or state.state == 'paused' |
227 |
| - ): |
228 |
| - continue |
229 |
| - run_state_history.append( |
230 |
| - NodeState( |
231 |
| - state=state.state, |
232 |
| - status_code=state.status_code, |
233 |
| - last_updated_time=state.update_time).to_run_state()) |
234 |
| - return run_state_history |
235 |
| - |
236 |
| - # By default, json_utils.Jsonable serializes and deserializes objects using |
237 |
| - # obj.__dict__, which prevents attr.ib from populating default fields. |
238 |
| - # Overriding this function to ensure default fields are populated. |
239 |
| - @classmethod |
240 |
| - def from_json_dict(cls, dict_data: Dict[str, Any]) -> Any: |
241 |
| - """Convert from dictionary data to an object.""" |
242 |
| - return cls(**dict_data) |
243 |
| - |
244 |
| - def latest_predicate_time_s(self, predicate: Callable[[StateRecord], bool], |
245 |
| - include_current_state: bool) -> Optional[int]: |
246 |
| - """Returns the latest time the StateRecord satisfies the given predicate. |
247 |
| -
|
248 |
| - Args: |
249 |
| - predicate: Predicate that takes the state string. |
250 |
| - include_current_state: Whether to include the current node state when |
251 |
| - checking the node state history (the node state history doesn't include |
252 |
| - the current node state). |
253 |
| -
|
254 |
| - Returns: |
255 |
| - The latest time (in the state history) the StateRecord satisfies the given |
256 |
| - predicate, or None if the predicate is never satisfied. |
257 |
| - """ |
258 |
| - if include_current_state: |
259 |
| - current_record = StateRecord( |
260 |
| - state=self.state, |
261 |
| - backfill_token=self.backfill_token, |
262 |
| - status_code=self.status_code, |
263 |
| - update_time=self.last_updated_time, |
264 |
| - ) |
265 |
| - if predicate(current_record): |
266 |
| - return int(current_record.update_time) |
267 |
| - |
268 |
| - for s in reversed(self.state_history): |
269 |
| - if predicate(s): |
270 |
| - return int(s.update_time) |
271 |
| - return None |
272 |
| - |
273 |
| - def latest_running_time_s(self) -> Optional[int]: |
274 |
| - """Returns the latest time the node entered a RUNNING state. |
275 |
| -
|
276 |
| - Returns: |
277 |
| - The latest time (in the state history) the node entered a RUNNING |
278 |
| - state, or None if the node never entered a RUNNING state. |
279 |
| - """ |
280 |
| - return self.latest_predicate_time_s( |
281 |
| - lambda s: is_node_state_running(s.state), include_current_state=True) |
282 |
| - |
283 |
| - |
284 |
| -class _NodeStatesProxy: |
285 |
| - """Proxy for reading and updating deserialized NodeState dicts from Execution. |
286 |
| -
|
287 |
| - This proxy contains an internal write-back cache. Changes are not saved back |
288 |
| - to the `Execution` until `save()` is called; cache would not be updated if |
289 |
| - changes were made outside of the proxy, either. This is primarily used to |
290 |
| - reduce JSON serialization/deserialization overhead for getting node state |
291 |
| - execution property from pipeline execution. |
292 |
| - """ |
293 |
| - |
294 |
| - def __init__(self, execution: metadata_store_pb2.Execution): |
295 |
| - self._custom_properties = execution.custom_properties |
296 |
| - self._deserialized_cache: Dict[str, Dict[str, NodeState]] = {} |
297 |
| - self._changed_state_types: Set[str] = set() |
298 |
| - |
299 |
| - def get(self, state_type: str = _NODE_STATES) -> Dict[str, NodeState]: |
300 |
| - """Gets node states dict from pipeline execution with the specified type.""" |
301 |
| - if state_type not in [_NODE_STATES, _PREVIOUS_NODE_STATES]: |
302 |
| - raise status_lib.StatusNotOkError( |
303 |
| - code=status_lib.Code.INVALID_ARGUMENT, |
304 |
| - message=( |
305 |
| - f'Expected state_type is {_NODE_STATES} or' |
306 |
| - f' {_PREVIOUS_NODE_STATES}, got {state_type}.' |
307 |
| - ), |
308 |
| - ) |
309 |
| - if state_type not in self._deserialized_cache: |
310 |
| - node_states_json = _get_metadata_value( |
311 |
| - self._custom_properties.get(state_type) |
312 |
| - ) |
313 |
| - self._deserialized_cache[state_type] = ( |
314 |
| - json_utils.loads(node_states_json) if node_states_json else {} |
315 |
| - ) |
316 |
| - return self._deserialized_cache[state_type] |
317 |
| - |
318 |
| - def set( |
319 |
| - self, node_states: Dict[str, NodeState], state_type: str = _NODE_STATES |
320 |
| - ) -> None: |
321 |
| - """Sets node states dict with the specified type.""" |
322 |
| - self._deserialized_cache[state_type] = node_states |
323 |
| - self._changed_state_types.add(state_type) |
324 |
| - |
325 |
| - def save(self) -> None: |
326 |
| - """Saves all changed node states dicts to pipeline execution.""" |
327 |
| - max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length() |
328 |
| - |
329 |
| - for state_type in self._changed_state_types: |
330 |
| - node_states = self._deserialized_cache[state_type] |
331 |
| - node_states_json = json_utils.dumps(node_states) |
332 |
| - |
333 |
| - # Removes state history from node states if it's too large to avoid |
334 |
| - # hitting MLMD limit. |
335 |
| - if ( |
336 |
| - max_mlmd_str_value_len |
337 |
| - and len(node_states_json) > max_mlmd_str_value_len |
338 |
| - ): |
339 |
| - logging.info( |
340 |
| - 'Node states length %d is too large (> %d); Removing state history' |
341 |
| - ' from it.', |
342 |
| - len(node_states_json), |
343 |
| - max_mlmd_str_value_len, |
344 |
| - ) |
345 |
| - node_states_no_history = {} |
346 |
| - for node, old_state in node_states.items(): |
347 |
| - new_state = copy.deepcopy(old_state) |
348 |
| - new_state.state_history.clear() |
349 |
| - node_states_no_history[node] = new_state |
350 |
| - node_states_json = json_utils.dumps(node_states_no_history) |
351 |
| - logging.info( |
352 |
| - 'Node states length after removing state history: %d', |
353 |
| - len(node_states_json), |
354 |
| - ) |
355 |
| - |
356 |
| - data_types_utils.set_metadata_value( |
357 |
| - self._custom_properties[state_type], node_states_json |
358 |
| - ) |
359 |
| - |
360 |
| - |
361 |
| -def is_node_state_success(state: str) -> bool: |
362 |
| - return state in (NodeState.COMPLETE, NodeState.SKIPPED, |
363 |
| - NodeState.SKIPPED_PARTIAL_RUN) |
364 |
| - |
365 |
| - |
366 |
| -def is_node_state_failure(state: str) -> bool: |
367 |
| - return state == NodeState.FAILED |
368 |
| - |
369 |
| - |
370 |
| -def is_node_state_running(state: str) -> bool: |
371 |
| - return state == NodeState.RUNNING |
372 |
| - |
373 |
| - |
374 | 92 | _NODE_STATE_TO_RUN_STATE_MAP = {
|
375 | 93 | NodeState.STARTED: run_state_pb2.RunState.READY,
|
376 | 94 | NodeState.STOPPING: run_state_pb2.RunState.UNKNOWN,
|
|
0 commit comments