Add Unittest For Distributed Adam With CUDA Graph (#1836)
* Add a unit test for distributed Adam with CUDA graph.

* Fix distributed Adam when the user passes a float LR.

* Skip the test if world_size < 8.

---------

Co-authored-by: Masaki Kozuki <[email protected]>
alpha0422 and crcrpar committed Aug 30, 2024
1 parent c3e4adf commit b7a4acc
Showing 2 changed files with 127 additions and 54 deletions.
apex/contrib/optimizers/distributed_fused_adam.py (6 changes: 5 additions & 1 deletion)
@@ -819,7 +819,11 @@ def __init__(
             if len(group['params']) == 0:
                 continue
             for item in ['lr']:
-                self.param_groups[idx][item] = group[item].to(device=self.device)
+                if torch.is_tensor(group[item]):
+                    self.param_groups[idx][item] = group[item].to(device=self.device)
+                else:
+                    self.param_groups[idx][item] = torch.tensor(group[item],
+                                                                device=self.device)
 
         # For better representation string
         arg_names = inspect.getfullargspec(DistributedFusedAdam.__init__).args
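Why the change matters: a CUDA graph replays the kernels it captured, so a hyperparameter baked in as a Python float is frozen at capture time, while a device tensor is re-read on every replay. The hunk above therefore coerces a float `lr` into a tensor on the optimizer's device. A minimal sketch of the same coercion outside apex (the helper name and `param_groups` layout here are illustrative, not apex internals):

```python
import torch

def to_device_hyperparam(value, device):
    """Return `value` as a tensor on `device`, whether it arrives as a
    Python float or as an already-materialized tensor."""
    if torch.is_tensor(value):
        return value.to(device=device)
    return torch.tensor(value, device=device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
param_groups = [{"lr": 1e-3}, {"lr": torch.tensor(5e-4)}]
for group in param_groups:
    group["lr"] = to_device_hyperparam(group["lr"], device)

# Both groups now hold tensor LRs; a schedule can adjust them in place
# (group["lr"].copy_(new_lr)) without invalidating a captured graph.
```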
apex/contrib/test/optimizers/test_dist_adam.py (175 changes: 122 additions & 53 deletions)
@@ -3,6 +3,7 @@
 from typing import Callable, Optional, Tuple
 import unittest
 import warnings
+from contextlib import nullcontext
 
 import torch
 from torch.testing._internal import common_utils
@@ -49,6 +50,7 @@ def make_models(
     store_param_remainders: bool = False,
     with_scaled_states: bool = False,
     nccl_ub: bool = False,
+    with_cuda_graph: bool = False,
 ):
 
     # Construct models with same parameters
@@ -100,6 +102,7 @@ def make_models(
         store_param_remainders=store_param_remainders,
         with_scaled_states=with_scaled_states,
         nccl_ub=nccl_ub,
+        capturable=with_cuda_graph,
         **optim_args,
     )
 
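In `make_models` the test's `with_cuda_graph` flag is forwarded to the optimizer as `capturable`. A rough sketch of that wiring used directly, assuming the process group is already set up (e.g. the script was launched with torchrun) and using a placeholder model; only the `lr` and `capturable` arguments are taken from this diff:

```python
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

# Assumes torchrun has set the usual rank/world-size env vars; one GPU per rank.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(128, 128).cuda()
optim = DistributedFusedAdam(
    model.parameters(),
    lr=1e-3,          # a plain float now works thanks to the fix above
    capturable=True,  # keep the optimizer step safe to capture and replay
)
```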
@@ -143,78 +146,130 @@ def test_matches_pytorch(
             with_scaled_states: bool = False,
             nccl_ub: bool = False,
             init_optim_func: Optional[Callable[[DistributedFusedAdam], None]] = None,
+            with_cuda_graph: bool = False,
     ):
 
         torch.manual_seed(self.seed + self.rank)
 
         # Identical models with data-parallel and ZeRO
-        ref_model, ref_optim, dist_model, dist_optim = make_models(
-            num_layers,
-            layer_size,
-            adam_w_mode=adam_w_mode,
-            model_dtype=model_dtype,
-            optim_dtype=optim_dtype,
-            grad_sync_dtype=grad_sync_dtype,
-            param_sync_dtype=param_sync_dtype,
-            device=device,
-            overlap_communication=overlap_communication,
-            bucket_cap_mb=bucket_cap_mb,
-            contiguous_buffers=contiguous_buffers,
-            store_params=store_params,
-            store_param_remainders=store_param_remainders,
-            with_scaled_states=with_scaled_states,
-            nccl_ub=nccl_ub,
-        )
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            ref_model, ref_optim, dist_model, dist_optim = make_models(
+                num_layers,
+                layer_size,
+                adam_w_mode=adam_w_mode,
+                model_dtype=model_dtype,
+                optim_dtype=optim_dtype,
+                grad_sync_dtype=grad_sync_dtype,
+                param_sync_dtype=param_sync_dtype,
+                device=device,
+                overlap_communication=overlap_communication,
+                bucket_cap_mb=bucket_cap_mb,
+                contiguous_buffers=contiguous_buffers,
+                store_params=store_params,
+                store_param_remainders=store_param_remainders,
+                with_scaled_states=with_scaled_states,
+                nccl_ub=nccl_ub,
+                with_cuda_graph=with_cuda_graph,
+            )
 
         # Initialize distributed optimizer
         if init_optim_func is not None:
-            init_optim_func(dist_optim)
+            with torch.cuda.stream(stream):
+                init_optim_func(dist_optim)
+
+        # Static data
+        static_xs, static_dys = [], []
+        ys_ref, grad_xs_ref = [], []
+        ys_dist, grad_xs_dist = [], []
 
-        # Training loop
-        for step in range(num_steps):
-
-            # Reset gradients
-            ref_optim.zero_grad()
-            dist_optim.zero_grad()
-
-            # Forward and backward passes
-            for micro_step in range(micro_batch_steps):
+        graph = torch.cuda.CUDAGraph() if with_cuda_graph else None
+        CAPTURE_ITERATION = 11
+        if with_cuda_graph:
+            assert num_steps > CAPTURE_ITERATION + 3, \
+                "Not enough iterations for CUDA graph test."
 
+        # Training loop
+        with torch.cuda.stream(stream):
+            for step in range(num_steps):
                 # Synthetic data
-                x = torch.rand(batch_size, layer_size) - 0.5
-                dy = torch.rand_like(x) - 0.5
-                x = x.to(dtype=model_dtype, device=device)
-                dy = dy.to(dtype=model_dtype, device=device)
+                for micro_step in range(micro_batch_steps):
+                    x = torch.rand(batch_size, layer_size) - 0.5
+                    dy = torch.rand_like(x) - 0.5
+                    x = x.to(dtype=model_dtype, device=device)
+                    dy = dy.to(dtype=model_dtype, device=device)
+                    if step == 0:
+                        static_xs.append(x)
+                        static_dys.append(dy)
+                    else:
+                        static_xs[micro_step].copy_(x)
+                        static_dys[micro_step].copy_(dy)
 
                 # Reference implementation
-                x_ref = x.detach().clone().requires_grad_(True)
-                y_ref = ref_model(x_ref)
-                y_ref.backward(dy)
+                ref_optim.zero_grad()
+                for micro_step in range(micro_batch_steps):
+                    x, dy = static_xs[micro_step], static_dys[micro_step]
+
+                    x_ref = x.detach().clone().requires_grad_(True)
+                    y_ref = ref_model(x_ref)
+                    y_ref.backward(dy)
+
+                    if step == 0:
+                        ys_ref.append(y_ref)
+                        grad_xs_ref.append(x_ref.grad)
+                    else:
+                        with torch.no_grad():
+                            ys_ref[micro_step].copy_(y_ref)
+                            grad_xs_ref[micro_step].copy_(x_ref.grad)
+                ref_optim.step()
 
                 # Distributed implementation
-                x_dist = x.detach().clone().requires_grad_(True)
-                y_dist = dist_model(x_dist)
-                backward_context = dummy_context
-                if use_nosync and micro_step < micro_batch_steps-1:
-                    backward_context = dist_optim.no_sync
-                with backward_context():
-                    y_dist.backward(dy)
+                if not with_cuda_graph or step <= CAPTURE_ITERATION:
+                    if with_cuda_graph and step == CAPTURE_ITERATION:
+                        ctx = torch.cuda.graph(graph)
+                        torch.cuda.synchronize()
+                    else:
+                        ctx = nullcontext()
+
+                    with ctx:
+                        dist_optim.zero_grad()
+                        for micro_step in range(micro_batch_steps):
+                            x, dy = static_xs[micro_step], static_dys[micro_step]
+
+                            x_dist = x.detach().clone().requires_grad_(True)
+                            y_dist = dist_model(x_dist)
+                            backward_context = dummy_context
+                            if use_nosync and micro_step < micro_batch_steps-1:
+                                backward_context = dist_optim.no_sync
+                            with backward_context():
+                                y_dist.backward(dy)
+
+                            if step == 0:
+                                ys_dist.append(y_dist)
+                                grad_xs_dist.append(x_dist.grad)
+                            else:
+                                with torch.no_grad():
+                                    ys_dist[micro_step].copy_(y_dist)
+                                    grad_xs_dist[micro_step].copy_(x_dist.grad)
+                        dist_optim.step()
+
+                    if with_cuda_graph and step == CAPTURE_ITERATION:
+                        graph.replay()
+                else:
+                    graph.replay()
 
                 # Check that data tensors match
-                torch.testing.assert_close(
-                    y_dist, y_ref, rtol=rtol, atol=atol)
-                torch.testing.assert_close(
-                    x_dist.grad, x_ref.grad, rtol=rtol, atol=atol)
-
-            # Optimization step
-            ref_optim.step()
-            dist_optim.step()
+                for mbs in range(micro_batch_steps):
+                    torch.testing.assert_close(
+                        ys_dist[mbs], ys_ref[mbs], rtol=rtol, atol=atol)
+                    torch.testing.assert_close(
+                        grad_xs_dist[mbs], grad_xs_ref[mbs], rtol=rtol, atol=atol)
 
-            # Check that parameters match
-            for ref_param, dist_param in zip(ref_model.parameters(),
-                                             dist_model.parameters()):
-                torch.testing.assert_close(
-                    dist_param, ref_param, rtol=rtol, atol=atol)
+                # Check that parameters match
+                for ref_param, dist_param in zip(ref_model.parameters(),
+                                                 dist_model.parameters()):
+                    torch.testing.assert_close(
+                        dist_param, ref_param, rtol=rtol, atol=atol)
 
     def test_matches_pytorch_l2_reg(self):
         self.test_matches_pytorch(adam_w_mode=False)
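The reworked loop follows the usual PyTorch capture/replay recipe: warm up on a side stream, capture one full iteration (forward, backward, optimizer step) into a `torch.cuda.CUDAGraph` once training has settled (here at `CAPTURE_ITERATION`), then drive later steps by copying fresh data into the static buffers (`static_xs`, `static_dys`) and calling `graph.replay()`. A self-contained, single-GPU sketch of the same pattern, with `torch.optim.Adam(capturable=True)` standing in for `DistributedFusedAdam`:

```python
import torch

device = torch.device("cuda")
model = torch.nn.Linear(64, 64, device=device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)

# Static buffers: every replay reads from these exact tensors.
static_x = torch.randn(8, 64, device=device)
static_dy = torch.randn(8, 64, device=device)

# Warm up on a side stream so capture starts from a quiet default stream.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        model(static_x).backward(static_dy)
        opt.step()
torch.cuda.current_stream().wait_stream(side)

# Capture one full iteration: forward, backward, optimizer step.
graph = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(graph):
    static_y = model(static_x)
    static_y.backward(static_dy)
    opt.step()

# Later iterations: refresh the static buffers in place, then replay.
for _ in range(10):
    static_x.copy_(torch.randn_like(static_x))
    static_dy.copy_(torch.randn_like(static_dy))
    graph.replay()
```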
@@ -797,6 +852,20 @@ def test_bucket_low_utilization_warning(self):
         for w in warns:
             self.assertNotRegex(str(w.message), ".*Consider decreasing the bucket_cap_mb argument.")
 
+    def test_cuda_graph(self):
+        """Test distributed adam with CUDA graph"""
+        if self.world_size < 8:
+            self.skipTest(f"{self.world_size=} is expected to be >= 8")
+        self.test_matches_pytorch(
+            rtol=5e-3,
+            atol=1e-5,
+            num_steps=15,
+            micro_batch_steps=1,
+            model_dtype=torch.float16,
+            optim_dtype=torch.float16,
+            contiguous_buffers=True,
+            with_cuda_graph=True,
+        )
 
 if __name__ == "__main__":
     # Assume script has been run with torchrun
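Because the new test only runs with at least eight ranks and the file's main block assumes a torchrun launch, an illustrative invocation on a single 8-GPU node would be:

```
torchrun --nproc_per_node=8 apex/contrib/test/optimizers/test_dist_adam.py
```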
