From 5b31b7094a044c7f4dc964b7192685519e484ee5 Mon Sep 17 00:00:00 2001 From: Yuxiang Yang Date: Tue, 17 Oct 2023 15:12:15 +0800 Subject: [PATCH] Support checkpoint for Megatron-LM (#100) **Description** Support checkpoint for Megatron-LM --- msamp/megatron/optimizer/distrib_optimizer.py | 57 +++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/msamp/megatron/optimizer/distrib_optimizer.py b/msamp/megatron/optimizer/distrib_optimizer.py index 560ecce5..3b8f8148 100644 --- a/msamp/megatron/optimizer/distrib_optimizer.py +++ b/msamp/megatron/optimizer/distrib_optimizer.py @@ -10,6 +10,7 @@ from megatron.core import mpu, tensor_parallel from megatron.optimizer.optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper from megatron.optimizer.distrib_optimizer import DistributedOptimizer, Range +from megatron.utils import print_rank_0 from msamp.common.dtype import Dtypes from msamp.common.tensor import ScalingTensor, ScalingMeta @@ -350,12 +351,60 @@ def get_model_parallel_group(self): return None def state_dict(self): - """Return the optimizer's state dict.""" - raise NotImplementedError + """The state dict must contain the fp32-from-float16 and fp16-from-fp8 shards.""" + state_dict = {} + state_dict['optimizer'] = self.optimizer.state_dict() + if self.grad_scaler: + state_dict['grad_scaler'] = self.grad_scaler.state_dict() + # shared master weight + state_dict['shard_fp32_from_float16_groups'] = \ + self.shard_fp32_from_float16_groups + state_dict['shard_hp_from_fp8_groups'] = \ + self.shard_hp_from_fp8_groups + return state_dict def load_state_dict(self, state_dict): - """Load the optimizer's state dict.""" - raise NotImplementedError + """Load the state dict.""" + optimizer_key = 'optimizer' + if optimizer_key not in state_dict: + optimizer_key = 'optimizer_state_dict' + print_rank_0('***WARNING*** loading optimizer from ' + 'an old checkpoint ...') + # convert optimizer states + ckpt_state_dict = state_dict[optimizer_key] + self.optimizer.load_state_dict(ckpt_state_dict) + + # Grad scaler. + if 'grad_scaler' not in state_dict: + if self.fp16: + print_rank_0('***WARNING*** found an old checkpoint, will not ' + 'load grad scaler ...') + else: + if self.grad_scaler: + self.grad_scaler.load_state_dict(state_dict['grad_scaler']) + else: + print_rank_0( + '***WARNING*** fould the grad scaler in the ' + 'checkpoint but it is None in the class. ' + 'Skipping loading grad scaler ...' + ) + + # Copy data for the main params. + for current_group, saved_group in zip( + self.shard_fp32_from_float16_groups, state_dict['shard_fp32_from_float16_groups'] + ): + for current_param, saved_param in zip(current_group, saved_group): + current_param.data.copy_(saved_param.data) + + for current_group, saved_group in zip(self.shard_hp_from_fp8_groups, state_dict['shard_hp_from_fp8_groups']): + for current_param, saved_param in zip(current_group, saved_group): + if current_param.data.qtype == saved_param.data.qtype: + current_param.data.copy_(saved_param.data) + else: + # when the data type of optimizer's master weight and checkpoint's is different + current_param.data.copy_( + saved_param.data.to(current_param.data.device).cast(current_param.data.qtype) + ) def zero_grad(self, set_to_none=True): """Zero grads.