Commit a18ba8b

changes to include the distributed operations in the aten_ops lib
1 parent fc74cff commit a18ba8b

9 files changed: +312 -204 lines changed

examples/distributed_inference/README.md

+1 -1
@@ -25,7 +25,7 @@ torchrun --nproc_per_node=2 tensor_parallel_llama2.py

pip install tensorrt-llm

-For other python versions, you need to load the libnvinfer_plugin_tensorrt_llm.so. Please set that in the environment variable export trtllm_env={lib_path}. For example, we have already set the variable in initialize_distributed_env(). Note that won't work while running example, since it needs to be preset for the converter library to get.
+For other python versions, you need to load libnvinfer_plugin_tensorrt_llm.so. Set its location in the environment variable, e.g. export TRTLLM_PLUGINS_PATH={lib_path}. In the examples we have already set this variable in initialize_distributed_env(); you can replace it with your own TRTLLM_PLUGINS_PATH there, or unset it.

#then pip install the tensorrt and torch version compatible with installed torchTRT
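As a minimal sketch of what the updated instruction amounts to (the .so path below is a placeholder for your local build, not a path shipped by this commit):

```python
import os

# Placeholder path: point this at your local libnvinfer_plugin_tensorrt_llm.so
os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/libnvinfer_plugin_tensorrt_llm.so"

# The variable must be set before this import so the plugin library can be found.
import tensorrt_llm  # noqa: E402
```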

examples/distributed_inference/tensor_parallel_initialize_dist.py

+67

@@ -0,0 +1,67 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
    dir_path = os.path.dirname(os.path.realpath(__file__))
    for i in range(max_depth):
        files = os.listdir(dir_path)
        if "MODULE.bazel" in files:
            return dir_path
        else:
            dir_path = os.path.dirname(dir_path)

    raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variable to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    os.environ["TRTLLM_PLUGINS_PATH"] = (
        find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
    )

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use nccl backend
    dist.init_process_group("nccl")

    # set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)
    device_id = (
        rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
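A rough usage sketch for this new helper, assuming a hypothetical driver script my_tp_script.py alongside it and an mpirun launch (which is what the OMPI_* environment lookups above expect):

```python
# my_tp_script.py -- hypothetical driver, e.g.: mpirun -n 2 python my_tp_script.py
from tensor_parallel_initialize_dist import initialize_distributed_env

device_mesh, world_size, rank, logger = initialize_distributed_env("./my_tp_script")

# The helper has already exported TRTLLM_PLUGINS_PATH, so the TRT-LLM plugin
# library can be located when tensorrt_llm is imported afterwards.
import tensorrt_llm  # noqa: E402

logger.info(f"Initialized rank {rank} of {world_size}")
```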

examples/distributed_inference/tensor_parallel_llama3.py

+6 -3
@@ -5,17 +5,20 @@
import time

import torch
-import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
-from tensor_parallel_nccl_ops import register_nccl_ops
+from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

-device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_llama3"
+)
+# Import should be after initialization of the TRT-LLM plugin .so path
+import tensorrt_llm

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (

examples/distributed_inference/tensor_parallel_nccl_ops.py

-197
This file was deleted.

examples/distributed_inference/tensor_parallel_simple_example.py

+3 -3
@@ -1,21 +1,21 @@
import time

import tensorrt as trt
-import tensorrt_llm
import torch
import torch.nn as nn
import torch_tensorrt
-from tensor_parallel_nccl_ops import register_nccl_ops
+from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

-device_mesh, _world_size, _rank, logger = register_nccl_ops(
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "./tensor_parallel_simple_example"
)
+import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
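The hunk above only swaps the initialization imports; for orientation, here is a self-contained sketch of the tensor-parallel pattern those imports support. The toy model and sharding plan are illustrative, not taken from the example file; run it under torchrun with two GPUs.

```python
# tp_sketch.py -- illustrative only; run with: torchrun --nproc_per_node=2 tp_sketch.py
import os

import torch
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.in_proj = nn.Linear(16, 32)
        self.out_proj = nn.Linear(32, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(torch.relu(self.in_proj(x)))


# torchrun sets LOCAL_RANK/WORLD_SIZE; init_device_mesh creates the NCCL group.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Shard the first projection column-wise and the second row-wise across ranks.
tp_model = parallelize_module(
    ToyMLP().cuda(),
    mesh,
    {"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)

out = tp_model(torch.randn(4, 16, device="cuda"))
```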

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+44
@@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
+import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -20,6 +21,11 @@
    enforce_tensor_types,
    get_positive_dim,
    is_only_operator_on_placeholder,
+    load_tensorrt_llm,
+)
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    tensorrt_fused_nccl_all_gather_op,
+    tensorrt_fused_nccl_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor

@@ -3585,3 +3591,41 @@ def aten_ops_full(
        fill_value=args[1],
        dtype=kwargs.get("dtype", None),
    )
+
+
+if load_tensorrt_llm():
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
+    def insert_nccl_gather_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.distributed.gather_op(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
+    def insert_nccl_reduce_scatter_plugin(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.distributed.reduce_scatter_op(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
+else:
+    _LOGGER.warning("Unable to load the TRT-LLM plugins")
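The converters above are registered only when load_tensorrt_llm() succeeds. As a rough standalone illustration of that kind of availability probe (an assumption-laden sketch, not the project's own implementation), one could try to dlopen the library pointed to by TRTLLM_PLUGINS_PATH and fall back to a warning:

```python
import ctypes
import logging
import os

logger = logging.getLogger(__name__)


def trtllm_plugin_available() -> bool:
    """Illustrative probe: try to dlopen the TRT-LLM plugin shared library."""
    plugin_path = os.environ.get("TRTLLM_PLUGINS_PATH", "")
    if not plugin_path:
        return False
    try:
        ctypes.CDLL(plugin_path)
        return True
    except OSError:
        return False


if trtllm_plugin_available():
    # The real module registers the NCCL all-gather / reduce-scatter converters
    # at this point via @dynamo_tensorrt_converter.
    logger.info("TRT-LLM plugin found; distributed converters can be registered")
else:
    logger.warning("Unable to load the TRT-LLM plugins")
```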
