[pull] main from NVIDIA:main #88

Merged · 1 commit · Mar 5, 2025
104 changes: 23 additions & 81 deletions examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py
@@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None):
help="Disable the comm+GEMM overlap.",
)
parser.add_argument(
"--num-replicas", type=int, default=1, help="Number of data-parallel model replicas."
"--num-replicas",
type=int,
default=1,
help="Number of data-parallel model replicas per node.",
)
parser.add_argument(
"--use-global-replica-count",
action="store_true",
default=False,
help="Treat '--num-replicas' as the total number of replicas.",
)
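
With these two flags, `--num-replicas` now counts data-parallel replicas per node by default, while `--use-global-replica-count` reinterprets the same value as the total across all nodes, which is what allows a tensor-parallel group to grow beyond a single node. A minimal arithmetic sketch with hypothetical sizes (it mirrors the `total_replicas` / `tp_size` logic added to `_train()` further down):

# Illustrative only: how the two flags determine the replica count and the
# tensor-parallel group size. All values here are hypothetical.
WORLD_SIZE, NUM_NODES = 16, 2                 # 2 nodes x 8 GPUs
num_replicas = 1

total_replicas = num_replicas * NUM_NODES     # default: --num-replicas is per node
print(WORLD_SIZE // total_replicas)           # tp_size = 8 -> one TP group per node

total_replicas = num_replicas                 # with --use-global-replica-count
print(WORLD_SIZE // total_replicas)           # tp_size = 16 -> one TP group across both nodes
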
parser.add_argument(
"--tcp-init",
@@ -173,13 +182,12 @@ def _train(opts):
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
else: # TORCHELASTIC, SLURM, etc...
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))

NUM_NODES = WORLD_SIZE // LOCAL_SIZE

# Initialize torch.distributed global process group and get DP/TP groups
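
The script no longer refuses to run outside `mpirun`/`torchrun`: any launcher that exports the standard `RANK`, `WORLD_SIZE` and `LOCAL_RANK` variables is accepted, with `LOCAL_WORLD_SIZE` defaulting to the node's visible GPU count. Plain SLURM `srun` does not set these names by itself, so a wrapper along the following lines is one way to bridge the gap (a hypothetical helper, not part of the patch; the `SLURM_*` variables are the standard per-task ones):

# Hypothetical helper: map SLURM's per-task variables onto the env vars the
# script reads when it is not launched via mpirun or torchrun.
import os

def _slurm_to_torch_env():
    mapping = {
        "RANK": "SLURM_PROCID",
        "WORLD_SIZE": "SLURM_NTASKS",
        "LOCAL_RANK": "SLURM_LOCALID",
        "LOCAL_WORLD_SIZE": "SLURM_NTASKS_PER_NODE",
    }
    for dst, src in mapping.items():
        if dst not in os.environ and src in os.environ:
            os.environ[dst] = os.environ[src]
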
@@ -214,90 +222,24 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False)

dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

# Figure out process groups for tensor- and data-parallelism (if any)
if NUM_NODES > 1:
# Create a list of world ranks on this node
hostname = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)

if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err

hostnames = [None for _ in range(WORLD_SIZE)]
dist.all_gather_object(hostnames, hostname)
unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
assert len(unique_hosts) == NUM_NODES

ranks_per_node_list = [[] for _ in range(NUM_NODES)]
self_node_idx = -1
for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname:
self_node_idx = node_idx
assert self_node_idx >= 0
self_node_ranks = ranks_per_node_list[self_node_idx]

if opts.num_replicas > 1:
# Split node ranks into multiple replicas
assert len(self_node_ranks) % opts.num_replicas == 0
tp_size = len(self_node_ranks) // opts.num_replicas
ranks_per_replica_list = []
for node_ranks in ranks_per_node_list:
for i in range(opts.num_replicas):
start = i * tp_size
end = start + tp_size
ranks_per_replica_list.append(node_ranks[start:end])

self_replica_idx = -1
for i, replica_ranks in enumerate(ranks_per_replica_list):
if WORLD_RANK in replica_ranks:
self_replica_idx = i
break
assert self_replica_idx >= 0
total_replicas = (
opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES
)
tp_size = WORLD_SIZE // total_replicas

else:
# The entire node is the tensor-parallel group
ranks_per_replica_list = ranks_per_node_list
self_replica_idx = self_node_idx
if total_replicas > 1:
ranks_per_replica_list = [
[i * tp_size + t for t in range(tp_size)] for i in range(total_replicas)
]

tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)

else:
if opts.num_replicas > 1:
# Mixed data- and tensor-parallelism on a single node
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
ranks_per_replica_tensor = all_ranks.reshape(
(opts.num_replicas, LOCAL_SIZE // opts.num_replicas)
)
tp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.tolist(), backend="nccl"
)
dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else:
dp_group = None
tp_group = nccl_world
dp_group = None
tp_group = nccl_world

tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group)
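The hostname-based node discovery is gone: ranks are now split into `total_replicas` contiguous tensor-parallel groups by plain enumeration, and the data-parallel groups are the transpose of that list. A standalone sketch with hypothetical sizes shows the two rank layouts that `dist.new_subgroups_by_enumeration()` receives:

# Illustrative only: the rank lists passed to new_subgroups_by_enumeration().
import torch

WORLD_SIZE, total_replicas = 8, 2
tp_size = WORLD_SIZE // total_replicas

ranks_per_replica_list = [
    [i * tp_size + t for t in range(tp_size)] for i in range(total_replicas)
]
print(ranks_per_replica_list)        # [[0, 1, 2, 3], [4, 5, 6, 7]] -> TP groups

ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
print(ranks_per_replica_tensor.transpose(0, 1).tolist())
                                     # [[0, 4], [1, 5], [2, 6], [3, 7]] -> DP groups
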
19 changes: 13 additions & 6 deletions tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -180,15 +180,22 @@ def _main(opts):
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
else: # TORCHELASTIC, SLURM, etc...
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))

result = subprocess.run(
"nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
capture_output=True,
text=True,
shell=True,
)

if result.stdout == "0": # Extra checks for non-MNNVL platforms
assert WORLD_SIZE == LOCAL_SIZE
assert LOCAL_SIZE <= torch.cuda.device_count()

# Fix clock speed
torch.cuda.set_device(LOCAL_RANK)
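Both distributed tests now skip the single-node assertions when the GPU reports a non-zero NVLink CliqueId, i.e. on multi-node-NVLink (MNNVL) platforms. Wrapped into a helper, the same detection would look roughly like this (the function name is hypothetical; the shell pipeline is the one used in the diff):

# Hypothetical helper wrapping the detection above: a non-zero CliqueId in
# `nvidia-smi -q` is taken to mean the GPU belongs to a multi-node NVLink clique.
import subprocess

def _is_mnnvl_platform() -> bool:
    result = subprocess.run(
        "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
        capture_output=True,
        text=True,
        shell=True,
    )
    return result.stdout.strip() not in ("", "0")
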
24 changes: 16 additions & 8 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -7,6 +7,7 @@
import os
import sys
import socket
import subprocess
import argparse
import warnings
import pprint
@@ -209,14 +210,21 @@ def _train(opts):
opts.tcp_init = True
opts.bind_to_device = True
opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ:
else:
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
assert LOCAL_SIZE == WORLD_SIZE
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))

result = subprocess.run(
"nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
capture_output=True,
text=True,
shell=True,
)

if result.stdout == "0": # Extra checks for non-MNNVL platforms
assert WORLD_SIZE == LOCAL_SIZE

def dist_print(msg, src=None, end="\n", debug=False, error=False):
if debug and not opts.debug:
@@ -227,7 +235,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
dist.barrier()

# Set device and initialize RNG states
torch.cuda.set_device(WORLD_RANK)
torch.cuda.set_device(LOCAL_RANK)
torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed)

@@ -312,7 +320,7 @@ def run_fwd_bwd(model, x):
return out

torch_rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}"))
cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}"))
if opts.use_cuda_graphs:
test_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(test_graph):
@@ -329,7 +337,7 @@ def run_fwd_bwd(model, x):
names.append(test_name + ".grad")

torch.set_rng_state(torch_rng_state)
torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}"))
torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}"))
if opts.use_cuda_graphs:
ref_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(ref_graph):
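Device selection and the CUDA RNG state handling now use `LOCAL_RANK` instead of `WORLD_RANK`, since CUDA device indices only range over the GPUs of a single node. A minimal sketch of the distinction, with hypothetical values and the usual block-wise rank-to-node placement:

# Illustrative only: the global rank is not a valid CUDA device index once
# there is more than one node; the local rank is.
LOCAL_SIZE = 8                          # 8 GPUs per node
WORLD_RANK = 11                         # rank 11 lives on node 1
LOCAL_RANK = WORLD_RANK % LOCAL_SIZE    # == 3
# torch.cuda.set_device(LOCAL_RANK)   -> valid device index (0..7)
# torch.cuda.set_device(WORLD_RANK)   -> would fail: this node only has devices 0..7
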
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -78,6 +78,7 @@ list(APPEND transformer_engine_SOURCES
util/cast.cu
util/padding.cu
util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp
util/rtc.cpp
swizzle/swizzle.cu