diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index d94c352401..e510df1761 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None): help="Disable the comm+GEMM overlap.", ) parser.add_argument( - "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + "--num-replicas", + type=int, + default=1, + help="Number of data-parallel model replicas per node.", + ) + parser.add_argument( + "--use-global-replica-count", + action="store_true", + default=False, + help="Treat '--num-replicas' as the total number of replicas.", ) parser.add_argument( "--tcp-init", @@ -173,13 +182,12 @@ def _train(opts): opts.tcp_init = True opts.bind_to_device = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: # TORCHELASTIC, SLURM, etc... WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + NUM_NODES = WORLD_SIZE // LOCAL_SIZE # Initialize torch.distributed global process group and get DP/TP groups @@ -214,90 +222,24 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") - # Figure out process groups for tensor- and data-parallelism (if any) - if NUM_NODES > 1: - # Create a list of world ranks on this node - hostname = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), - ) - - if ifname is not None: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - hostname = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - - hostnames = [None for _ in range(WORLD_SIZE)] - dist.all_gather_object(hostnames, hostname) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - assert len(unique_hosts) == NUM_NODES - - ranks_per_node_list = [[] for _ in range(NUM_NODES)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0 - self_node_ranks = ranks_per_node_list[self_node_idx] - - if opts.num_replicas > 1: - # Split node ranks into multiple replicas - assert len(self_node_ranks) % opts.num_replicas == 0 - tp_size = len(self_node_ranks) // opts.num_replicas - ranks_per_replica_list = [] - for node_ranks in ranks_per_node_list: - for i in range(opts.num_replicas): - start = i * tp_size - end = start + tp_size - ranks_per_replica_list.append(node_ranks[start:end]) - - self_replica_idx = -1 - for i, replica_ranks in enumerate(ranks_per_replica_list): - if WORLD_RANK in replica_ranks: - self_replica_idx = i - break - assert self_replica_idx >= 0 + total_replicas = ( + opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES + ) + 
tp_size = WORLD_SIZE // total_replicas - else: - # The entire node is the tensor-parallel group - ranks_per_replica_list = ranks_per_node_list - self_replica_idx = self_node_idx + if total_replicas > 1: + ranks_per_replica_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(total_replicas) + ] tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) dp_group, _ = dist.new_subgroups_by_enumeration( ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" ) - else: - if opts.num_replicas > 1: - # Mixed data- and tensor-parallelism on a single node - # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions - all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") - ranks_per_replica_tensor = all_ranks.reshape( - (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) - ) - tp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.tolist(), backend="nccl" - ) - dp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" - ) - else: - dp_group = None - tp_group = nccl_world + dp_group = None + tp_group = nccl_world tp_rank = dist.get_rank(tp_group) tp_size = dist.get_world_size(tp_group) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 9e11e07e11..4bbdd23fd6 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -180,15 +180,22 @@ def _main(opts): LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) opts.tcp_init = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: # TORCHELASTIC, SLURM, etc... 
WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node - assert LOCAL_SIZE <= torch.cuda.device_count() + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + + result = subprocess.run( + "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'", + capture_output=True, + text=True, + shell=True, + ) + + if result.stdout == "0": # Extra checks for non-MNNVL platforms + assert WORLD_SIZE == LOCAL_SIZE + assert LOCAL_SIZE <= torch.cuda.device_count() # Fix clock speed torch.cuda.set_device(LOCAL_RANK) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index d4a01386ee..39200775c9 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -7,6 +7,7 @@ import os import sys import socket +import subprocess import argparse import warnings import pprint @@ -209,14 +210,21 @@ def _train(opts): opts.tcp_init = True opts.bind_to_device = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - assert LOCAL_SIZE == WORLD_SIZE + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + + result = subprocess.run( + "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'", + capture_output=True, + text=True, + shell=True, + ) + + if result.stdout == "0": # Extra checks for non-MNNVL platforms + assert WORLD_SIZE == LOCAL_SIZE def dist_print(msg, src=None, end="\n", debug=False, error=False): if debug and not opts.debug: @@ -227,7 +235,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist.barrier() # Set device and initialize RNG states - torch.cuda.set_device(WORLD_RANK) + torch.cuda.set_device(LOCAL_RANK) torch.manual_seed(opts.seed) torch.cuda.manual_seed(opts.seed) @@ -312,7 +320,7 @@ def run_fwd_bwd(model, x): return out torch_rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}")) + cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}")) if opts.use_cuda_graphs: test_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(test_graph): @@ -329,7 +337,7 @@ def run_fwd_bwd(model, x): names.append(test_name + ".grad") torch.set_rng_state(torch_rng_state) - torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}")) + torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}")) if opts.use_cuda_graphs: ref_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(ref_graph): diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 68231f6c04..0a2abb6e4e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -78,6 +78,7 @@ list(APPEND transformer_engine_SOURCES util/cast.cu util/padding.cu util/cuda_driver.cpp + util/cuda_nvml.cpp util/cuda_runtime.cpp util/rtc.cpp swizzle/swizzle.cu diff --git 
a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index c3453aeffe..14ff853266 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -20,6 +20,7 @@ #include #include "common/util/cuda_driver.h" +#include "common/util/cuda_nvml.h" #include "common/util/cuda_runtime.h" #include "common/util/logging.h" #include "common/util/system.h" @@ -29,7 +30,6 @@ #ifdef NVTE_UB_WITH_MPI static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD; static MPI_Comm EXT_COMM_INTRA; -static MPI_Comm EXT_COMM_INTER; #define UB_MPI_CHECK(expr) \ do { \ @@ -58,11 +58,20 @@ void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } #else #define EXT_COMM_WORLD "world" #define EXT_COMM_INTRA "intra" -#define EXT_COMM_INTER "inter" #endif #define MULTICAST_GB_TOTAL 512 +#if CUDART_VERSION < 12030 +// MNNVL: FABRIC handle support lifted from CUDA 12.3 +#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) +#define CU_IPC_HANDLE_SIZE 64 +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; +#endif + int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } #define IPCCHECK(cmd) \ @@ -82,18 +91,43 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co } \ } while (0); -int pipe_rank(communicator *comm, int step) { - int mynode = comm->myrank / comm->nvsize; - int mylocal = comm->nvrank; - int numlocal = comm->nvsize; - - int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize; - int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal; - int newnode = mynode; - newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes; - int allnodes = comm->nranks / comm->nvsize; - newnode = (allnodes + (newnode % allnodes)) % allnodes; - return newnode * numlocal + newlocal; +bool has_mnnvl_fabric(int device_id) { +#if CUDA_VERSION < 12040 + if (getenv("NVTE_UBDEBUG")) { + printf( + "TransformerEngine does not support multi-node NVLINK " + "since it was not built with CUDA version >= 12.4.\n"); + } + return false; +#else + bool mnnvl_fabric_support = false; + CUdevice dev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id); + int fabric_handle_supported = 0; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &fabric_handle_supported, + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev); + if (fabric_handle_supported) { + NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2); + nvmlDevice_t local_device; + NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device); + nvmlGpuFabricInfoV_t fabricInfo = {}; + fabricInfo.version = nvmlGpuFabricInfo_v2; + fabricInfo.clusterUuid[0] = '\0'; + NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo); + NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown); + if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') { + mnnvl_fabric_support = true; + } + } + if (getenv("NVTE_UBDEBUG")) { + if (mnnvl_fabric_support) { + printf("MNNVL NVLINK is supported on this platform.\n"); + } else { + printf("MNNVL NVLINK is not supported on this platform.\n"); + } + } + return mnnvl_fabric_support; +#endif } int create_communicator_grouped2(communicator **comm, int myrank, int numranks, 
int mylocal, @@ -122,10 +156,6 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, (*comm)->use_ce = 0; (*comm)->cga_size = 2; for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; - (*comm)->head = 0; - (*comm)->tail = 0; - (*comm)->active_nreqs = 0; - for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; int device_clock = 0; // 110 sec wait time by default @@ -182,29 +212,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, // ar2 has step equal to ar_nvsize int allnodes = numranks / numlocal; int nodeid = myrank / numlocal; - int datanodes = allnodes / pipenodes / tensornodes; - int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes); - (*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus); - - (*comm)->comm_inter = EXT_COMM_INTER; - (*comm)->first_node = nodeid - mynode; (*comm)->num_nodes = numnodes; (*comm)->my_node = mynode; - (*comm)->num2_nodes = tensornodes; - (*comm)->my2_node = (mynode / datanodes) % tensornodes; - (*comm)->first2_node = mynode - (*comm)->my2_node * datanodes; - - (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); - (*comm)->nblocks = 8; - (*comm)->alignblock = 1024 * 512; - (*comm)->minblock = 1024 * 2 * 1024; - (*comm)->asyncblocks = 16; - #define NBUF 2 #if CUDART_VERSION >= 12010 + bool mnnvl_fabric = has_mnnvl_fabric(cur_dev); if (!transformer_engine::getenv("UB_SKIPMC") && transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { // multicast init only for TP ops (____2 operations) @@ -215,7 +230,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, CUmulticastObjectProp mcProp = {}; mcProp.numDevices = (*comm)->ar2_nvsize; mcProp.size = (*comm)->mc_maxsize; - mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + mcProp.handleTypes = + mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; NVTE_CALL_CHECK_CUDA_DRIVER( cuMulticastGetGranularity, &gran, &mcProp, @@ -223,46 +239,78 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran; mcProp.size = mc_maxsize; (*comm)->mc_maxsize = mc_maxsize; - - // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. - // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the - // file descriptor and prevent cuMemImportFromShareableHandle() from correctly - // interpreting the file. Instead, we use Unix domain sockets for the kernel to - // recreate the correct file descriptor on every receiving rank. 
- int fd; - volatile uint32_t abortFlag = 0; - IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu; - ipcSocketResult_t ret = ipcSocketSuccess; - IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); - (*comm)->_barrier((*comm)->comm_world); - - if ((*comm)->ar2_nvrank == 0) { + if ((*comm)->ar2_nvrank == 0) NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp); - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), - (uint64_t)0); - for (int p = 1; p < (*comm)->ar2_nvsize; p++) { - (*comm)->_barrier((*comm)->comm_intra); - IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + if (mnnvl_fabric) { + CUmemFabricHandle *exphndl = + reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); + CUmemFabricHandle *tmphndl = + reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); + CUmemFabricHandle *exphndls; + NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle))); + if ((*comm)->ar2_nvrank == 0) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast(tmphndl), + (*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0); + for (int grp = 0; grp < (*comm)->ar_nvsize; + grp++) { // we do N broadcasts for N TP groups in the NVL domain + int root = grp * (*comm)->ar2_nvsize; + + // It only needs to be a broadcast, but we reuse the existing allgather comm + (*comm)->_allgather( + reinterpret_cast(exphndls), (*comm)->nvsize * sizeof(CUmemFabricHandle), + reinterpret_cast(tmphndl), sizeof(CUmemFabricHandle), (*comm)->comm_intra); + + // save the data if the broadcast was from rank 0 in our group + if ((*comm)->ar2_firstgpu == root) + memcpy(exphndl, exphndls + root, sizeof(CUmemFabricHandle)); } + if ((*comm)->ar2_nvrank != 0) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &(*comm)->mc_handle, + reinterpret_cast(exphndl), CU_MEM_HANDLE_TYPE_FABRIC); + free(exphndl); + free(tmphndl); + NVTE_CHECK_CUDA(cudaFreeHost(exphndls)); } else { - for (int p = 1; p < (*comm)->ar2_nvsize; p++) { - (*comm)->_barrier((*comm)->comm_intra); - if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + // Broadcast a POSIX file descriptor from the local root rank to other local ranks. + // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the + // file descriptor and prevent cuMemImportFromShareableHandle() from correctly + // interpreting the file. Instead, we use Unix domain sockets for the kernel to + // recreate the correct file descriptor on every receiving rank.
+ int fd; + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafeb000 + (*comm)->my_node + (*comm)->ar2_firstgpu; + ipcSocketResult_t ret = ipcSocketSuccess; + IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); + (*comm)->_barrier((*comm)->comm_world); + + if ((*comm)->ar2_nvrank == 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); + + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + } + } else { + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + } } - } - error: - if ((*comm)->ar2_nvrank != 0) { - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + error: + if ((*comm)->ar2_nvrank != 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + } + IPCCHECK(ipcSocketClose(&ipcSock)); + close(fd); } - IPCCHECK(ipcSocketClose(&ipcSock)); - close(fd); NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle, (CUdeviceptr)(*comm)->mydev); @@ -327,12 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, if (getenv("NVTE_UBDEBUG")) printf( - "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " - "%dx%d PIPE_ID %d/%d\n", + "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP x%d TPGROUP " + "%dx%d\n", myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, - (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, - (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, - pipegpus * pipenodes); + (*comm)->ar_nvrank, (*comm)->my_node, (*comm)->ar2_nvrank, (*comm)->ar_nvsize, + (*comm)->num_nodes, (*comm)->ar2_nvsize); fflush(NULL); return 0; @@ -361,43 +408,14 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks)); - // find intranode numbers and make internode communicator - char hostname[MPI_MAX_PROCESSOR_NAME]; - int namelen; - UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen)); - - char(*hostnames)[MPI_MAX_PROCESSOR_NAME] = - static_cast(malloc(numranks * MPI_MAX_PROCESSOR_NAME)); - strcpy(hostnames[myrank], hostname); // NOLINT(*) - for (int n = 0; n < numranks; n++) - UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD)); - qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp); - - int color = 0; - for (int n = 0; n < numranks; n++) { - if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++; - if (strcmp(hostname, hostnames[n]) == 0) break; - } - free(hostnames); - int mylocal, numlocal; - UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA)); + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, myrank / tensorgpus, myrank, &EXT_COMM_INTRA)); UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal)); // find internode numbers and make 
internode communicator NVTE_CHECK_CUDA(cudaFree(0)); - int allnodes = numranks / numlocal; - int datanodes = allnodes / pipenodes / tensornodes; - // data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1 - int datanodegroup_id = myrank / numlocal / datanodes; - // mpi communicator only needed for SHARP which is always allreduce1/data-parallel - UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, myrank, - &EXT_COMM_INTER)); - // different rails from same group are in different subcommunicators int mynode, numnodes; - UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes)); - UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode)); // finally call the abstracted constructor with MPI info return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, @@ -447,13 +465,11 @@ void destroy_communicator(communicator *comm) { if (comm->use_mc) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle); } - free(comm->fifo); delete comm; } void destroy_communicator_mpi(communicator *comm) { #ifdef NVTE_UB_WITH_MPI - MPI_Comm_free(static_cast(&(comm->comm_inter))); MPI_Comm_free(static_cast(&(comm->comm_intra))); destroy_communicator(comm); #else @@ -472,6 +488,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * #if CUDART_VERSION >= 12010 if (comm->use_mc && alloc) { + bool mnnvl_fabric = has_mnnvl_fabric(comm->mydev); int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; void **remptrs = reinterpret_cast(malloc(nranks * sizeof(void *))); @@ -481,7 +498,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = comm->mydev; prop.requestedHandleTypes = - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; // CU_MEM_HANDLE_TYPE_FABRIC; + mnnvl_fabric ? 
CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; size_t granularity = 0; NVTE_CALL_CHECK_CUDA_DRIVER( @@ -507,41 +524,58 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop, (uint64_t)0); - int *peerfd = reinterpret_cast(malloc(nranks * sizeof(int))); - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), - comm->uchandles[hndl][myrank], - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), - (uint64_t)0); - - volatile uint32_t abortFlag = 0; - IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafebeef; - ipcSocketResult_t ret = ipcSocketSuccess; - - // All-gather POSIX file descriptors across local ranks - IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); - for (int p = 1; p < nranks; p++) { - int send_to = (myrank + p) % nranks; - int recv_from = (myrank + nranks - p) % nranks; - comm->_barrier(comm->comm_intra); - IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error); - IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); - } + if (mnnvl_fabric) { + CUmemFabricHandle *exphndl; + CUmemFabricHandle myhndl; + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl, + comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0); + NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle))); + comm->_allgather(reinterpret_cast(exphndl), comm->nvsize * sizeof(CUmemFabricHandle), + reinterpret_cast(&myhndl), sizeof(CUmemFabricHandle), + comm->comm_intra); + for (int p = 0; p < nranks; p++) + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(&exphndl[p]), + CU_MEM_HANDLE_TYPE_FABRIC); + NVTE_CHECK_CUDA(cudaFreeHost(exphndl)); + } else { + int *peerfd = reinterpret_cast(malloc(nranks * sizeof(int))); + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), + comm->uchandles[hndl][myrank], + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); - error: - IPCCHECK(ipcSocketClose(&ipcSock)); + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafebeef + comm->my_node; + ipcSocketResult_t ret = ipcSocketSuccess; + + // All-gather POSIX file descriptors across local ranks + IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); + for (int p = 1; p < nranks; p++) { + int send_to = (myrank + p) % nranks; + int recv_from = (myrank + nranks - p) % nranks; + comm->_barrier(comm->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, + error); + IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); + } - for (int p = 0; p < nranks; p++) { - if (p != myrank) - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], - reinterpret_cast(peerfd[p]), - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - close(peerfd[p]); - } - free(peerfd); + error: + IPCCHECK(ipcSocketClose(&ipcSock)); + for (int p = 0; p < nranks; p++) { + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(peerfd[p]), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(peerfd[p]); + } + free(peerfd); + } CUdeviceptr ptr; 
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks), (size_t)0, (CUdeviceptr)0, (uint64_t)0); @@ -571,13 +605,13 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * cudaMemcpy((reinterpret_cast(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice)); free(remptrs); - comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED; + comm->memflags[hndl] = NVTE_UB_MEM_UC_CONTIG | NVTE_UB_MEM_ALLOCATED; if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset, comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/, aligned_size, (uint64_t)0); - comm->memflags[hndl] |= UB_MEM_MC_CREATED; + comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED; comm->mc_ptr[hndl] = reinterpret_cast(comm->mc_baseptr) + comm->mc_offset; comm->mc_offset += aligned_size; } else if (!comm->myrank) { diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 58de844858..1211392e40 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1682,6 +1682,7 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) + callranks_rs_oop_stride(16) callranks_rs_oop_stride(32) } void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, const int rowelements, const int colelements, @@ -1703,7 +1704,8 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) - callranks_rs_oop_stride_atomic(8) + callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16) + callranks_rs_oop_stride_atomic(32) } template @@ -1729,6 +1731,7 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) + callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32) } template @@ -1773,7 +1776,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) - callranks_rs_oop_stride_multiatomic(8) + callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16) + callranks_rs_oop_stride_multiatomic(32) } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, @@ -1793,17 +1797,17 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + callranks_ag(2) callranks_ag(4) 
callranks_ag(8) callranks_ag(16) callranks_ag(32) } } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) } } } @@ -1840,17 +1844,17 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) } } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) } } } @@ -1873,17 +1877,21 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16) + callranks_rs_oopMC(32) } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) } } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16) + callranks_rs_oopMC(32) } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) } } } @@ -1915,10 +1923,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const if (comm_launch_event) { SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) } else { SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) 
+ callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index ee808b7f9a..84defcdb23 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -34,11 +34,7 @@ using ExtBarrierOp = std::function; #define NVTE_MAX_REQUESTS 1024 #define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_CPU 2 -#define NVTE_MAX_NVLINK 8 - -#define UB_MEM_UC_CONTIG 1 -#define UB_MEM_MC_CREATED 2 -#define UB_MEM_ALLOCATED 4 +#define NVTE_MAX_NVLINK 32 #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -124,11 +120,8 @@ struct communicator { ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup // (_splitar init used) would be equal to (nvsize,0) for regular comm_create int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step - int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size) int sm_arch; - int num_nodes, my_node, - first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes) - int num2_nodes, my2_node, first2_node; // with num_nodes as a stride + int num_nodes, my_node; // max value for running block counters in hostflags int basecounter[userbuffers_op_types]; // NOLINT(*) @@ -136,20 +129,11 @@ struct communicator { void *mem_mr[NVTE_MAX_REGIONS]; - ub_request *fifo; - int nblocks, alignblock, minblock, asyncblocks, active_nreqs; - ub_request active_req[userbuffers_op_types]; // NOLINT(*) - int padding[7]; - volatile int head; - int padding2[15]; - volatile int tail; - // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) ExtAllgatherOp _allgather; ExtBarrierOp _barrier; ExtComm comm_world; - ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail ExtComm comm_intra; // full intranode (all ndev GPUS) #ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; @@ -199,11 +183,6 @@ void destroy_communicator_mpi(communicator *comm); returned offset is offset of gpubuff relative to buffer registered */ -int pipe_rank(communicator *comm, - int step); // helper function to help walk across allreduce1 x allreduce2 groups - // data-parallel and tensor-parallel position within data and tensor - // groups would be preserved - int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc); /* returns handler and registers buffers. assumed to be collective i.e. you use same groups and dont mix buffers for different operations returns -1 if cant register (too many preregistered diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 8605447c61..48fb5d77d9 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -4,8 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#include - #include #include "../common.h" @@ -13,84 +11,6 @@ namespace transformer_engine { -namespace { - -/*! 
\brief Wrapper class for a shared library - * - * \todo Windows support - */ -class Library { - public: - explicit Library(const char *filename) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support - NVTE_ERROR("Shared library initialization is not supported with Windows"); -#else - handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); - NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - ~Library() { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support -#else - if (handle_ != nullptr) { - dlclose(handle_); - } -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - Library(const Library &) = delete; // move-only - - Library(Library &&other) noexcept { swap(*this, other); } - - Library &operator=(Library other) noexcept { - // Copy-and-swap idiom - swap(*this, other); - return *this; - } - - friend void swap(Library &first, Library &second) noexcept; - - void *get() noexcept { return handle_; } - - const void *get() const noexcept { return handle_; } - - /*! \brief Get pointer corresponding to symbol in shared library */ - void *get_symbol(const char *symbol) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support - NVTE_ERROR("Shared library initialization is not supported with Windows"); -#else - void *ptr = dlsym(handle_, symbol); - NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); - return ptr; -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - private: - void *handle_ = nullptr; -}; - -void swap(Library &first, Library &second) noexcept { - using std::swap; - swap(first.handle_, second.handle_); -} - -/*! \brief Lazily-initialized shared library for CUDA driver */ -Library &cuda_driver_lib() { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - constexpr char lib_name[] = "nvcuda.dll"; -#else - constexpr char lib_name[] = "libcuda.so.1"; -#endif - static Library lib(lib_name); - return lib; -} - -} // namespace - namespace cuda_driver { void *get_symbol(const char *symbol) { diff --git a/transformer_engine/common/util/cuda_nvml.cpp b/transformer_engine/common/util/cuda_nvml.cpp new file mode 100644 index 0000000000..0af9cd7411 --- /dev/null +++ b/transformer_engine/common/util/cuda_nvml.cpp @@ -0,0 +1,26 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "cuda_nvml.h" + +#include "shared_lib_wrapper.h" + +namespace transformer_engine { + +namespace cuda_nvml { + +/*! \brief Lazily-initialized shared library for CUDA NVML */ +Library &cuda_nvml_lib() { + constexpr char lib_name[] = "libnvidia-ml.so.1"; + static Library lib(lib_name); + return lib; +} + +void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); } + +} // namespace cuda_nvml + +} // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_nvml.h b/transformer_engine/common/util/cuda_nvml.h new file mode 100644 index 0000000000..14131a3cdd --- /dev/null +++ b/transformer_engine/common/util/cuda_nvml.h @@ -0,0 +1,69 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ + +#include + +#include + +#include "../common.h" +#include "../util/string.h" + +namespace transformer_engine { + +namespace cuda_nvml { + +/*! \brief Get pointer corresponding to symbol in CUDA NVML library */ +void *get_symbol(const char *symbol); + +/*! \brief Call function in CUDA NVML library + * + * The CUDA NVML library (libnvidia-ml.so.1 on Linux) may be different at + * compile-time and run-time. + * + * \param[in] symbol Function name + * \param[in] args Function arguments + */ +template +inline nvmlReturn_t call(const char *symbol, ArgTs... args) { + using FuncT = nvmlReturn_t(ArgTs...); + FuncT *func = reinterpret_cast(get_symbol(symbol)); + return (*func)(args...); +} + +/*! \brief Get NVML error string + * + * \param[in] rc NVML return code + */ +inline const char *get_nvml_error_string(nvmlReturn_t rc) { + using FuncT = const char *(nvmlReturn_t); + FuncT *func = reinterpret_cast(get_symbol("nvmlErrorString")); + return (*func)(rc); +} + +} // namespace cuda_nvml + +} // namespace transformer_engine + +#define NVTE_CHECK_CUDA_NVML(expr) \ + do { \ + const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \ + if (status_NVTE_CHECK_CUDA_NVML != NVML_SUCCESS) { \ + const char *desc_NVTE_CHECK_CUDA_NVML = \ + ::transformer_engine::cuda_nvml::get_nvml_error_string(status_NVTE_CHECK_CUDA_NVML); \ + NVTE_ERROR("NVML Error: ", desc_NVTE_CHECK_CUDA_NVML); \ + } \ + } while (false) + +#define VA_ARGS(...) , ##__VA_ARGS__ +#define NVTE_CALL_CHECK_CUDA_NVML(symbol, ...) \ + do { \ + NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ diff --git a/transformer_engine/common/util/shared_lib_wrapper.h b/transformer_engine/common/util/shared_lib_wrapper.h new file mode 100644 index 0000000000..3ccc8239b8 --- /dev/null +++ b/transformer_engine/common/util/shared_lib_wrapper.h @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ + +#include + +namespace transformer_engine { + +/*! \brief Wrapper class for a shared library + * + * \todo Windows support + */ +class Library { + public: + explicit Library(const char *filename) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); + NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + ~Library() { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support +#else + if (handle_ != nullptr) { + dlclose(handle_); + } +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + Library(const Library &) = delete; // move-only + + void *get() noexcept { return handle_; } + + const void *get() const noexcept { return handle_; } + + /*! 
\brief Get pointer corresponding to symbol in shared library */ + void *get_symbol(const char *symbol) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + void *ptr = dlsym(handle_, symbol); + NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); + return ptr; +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + private: + void *handle_ = nullptr; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e871228b80..d8fc76a2eb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -390,8 +390,7 @@ class CommOverlapHelper : torch::CustomClassHolder { CommOverlapHelper(); CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group, - std::optional inter_node_group); + std::optional intra_node_group); ~CommOverlapHelper(); diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 30126651ce..6d05869c36 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -26,8 +26,7 @@ CommOverlapHelper::CommOverlapHelper() { } // empty constructor for NVTE_UB_WITH_MPI=1 CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_domain_group, - std::optional inter_domain_group) { + std::optional intra_domain_group) { #ifndef NVTE_UB_WITH_MPI pgs.insert({"world", world_group}); myrank = pgs["world"]->getRank(); @@ -53,20 +52,9 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, mynode = 0; numnodes = 1; } else { - // Intra-node group is different than the world group so there must be multiple nodes - NVTE_CHECK( - inter_domain_group.has_value(), - "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", - "identical to the world_group!"); - // Get node ID and number of nodes - NVTE_CHECK( - inter_domain_group.value()->getBackendType() == backend, - "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", - "group!", pgs["world"]->getBackendName()); - pgs.insert({"inter", inter_domain_group.value()}); - mynode = pgs["inter"]->getRank(); - numnodes = pgs["inter"]->getSize(); + mynode = myrank / numlocal; + numnodes = numranks / numlocal; } } else { // Intra-node group is not set so we assume there is only 1 node diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 442837d767..0604847235 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -285,10 +285,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) - .def(py::init, - std::optional>(), + .def(py::init>(), py::call_guard(), py::arg("world_group"), - py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); + py::arg("intra_node_group") = py::none()); py::class_, transformer_engine::CommOverlapBase, transformer_engine::CommOverlapCore>(m, "CommOverlap") diff --git a/transformer_engine/pytorch/module/base.py 
b/transformer_engine/pytorch/module/base.py index d0f9525135..84326f58ea 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -7,9 +7,6 @@ import os import pickle import warnings -import socket -import fcntl -import struct from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager @@ -177,85 +174,32 @@ def initialize_ub( world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - # We have single-node NVLink so we can color based on physical node hostnames. - # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and - # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on - # the chosen bootstrap backend. - mydomain = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME") - ) - if ifname is not None: - # Make sure the ifname found in the environment is a valid network interface - if ifname in [name for _, name in socket.if_nameindex()]: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - mydomain = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - finally: - s.close() - else: - ifname_warning = ( - f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - + " attempt to detect ranks on the same node by matching " - + "'socket.gethostname()', which is known to fail on virtual clusters like " - + "Kubernetes. If Userbuffers initialization fails, please set the " - + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network " - + "interface." - ) - warnings.warn(ifname_warning, UserWarning) - - # Allgather the domain colors across ranks and reduce to a list of unique domains - domain_per_rank_list = [None for _ in range(world_size)] - torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group) - unique_domains = [] - for domain in domain_per_rank_list: - if domain not in unique_domains: - unique_domains.append(domain) - num_domains = len(unique_domains) - + num_domains = world_size // tp_size + mydomain_idx = world_rank // tp_size if num_domains > 1: - # DP/TP model replicated on multiple NVLink domains - ranks_per_domain_list = [[] for _ in range(num_domains)] - mydomain_idx = -1 - for i, domain in enumerate(domain_per_rank_list): - domain_idx = unique_domains.index(domain) - ranks_per_domain_list[domain_idx].append(i) - if domain == mydomain: - mydomain_idx = domain_idx - assert mydomain_idx >= 0, "Internal TE error!" 
- - intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(num_domains) + ] + tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( ranks_per_domain_list, backend=bootstrap_backend ) - local_rank = torch.distributed.get_rank(intra_domain_group) - intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) - - inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( - [list(ranks) for ranks in zip(*ranks_per_domain_list)], - backend=bootstrap_backend, - ) - - helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group) + local_rank = torch.distributed.get_rank(tp_domain_group) + tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group) + helper = tex.CommOverlapHelper(world_group, tp_domain_group) else: # TP model on single NVLink domain, no replication, no data-parallelism mydomain_idx = 0 local_rank = world_rank - intra_domain_ranks = list(range(world_size)) + tp_domain_ranks = list(range(world_size)) helper = tex.CommOverlapHelper(world_group) if world_rank == 0: - print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) + print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", + f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n", end="", flush=True, )
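
Note on the rank layout assumed above: after this change, both the comm+GEMM overlap example and `initialize_ub()` stop grouping ranks by hostname/IP and instead assume that each tensor-parallel domain is a contiguous block of `tp_size` ranks, with data-parallel groups formed by taking one rank from each block. The sketch below (plain Python, hypothetical helper name, not part of the patch) spells out that arithmetic.

```python
# Minimal sketch of the contiguous TP-domain layout assumed by the patch.
# world_size: total number of ranks; tp_size: ranks per tensor-parallel domain.
def build_tp_dp_groups(world_size: int, tp_size: int):
    assert world_size % tp_size == 0, "world size must be divisible by TP size"
    num_domains = world_size // tp_size
    # TP domain i owns the contiguous ranks [i * tp_size, (i + 1) * tp_size).
    tp_groups = [[i * tp_size + t for t in range(tp_size)] for i in range(num_domains)]
    # DP groups are the "columns" of that grid: one rank from every TP domain.
    dp_groups = [[i * tp_size + t for i in range(num_domains)] for t in range(tp_size)]
    return tp_groups, dp_groups

if __name__ == "__main__":
    tp, dp = build_tp_dp_groups(world_size=8, tp_size=4)
    print(tp)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(dp)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

These are the same rank lists the patch feeds to `torch.distributed.new_subgroups_by_enumeration()`, with `mydomain_idx = world_rank // tp_size` identifying the local TP domain.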
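
The test scripts decide whether to relax the single-node assertions by reading the first `CliqueId` field from `nvidia-smi -q`; a non-zero value is taken to mean the GPUs sit on an MNNVL fabric. A rough pure-Python equivalent of that probe, without the `grep`/`awk` pipeline, might look like the following (sketch only; it assumes `nvidia-smi` is on `PATH` and treats a missing field as "no fabric"):

```python
import shutil
import subprocess

def first_clique_id() -> str:
    """Return the first CliqueId reported by `nvidia-smi -q`, or "0" if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return "0"
    out = subprocess.run(["nvidia-smi", "-q"], capture_output=True, text=True, check=False)
    for line in out.stdout.splitlines():
        # Fabric-attached GPUs report a line of the form "CliqueId : <value>".
        if "CliqueId" in line and ":" in line:
            return line.split(":", 1)[1].strip()
    return "0"

is_mnnvl = first_clique_id() != "0"
```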
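
The new `shared_lib_wrapper.h`/`cuda_nvml.cpp` pair resolves NVML entry points at run time with `dlopen`/`dlsym` instead of linking `libnvidia-ml` at build time, mirroring how the CUDA driver library is already loaded. For quick experiments against the same library from Python, a ctypes sketch of that lazy-loading pattern is shown below (Linux-only, error handling reduced to asserts; not part of the patch):

```python
import ctypes

# Lazily load the NVML shared library, the same object cuda_nvml.cpp opens.
nvml = ctypes.CDLL("libnvidia-ml.so.1")
NVML_SUCCESS = 0

def nvml_check(rc: int) -> None:
    # nvmlErrorString maps a return code to a human-readable description.
    nvml.nvmlErrorString.restype = ctypes.c_char_p
    assert rc == NVML_SUCCESS, nvml.nvmlErrorString(rc).decode()

nvml_check(nvml.nvmlInit_v2())
try:
    count = ctypes.c_uint(0)
    nvml_check(nvml.nvmlDeviceGetCount_v2(ctypes.byref(count)))
    print(f"NVML sees {count.value} GPU(s)")
finally:
    nvml_check(nvml.nvmlShutdown())
```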