
Commit 08decba

Add save and load checkpoint mechanism
1 parent b769149 commit 08decba

File tree

  examples/igbh/dist_train_rgnn.py
  examples/igbh/split_seeds.py
  examples/igbh/train_rgnn_multi_gpu.py
  examples/igbh/utilities.py
  graphlearn_torch/python/utils/common.py

5 files changed: +169 lines, -15 lines

examples/igbh/dist_train_rgnn.py

Lines changed: 48 additions & 5 deletions
@@ -26,6 +26,7 @@
 
 from mlperf_logging_utils import get_mlperf_logger, submission_info
 from torch.nn.parallel import DistributedDataParallel
+from utilities import create_ckpt_folder
 from rgnn import RGNN
 
 mllogger = get_mlperf_logger(path=osp.dirname(osp.abspath(__file__)))
@@ -93,12 +94,15 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
                       val_loader_master_port,
                       with_gpu, trim_to_layer, use_fp16,
                       edge_dir, rpc_timeout,
-                      validation_acc, validation_frac_within_epoch, evaluate_on_epoch_end):
+                      validation_acc, validation_frac_within_epoch, evaluate_on_epoch_end,
+                      checkpoint_on_epoch_end, ckpt_steps, ckpt_path):
 
   world_size=num_nodes*num_training_procs
   rank=node_rank*num_training_procs+local_proc_rank
   if rank == 0:
     mllogger.start(key=mllog_constants.RUN_START)
+    if ckpt_steps > 0:
+      ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__)))
 
   glt.utils.common.seed_everything(random_seed)
 
@@ -180,6 +184,14 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
     )
   )
 
+  # Load checkpoint
+  ckpt = None
+  if ckpt_path is not None:
+    try:
+      ckpt = torch.load(ckpt_path)
+    except FileNotFoundError:
+      return -1
+
   # Define model and optimizer.
   if with_gpu:
     torch.cuda.set_device(current_device)
@@ -193,6 +205,8 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
                heads=num_heads,
                node_type='paper',
                with_trim=trim_to_layer).to(current_device)
+  if ckpt is not None:
+    model.load_state_dict(ckpt['model_state_dict'])
   model = DistributedDataParallel(model,
                                   device_ids=[current_device.index] if with_gpu else None,
                                   find_unused_parameters=True)
@@ -209,6 +223,8 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
 
   loss_fcn = torch.nn.CrossEntropyLoss().to(current_device)
   optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+  if ckpt is not None:
+    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
   batch_num = (len(train_idx) + train_batch_size - 1) // train_batch_size
   validation_freq = int(batch_num * validation_frac_within_epoch)
   is_success = False
@@ -249,6 +265,16 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
         if with_gpu
         else 0
       )
+      # checkpoint
+      if ckpt_steps > 0 and idx % ckpt_steps == 0:
+        if with_gpu:
+          torch.cuda.synchronize()
+        torch.distributed.barrier()
+        if rank == 0:
+          epoch_num = epoch + idx / batch_num
+          glt.utils.common.save_ckpt(idx + epoch * batch_num,
+                                     ckpt_dir, model.module, optimizer, epoch_num)
+        torch.distributed.barrier()
       # evaluate
       if idx % validation_freq == 0:
         if with_gpu:
@@ -271,6 +297,14 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
       torch.cuda.synchronize()
     torch.distributed.barrier()
 
+    # checkpoint at the end of epoch
+    if checkpoint_on_epoch_end:
+      if rank == 0:
+        epoch_num = epoch + 1
+        glt.utils.common.save_ckpt(idx + epoch * batch_num,
+                                   ckpt_dir, model.module, optimizer, epoch_num)
+      torch.distributed.barrier()
+
     # evaluate at the end of epoch
     if evaluate_on_epoch_end and not is_success:
       epoch_num = epoch + 1
@@ -332,7 +366,7 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
   parser.add_argument('--val_batch_size', type=int, default=512)
   parser.add_argument('--hidden_channels', type=int, default=128)
   parser.add_argument('--learning_rate', type=float, default=0.001)
-  parser.add_argument('--epochs', type=int, default=20)
+  parser.add_argument('--epochs', type=int, default=2)
   parser.add_argument('--num_layers', type=int, default=3)
   parser.add_argument('--num_heads', type=int, default=4)
   parser.add_argument('--random_seed', type=int, default=42)
@@ -371,10 +405,16 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
                       help="load node/edge feature using fp16 format to reduce memory usage")
   parser.add_argument("--validation_frac_within_epoch", type=float, default=0.05,
                       help="Fraction of the epoch after which validation should be performed.")
-  parser.add_argument("--validation_acc", type=float, default=0.72,
+  parser.add_argument("--validation_acc", type=float, default=1,
                       help="Validation accuracy threshold to stop training once reached.")
   parser.add_argument("--evaluate_on_epoch_end", action="store_true",
-                      help="Evaluate using validation set on each epoch end.")
+                      help="Evaluate using validation set on each epoch end."),
+  parser.add_argument("--checkpoint_on_epoch_end", action="store_true",
+                      help="Save checkpoint on each epoch end."),
+  parser.add_argument('--ckpt_steps', type=int, default=-1,
+                      help="Save checkpoint every n steps. Default is -1, which means no checkpoint is saved.")
+  parser.add_argument('--ckpt_path', type=str, default=None,
+                      help="Path to load checkpoint from. Default is None.")
   args = parser.parse_args()
   assert args.layout in ['COO', 'CSC', 'CSR']
 
@@ -436,7 +476,10 @@ def run_training_proc(local_proc_rank, num_nodes, node_rank, num_training_procs,
           args.rpc_timeout,
           args.validation_acc,
           args.validation_frac_within_epoch,
-          args.evaluate_on_epoch_end),
+          args.evaluate_on_epoch_end,
+          args.checkpoint_on_epoch_end,
+          args.ckpt_steps,
+          args.ckpt_path),
     nprocs=args.num_training_procs,
     join=True
  )
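Together, the new flags wire checkpointing into run_training_proc: --ckpt_steps > 0 saves on rank 0 every ckpt_steps training steps, --checkpoint_on_epoch_end adds a save at each epoch boundary, and --ckpt_path restores model and optimizer state before training starts. A minimal launch sketch, with the dataset, sampling, and cluster flags the script also requires elided (the step count is illustrative):

  # Save a checkpoint every 512 training steps and at each epoch end; checkpoints
  # land in a ckpt_<timestamp> folder created next to the script.
  python dist_train_rgnn.py --ckpt_steps 512 --checkpoint_on_epoch_end ...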

examples/igbh/split_seeds.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ def __init__(self,
                dataset_size='tiny',
                use_label_2K=True,
                random_seed=42,
-               validation_frac=0.05):
+               validation_frac=0.01):
     self.path = path
     self.dataset_size = dataset_size
     self.use_label_2K = use_label_2K
@@ -49,7 +49,7 @@ def process(self):
   parser.add_argument("--random_seed", type=int, default='42')
   parser.add_argument('--num_classes', type=int, default=2983,
                       choices=[19, 2983], help='number of classes')
-  parser.add_argument("--validation_frac", type=float, default=0.05,
+  parser.add_argument("--validation_frac", type=float, default=0.01,
                       help="Fraction of labeled vertices to be used for validation.")
 
   args = parser.parse_args()

examples/igbh/train_rgnn_multi_gpu.py

Lines changed: 48 additions & 8 deletions
@@ -28,6 +28,7 @@
 
 from dataset import IGBHeteroDataset
 from mlperf_logging_utils import get_mlperf_logger, submission_info
+from utilities import create_ckpt_folder
 from rgnn import RGNN
 
 warnings.filterwarnings("ignore")
@@ -80,9 +81,11 @@ def run_training_proc(rank, world_size,
                       hidden_channels, num_classes, num_layers, model_type, num_heads, fan_out,
                       epochs, train_batch_size, val_batch_size, learning_rate, random_seed, dataset,
                       train_idx, val_idx, with_gpu, validation_acc, validation_frac_within_epoch,
-                      evaluate_on_epoch_end):
+                      evaluate_on_epoch_end, checkpoint_on_epoch_end, ckpt_steps, ckpt_path):
   if rank == 0:
     mllogger.start(key=mllog_constants.RUN_START)
+    if ckpt_steps > 0:
+      ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__)))
   os.environ['MASTER_ADDR'] = 'localhost'
   os.environ['MASTER_PORT'] = '12355'
   dist.init_process_group('nccl', rank=rank, world_size=world_size)
@@ -116,7 +119,14 @@ def run_training_proc(rank, world_size,
     device=current_device,
     seed=random_seed
   )
-
+  # Load checkpoint
+  ckpt = None
+  if ckpt_path is not None:
+    try:
+      ckpt = torch.load(ckpt_path)
+    except FileNotFoundError:
+      return -1
+
   # Define model and optimizer.
   model = RGNN(dataset.get_edge_types(),
                dataset.node_features['paper'].shape[1],
@@ -127,6 +137,8 @@ def run_training_proc(rank, world_size,
                model=model_type,
                heads=num_heads,
                node_type='paper').to(current_device)
+  if ckpt is not None:
+    model.load_state_dict(ckpt['model_state_dict'])
   model = DistributedDataParallel(model,
                                   device_ids=[current_device.index] if with_gpu else None,
                                   find_unused_parameters=True)
@@ -143,6 +155,9 @@ def run_training_proc(rank, world_size,
 
   loss_fcn = torch.nn.CrossEntropyLoss().to(current_device)
   optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+  if ckpt is not None:
+    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+
   batch_num = (len(train_idx) + train_batch_size - 1) // train_batch_size
   validation_freq = int(batch_num * validation_frac_within_epoch)
   is_success = False
@@ -179,6 +194,16 @@ def run_training_proc(rank, world_size,
         if with_gpu
         else 0
       )
+      # checkpoint
+      if ckpt_steps > 0 and idx % ckpt_steps == 0:
+        if with_gpu:
+          torch.cuda.synchronize()
+        dist.barrier()
+        if rank == 0:
+          epoch_num = epoch + idx / batch_num
+          glt.utils.common.save_ckpt(idx + epoch * batch_num,
+                                     ckpt_dir, model.module, optimizer, epoch_num)
+        dist.barrier()
       # evaluate
       if idx % validation_freq == 0:
         if with_gpu:
@@ -197,6 +222,14 @@ def run_training_proc(rank, world_size,
       torch.cuda.synchronize()
     dist.barrier()
 
+    # checkpoint at the end of epoch
+    if checkpoint_on_epoch_end:
+      if rank == 0:
+        epoch_num = epoch + 1
+        glt.utils.common.save_ckpt(idx + epoch * batch_num,
+                                   ckpt_dir, model.module, optimizer, epoch_num)
+      dist.barrier()
+
     # evaluate at the end of epoch
     if evaluate_on_epoch_end and not is_success:
       epoch_num = epoch + 1
@@ -257,12 +290,12 @@ def run_training_proc(rank, world_size,
                       choices=['rgat', 'rsage'])
   # Model parameters
   parser.add_argument('--fan_out', type=str, default='15,10,5')
-  parser.add_argument('--train_batch_size', type=int, default=1024)
-  parser.add_argument('--val_batch_size', type=int, default=1024)
+  parser.add_argument('--train_batch_size', type=int, default=512)
+  parser.add_argument('--val_batch_size', type=int, default=512)
   parser.add_argument('--hidden_channels', type=int, default=128)
-  parser.add_argument('--learning_rate', type=float, default=0.01)
-  parser.add_argument('--epochs', type=int, default=3)
-  parser.add_argument('--num_layers', type=int, default=2)
+  parser.add_argument('--learning_rate', type=float, default=0.001)
+  parser.add_argument('--epochs', type=int, default=2)
+  parser.add_argument('--num_layers', type=int, default=3)
   parser.add_argument('--num_heads', type=int, default=4)
   parser.add_argument('--random_seed', type=int, default=42)
   parser.add_argument("--cpu_mode", action="store_true",
@@ -280,6 +313,12 @@ def run_training_proc(rank, world_size,
                       help="Validation accuracy threshold to stop training once reached.")
   parser.add_argument("--evaluate_on_epoch_end", action="store_true",
                       help="Evaluate using validation set on each epoch end.")
+  parser.add_argument("--checkpoint_on_epoch_end", action="store_true",
+                      help="Save checkpoint on each epoch end.")
+  parser.add_argument('--ckpt_steps', type=int, default=-1,
+                      help="Save checkpoint every n steps. Default is -1, which means no checkpoint is saved.")
+  parser.add_argument('--ckpt_path', type=str, default=None,
+                      help="Path to load checkpoint from. Default is None.")
   args = parser.parse_args()
   args.with_gpu = (not args.cpu_mode) and torch.cuda.is_available()
   assert args.layout in ['COO', 'CSC', 'CSR']
@@ -324,7 +363,8 @@ def run_training_proc(rank, world_size,
           args.learning_rate, args.random_seed,
           glt_dataset, train_idx, val_idx, args.with_gpu,
           args.validation_acc, args.validation_frac_within_epoch,
-          args.evaluate_on_epoch_end),
+          args.evaluate_on_epoch_end, args.checkpoint_on_epoch_end,
+          args.ckpt_steps, args.ckpt_path),
     nprocs=world_size,
     join=True
  )
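The single-node script exposes the same three flags, so resuming from an earlier run looks like the sketch below; the folder name and sequence number are illustrative, but the ckpt_<timestamp>/model_seq_<step>.ckpt layout matches what create_ckpt_folder and save_ckpt produce. Note that if the file at --ckpt_path does not exist, run_training_proc returns -1 and skips training on that process instead of raising.

  # Resume model and optimizer state from a previously saved checkpoint
  # (remaining dataset and model flags elided).
  python train_rgnn_multi_gpu.py --ckpt_path ckpt_20240101-120000/model_seq_2048.ckpt ...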

examples/igbh/utilities.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import os
+import time
+import torch
+
+def create_ckpt_folder(base_dir, prefix="ckpt"):
+  timestamp = time.strftime("%Y%m%d-%H%M%S")
+  folder_name = f"{prefix}_{timestamp}" if prefix else timestamp
+  full_path = os.path.join(base_dir, folder_name)
+  if not os.path.exists(full_path):
+    os.makedirs(full_path)
+  return full_path
+
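A short usage sketch of this helper as the training scripts above call it; the timestamp in the resulting folder name is illustrative:

  import os.path as osp
  from utilities import create_ckpt_folder

  # Creates e.g. <script_dir>/ckpt_20240101-120000 if it does not already exist
  # and returns the full path; rank 0 does this once when --ckpt_steps > 0.
  ckpt_dir = create_ckpt_folder(base_dir=osp.dirname(osp.abspath(__file__)))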

graphlearn_torch/python/utils/common.py

Lines changed: 59 additions & 0 deletions
@@ -165,3 +165,62 @@ def load_and_concatenate_tensors(filename, device):
     combined_tensor[start_idx:end_idx] = tensor.to(device)
     start_idx = end_idx
   return combined_tensor
+
+def save_ckpt(
+  ckpt_seq: int,
+  ckpt_dir: str,
+  model: torch.nn.Module,
+  optimizer: Optional[torch.optim.Optimizer] = None,
+  epoch: float = 0,
+):
+  """
+  Saves a checkpoint of the model's state.
+
+  Parameters:
+  ckpt_seq (int): The sequence number of the checkpoint.
+  ckpt_dir (str): The directory where the checkpoint will be saved.
+  model (torch.nn.Module): The model to be saved.
+  optimizer (Optional[torch.optim.Optimizer]): The optimizer, if any.
+  epoch (float): The current epoch. Default is 0.
+  """
+  if not os.path.isdir(ckpt_dir):
+    os.makedirs(ckpt_dir)
+  ckpt_path = os.path.join(ckpt_dir, f"model_seq_{ckpt_seq}.ckpt")
+
+  ckpt = {
+    'seq': ckpt_seq,
+    'epoch': epoch,
+    'model_state_dict': model.state_dict()
+  }
+  if optimizer:
+    ckpt['optimizer_state_dict'] = optimizer.state_dict()
+
+  torch.save(ckpt, ckpt_path)
+
+def load_ckpt(
+  ckpt_seq: int,
+  ckpt_dir: str,
+  model: torch.nn.Module,
+  optimizer: Optional[torch.optim.Optimizer] = None,
+) -> float:
+  """
+  Loads a checkpoint of the model's state, returns the epoch of the checkpoint.
+
+  Parameters:
+  ckpt_seq (int): The sequence number of the checkpoint.
+  ckpt_dir (str): The directory to load the checkpoint from.
+  model (torch.nn.Module): The model whose state will be restored.
+  optimizer (Optional[torch.optim.Optimizer]): The optimizer, if any.
+  """
+
+  ckpt_path = os.path.join(ckpt_dir, f"model_seq_{ckpt_seq}.ckpt")
+  try:
+    ckpt = torch.load(ckpt_path)
+  except FileNotFoundError:
+    return -1
+
+  model.load_state_dict(ckpt['model_state_dict'])
+  epoch = ckpt.get('epoch')
+  if optimizer:
+    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+  return epoch
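For completeness, a round-trip sketch of the two new helpers, assuming graphlearn_torch is imported as glt as in the example scripts; the toy model, directory, and sequence number are placeholders:

  import torch
  import graphlearn_torch as glt

  model = torch.nn.Linear(16, 4)                          # placeholder model
  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

  # Writes /tmp/ckpt_demo/model_seq_7.ckpt containing 'seq', 'epoch',
  # 'model_state_dict' and, because an optimizer is passed, 'optimizer_state_dict'.
  glt.utils.common.save_ckpt(7, '/tmp/ckpt_demo', model, optimizer, epoch=1.5)

  # Restores both state dicts in place and returns the stored epoch (1.5 here),
  # or -1 if the checkpoint file is missing.
  epoch = glt.utils.common.load_ckpt(7, '/tmp/ckpt_demo', model, optimizer)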
