scr: initial integration #4

Open · wants to merge 1 commit into base: distshuf
9 changes: 9 additions & 0 deletions megatron/arguments.py
@@ -546,6 +546,15 @@ def _add_checkpointing_args(parser):
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')

    group.add_argument('--scr', action='store_true', default=None,
                       help='Enable SCR for checkpointing.')
    group.add_argument('--scr-interval', type=int, default=None,
                       help='Number of iterations between defensive checkpoints.')
    group.add_argument('--scr-seconds', type=float, default=None,
                       help='Number of seconds between defensive checkpoints.')
    group.add_argument('--scr-overhead', type=float, default=None,
                       help='Maximum runtime percentage for defensive checkpoints.')

    return parser
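
As a quick illustration of how these flags behave, here is a self-contained parse of the new options; the group title and sample values below are made up for the example and are not part of this change:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='checkpointing')
group.add_argument('--scr', action='store_true', default=None,
                   help='Enable SCR for checkpointing.')
group.add_argument('--scr-interval', type=int, default=None,
                   help='Number of iterations between defensive checkpoints.')
group.add_argument('--scr-seconds', type=float, default=None,
                   help='Number of seconds between defensive checkpoints.')
group.add_argument('--scr-overhead', type=float, default=None,
                   help='Maximum runtime percentage for defensive checkpoints.')

# Hypothetical command line: --scr --scr-interval 100 --scr-seconds 1800
args = parser.parse_args(['--scr', '--scr-interval', '100', '--scr-seconds', '1800'])
assert args.scr is True
assert args.scr_interval == 100 and args.scr_seconds == 1800.0
assert args.scr_overhead is None  # unset options stay None so callers can detect "not provided"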


Expand Down
8 changes: 8 additions & 0 deletions megatron/checkpointing.py
@@ -119,6 +119,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # SCR: the call to start output lives in DeepSpeed right now
    #if args.scr:
    #    scr.start_output(checkpoint_name, scr.FLAG_CHECKPOINT)

    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0 \
            or args.deepspeed:

@@ -184,6 +188,10 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # SCR: the call to complete output lives in DeepSpeed right now
    #if args.scr:
    #    scr.complete_output(valid)

    # since the code can be exited or aborted in various places we use the checkpoint saving as
    # a safe saving point for the codecarbon tracker. If the program doesn't run to its normal
    # end, then only the data since the last saved checkpoint will be lost.
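
The commented-out calls above and in the earlier hunk are the two halves of SCR's output bracket, which currently lives inside DeepSpeed. A minimal sketch of what that bracket would look like if this code path wrapped the write itself; the helper name and error handling are assumptions for illustration, and only the scr.start_output/scr.complete_output calls come from the lines above:

import torch
import scr

def save_with_scr(state_dict, checkpoint_name):
    # Tell SCR that a checkpoint phase is starting so it can track the files.
    scr.start_output(checkpoint_name, scr.FLAG_CHECKPOINT)
    valid = True
    try:
        torch.save(state_dict, checkpoint_name)
    except Exception:
        valid = False  # mark this rank's write as failed
        raise
    finally:
        # Every rank reports success or failure; the checkpoint set counts as
        # complete only if all ranks report success.
        scr.complete_output(valid)
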
31 changes: 30 additions & 1 deletion megatron/initialize.py
@@ -33,6 +33,9 @@

import deepspeed

from mpi4py import MPI
import scr


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
@@ -66,6 +69,31 @@ def finish_mpu_init():
        _set_random_seed(args.seed)

    args = get_args()

    # SCR: point SCR_PREFIX to the checkpoint path
    if args.scr:
        # SCR only supports a single directory both to read previous checkpoints and to write new ones
        if args.save != args.load:
            raise ValueError(f"--save {args.save} must match --load {args.load} when using SCR")

        # SCR defaults to the current working directory if args.save is not specified
        if args.save is not None:
            scr.config(f"SCR_PREFIX={args.save}")

        # DeepSpeed expects files to be on the global file system.
        # This flushes any cached checkpoint to the file system on restart.
        scr.config("SCR_GLOBAL_RESTART=1")

        # Configure seconds between checkpoints if the user provided a limit.
        if args.scr_seconds is not None:
            scr.config(f"SCR_CHECKPOINT_SECONDS={args.scr_seconds}")

        # Configure the maximum percentage of runtime for checkpointing if the user provided a limit.
        if args.scr_overhead is not None:
            scr.config(f"SCR_CHECKPOINT_OVERHEAD={args.scr_overhead}")

        scr.init()

    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
@@ -208,7 +236,8 @@ def _initialize_distributed():
            args.local_rank = device
        torch.cuda.set_device(device)
    # Call the init process
    #init_method = 'tcp://'
    init_method = 'env://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
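
The init_method switch above moves the rendezvous from an explicit tcp:// address to environment variables. A minimal single-process sanity check of the env:// path, assuming MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are set; the values below are illustrative only:

import os
import torch

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '6000')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')

# With init_method='env://', torch.distributed reads the address, port,
# rank, and world size from the environment instead of the URL.
torch.distributed.init_process_group(backend='gloo', init_method='env://')
print(torch.distributed.get_rank(), torch.distributed.get_world_size())
torch.distributed.destroy_process_group()
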
42 changes: 42 additions & 0 deletions megatron/training.py
@@ -55,6 +55,9 @@

import deepspeed

# SCR: import scalable checkpoint/restart library
import scr


def print_datetime(string):
    """Note that this call will sync across all ranks."""
@@ -167,6 +170,10 @@ def pretrain(train_valid_test_dataset_provider,

    codecarbon_tracker_stop()

    # SCR: flush any cached checkpoint
    if args.scr:
        scr.finalize()


def update_train_iters(args):

@@ -730,6 +737,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                     lr_scheduler)
            saved_checkpoint = True

        # SCR: take a defensive checkpoint if it's time
        if args.save and args.scr and args.scr_interval and \
                iteration % args.scr_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
                saved_checkpoint = True

        # SCR: take a defensive checkpoint if SCR recommends it
        if args.save and args.scr and scr.need_checkpoint():
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
                saved_checkpoint = True

        # SCR: save a checkpoint and exit the run if SCR recommends it
        #if args.save and args.scr and scr.should_exit():
        #    if not saved_checkpoint:
        #        save_checkpoint_and_time(iteration, model, optimizer,
        #                                 lr_scheduler)
        #    torch.distributed.barrier()
        #    print_datetime('exiting program at iteration {}'.format(iteration))
        #    scr.finalize()
        #    sys.exit()

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
@@ -743,6 +775,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))

                # SCR: finalize to flush any cached checkpoint
                if args.scr:
                    scr.finalize()

                sys.exit()

# Exiting based on iterations
@@ -752,6 +789,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))

            # SCR: finalize to flush any cached checkpoint
            if args.scr:
                scr.finalize()

            sys.exit()
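
Taken together, the training.py changes follow one lifecycle: initialize SCR once, take a defensive checkpoint when the interval or scr.need_checkpoint() says so, and finalize before any exit so cached checkpoint data is flushed. A stripped-down sketch of that cadence; train_step, save_checkpoint, and num_iters are placeholders rather than Megatron APIs, and only the scr.* calls come from the diff:

import scr

def run(args, train_step, save_checkpoint, num_iters):
    scr.init()
    try:
        for iteration in range(1, num_iters + 1):
            train_step(iteration)
            due = bool(args.scr_interval) and iteration % args.scr_interval == 0
            if due or scr.need_checkpoint():
                save_checkpoint(iteration)  # defensive checkpoint, as in train() above
    finally:
        # Flush any checkpoint still sitting in SCR's cache to the file system.
        scr.finalize()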

