Add an argument to set checkpoint dir.

workingloong committed Jan 2, 2024
1 parent e833e69 commit faa5b4f
Showing 6 changed files with 32 additions and 15 deletions.
dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py (3 additions, 0 deletions)
@@ -503,6 +503,9 @@ def save_to_storage(self, step, state_dict, path):
         if self._local_rank != 0:
             return
         if path:
+            logger.info(
+                "Put a save event to notify the agent persists checkpoint."
+            )
             event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
             self._event_queue.put(event)

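Note: the new log line documents the hand-off this method performs. On local rank 0, save_to_storage enqueues a SAVE event so that a separate agent process can persist the checkpoint (already staged in shared memory) to storage asynchronously. Below is a minimal, illustrative sketch of that producer/consumer pattern; the queue wiring, event classes, and agent loop are stand-ins, not the actual DLRover internals.

import multiprocessing as mp
from dataclasses import dataclass
from enum import Enum


class CheckpointEventType(Enum):
    SAVE = 1
    EXIT = 2


@dataclass
class CheckpointEvent:
    type: CheckpointEventType
    step: int = 0


def agent_loop(event_queue):
    # The agent blocks on the queue and persists a checkpoint whenever the
    # training process signals a SAVE event; EXIT stops the loop.
    while True:
        event = event_queue.get()
        if event.type == CheckpointEventType.EXIT:
            break
        if event.type == CheckpointEventType.SAVE:
            print(f"Persisting checkpoint of step {event.step} to storage.")


if __name__ == "__main__":
    queue = mp.Queue()
    agent = mp.Process(target=agent_loop, args=(queue,))
    agent.start()
    # This mirrors what save_to_storage does on local rank 0: put a SAVE
    # event so the agent writes the checkpoint from shared memory to disk.
    queue.put(CheckpointEvent(type=CheckpointEventType.SAVE, step=100))
    queue.put(CheckpointEvent(type=CheckpointEventType.EXIT))
    agent.join()
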
examples/pytorch/example.dockerfile (2 additions, 1 deletion)
@@ -13,7 +13,8 @@ FROM python:3.8.14 as base
 
 WORKDIR /dlrover
 RUN apt-get update && apt-get install -y sudo vim libgl1-mesa-glx libglib2.0-dev
-RUN pip install deprecated pyparsing torch==2.0.1 opencv-python==4.7.0.72 torchvision==0.15.2 transformers
+RUN pip install deprecated pyparsing torch==2.0.1 opencv-python==4.7.0.72 \
+    torchvision==0.15.2 transformers deepspeed
 
 COPY ./data /data
 COPY ./examples ./examples
examples/pytorch/nanogpt/ds_train.py (3 additions, 3 deletions)
@@ -47,12 +47,10 @@
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
-# We should use a shared storage to persist the checkpiont.
-checkpoint_dir = "/nas/nanogpt-ckpt-ds/"
-
 
 def train():
     args = arg_parser()
+    checkpoint_dir = args.save_dir
     setup()
     os.makedirs(checkpoint_dir, exist_ok=True)
     world_size = int(os.getenv("WORLD_SIZE", 1))
@@ -212,6 +210,7 @@ def train():
                 iter_num,
                 args.save_memory_interval,
                 args.save_storage_interval,
+                checkpoint_dir,
             )
             if saved:
                 save_time = round(time.time() - start_save_t, 2)
@@ -243,6 +242,7 @@ def flash_save_checkpoint(
     iter_num,
     save_memory_interval,
     save_storage_interval,
+    checkpoint_dir,
 ):
     saved = False
     if iter_num % save_memory_interval == 0:
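With this change, ds_train.py threads the checkpoint directory from the new --save_dir argument into flash_save_checkpoint instead of reading a module-level constant. Below is a sketch of how the updated helper could look; the DeepSpeedCheckpointer and StorageType names are assumed from DLRover's flash-checkpoint examples, and the leading parameters are not visible in this diff.

# Import path assumed from DLRover's flash checkpoint examples; verify it
# against the installed dlrover version.
from dlrover.trainer.torch.flash_checkpoint.deepspeed import (
    DeepSpeedCheckpointer,
    StorageType,
)


def flash_save_checkpoint(
    checkpointer,     # assumed to be a DeepSpeedCheckpointer
    model_engine,     # assumed leading parameter, not shown in the diff
    iter_num,
    save_memory_interval,
    save_storage_interval,
    checkpoint_dir,
):
    saved = False
    if iter_num % save_memory_interval == 0:
        saved = True
        # Fast path: snapshot the engine state into shared memory.
        checkpointer.save_checkpoint(
            checkpoint_dir, tag=iter_num, storage_type=StorageType.MEMORY
        )
    if iter_num % save_storage_interval == 0:
        saved = True
        # Slow path: persist the checkpoint under the --save_dir directory.
        checkpointer.save_checkpoint(
            checkpoint_dir, tag=iter_num, storage_type=StorageType.DISK
        )
    return saved
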
examples/pytorch/nanogpt/fsdp_train.py (13 additions, 7 deletions)
@@ -55,12 +55,10 @@
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
-# We should use a shared storage to persist the checkpiont.
-checkpoint_dir = "/nas/nanogpt-ckpt-fsdp/"
-
 
 def train():
     args = arg_parser()
+    checkpoint_dir = args.save_dir
     setup()
     os.makedirs(checkpoint_dir, exist_ok=True)
     world_size = int(os.getenv("WORLD_SIZE", 1))
@@ -162,7 +160,7 @@ def train():
 
     start_load_t = time.time()
    if args.use_native_ckpt:
-        iter_num = native_load_checkpoint(0, model, optimizer)
+        iter_num = native_load_checkpoint(0, model, optimizer, checkpoint_dir)
    else:
        checkpointer = FsdpCheckpointer(checkpoint_dir)
        iter_num = flash_load_checkpoint(checkpointer, model, optimizer)
@@ -231,7 +229,11 @@ def train():
             start_save_t = time.time()
             if args.use_native_ckpt:
                 saved = native_save_checkpoint(
-                    iter_num, model, optimizer, args.save_storage_interval
+                    iter_num,
+                    model,
+                    optimizer,
+                    args.save_storage_interval,
+                    checkpoint_dir,
                 )
             else:
                 saved = flash_save_checkpoint(
@@ -241,6 +243,7 @@ def train():
                     optimizer,
                     args.save_memory_interval,
                     args.save_storage_interval,
+                    checkpoint_dir,
                 )
             if saved:
                 save_time = round(time.time() - start_save_t, 2)
@@ -255,7 +258,7 @@ def train():
                 break
 
 
-def native_load_checkpoint(step, model, optimizer):
+def native_load_checkpoint(step, model, optimizer, checkpoint_dir):
     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
         state_dict = {
             "model": model.state_dict(),
@@ -286,7 +289,9 @@ def native_load_checkpoint(step, model, optimizer):
     return state_dict["step"]
 
 
-def native_save_checkpoint(step, model, optimizer, save_storage_interval):
+def native_save_checkpoint(
+    step, model, optimizer, save_storage_interval, checkpoint_dir
+):
     saved = False
     if step % save_storage_interval != 0:
         return saved
@@ -344,6 +349,7 @@ def flash_save_checkpoint(
     optimizer,
     save_memory_interval,
     save_storage_interval,
+    checkpoint_dir,
 ):
     saved = False
     if step % save_memory_interval != 0 and step % save_storage_interval != 0:
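In fsdp_train.py the same directory now also reaches the native (non-flash) checkpoint path. The following is a sketch of native_save_checkpoint under its new signature, using the torch.distributed.checkpoint sharded-save API that the surrounding code (FSDP.state_dict_type with SHARDED_STATE_DICT) suggests; the exact state-dict contents and on-disk layout are assumptions, not taken from this diff.

import os

import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def native_save_checkpoint(
    step, model, optimizer, save_storage_interval, checkpoint_dir
):
    saved = False
    if step % save_storage_interval != 0:
        return saved
    saved = True
    # Each rank writes its shard of the FSDP state dict under the directory
    # supplied via --save_dir (previously a hard-coded /nas/... constant).
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model.state_dict(),
            "optim": FSDP.optim_state_dict(model, optimizer),
            "step": step,
        }
        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(
                os.path.join(checkpoint_dir, str(step))
            ),
        )
    return saved
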
examples/pytorch/nanogpt/train.py (8 additions, 4 deletions)
@@ -48,12 +48,10 @@
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
-# We should use a shared storage to persist the checkpiont.
-checkpoint_dir = "/nas/nanogpt-ckpt/"
-
 
 def train():
     args = arg_parser()
+    checkpoint_dir = args.save_dir
     setup()
     os.makedirs(checkpoint_dir, exist_ok=True)
     world_size = int(os.getenv("WORLD_SIZE", 1))
@@ -244,6 +242,7 @@ def train():
                     optimizer,
                     train_loader,
                     args.save_storage_interval,
+                    checkpoint_dir,
                 )
             else:
                 saved = flash_save_checkpoint(
@@ -269,7 +268,12 @@ def train():
 
 
 def native_save_checkpoint(
-    iter_num, model, optimizer, train_loader, save_storage_interval
+    iter_num,
+    model,
+    optimizer,
+    train_loader,
+    save_storage_interval,
+    checkpoint_dir,
 ):
     saved = False
     if iter_num % save_storage_interval != 0:
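train.py is the plain DDP example, so its native checkpoint path can be as simple as a torch.save into the configured directory. Below is a sketch of native_save_checkpoint with the new checkpoint_dir parameter; everything beyond the signature and the interval check shown in the diff is illustrative.

import os

import torch


def native_save_checkpoint(
    iter_num, model, optimizer, train_loader, save_storage_interval, checkpoint_dir
):
    saved = False
    if iter_num % save_storage_interval != 0:
        return saved
    saved = True
    # train_loader is accepted because the full example presumably persists
    # dataloader/sampler state to resume mid-epoch; that part is omitted here.
    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": iter_num,
    }
    os.makedirs(checkpoint_dir, exist_ok=True)
    path = os.path.join(checkpoint_dir, f"checkpoint-{iter_num}.pt")
    torch.save(state_dict, path)
    return saved
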
examples/pytorch/nanogpt/train_utils.py (3 additions, 0 deletions)
@@ -234,3 +234,6 @@ def add_train_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--use_native_ckpt", action="store_true", required=False
     )
+    parser.add_argument(
+        "--save_dir", type=str, default="/tmp/checkpoint/", required=False
+    )
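The new --save_dir option is what feeds checkpoint_dir in all of the training scripts above. Here is a self-contained sketch of the parser side and of how a script consumes it; the parser description and the launch command in the trailing comment are hypothetical.

import argparse


def add_train_args(parser: argparse.ArgumentParser):
    # Only the two options visible in this diff are reproduced here; the real
    # helper defines many more training flags.
    parser.add_argument(
        "--use_native_ckpt", action="store_true", required=False
    )
    parser.add_argument(
        "--save_dir", type=str, default="/tmp/checkpoint/", required=False
    )


def arg_parser():
    # Hypothetical wrapper mirroring how the examples call arg_parser().
    parser = argparse.ArgumentParser(description="nanoGPT training example")
    add_train_args(parser)
    return parser.parse_args()


# Each training script now resolves the checkpoint directory from the command
# line instead of a hard-coded /nas/... constant:
#     args = arg_parser()
#     checkpoint_dir = args.save_dir
# so a run can point checkpoints at shared storage, e.g. (hypothetical launch):
#     torchrun --nproc_per_node=2 examples/pytorch/nanogpt/fsdp_train.py \
#         --save_dir /nas/nanogpt-ckpt-fsdp/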
