diff --git a/msamp/megatron/__init__.py b/msamp/megatron/__init__.py index 74954e73..1932453a 100644 --- a/msamp/megatron/__init__.py +++ b/msamp/megatron/__init__.py @@ -3,7 +3,7 @@ """Expose the interface of MS-AMP megatron package.""" -from msamp.megatron.optimizer.clip_grads import clip_grad_norm_fp8 +from msamp.megatron.optimizer.clip_grads import clip_grad_norm_fp32 from msamp.megatron.distributed import FP8DistributedDataParallel from msamp.common.utils.lazy_import import LazyImport @@ -13,6 +13,6 @@ FP8DistributedOptimizer = LazyImport('msamp.megatron.optimizer.distrib_optimizer', 'FP8DistributedOptimizer') __all__ = [ - 'clip_grad_norm_fp8', 'FP8DistributedDataParallel', 'FP8LinearWithGradAccumulationAndAsyncCommunication', + 'clip_grad_norm_fp32', 'FP8DistributedDataParallel', 'FP8LinearWithGradAccumulationAndAsyncCommunication', 'FP8DistributedOptimizer' ]