
Commit 509d917

changes to include the distributed operations in the aten_ops lib
1 parent c38f613 commit 509d917

7 files changed: +276 -201 lines changed

examples/distributed_inference/tensor_parallel_initialize_dist.py

+69

@@ -0,0 +1,69 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    for i in range(max_depth):
        files = os.listdir(dir_path)
        if "MODULE.bazel" in files:
            return dir_path
        else:
            dir_path = os.path.dirname(dir_path)

    raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variable to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    # Note this will not work in the initialization here
    # You would need to set it externally as a user
    os.environ["trtllm_env"] = (
        find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
    )

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use nccl backend
    dist.init_process_group("nccl")

    # set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)
    device_id = (
        rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
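For orientation, here is a minimal caller sketch, assuming the helper above is importable as tensor_parallel_initialize_dist and the job is launched with one process per GPU (for example via mpirun). The script body, the log-file prefix, and the explicit teardown are illustrative assumptions, not part of this commit:

# Hypothetical caller of the new helper; launched with something like
#   mpirun -n <num_gpus> python caller.py
import torch.distributed as dist

from tensor_parallel_initialize_dist import initialize_distributed_env

# Each rank writes its log to <prefix>_<rank>.log; the prefix is an assumption here.
device_mesh, world_size, rank, logger = initialize_distributed_env(
    "./my_tensor_parallel_run"
)
logger.info(f"NCCL process group ready: rank {rank} of {world_size}")

# ... build and run the tensor-parallel model on this rank's GPU ...

# Tear down the NCCL process group once this rank is done.
dist.destroy_process_group()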

examples/distributed_inference/tensor_parallel_llama3.py

+4-2
@@ -7,15 +7,17 @@
 import torch
 import torch_tensorrt
 from llama3_model import ModelArgs, ParallelTransformer
-from tensor_parallel_nccl_ops import register_nccl_ops
+from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )

-device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_llama3"
+)

 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (

examples/distributed_inference/tensor_parallel_nccl_ops.py

-197
This file was deleted.

examples/distributed_inference/tensor_parallel_simple_example.py

+2-2
@@ -5,15 +5,15 @@
 import torch
 import torch.nn as nn
 import torch_tensorrt
-from tensor_parallel_nccl_ops import register_nccl_ops
+from tensor_parallel_initialize_dist import initialize_distributed_env
 from torch.distributed._tensor import Shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
     parallelize_module,
 )

-device_mesh, _world_size, _rank, logger = register_nccl_ops(
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
     "./tensor_parallel_simple_example"
 )
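For reference, a sketch of the tensor-parallel pattern these examples build on: parallelize_module shards a model over the device_mesh returned by initialize_distributed_env, pairing a column-parallel first linear with a row-parallel second linear. The ToyMLP module, its dimensions, and the log prefix are illustrative assumptions, not code from this commit:

import torch
import torch.nn as nn
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

from tensor_parallel_initialize_dist import initialize_distributed_env

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_sketch"
)


class ToyMLP(nn.Module):
    """Hypothetical two-layer MLP used only to illustrate the sharding plan."""

    def __init__(self, dim=128):
        super().__init__()
        self.in_proj = nn.Linear(dim, 4 * dim)
        self.out_proj = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))


model = ToyMLP().cuda()

# Megatron-style split: column-parallel first linear, row-parallel second,
# so the intermediate activation stays sharded across ranks.
tp_model = parallelize_module(
    model,
    device_mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)

x = torch.randn(8, 128, device="cuda")
out = tp_model(x)
logger.info(f"rank {_rank}: tensor-parallel output shape {tuple(out.shape)}")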
