From d50d3191e23c039dded296d7ba3fa4d583bcefc5 Mon Sep 17 00:00:00 2001
From: CK Luk
Date: Tue, 6 Aug 2024 12:30:52 -0700
Subject: [PATCH] Support global norm gradient clipping if DTensor is used (#2271)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2271

If DTensor is used in the parameters passed to gradient clipping, we compute the global gradient norm across all ranks (instead of the per-rank local norm) and clip against it. This is needed according to a study in https://fb.workplace.com/chat/t/26249100244704391#:~:text=was%20really%20important%3A-,D59763695,-today%20in%20APS

The existing test for gradient clipping is modified to test this new capability.

Note: global norm gradient clipping was originally implemented by Andrew in D60625965. I combine it with the unit test here as a single diff.

Reviewed By: iamzainhuda

Differential Revision: D60704597

fbshipit-source-id: 07a50ef04a9d950121cc828bedeb43d5493d6add
---
 torchrec/optim/clipping.py | 124 +++++++++++++++++++++++++++++--------
 1 file changed, 99 insertions(+), 25 deletions(-)

diff --git a/torchrec/optim/clipping.py b/torchrec/optim/clipping.py
index 1f7f9ffec..e48a6a4ca 100644
--- a/torchrec/optim/clipping.py
+++ b/torchrec/optim/clipping.py
@@ -7,10 +7,12 @@
 
 # pyre-strict
 
+from collections import defaultdict
 from enum import Enum, unique
-from typing import Any, List, Union
+from typing import Any, cast, Dict, List, Union
 
 import torch
+import torch.distributed as dist
 from torch.distributed._tensor.api import DTensor
 
 
@@ -49,28 +51,41 @@ def __init__(
         self._check_meta: bool = True
 
         self._params: List[torch.Tensor] = []
+        # Only used if there are DTensor parameters, in which case this dict
+        # holds the sharded DTensor parameters and `self._params` holds the
+        # replicated tensor parameters
+        self._mesh_to_dtensor_params: Dict[dist.DeviceMesh, List[DTensor]] = (
+            defaultdict(list)
+        )
+
         for param_group in self.param_groups:
-            self._params += list(param_group["params"])
-
-        # Convert dtensors to local tensors for performance reason;
-        # otherwise, it needs to go thru dtensor dispatch, which is
-        # quite slow currently.
-        with torch.autograd.profiler.record_function(
-            "Dtensors => Tensors in GradientClippingOptimizer::init()"
-        ):
-            with torch.no_grad():
-                # Under no_grad(), p.to_local() will be as cheap as p._local_tensor.
-                for i, p in enumerate(self._params):
-                    if not isinstance(p, DTensor):
-                        continue
-                    local_p = p.to_local()
-                    if p.grad is None:
-                        local_p.grad = None
-                    else:
-                        # if p is a DTensor, so should be p.grad
-                        assert isinstance(p.grad, DTensor)
-                        local_p.grad = p.grad.to_local()
-                    self._params[i] = local_p
+            for param in param_group["params"]:
+                if isinstance(param, DTensor):
+                    self._mesh_to_dtensor_params[param.device_mesh].append(param)
+                else:
+                    self._params.append(param)
+
+        if len(self._mesh_to_dtensor_params) == 0:
+            return
+
+        # check if we have the support for DTensor
+        if len(self._mesh_to_dtensor_params) > 1:
+            raise NotImplementedError(
+                "More than one device mesh is not supported yet: "
+                f"{self._mesh_to_dtensor_params.keys()}"
+            )
+
+        device_mesh = next(iter(self._mesh_to_dtensor_params.keys()))
+        if device_mesh.ndim > 1:
+            raise NotImplementedError(
+                f"{device_mesh.ndim}D device mesh is not supported yet"
+            )
+
+        if self._clipping == GradientClipping.VALUE:
+            # This path is currently not used in any production.
+            raise NotImplementedError(
+                "clip_grad_value_ for DTensor parameters is not supported yet"
+            )
 
     # pyre-ignore [2]
     def step(self, closure: Any = None) -> None:
@@ -82,10 +97,69 @@ def step(self, closure: Any = None) -> None:
             self._check_meta = False
 
         if self._clipping == GradientClipping.NORM:
-            torch.nn.utils.clip_grad_norm_(
-                self._params, self._max_gradient, norm_type=self._norm_type
-            )
+            if len(self._mesh_to_dtensor_params) == 0:
+                # No DTensor parameters, so we can use the regular clip_grad_norm_
+                torch.nn.utils.clip_grad_norm_(
+                    self._params, self._max_gradient, norm_type=self._norm_type
+                )
+            else:
+                # There are DTensor parameters, so we need to use _dist_clip_grad_norm
+                device_mesh = next(iter(self._mesh_to_dtensor_params.keys()))
+                dtensor_params = self._mesh_to_dtensor_params[device_mesh]
+                process_group = device_mesh.get_group()
+                sharded_grads = [
+                    cast(DTensor, p.grad)._local_tensor
+                    for p in dtensor_params
+                    if p.grad is not None
+                ]
+                if sharded_grads:
+                    replicated_grads = [
+                        p.grad for p in self._params if p.grad is not None
+                    ]
+                    _dist_clip_grad_norm(
+                        sharded_grads,
+                        replicated_grads,
+                        process_group,
+                        self._max_gradient,
+                        float(self._norm_type),
+                    )
         elif self._clipping == GradientClipping.VALUE:
             torch.nn.utils.clip_grad_value_(self._params, self._max_gradient)
 
         super().step(closure)
+
+
+def _dist_clip_grad_norm(
+    sharded_grads: List[torch.Tensor],
+    replicated_grads: List[torch.Tensor],
+    process_group: dist.ProcessGroup,
+    max_norm: float,
+    norm_type: float = 2.0,
+) -> torch.Tensor:
+    assert len(sharded_grads) > 0
+    sharded_norms = torch._foreach_norm(sharded_grads, norm_type)
+    local_norm = torch.linalg.vector_norm(torch.stack(sharded_norms), norm_type)
+    if replicated_grads:
+        replicated_norms = torch._foreach_norm(replicated_grads, norm_type)
+        replicated_norm = torch.linalg.vector_norm(
+            torch.stack(replicated_norms), norm_type
+        )
+    else:
+        replicated_norm = None
+
+    if norm_type == torch.inf:
+        total_norm = local_norm
+        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=process_group)
+        if replicated_norm is not None:
+            total_norm = torch.maximum(total_norm, replicated_norm)
+    else:
+        total_norm = local_norm**norm_type
+        dist.all_reduce(total_norm, group=process_group)
+        if replicated_norm is not None:
+            total_norm += replicated_norm**norm_type
+        total_norm = total_norm ** (1.0 / norm_type)
+
+    clip_coef = cast(torch.Tensor, max_norm / (total_norm + 1e-6))
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    torch._foreach_mul_(sharded_grads + replicated_grads, clip_coef_clamped)
+    return total_norm
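
For reference, here is a minimal single-process sketch (not part of the patch; the tensor shapes and the two-shard split are made up for illustration) of why _dist_clip_grad_norm can recover the global p-norm from per-rank pieces: each rank raises its local norm to the p-th power, the powers are summed across ranks (the all_reduce in the patch), and the (1/p)-th root of the sum equals the norm of the unsharded gradient, which then drives the clip coefficient.

# Single-process illustration of the norm combination used by _dist_clip_grad_norm.
# Assumes a finite p-norm (the torch.inf branch uses MAX instead of SUM).
import torch

norm_type = 2.0
full_grad = torch.randn(8, 4)

# Pretend the gradient is row-sharded across two "ranks".
shard_0, shard_1 = full_grad.chunk(2, dim=0)

# Each rank computes ||local_shard||_p ** p before the collective.
local_pow_0 = torch.linalg.vector_norm(shard_0, norm_type) ** norm_type
local_pow_1 = torch.linalg.vector_norm(shard_1, norm_type) ** norm_type

# all_reduce(SUM) followed by ** (1 / p) in the patch is equivalent to:
global_norm = (local_pow_0 + local_pow_1) ** (1.0 / norm_type)

# This matches the norm of the unsharded gradient.
assert torch.allclose(global_norm, torch.linalg.vector_norm(full_grad, norm_type))

# Clipping then scales every shard by min(1, max_norm / global_norm), mirroring
# torch._foreach_mul_(sharded_grads + replicated_grads, clip_coef_clamped).
max_norm = 1.0
clip_coef = torch.clamp(max_norm / (global_norm + 1e-6), max=1.0)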