From 61eda30eb7fdb3cd1e254d2d5cc51504f6c16429 Mon Sep 17 00:00:00 2001
From: Chenheli Hua
Date: Thu, 2 Jan 2025 13:38:19 -0800
Subject: [PATCH] Check src & dst dtypes in allgather to prevent silent
 failures. (#3523)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3523

X-link: https://github.com/facebookresearch/FBGEMM/pull/604

We noticed silent failures when the dst & src dtypes mismatch: fbgemm's
allgather copies the raw memory buffer, which produces junk output. Here we
add an explicit dtype check so the mismatch fails loudly. See the usage
sketch after the diff for how the check surfaces in Python.

Reviewed By: xw285cornell

Differential Revision: D67535285

fbshipit-source-id: 66046a9368692fdea4e52401af33fdad95b3ea7e
---
 .../experimental/gen_ai/src/comm/car.cpp      |  3 +
 .../gen_ai/test/comm/multi_gpu_car_test.py    | 86 ++++++++++++++++---
 2 files changed, 78 insertions(+), 11 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
index e8d07e271..fb07f4cc4 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/comm/car.cpp
@@ -141,6 +141,9 @@ void nccl_allgather(at::Tensor dst, at::Tensor src, int64_t comm_idx) {
   using namespace c10d;
   TORCH_CHECK(src.is_contiguous());
   TORCH_CHECK(dst.is_contiguous());
+  TORCH_CHECK(
+      src.dtype() == dst.dtype(),
+      "dst and src tensors must have the same dtype.");
   ncclDataType_t type = to_nccl_data_type(src.scalar_type());
   C10D_NCCL_CHECK(
       ncclAllGather(
diff --git a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
index 47065972c..34b249e8a 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/comm/multi_gpu_car_test.py
@@ -14,6 +14,7 @@
 import tempfile
 import unittest
 import uuid
+from typing import Tuple
 
 import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
@@ -36,7 +37,12 @@ def has_nvswitch() -> bool:
     return "GRANDTETON" in model or "SUPERMICRO" in model
 
 
-def _run_allgather_inner(rdvz: str, dtype: torch.dtype) -> None:
+def _run_allgather_inner(
+    rdvz: str,
+    dst_dtype: torch.dtype,
+    src_dtype: torch.dtype,
+    skip_torch_compile: bool = False,
+) -> None:
     rank = int(os.environ["LOCAL_RANK"])
     W = int(os.environ["WORLD_SIZE"])
     device = torch.device(f"cuda:{rank}")
@@ -45,16 +51,23 @@ def _run_allgather_inner(rdvz: str, dtype: torch.dtype) -> None:
     torch.ops.fbgemm.nccl_init(rank, W, rdvz)
 
     B, T, D = 2, 4096, 1024
-    y = torch.empty(size=(B, T, D), dtype=dtype, device="cuda")
+    y = torch.empty(size=(B, T, D), dtype=src_dtype, device="cuda")
     y[:] = rank
-    y_gather = torch.full(size=(W, B, T, D), fill_value=-1, dtype=dtype, device="cuda")
-    # Here we test to confirm that allgather is compatible with torch.compile.
-    torch.compile(torch.ops.fbgemm.nccl_allgather)(y_gather, y)
-    for w in range(W):
-        torch.testing.assert_close(
-            y_gather[w],
-            torch.full(size=(B, T, D), fill_value=w, dtype=dtype, device=y.device),
-        )
+    y_gather = torch.full(
+        size=(W, B, T, D), fill_value=-1, dtype=dst_dtype, device="cuda"
+    )
+    # TORCH_CHECK failures can be suppressed by torch.compile, in which case
+    # we may not be able to capture the right exception in Python.
+    if not skip_torch_compile:
+        # Here we test to confirm that allgather is compatible with torch.compile.
+        torch.compile(torch.ops.fbgemm.nccl_allgather)(y_gather, y)
+        for w in range(W):
+            torch.testing.assert_close(
+                y_gather[w],
+                torch.full(
+                    size=(B, T, D), fill_value=w, dtype=dst_dtype, device=y.device
+                ),
+            )
 
     for _ in range(20):
         torch.ops.fbgemm.nccl_allgather(y_gather, y)
@@ -303,7 +316,58 @@ def test_allgather(self, dtype: torch.dtype) -> None:
             max_restarts=0,
         )
         elastic_launch(config=lc, entrypoint=_run_allgather_inner)(
-            os.path.join(path, "rdvz"), dtype
+            os.path.join(path, "rdvz"), dtype, dtype
         )
 
+    @given(
+        dtypes=st.sampled_from(
+            [
+                (torch.bfloat16, torch.float16),
+                (torch.bfloat16, torch.int),
+                (torch.bfloat16, torch.float8_e4m3fn),
+            ]
+        )
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=60000)
+    def test_allgather_dtype_mismatch(
+        self, dtypes: Tuple[torch.dtype, torch.dtype]
+    ) -> None:
+        dst_dtype, src_dtype = dtypes
+        # float8 is only supported in H100 or MI300x
+        if dst_dtype == torch.float8_e4m3fn or src_dtype == torch.float8_e4m3fn:
+            if torch.version.hip:
+                if dst_dtype == torch.float8_e4m3fn:
+                    dst_dtype = torch.float8_e4m3fnuz
+                if src_dtype == torch.float8_e4m3fn:
+                    src_dtype = torch.float8_e4m3fnuz
+            elif torch.cuda.get_device_capability() < (9, 0):
+                self.skipTest(
+                    "float8_e4m3fn is only supported in H100 or MI300x, but we're running "
+                    f"on {torch.cuda.get_device_capability()}"
+                )
+
+        with tempfile.TemporaryDirectory() as tmpdir, tempfile.TemporaryDirectory() as path:
+            lc = LaunchConfig(
+                min_nodes=1,
+                max_nodes=1,
+                nproc_per_node=torch.cuda.device_count(),
+                run_id=str(uuid.uuid4()),
+                rdzv_backend="c10d",
+                rdzv_endpoint=os.path.join(tmpdir, "rdzv"),
+                rdzv_configs={"store_type": "file"},
+                start_method="spawn",
+                monitor_interval=1,
+                max_restarts=0,
+            )
+            with self.assertRaises(Exception) as cm:
+                elastic_launch(config=lc, entrypoint=_run_allgather_inner)(
+                    os.path.join(path, "rdvz"),
+                    dst_dtype,
+                    src_dtype,
+                    True,
+                )
+            self.assertTrue(
+                "dst and src tensors must have the same dtype." in cm.exception.args[0]
+            )
+
     def test_allreduce(self) -> None:
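Usage sketch (illustrative only, not applied by this diff): with the new
TORCH_CHECK in place, a dtype mismatch fails loudly from Python instead of
gathering reinterpreted bytes. The sketch mirrors the setup in
_run_allgather_inner; it assumes it is launched under torchrun/elastic_launch
so that LOCAL_RANK and WORLD_SIZE are set, that the c10::Error raised by
TORCH_CHECK surfaces as a Python RuntimeError, and that the rendezvous path
"/tmp/rdvz" is an arbitrary placeholder.

    import os

    import torch
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

    rank = int(os.environ["LOCAL_RANK"])
    W = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(f"cuda:{rank}")
    torch.ops.fbgemm.nccl_init(rank, W, "/tmp/rdvz")  # placeholder rendezvous

    src = torch.ones(2, 4096, 1024, dtype=torch.bfloat16, device="cuda")
    dst = torch.full((W, 2, 4096, 1024), -1, dtype=torch.float16, device="cuda")

    try:
        # fp16 dst vs. bf16 src: sizes match, but the dtypes do not.
        torch.ops.fbgemm.nccl_allgather(dst, src)
    except RuntimeError as e:
        # With this patch, the mismatch raises instead of copying junk bytes.
        assert "dst and src tensors must have the same dtype." in str(e)

Before the check, the same call would complete and dst would hold the bf16
payload reinterpreted as fp16, which is the silent "junk output" failure mode
described in the summary.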