Check src & dst dtypes in allgather to prevent silent failures. (#3523)
Summary:
Pull Request resolved: #3523

X-link: facebookresearch/FBGEMM#604

We noticed silent failures when the dst & src dtypes mismatch: fbgemm's allgather copies the raw memory buffer, so a mismatch produces junk output. Here we add an explicit dtype check.

Reviewed By: xw285cornell

Differential Revision: D67535285

fbshipit-source-id: 66046a9368692fdea4e52401af33fdad95b3ea7e
Chenheli Hua authored and facebook-github-bot committed Jan 2, 2025
1 parent 213d849 commit 61eda30
Showing 2 changed files with 78 additions and 11 deletions.
3 changes: 3 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
@@ -141,6 +141,9 @@ void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
   using namespace c10d;
   TORCH_CHECK(src.is_contiguous());
   TORCH_CHECK(dst.is_contiguous());
+  TORCH_CHECK(
+      src.dtype() == dst.dtype(),
+      "dst and src tensors must have the same dtype.");
   ncclDataType_t type = to_nccl_data_type(src.scalar_type());
   C10D_NCCL_CHECK(
       ncclAllGather(
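To illustrate the behavior this check enforces, here is a minimal sketch (not part of the commit) of a mismatched call as seen from Python. It assumes the fbgemm_gpu gen_ai ops are loaded and that torch.ops.fbgemm.nccl_init has already been called on each rank, as in the test below; the shapes and dtypes are arbitrary.

# Hypothetical sketch: requires a CUDA device and a prior torch.ops.fbgemm.nccl_init(...)
# call on every rank, as done in _run_allgather_inner below.
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # registers torch.ops.fbgemm.*

world_size = 2
src = torch.ones(4, 8, dtype=torch.bfloat16, device="cuda")
dst = torch.empty(world_size, 4, 8, dtype=torch.float16, device="cuda")  # mismatched dtype

try:
    # Before this commit the bf16 bytes were gathered into the fp16 buffer and
    # silently reinterpreted as junk values; now the TORCH_CHECK above raises.
    torch.ops.fbgemm.nccl_allgather(dst, src)
except RuntimeError as e:
    assert "dst and src tensors must have the same dtype." in str(e)

With matching dtypes the same call succeeds and each dst[w] slice holds rank w's contribution.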
86 changes: 75 additions & 11 deletions fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
@@ -14,6 +14,7 @@
 import tempfile
 import unittest
 import uuid
+from typing import Tuple

 import fbgemm_gpu.experimental.gen_ai  # noqa: F401

@@ -36,7 +37,12 @@ def has_nvswitch() -> bool:
     return "GRANDTETON" in model or "SUPERMICRO" in model


-def _run_allgather_inner(rdvz: str, dtype: torch.dtype) -> None:
+def _run_allgather_inner(
+    rdvz: str,
+    dst_dtype: torch.dtype,
+    src_dtype: torch.dtype,
+    skip_torch_compile: bool = False,
+) -> None:
     rank = int(os.environ["LOCAL_RANK"])
     W = int(os.environ["WORLD_SIZE"])
     device = torch.device(f"cuda:{rank}")
@@ -45,16 +51,23 @@ def _run_allgather_inner(rdvz: str, dtype: torch.dtype) -> None:
     torch.ops.fbgemm.nccl_init(rank, W, rdvz)

     B, T, D = 2, 4096, 1024
-    y = torch.empty(size=(B, T, D), dtype=dtype, device="cuda")
+    y = torch.empty(size=(B, T, D), dtype=src_dtype, device="cuda")
     y[:] = rank
-    y_gather = torch.full(size=(W, B, T, D), fill_value=-1, dtype=dtype, device="cuda")
-    # Here we test to confirm that allgather is compatible with torch.compile.
-    torch.compile(torch.ops.fbgemm.nccl_allgather)(y_gather, y)
-    for w in range(W):
-        torch.testing.assert_close(
-            y_gather[w],
-            torch.full(size=(B, T, D), fill_value=w, dtype=dtype, device=y.device),
-        )
+    y_gather = torch.full(
+        size=(W, B, T, D), fill_value=-1, dtype=dst_dtype, device="cuda"
+    )
+    # TORCH_CHECK failures can be suppressed by torch.compile, in which case
+    # we may not be able to capture the right exception in Python.
+    if not skip_torch_compile:
+        # Here we test to confirm that allgather is compatible with torch.compile.
+        torch.compile(torch.ops.fbgemm.nccl_allgather)(y_gather, y)
+        for w in range(W):
+            torch.testing.assert_close(
+                y_gather[w],
+                torch.full(
+                    size=(B, T, D), fill_value=w, dtype=dst_dtype, device=y.device
+                ),
+            )

     for _ in range(20):
         torch.ops.fbgemm.nccl_allgather(y_gather, y)
@@ -303,7 +316,58 @@ def test_allgather(self, dtype: torch.dtype) -> None:
                 max_restarts=0,
             )
             elastic_launch(config=lc, entrypoint=_run_allgather_inner)(
-                os.path.join(path, "rdvz"), dtype
+                os.path.join(path, "rdvz"), dtype, dtype
             )

+    @given(
+        dtypes=st.sampled_from(
+            [
+                (torch.bfloat16, torch.float16),
+                (torch.bfloat16, torch.int),
+                (torch.bfloat16, torch.float8_e4m3fn),
+            ]
+        )
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=60000)
+    def test_allgather_dtype_mismatch(
+        self, dtypes: Tuple[torch.dtype, torch.dtype]
+    ) -> None:
+        dst_dtype, src_dtype = dtypes
+        # float8 is only supported in H100 or MI300x
+        if dst_dtype == torch.float8_e4m3fn or src_dtype == torch.float8_e4m3fn:
+            if torch.version.hip:
+                if dst_dtype == torch.float8_e4m3fn:
+                    dst_dtype = torch.float8_e4m3fnuz
+                if src_dtype == torch.float8_e4m3fn:
+                    src_dtype = torch.float8_e4m3fnuz
+            elif torch.cuda.get_device_capability() < (9, 0):
+                self.skipTest(
+                    "float8_e4m3fn is only supported in H100 or MI300x, but we're running "
+                    f"on {torch.cuda.get_device_capability()}"
+                )
+
+        with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path:
+            lc = LaunchConfig(
+                min_nodes=1,
+                max_nodes=1,
+                nproc_per_node=torch.cuda.device_count(),
+                run_id=str(uuid.uuid4()),
+                rdzv_backend="c10d",
+                rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
+                rdzv_configs={"store_type": "file"},
+                start_method="spawn",
+                monitor_interval=1,
+                max_restarts=0,
+            )
+            with self.assertRaises(Exception) as cm:
+                elastic_launch(config=lc, entrypoint=_run_allgather_inner)(
+                    os.path.join(path, "rdvz"),
+                    dst_dtype,
+                    src_dtype,
+                    True,
+                )
+            self.assertTrue(
+                "dst and src tensors must have the same dtype." in cm.exception.args[0]
+            )
+
     def test_allreduce(self) -> None:
