Support global norm gradient clipping if DTensor is used (pytorch#2271)
Summary:
Pull Request resolved: pytorch#2271

If DTensor is used in the parameters passed to gradient clipping, we compute the global norm across all ranks (instead of the per-rank local norm) and clip against it. This is needed according to a study in https://fb.workplace.com/chat/t/26249100244704391#:~:text=was%20really%20important%3A-,D59763695,-today%20in%20APS
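
For illustration, a minimal single-process sketch (not part of this diff; the helper name combine_shard_norms is made up) of how per-shard norms combine into the global norm that the new code all-reduces across ranks:

import torch

def combine_shard_norms(shard_norms, norm_type: float = 2.0) -> torch.Tensor:
    # Global p-norm from per-shard p-norms: (sum_r ||g_r||_p ** p) ** (1 / p).
    # In the optimizer, the sum over ranks is performed with dist.all_reduce.
    return torch.stack(shard_norms).pow(norm_type).sum() ** (1.0 / norm_type)

# Two "ranks", each holding one shard of the same gradient.
full_grad = torch.arange(6, dtype=torch.float32)
shards = [full_grad[:3], full_grad[3:]]
shard_norms = [torch.linalg.vector_norm(s, 2.0) for s in shards]
assert torch.allclose(combine_shard_norms(shard_norms), torch.linalg.vector_norm(full_grad, 2.0))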

The existing test for gradient clipping is modified to test this new capability.

Note: global norm gradient clipping was originally implemented by Andrew in D60625965. I combined it with the unit test here as a single diff.

Reviewed By: iamzainhuda

Differential Revision: D60704597

fbshipit-source-id: 07a50ef04a9d950121cc828bedeb43d5493d6add
ckluk2 authored and facebook-github-bot committed Aug 6, 2024
1 parent 0198312 commit d50d319
Showing 1 changed file with 99 additions and 25 deletions.
124 changes: 99 additions & 25 deletions torchrec/optim/clipping.py
@@ -7,10 +7,12 @@

# pyre-strict

from collections import defaultdict
from enum import Enum, unique
from typing import Any, List, Union
from typing import Any, cast, Dict, List, Union

import torch
import torch.distributed as dist

from torch.distributed._tensor.api import DTensor

@@ -49,28 +51,41 @@ def __init__(
        self._check_meta: bool = True

        self._params: List[torch.Tensor] = []
        # Only used if there are DTensor parameters, in which case this dict
        # holds the sharded DTensor parameters and `self._params` holds the
        # replicated tensor parameters
        self._mesh_to_dtensor_params: Dict[dist.DeviceMesh, List[DTensor]] = (
            defaultdict(list)
        )

        for param_group in self.param_groups:
            self._params += list(param_group["params"])

        # Convert dtensors to local tensors for performance reason;
        # otherwise, it needs to go thru dtensor dispatch, which is
        # quite slow currently.
        with torch.autograd.profiler.record_function(
            "Dtensors => Tensors in GradientClippingOptimizer::init()"
        ):
            with torch.no_grad():
                # Under no_grad(), p.to_local() will be as cheap as p._local_tensor.
                for i, p in enumerate(self._params):
                    if not isinstance(p, DTensor):
                        continue
                    local_p = p.to_local()
                    if p.grad is None:
                        local_p.grad = None
                    else:
                        # if p is a DTensor, so should be p.grad
                        assert isinstance(p.grad, DTensor)
                        local_p.grad = p.grad.to_local()
                    self._params[i] = local_p
            for param in param_group["params"]:
                if isinstance(param, DTensor):
                    self._mesh_to_dtensor_params[param.device_mesh].append(param)
                else:
                    self._params.append(param)

        if len(self._mesh_to_dtensor_params) == 0:
            return

        # check if we have the support for DTensor
        if len(self._mesh_to_dtensor_params) > 1:
            raise NotImplementedError(
                "More than one device mesh is not supported yet: "
                f"{self._mesh_to_dtensor_params.keys()}"
            )

        device_mesh = next(iter(self._mesh_to_dtensor_params.keys()))
        if device_mesh.ndim > 1:
            raise NotImplementedError(
                f"{device_mesh.ndim}D device mesh is not supported yet"
            )

        if self._clipping == GradientClipping.VALUE:
            # This path is currently not used in any production.
            raise NotImplementedError(
                "clip_grad_value_ for DTensor parameters is not supported yet"
            )

    # pyre-ignore [2]
    def step(self, closure: Any = None) -> None:
@@ -82,10 +97,69 @@ def step(self, closure: Any = None) -> None:
            self._check_meta = False

        if self._clipping == GradientClipping.NORM:
            torch.nn.utils.clip_grad_norm_(
                self._params, self._max_gradient, norm_type=self._norm_type
            )
            if len(self._mesh_to_dtensor_params) == 0:
                # No DTensor parameters, so we can use the regular clip_grad_norm_
                torch.nn.utils.clip_grad_norm_(
                    self._params, self._max_gradient, norm_type=self._norm_type
                )
            else:
                # There are DTensor parameters, so we need to use _dist_clip_grad_norm
                device_mesh = next(iter(self._mesh_to_dtensor_params.keys()))
                dtensor_params = self._mesh_to_dtensor_params[device_mesh]
                process_group = device_mesh.get_group()
                sharded_grads = [
                    cast(DTensor, p.grad)._local_tensor
                    for p in dtensor_params
                    if p.grad is not None
                ]
                if sharded_grads:
                    replicated_grads = [
                        p.grad for p in self._params if p.grad is not None
                    ]
                    _dist_clip_grad_norm(
                        sharded_grads,
                        replicated_grads,
                        process_group,
                        self._max_gradient,
                        float(self._norm_type),
                    )
        elif self._clipping == GradientClipping.VALUE:
            torch.nn.utils.clip_grad_value_(self._params, self._max_gradient)

        super().step(closure)


def _dist_clip_grad_norm(
    sharded_grads: List[torch.Tensor],
    replicated_grads: List[torch.Tensor],
    process_group: dist.ProcessGroup,
    max_norm: float,
    norm_type: float = 2.0,
) -> torch.Tensor:
    assert len(sharded_grads) > 0
    sharded_norms = torch._foreach_norm(sharded_grads, norm_type)
    local_norm = torch.linalg.vector_norm(torch.stack(sharded_norms), norm_type)
    if replicated_grads:
        replicated_norms = torch._foreach_norm(replicated_grads, norm_type)
        replicated_norm = torch.linalg.vector_norm(
            torch.stack(replicated_norms), norm_type
        )
    else:
        replicated_norm = None

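    # Combine the local shard norm with the other ranks' norms: the max for the
    # infinity norm, otherwise a sum of p-th powers followed by a p-th root.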
    if norm_type == torch.inf:
        total_norm = local_norm
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=process_group)
        if replicated_norm is not None:
            total_norm = torch.maximum(total_norm, replicated_norm)
    else:
        total_norm = local_norm**norm_type
        dist.all_reduce(total_norm, group=process_group)
        if replicated_norm is not None:
            total_norm += replicated_norm**norm_type
        total_norm = total_norm ** (1.0 / norm_type)

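    # Scale every gradient in place by min(1, max_norm / total_norm), the same
    # clipping rule that torch.nn.utils.clip_grad_norm_ applies.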
    clip_coef = cast(torch.Tensor, max_norm / (total_norm + 1e-6))
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    torch._foreach_mul_(sharded_grads + replicated_grads, clip_coef_clamped)
    return total_norm
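
A hedged usage sketch (not part of the commit): exercising _dist_clip_grad_norm on a single-rank gloo process group, where the all_reduce is a no-op, so the returned value is simply the combined norm of the local sharded and replicated gradients. Importing the private helper directly and the gradient values are assumptions for illustration only.

import os
import torch
import torch.distributed as dist
from torchrec.optim.clipping import _dist_clip_grad_norm

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

sharded_grads = [torch.ones(4)]     # local shard of a sharded parameter's gradient
replicated_grads = [torch.ones(2)]  # gradient of a replicated parameter
total_norm = _dist_clip_grad_norm(
    sharded_grads, replicated_grads, dist.group.WORLD, max_norm=1.0, norm_type=2.0
)
print(total_norm)  # sqrt(4 + 2) ~= 2.449; both gradient lists were scaled in place
dist.destroy_process_group()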
