
Commit

fixing the full state path in checkpoint handler+loss report calculation (#51)
chauhang authored Aug 1, 2023
2 parents 174b856 + 88d3e1f commit 1387b76
Showing 6 changed files with 135 additions and 140 deletions.
13 changes: 13 additions & 0 deletions docs/FAQ.md
@@ -26,3 +26,16 @@ Here we discuss frequently asked questions that may occur and we found useful al
6. What are the hardware SKU requirements for fine-tuning Llama pre-trained models?

Fine-tuning requirements vary based on amount of data, time to complete fine-tuning and cost constraints. To fine-tune these models we have generally used multiple NVIDIA A100 machines with data parallelism across nodes and a mix of data and tensor parallelism intra node. But using a single machine, or other GPU types like NVIDIA A10G or H100 are definitely possible (e.g. alpaca models are trained on a single RTX4090: https://github.com/tloen/alpaca-lora).

7. How to handle CUDA memory fragmentation during fine-tuning that may lead to an OOM? In some cases you may experience that after model checkpointing, especially with FSDP (this usually does not happen with PEFT methods), the reserved and allocated CUDA memory has increased. This might be due to CUDA memory fragmentation. PyTorch recently added an environment variable that helps to better manage memory fragmentation (at the time of writing this doc, July 30 2023, this feature is available in PyTorch nightlies). You can set this in your main training script as follows:

```python
import os

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
```
We also added this environment variable in `setup_environ_flags` of [train_utils.py](../utils/train_utils.py); feel free to uncomment it if required.

8. Additional debugging flags? The environment variable `TORCH_DISTRIBUTED_DEBUG` can be used to trigger additional useful logging and collective synchronization checks to ensure all ranks are synchronized appropriately. `TORCH_DISTRIBUTED_DEBUG` can be set to OFF (default), INFO, or DETAIL depending on the debugging level required. Please note that the most verbose option, DETAIL, may impact application performance and thus should only be used when debugging issues.

We also added this environment variable in `setup_environ_flags` of [train_utils.py](../utils/train_utils.py); feel free to uncomment it if required.
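
For illustration, here is a minimal sketch of what a `setup_environ_flags`-style helper could look like with both opt-in flags wired in; this is an assumption about its shape, not the exact contents of the helper in [train_utils.py](../utils/train_utils.py):

```python
import os

def setup_environ_flags(rank):
    """Sketch of an environment-flag helper; uncomment the opt-in flags when you need them."""
    # Opt-in: extra distributed logging plus collective synchronization checks.
    # DETAIL is the most verbose level and can slow training, so use it only while debugging.
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    # Opt-in: lets the caching allocator grow segments to reduce fragmentation
    # (available on PyTorch nightlies as of July 2023).
    # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    if rank == 0:
        print("--> environment flags configured")
```

Note that both variables generally need to be set before the distributed process group is initialized and before the first CUDA allocation, respectively, to take effect.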
2 changes: 0 additions & 2 deletions model_checkpointing/__init__.py
@@ -4,8 +4,6 @@
from .checkpoint_handler import (
load_model_checkpoint,
save_model_checkpoint,
save_distributed_model_checkpoint,
load_distributed_model_checkpoint,
load_optimizer_checkpoint,
save_optimizer_checkpoint,
save_model_and_optimizer_sharded,
129 changes: 31 additions & 98 deletions model_checkpointing/checkpoint_handler.py
@@ -44,7 +44,7 @@ def get_date_of_run():
fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)


def load_model_sharded(model, rank, cfg, verbose=True):
def load_model_sharded(model, rank, cfg):
# torch.manual_seed(103)
folder_name = (
cfg.dist_checkpoint_root_folder
@@ -83,7 +83,7 @@ def load_model_sharded(model, rank, cfg, verbose=True):
print(f"Sharded state checkpoint loaded from {load_dir}")


def save_model_and_optimizer_sharded(model, rank, cfg,optim=None, verbose=True):
def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
"""save model and optimizer via sharded_state_dict to save_dir"""

folder_name = (
@@ -142,20 +142,27 @@ def save_model_checkpoint(
if rank == 0:
print(f"--> saving model ...")
# create save path
save_dir = Path.cwd() / cfg.checkpoint_folder
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
save_dir = Path.cwd() / folder_name
save_dir.mkdir(parents=True, exist_ok=True)
save_name = cfg.model_name + "-" + str(epoch) + ".pt"
save_full_path = str(save_dir) + "/" + save_name

# save model
torch.save(cpu_state, save_full_path)

if cfg.verbose:
print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")



def load_model_checkpoint(model, rank, cfg, verbose=True):
def load_model_checkpoint(model, rank, cfg):
"""load local checkpoint to rank0 cpu
must be called * before * passing to FSDP"""

@@ -178,8 +185,8 @@ def load_model_checkpoint(model, rank, cfg, verbose=True):
# integrate into loaded model
model.load_state_dict(model_checkpoint)

if cfg.verbose:
print(f"model checkpoint loaded to rank0 cpu")

print(f"model checkpoint loaded to rank0 cpu")


def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
@@ -192,15 +199,22 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):

optim_state = FSDP.full_optim_state_dict(model, optimizer)

if cfg.verbose:
print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n")

if rank == 0:
save_dir = Path.cwd() / cfg.checkpoint_folder
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
save_dir = Path.cwd() / folder_name
save_dir.mkdir(parents=True, exist_ok=True)

opt_save_name = (
cfg.optimizer_name + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
"optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt"
)
opt_save_full_path = save_dir / opt_save_name

@@ -211,109 +225,28 @@ def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1):
print(f"--> saved {opt_save_full_path} to disk")


def load_optimizer_checkpoint(model, optimizer, rank, cfg):
def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank):
"""load an fsdp optimizer full_state checkpoint using scatter method
this ensures only rank 0 loads the optimizer state dict and scatters to other ranks
"""

opt_file_path = Path.cwd() / cfg.checkpoint_folder / cfg.optimizer_checkpoint_file

if not opt_file_path.is_file():
if not optimizer_checkpoint_path.is_file():
print(
f"warning - optimizer checkpoint not present {opt_file_path}. Returning. "
f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. "
)
return

full_osd = None

if rank == 0:
full_osd = torch.load(opt_file_path)

if cfg.verbose:
print(f"loaded full osd on rank 0")
full_osd = torch.load(optimizer_checkpoint_path)

# called from all ranks, though only rank0 has a valid param for full_osd
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)

if cfg.verbose:
print(f"optimizer shard loaded on rank {rank}")

print(f"optimizer shard loaded on rank {rank}")


def load_distributed_model_checkpoint(model, rank, cfg):
if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
print(f"loading distributed checkpoint, rank {rank}...")
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)

checkdir = Path.cwd() / folder_name

if not checkdir.exists():
if rank == 0:
print(f"No checkpoint directory found...skipping")
return


reader = FileSystemReader(checkdir)

with FSDP.state_dict_type(
model,
StateDictType.LOCAL_STATE_DICT,
):
state_dict = model.state_dict()
load_state_dict(state_dict, reader)
model.load_state_dict(state_dict)

print(f"--> local state loaded on rank {rank}")

return


def save_distributed_model_checkpoint(model, rank, cfg, epoch=1):
# distributed checkpoint saving

# confirm type of checkpoint and save
if cfg.checkpoint_type == StateDictType.LOCAL_STATE_DICT:
# create writer to current path
folder_name = (
cfg.dist_checkpoint_root_folder
+ "/"
+ cfg.dist_checkpoint_folder
+ "-"
+ cfg.model_name
)
save_dir = Path.cwd() / folder_name

writer = FileSystemWriter(
save_dir,
)

with FSDP.state_dict_type(
model,
StateDictType.LOCAL_STATE_DICT,
):
state_dict = model.state_dict()


# write out distributed checkpoint
save_state_dict(state_dict, writer)

return

def load_sharded_model_single_gpu(model, model_path):

dcp.load_state_dict(
state_dict=state_dict_to_load_to,
storage_reader=FsspecReader(path),
no_dist=True,
)
print(f"Sharded state checkpoint loaded from {load_dir}")

def load_sharded_model_single_gpu(model,model_path):

reader = FileSystemReader(model_path)
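To summarize the checkpoint_handler.py change above: both `save_model_checkpoint` and `save_optimizer_checkpoint` now write full-state checkpoints under `<dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>` instead of `cfg.checkpoint_folder`, and `load_optimizer_checkpoint` now receives the checkpoint path explicitly rather than deriving it from the config. A hypothetical caller-side sketch (the real call sites live in the training utilities, which are not rendered on this page; the helper name and config fields below are assumptions for illustration):

```python
from pathlib import Path

from model_checkpointing import load_optimizer_checkpoint

def resume_optimizer_from_full_state(model, train_config, rank, epoch=1):
    """Sketch: rebuild the folder the save functions now use and load the optimizer state from it."""
    folder_name = (
        train_config.dist_checkpoint_root_folder
        + "/"
        + train_config.dist_checkpoint_folder
        + "-"
        + train_config.model_name
    )
    # Matches the filename written by save_optimizer_checkpoint: optimizer-<model_name>-<epoch>.pt
    optimizer_checkpoint_path = (
        Path.cwd() / folder_name / f"optimizer-{train_config.model_name}-{epoch}.pt"
    )
    # New signature: the caller computes the path and passes it in; rank 0 loads and scatters the state.
    load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank)
```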
13 changes: 12 additions & 1 deletion scripts/spellcheck_conf/wordlist.txt
@@ -1078,4 +1078,15 @@ samsum
vLLM
TGI
vLLM
vLLM's
vLLM's
OOM
RTX
SKU
TPUs
checkpointing
enviroment
fragmentations
intra
nightlies
recenly
uncomment
1 change: 1 addition & 0 deletions utils/memory_utils.py
@@ -52,6 +52,7 @@ def __exit__(self, *exc):
cuda_info = torch.cuda.memory_stats()
self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
self.used = byte2gb(self.end - self.begin)
self.peaked = byte2gb(self.peak - self.begin)
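The memory_utils.py change adds a CUDA OOM counter alongside the existing peak/retry statistics. A hedged usage sketch, assuming the tracker is the `MemoryTrace` context manager from utils/memory_utils.py (its class name is not visible in the hunk) and `train_one_epoch` stands in for the actual training loop:

```python
from utils.memory_utils import MemoryTrace

def train_one_epoch():
    ...  # placeholder for the actual epoch loop

with MemoryTrace() as memtrace:  # statistics are collected in __exit__
    train_one_epoch()

print(f"Peak active CUDA memory: {memtrace.peak_active_gb} GB")
print(f"CUDA malloc retries: {memtrace.cuda_malloc_retires}")
print(f"CUDA OOM events: {memtrace.m_cuda_ooms}")
```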