Commit b6f5980
using nccl ops from TRT-LLM namespace
1 parent 6d40ff1

File tree: 7 files changed, +296 −27 lines

examples/distributed_inference/README.md (+16)

@@ -14,3 +14,19 @@ See the examples started with `data_parallel` for more details.
 Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.
 
 torchrun --nproc_per_node=2 tensor_parallel_llama2.py
+
+3. Tensor parallel distributed inference using the NCCL ops plugin
+
+apt install libmpich-dev
+apt install libopenmpi-dev
+pip install tensorrt-llm
+# then pip install the tensorrt and torch versions compatible with the installed torch-tensorrt
+mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
+
+4. Tensor parallel distributed llama3 inference using the NCCL ops plugin
+
+apt install libmpich-dev
+apt install libopenmpi-dev
+pip install tensorrt-llm
+# then pip install the tensorrt and torch versions compatible with the installed torch-tensorrt
+mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
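As a quick sanity check that the TRT-LLM plugins are actually visible to TensorRT after `pip install tensorrt-llm` (a hedged sketch, not part of this commit; it mirrors the plugin-registry logging that `register_nccl_ops` performs in the new module below):

    import tensorrt as trt
    import tensorrt_llm  # importing tensorrt_llm makes its TensorRT plugins available

    # Keep only the creators registered under the "tensorrt_llm" namespace
    # (this is where the AllGather / ReduceScatter plugins used below live).
    registry = trt.get_plugin_registry()
    llm_plugins = [
        (c.name, c.plugin_version)
        for c in registry.plugin_creator_list
        if c.plugin_namespace == "tensorrt_llm"
    ]
    print(llm_plugins)  # expected to include ("AllGather", "1") and ("ReduceScatter", "1")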
examples/distributed_inference requirements file (+3 −1)

@@ -1,3 +1,5 @@
 accelerate
 transformers
-diffusers
+diffusers
+site
+tensorrt-llm

examples/distributed_inference/tensor_parallel_llama3.py (+7 −13)

@@ -7,25 +7,20 @@
 import torch
 import torch_tensorrt
 from llama3_model import ModelArgs, ParallelTransformer
+from tensor_parallel_nccl_ops import register_nccl_ops
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 
-_rank = int(os.environ["RANK"])
-_world_size = int(os.environ["WORLD_SIZE"])
-tp_size = 2
+device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")
 
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
-fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w")
-fh.setLevel(logging.INFO)
-logger.addHandler(fh)
-
-tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
+assert (
+    _world_size % 2 == 0
+), f"TP examples require even number of GPUs, but got {_world_size} gpus"
 
 model_args = ModelArgs(
     vocab_size=32000,
@@ -38,7 +33,7 @@
 )
 
 with torch.no_grad():
-    model = ParallelTransformer(model_args, tp_mesh)
+    model = ParallelTransformer(model_args, device_mesh)
     torch.manual_seed(0)
     inp = torch.randint(32000, (8, 256), device="cuda")
     python_result = model(inp)
@@ -53,7 +48,6 @@
         "use_python_runtime": True,
         "workspace_size": 1 << 33,
         "debug": False,
-        "timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
     },
     dynamic=False,
 )
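The last hunk only shows the backend options dict; for context, a `torch.compile` call using the Torch-TensorRT backend with these options looks roughly like the sketch below (the stand-in model, input shape, and variable names are assumptions, not taken from the file):

    import torch
    import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile

    # Stand-in model; in the example this is the sharded ParallelTransformer.
    model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU()).cuda()
    inp = torch.randn(8, 256, device="cuda")

    trt_model = torch.compile(
        model,
        backend="torch_tensorrt",
        options={
            "use_python_runtime": True,  # run TRT engines via the Python runtime wrapper
            "workspace_size": 1 << 33,   # 8 GiB TensorRT workspace
            "debug": False,
        },
        dynamic=False,
    )
    out = trt_model(inp)  # first call triggers TensorRT engine compilation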
examples/distributed_inference/tensor_parallel_nccl_ops.py (new file, +186; path inferred from the `from tensor_parallel_nccl_ops import register_nccl_ops` imports above)

import ctypes
import logging
import os
import site
from enum import IntEnum, IntFlag, auto
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import tensorrt_llm
import torch
import torch.distributed as dist
import torch_tensorrt
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.fx import GraphModule, Node
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    custom_fused_all_gather_op,
    custom_fused_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


# class for AllReduce
class AllReduceStrategy(IntEnum):
    """Warning: actual definition is in kernels/customAllReduceKernels.h.

    They must be kept in sync.
    """

    NCCL = 0
    ONESHOT = 1
    TWOSHOT = 2
    AUTO = 3


class AllReduceConfig(IntFlag):
    """Warning: actual definition is in kernels/customAllReduceKernels.h.

    They must be kept in sync
    """

    USE_MEMCPY = auto()
    PUSH_MODE = auto()


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variables to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use the nccl backend
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility
    torch.manual_seed(1111)

    return local_rank, world_size


def register_nccl_ops(logger_file_name):
    # Initialization
    initialize_distributed_env()
    # Create a device mesh based on the given world_size.
    _world_size = int(os.environ["WORLD_SIZE"])

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
    _rank = device_mesh.get_rank()
    logger = initialize_logger(_rank, logger_file_name)
    device_id = (
        _rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    # TensorRT NCCL plugins
    # Iterate over all registered plugin creators
    plugin_registry = trt.get_plugin_registry()
    for plugin_creator in plugin_registry.plugin_creator_list:
        logger.info(
            f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
        )

    @dynamo_tensorrt_converter(custom_fused_all_gather_op)
    def insert_nccl_gather_op(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        plug_inputs = [args[0]]
        allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
            "AllGather", "1", "tensorrt_llm"
        )
        assert allgather_plg_creator is not None
        _world_size = int(os.environ["WORLD_SIZE"])
        group = list(range(_world_size))
        group = trt.PluginField(
            "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
        )
        p_dtype = trt.float16
        pf_type = trt.PluginField(
            "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
        )
        pfc = trt.PluginFieldCollection([group, pf_type])
        allgather = allgather_plg_creator.create_plugin("allgather", pfc)
        layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
        set_layer_name(layer, target, name)
        return layer.get_output(0)

    @dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
    def insert_nccl_reduce_scatter_plugin(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        plug_inputs = [args[0]]
        allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
            "ReduceScatter", "1", "tensorrt_llm"
        )

        assert allreduce_plg_creator is not None

        counter = 0
        strategy = AllReduceStrategy.NCCL
        config = AllReduceConfig(0)

        world_size = dist.get_world_size()
        group = list(range(world_size))
        group = trt.PluginField(
            "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
        )

        p_dtype = trt.float16
        pf_dtype = trt.PluginField(
            "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
        )
        pfc = [group, pf_dtype]
        p_strategy = trt.PluginField(
            "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
        )
        pfc.append(p_strategy)
        p_config = trt.PluginField(
            "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
        )
        pfc.append(p_config)
        p_counter = trt.PluginField(
            "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
        )
        pfc.append(p_counter)

        pfc = trt.PluginFieldCollection(pfc)
        ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)

        layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
        set_layer_name(layer, target, name)
        return layer.get_output(0)

    return device_mesh, _world_size, _rank, logger
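Both example scripts consume this module the same way. A minimal caller (hedged sketch; the script name `my_tp_script.py` is hypothetical, launched via `mpirun -n 2 --allow-run-as-root python my_tp_script.py`) would look like:

    from tensor_parallel_nccl_ops import register_nccl_ops

    # Initializes the process group from the MPI environment, builds the device
    # mesh, pins each rank to a CUDA device, registers the TRT-LLM AllGather /
    # ReduceScatter converters, and returns a per-rank file logger.
    device_mesh, _world_size, _rank, logger = register_nccl_ops("./my_tp_script")

    logger.info(f"Rank {_rank}/{_world_size} ready on mesh {device_mesh}")
    # ...shard the model over device_mesh, then compile with backend="torch_tensorrt".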

examples/distributed_inference/tensor_parallel_simple_example.py (+10 −13)

@@ -1,18 +1,22 @@
-import os
-import sys
 import time
 
+import tensorrt as trt
+import tensorrt_llm
 import torch
 import torch.nn as nn
 import torch_tensorrt
+from tensor_parallel_nccl_ops import register_nccl_ops
 from torch.distributed._tensor import Shard
-from torch.distributed._tensor.device_mesh import init_device_mesh
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
     parallelize_module,
 )
 
+device_mesh, _world_size, _rank, logger = register_nccl_ops(
+    "./tensor_parallel_simple_example"
+)
+
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
@@ -36,14 +40,7 @@ def forward(self, x):
         return x
 
 
-# create a device mesh based on the given world_size.
-_world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
-_rank = device_mesh.get_rank()
-
-
-print(f"Starting PyTorch TP example on rank {_rank}.")
+logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
     _world_size % 2 == 0
 ), f"TP examples require even number of GPUs, but got {_world_size} gpus"
@@ -91,9 +88,9 @@ def forward(self, x):
     output = tp_model(inp)
     end = time.time()
     if i == 0:
-        print(f"Compilation time is {end-start}")
+        logger.info(f"Compilation time is {end-start}")
         assert (
             python_result - output
         ).std() < 0.01, "Compilation result is not correct."
     elif _rank == 0:
-        print(f"Inference time is {end-start}")
+        logger.info(f"Inference time is {end-start}")

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py (+2)

@@ -6,6 +6,7 @@
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
+from .fuse_distributed_ops import fuse_distributed_ops
 from .fuse_prims_broadcast import fuse_prims_broadcast
 from .lower_linear import lower_linear
 from .lower_scaled_dot_product_attention import lower_scaled_dot_product_attention
@@ -26,6 +27,7 @@
     lower_scaled_dot_product_attention,
     lower_linear,
     fuse_prims_broadcast,
+    fuse_distributed_ops,
     replace_max_pool_with_indices,
     replace_full_like_with_full,
     view_to_reshape,
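The `fuse_distributed_ops` pass registered here defines (and presumably produces in the lowered graph) the `custom_fused_all_gather_op` / `custom_fused_reduce_scatter_op` targets that the NCCL plugin converters in `tensor_parallel_nccl_ops.py` match. Conceptually, each entry in this pass list rewrites the FX graph before conversion; the toy pass below is purely illustrative (a made-up `remove_detach` rewrite, not torch_tensorrt's real pass or its exact signature) and only shows the GraphModule-rewrite mechanics:

    import torch
    from torch.fx import GraphModule, symbolic_trace

    def remove_detach(gm: GraphModule) -> GraphModule:
        # Rewire users of every torch.detach call to its input, erase the node,
        # then recompile the GraphModule so the change takes effect.
        for node in list(gm.graph.nodes):
            if node.op == "call_function" and node.target is torch.detach:
                node.replace_all_uses_with(node.args[0])
                gm.graph.erase_node(node)
        gm.graph.lint()
        gm.recompile()
        return gm

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.detach(x) + 1

    gm = symbolic_trace(Tiny())
    gm = remove_detach(gm)
    print(gm.graph)  # the detach node is gone; only the add remains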
