
Add save and load checkpoint mechanism
LiSu committed Jan 29, 2024
1 parent b769149 commit 08decba
Showing 5 changed files with 169 additions and 15 deletions.
53 changes: 48 additions & 5 deletions examples/igbh/dist_train_rgnn.py
@@ -26,6 +26,7 @@

from mlperf_logging_utils import get_mlperf_logger, submission_info
from torch.nn.parallel import DistributedDataParallel
from utilities import create_ckpt_folder
from rgnn import RGNN

mllogger = get_mlperf_logger(path=osp.dirname(osp.abspath(__file__)))
@@ -93,12 +94,15 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
val_loader_master_port,
with_gpu, trim_to_layer, use_fp16,
edge_dir, rpc_timeout,
validation_acc, validation_frac_within_epoch, evaluate_on_epoch_end):
validation_acc, validation_frac_within_epoch, evaluate_on_epoch_end,
checkpoint_on_epoch_end, ckpt_steps, ckpt_path):

world_size=num_nodes*num_training_procs
rank=node_rank*num_training_procs+local_proc_rank
if rank == 0:
mllogger.start(key=mllog_constants.RUN_START)
if ckpt_steps > 0 or checkpoint_on_epoch_end:
ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__)))

glt.utils.common.seed_everything(random_seed)

@@ -180,6 +184,14 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
)
)

# Load checkpoint
ckpt = None
if ckpt_path is not None:
try:
ckpt = torch.load(ckpt_path)
except FileNotFoundError:
return -1

# Define model and optimizer.
if with_gpu:
torch.cuda.set_device(current_device)
@@ -193,6 +205,8 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
heads=num_heads,
node_type='paper',
with_trim=trim_to_layer).to(current_device)
if ckpt is not None:
model.load_state_dict(ckpt['model_state_dict'])
model = DistributedDataParallel(model,
device_ids=[current_device.index] if with_gpu else None,
find_unused_parameters=True)
@@ -209,6 +223,8 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,

loss_fcn = torch.nn.CrossEntropyLoss().to(current_device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
if ckpt is not None:
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
batch_num = (len(train_idx) + train_batch_size - 1) // train_batch_size
validation_freq = int(batch_num * validation_frac_within_epoch)
is_success = False
@@ -249,6 +265,16 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
if with_gpu
else 0
)
# Save a checkpoint every ckpt_steps steps.
if ckpt_steps > 0 and idx % ckpt_steps == 0:
if with_gpu:
torch.cuda.synchronize()
torch.distributed.barrier()
if rank == 0:
epoch_num = epoch + idx / batch_num
glt.utils.common.save_ckpt(idx + epoch * batch_num,
ckpt_dir, model.module, optimizer, epoch_num)
torch.distributed.barrier()
# evaluate
if idx % validation_freq == 0:
if with_gpu:
@@ -271,6 +297,14 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
torch.cuda.synchronize()
torch.distributed.barrier()

# Save a checkpoint at the end of the epoch.
if checkpoint_on_epoch_end:
if rank == 0:
epoch_num = epoch + 1
glt.utils.common.save_ckpt(idx + epoch * batch_num,
ckpt_dir, model.module, optimizer, epoch_num)
torch.distributed.barrier()

# evaluate at the end of epoch
if evaluate_on_epoch_end and not is_success:
epoch_num = epoch + 1
@@ -332,7 +366,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
parser.add_argument('--val_batch_size', type=int, default=512)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--epochs', type=int, default=2)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=4)
parser.add_argument('--random_seed', type=int, default=42)
@@ -371,10 +405,16 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
help="load node/edge feature using fp16 format to reduce memory usage")
parser.add_argument("--validation_frac_within_epoch", type=float, default=0.05,
help="Fraction of the epoch after which validation should be performed.")
parser.add_argument("--validation_acc", type=float, default=0.72,
parser.add_argument("--validation_acc", type=float, default=1,
help="Validation accuracy threshold to stop training once reached.")
parser.add_argument("--evaluate_on_epoch_end", action="store_true",
help="Evaluate using validation set on each epoch end.")
help="Evaluate using validation set on each epoch end."),
parser.add_argument("--checkpoint_on_epoch_end", action="store_true",
help="Save checkpoint on each epoch end."),
parser.add_argument('--ckpt_steps', type=int, default=-1,
help="Save checkpoint every n steps. Default is -1, which means no checkpoint is saved.")
parser.add_argument('--ckpt_path', type=str, default=None,
help="Path to load checkpoint from. Default is None.")
args = parser.parse_args()
assert args.layout in ['COO', 'CSC', 'CSR']

@@ -436,7 +476,10 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
args.rpc_timeout,
args.validation_acc,
args.validation_frac_within_epoch,
args.evaluate_on_epoch_end),
args.evaluate_on_epoch_end,
args.checkpoint_on_epoch_end,
args.ckpt_steps,
args.ckpt_path),
nprocs=args.num_training_procs,
join=True
)
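Taken together, the new options work as follows (file and folder names below are illustrative): launching with a positive --ckpt_steps, or with --checkpoint_on_epoch_end, makes rank 0 periodically write model_seq_<step>.ckpt files into a fresh ckpt_<timestamp> folder created next to the script, and a later run can pass --ckpt_path pointing at one of those files to restore the model and optimizer state before training starts.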
4 changes: 2 additions & 2 deletions examples/igbh/split_seeds.py
@@ -9,7 +9,7 @@ def __init__(self,
dataset_size='tiny',
use_label_2K=True,
random_seed=42,
validation_frac=0.05):
validation_frac=0.01):
self.path = path
self.dataset_size = dataset_size
self.use_label_2K = use_label_2K
@@ -49,7 +49,7 @@ def process(self):
parser.add_argument("--random_seed", type=int, default='42')
parser.add_argument('--num_classes', type=int, default=2983,
choices=[19, 2983], help='number of classes')
parser.add_argument("--validation_frac", type=float, default=0.05,
parser.add_argument("--validation_frac", type=float, default=0.01,
help="Fraction of labeled vertices to be used for validation.")

args = parser.parse_args()
56 changes: 48 additions & 8 deletions examples/igbh/train_rgnn_multi_gpu.py
@@ -28,6 +28,7 @@

from dataset import IGBHeteroDataset
from mlperf_logging_utils import get_mlperf_logger, submission_info
from utilities import create_ckpt_folder
from rgnn import RGNN

warnings.filterwarnings("ignore")
@@ -80,9 +81,11 @@ def run_training_proc(rank, world_size,
hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out,
epochs, train_batch_size, val_batch_size, learning_rate, random_seed, dataset,
train_idx, val_idx, with_gpu, validation_acc, validation_frac_within_epoch,
evaluate_on_epoch_end):
evaluate_on_epoch_end, checkpoint_on_epoch_end, ckpt_steps, ckpt_path):
if rank == 0:
mllogger.start(key=mllog_constants.RUN_START)
if ckpt_steps > 0 or checkpoint_on_epoch_end:
ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__)))
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
@@ -116,7 +119,14 @@ def run_training_proc(rank, world_size,
device=current_device,
seed=random_seed
)

# Load checkpoint
ckpt = None
if ckpt_path is not None:
try:
ckpt = torch.load(ckpt_path)
except FileNotFoundError:
return -1

# Define model and optimizer.
model = RGNN(dataset.get_edge_types(),
dataset.node_features['paper'].shape[1],
@@ -127,6 +137,8 @@ def run_training_proc(rank, world_size,
model=model_type,
heads=num_heads,
node_type='paper').to(current_device)
if ckpt is not None:
model.load_state_dict(ckpt['model_state_dict'])
model = DistributedDataParallel(model,
device_ids=[current_device.index] if with_gpu else None,
find_unused_parameters=True)
@@ -143,6 +155,9 @@ def run_training_proc(rank, world_size,

loss_fcn = torch.nn.CrossEntropyLoss().to(current_device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
if ckpt is not None:
optimizer.load_state_dict(ckpt['optimizer_state_dict'])

batch_num = (len(train_idx) + train_batch_size - 1) // train_batch_size
validation_freq = int(batch_num * validation_frac_within_epoch)
is_success = False
@@ -179,6 +194,16 @@ def run_training_proc(rank, world_size,
if with_gpu
else 0
)
# Save a checkpoint every ckpt_steps steps.
if ckpt_steps > 0 and idx % ckpt_steps == 0:
if with_gpu:
torch.cuda.synchronize()
dist.barrier()
if rank == 0:
epoch_num = epoch + idx / batch_num
glt.utils.common.save_ckpt(idx + epoch * batch_num,
ckpt_dir, model.module, optimizer, epoch_num)
dist.barrier()
# evaluate
if idx % validation_freq == 0:
if with_gpu:
@@ -197,6 +222,14 @@ def run_training_proc(rank, world_size,
torch.cuda.synchronize()
dist.barrier()

# Save a checkpoint at the end of the epoch.
if checkpoint_on_epoch_end:
if rank == 0:
epoch_num = epoch + 1
glt.utils.common.save_ckpt(idx + epoch * batch_num,
ckpt_dir, model.module, optimizer, epoch_num)
dist.barrier()

# evaluate at the end of epoch
if evaluate_on_epoch_end and not is_success:
epoch_num = epoch + 1
@@ -257,12 +290,12 @@ def run_training_proc(rank, world_size,
choices=['rgat', 'rsage'])
# Model parameters
parser.add_argument('--fan_out', type=str, default='15,10,5')
parser.add_argument('--train_batch_size', type=int, default=1024)
parser.add_argument('--val_batch_size', type=int, default=1024)
parser.add_argument('--train_batch_size', type=int, default=512)
parser.add_argument('--val_batch_size', type=int, default=512)
parser.add_argument('--hidden_channels', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=3)
parser.add_argument('--num_layers', type=int, default=2)
parser.add_argument('--learning_rate', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=2)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--num_heads', type=int, default=4)
parser.add_argument('--random_seed', type=int, default=42)
parser.add_argument("--cpu_mode", action="store_true",
@@ -280,6 +313,12 @@ def run_training_proc(rank, world_size,
help="Validation accuracy threshold to stop training once reached.")
parser.add_argument("--evaluate_on_epoch_end", action="store_true",
help="Evaluate using validation set on each epoch end.")
parser.add_argument("--checkpoint_on_epoch_end", action="store_true",
help="Save checkpoint on each epoch end.")
parser.add_argument('--ckpt_steps', type=int, default=-1,
help="Save checkpoint every n steps. Default is -1, which means no checkpoint is saved.")
parser.add_argument('--ckpt_path', type=str, default=None,
help="Path to load checkpoint from. Default is None.")
args = parser.parse_args()
args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
assert args.layout in ['COO', 'CSC', 'CSR']
@@ -324,7 +363,8 @@ def run_training_proc(rank, world_size,
args.learning_rate, args.random_seed,
glt_dataset, train_idx, val_idx, args.with_gpu,
args.validation_acc, args.validation_frac_within_epoch,
args.evaluate_on_epoch_end),
args.evaluate_on_epoch_end, args.checkpoint_on_epoch_end,
args.ckpt_steps, args.ckpt_path),
nprocs=world_size,
join=True
)
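As a small worked example of how checkpoints are numbered in both training scripts (the values are illustrative, not taken from a real run):

# Illustrative values only: how a mid-epoch checkpoint is numbered above.
batch_num = 2000                      # steps per epoch
epoch, idx = 1, 500                   # second epoch, step 500 within it
ckpt_seq = idx + epoch * batch_num    # 2500 -> saved as model_seq_2500.ckpt
epoch_num = epoch + idx / batch_num   # 1.25, recorded as the checkpoint's 'epoch'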
12 changes: 12 additions & 0 deletions examples/igbh/utilities.py
@@ -0,0 +1,12 @@
import os
import time
import torch

def create_ckpt_folder(base_dir, prefix="ckpt"):
  timestamp = time.strftime("%Y%m%d-%H%M%S")
  folder_name = f"{prefix}_{timestamp}" if prefix else timestamp
  full_path = os.path.join(base_dir, folder_name)
  if not os.path.exists(full_path):
    os.makedirs(full_path)
  return full_path
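
A quick usage sketch of this helper (the base directory is illustrative; the training scripts pass the directory containing the script itself):

from utilities import create_ckpt_folder

# Creates a timestamped folder such as /tmp/igbh_ckpts/ckpt_20240129-153000
# (including any missing parent directories) and returns its path.
ckpt_dir = create_ckpt_folder("/tmp/igbh_ckpts")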

59 changes: 59 additions & 0 deletions graphlearn_torch/python/utils/common.py
@@ -165,3 +165,62 @@ def load_and_concatenate_tensors(filename, device):
combined_tensor[start_idx:end_idx] = tensor.to(device)
start_idx = end_idx
return combined_tensor

def save_ckpt(
  ckpt_seq: int,
  ckpt_dir: str,
  model: torch.nn.Module,
  optimizer: Optional[torch.optim.Optimizer] = None,
  epoch: float = 0,
):
  """
  Saves a checkpoint of the model's state.

  Parameters:
    ckpt_seq (int): The sequence number of the checkpoint.
    ckpt_dir (str): The directory where the checkpoint will be saved.
    model (torch.nn.Module): The model to be saved.
    optimizer (Optional[torch.optim.Optimizer]): The optimizer, if any.
    epoch (float): The current epoch. Default is 0.
  """
  if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)
  ckpt_path = os.path.join(ckpt_dir, f"model_seq_{ckpt_seq}.ckpt")

  ckpt = {
    'seq': ckpt_seq,
    'epoch': epoch,
    'model_state_dict': model.state_dict()
  }
  if optimizer:
    ckpt['optimizer_state_dict'] = optimizer.state_dict()

  torch.save(ckpt, ckpt_path)

def load_ckpt(
  ckpt_seq: int,
  ckpt_dir: str,
  model: torch.nn.Module,
  optimizer: Optional[torch.optim.Optimizer] = None,
) -> float:
  """
  Loads a checkpoint of the model's state and returns the epoch stored in it,
  or -1 if the checkpoint file does not exist.

  Parameters:
    ckpt_seq (int): The sequence number of the checkpoint.
    ckpt_dir (str): The directory the checkpoint is loaded from.
    model (torch.nn.Module): The model whose state will be restored.
    optimizer (Optional[torch.optim.Optimizer]): The optimizer to restore, if any.
  """
  ckpt_path = os.path.join(ckpt_dir, f"model_seq_{ckpt_seq}.ckpt")
  try:
    ckpt = torch.load(ckpt_path)
  except FileNotFoundError:
    return -1

  model.load_state_dict(ckpt['model_state_dict'])
  epoch = ckpt.get('epoch')
  if optimizer and 'optimizer_state_dict' in ckpt:
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
  return epoch
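
A minimal usage sketch of the two helpers, assuming the import graphlearn_torch as glt alias used by the example scripts (the toy model, directory, and sequence number are illustrative):

import torch
import graphlearn_torch as glt

model = torch.nn.Linear(8, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Writes /tmp/ckpt_demo/model_seq_100.ckpt containing model and optimizer state.
glt.utils.common.save_ckpt(100, '/tmp/ckpt_demo', model, optimizer, epoch=1.25)

# Restores that state into fresh instances; returns the stored epoch, or -1 if the file is missing.
restored = torch.nn.Linear(8, 4)
restored_opt = torch.optim.Adam(restored.parameters(), lr=0.001)
epoch = glt.utils.common.load_ckpt(100, '/tmp/ckpt_demo', restored, restored_opt)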
