Provide orbit preemption on-demand checkpoint action.
PiperOrigin-RevId: 506370331
weinan1997 authored and tensorflower-gardener committed Feb 1, 2023
1 parent 4711f47 commit c897ca4
Showing 2 changed files with 64 additions and 0 deletions.
2 changes: 2 additions & 0 deletions orbit/actions/__init__.py
@@ -72,3 +72,5 @@ class to make it easy to trigger actions conditionally based on reusable

from orbit.actions.new_best_metric import JSONPersistedValue
from orbit.actions.new_best_metric import NewBestMetric

from orbit.actions.save_checkpoint_if_preempted import SaveCheckpointIfPreempted
62 changes: 62 additions & 0 deletions orbit/actions/save_checkpoint_if_preempted.py
@@ -0,0 +1,62 @@
# Copyright 2022 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides the `SaveCheckpointIfPreempted` action."""

from typing import Optional

import tensorflow as tf


class SaveCheckpointIfPreempted:
  """Action that saves on-demand checkpoints after a preemption."""

  def __init__(
      self,
      cluster_resolver: tf.distribute.cluster_resolver.ClusterResolver,
      checkpoint_manager: tf.train.CheckpointManager,
      checkpoint_number: Optional[tf.Variable] = None,
      keep_running_after_save: bool = False,
  ):
"""Initializes the instance.
Args:
cluster_resolver: A `tf.distribute.cluster_resolver.ClusterResolver`
object.
checkpoint_manager: A `tf.train.CheckpointManager` object.
checkpoint_number: A `tf.Variable` to indicate the checkpoint_number for
checkpoint manager, usually it will be the global step.
keep_running_after_save: Whether to keep the job running after the
preemption on-demand checkpoint. Only set to True when in-process
preemption recovery with tf.distribute.experimental.PreemptionWatcher is
enabled.
"""
    self._checkpoint_number = checkpoint_number
    self._termination_config = None
    if keep_running_after_save:
      # An `exit_fn` that does nothing keeps the job alive after the
      # on-demand checkpoint has been saved.
      self._termination_config = tf.distribute.experimental.TerminationConfig(
          exit_fn=lambda: None
      )
    self._preemption_handler = (
        tf.distribute.experimental.PreemptionCheckpointHandler(
            cluster_resolver,
            checkpoint_manager,
            termination_config=self._termination_config,
        )
    )

  def __call__(self, _) -> None:
    # Saves a checkpoint only if a preemption signal has been received; this
    # relies on a private helper of `PreemptionCheckpointHandler`.
    self._preemption_handler._save_checkpoint_if_preempted(
        checkpoint_number=self._checkpoint_number, check_interval=False
    )

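For context, a minimal usage sketch follows. It is not part of the commit: the `trainer` and `model` objects and the checkpoint directory are hypothetical placeholders assumed to be defined elsewhere, and it assumes `orbit.Controller` accepts the action via its `train_actions` argument.

import orbit
import tensorflow as tf

from orbit import actions

# Hypothetical setup: `trainer` (an Orbit trainer) and `model` (a
# tf.keras.Model) are assumed to exist elsewhere.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
global_step = tf.Variable(0, dtype=tf.int64, trainable=False)
checkpoint = tf.train.Checkpoint(model=model, global_step=global_step)
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/ckpt", max_to_keep=3
)

save_if_preempted = actions.SaveCheckpointIfPreempted(
    cluster_resolver=cluster_resolver,
    checkpoint_manager=checkpoint_manager,
    checkpoint_number=global_step,
)

controller = orbit.Controller(
    trainer=trainer,
    global_step=global_step,
    steps_per_loop=100,
    checkpoint_manager=checkpoint_manager,
    train_actions=[save_if_preempted],
)
controller.train(steps=10_000)

With the default `keep_running_after_save=False`, the job exits after the on-demand checkpoint is written and relies on the cluster scheduler to restart it; passing `keep_running_after_save=True` is only appropriate with in-process recovery as noted in the docstring.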