
Commit 38335b9 (parent: 6707c6f)

Commit message: changes to include the distributed operations in the aten_ops lib

6 files changed: +257 -199 lines
examples/distributed_inference/tensor_parallel_initialize_dist.py (new file, +64 lines)
@@ -0,0 +1,64 @@
+import logging
+import os
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import tensorrt as trt
+import tensorrt_llm
+import torch
+import torch.distributed as dist
+from torch.distributed._tensor.device_mesh import init_device_mesh
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    tensorrt_fused_nccl_all_gather_op,
+    tensorrt_fused_nccl_reduce_scatter_op,
+)
+from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+
+
+def initialize_logger(rank, logger_file_name):
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
+    fh.setLevel(logging.INFO)
+    logger.addHandler(fh)
+    return logger
+
+
+# This is required for env initialization since we use mpirun
+def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
+    local_rank = int(
+        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
+    )
+    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
+
+    # Set up environment variables to run with mpirun
+    os.environ["RANK"] = str(local_rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(port)
+
+    # Necessary to assign a device to each rank.
+    torch.cuda.set_device(local_rank)
+
+    # We use the nccl backend
+    dist.init_process_group("nccl")
+
+    # Set a manual seed for reproducibility
+    torch.manual_seed(1111)
+
+    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
+    rank = device_mesh.get_rank()
+    assert rank == local_rank
+    logger = initialize_logger(rank, logger_file_name)
+    device_id = (
+        rank % torch.cuda.device_count()
+    )  # Ensure each rank gets a unique device
+    torch.cuda.set_device(device_id)
+
+    return device_mesh, world_size, rank, logger
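
For context, a minimal usage sketch of this helper, assuming a script launched with mpirun on a multi-GPU node (the script name, launch command, and log-file prefix below are illustrative and not part of the commit):

# Launch so that the OMPI_COMM_WORLD_* variables are populated, e.g.:
#   mpirun -n 2 python my_tensor_parallel_script.py
from tensor_parallel_initialize_dist import initialize_distributed_env

# Sets up the NCCL process group, the device mesh, and a per-rank file logger.
device_mesh, world_size, rank, logger = initialize_distributed_env(
    "./my_tensor_parallel_script"  # each rank writes <prefix>_<rank>.log
)
logger.info(f"Rank {rank} of {world_size} is using device mesh {device_mesh}.")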

examples/distributed_inference/tensor_parallel_llama3.py (+7 -2)

@@ -7,15 +7,20 @@
 import torch
 import torch_tensorrt
 from llama3_model import ModelArgs, ParallelTransformer
-from tensor_parallel_nccl_ops import register_nccl_ops
 from torch.distributed._composable.fsdp import MixedPrecisionPolicy
 from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )

-device_mesh, _world_size, _rank, logger = register_nccl_ops("./tensor_parallel_llama3")
+from TensorRT.examples.distributed_inference.tensor_parallel_initialize_dist import (
+    initialize_distributed_env,
+)
+
+device_mesh, _world_size, _rank, logger = initialize_distributed_env(
+    "./tensor_parallel_llama3"
+)

 logger.info(f"Starting PyTorch TP example on rank {_rank}.")
 assert (
examples/distributed_inference/tensor_parallel_nccl_ops.py (-197)

This file was deleted.

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+59)

@@ -1,10 +1,12 @@
 # mypy: disallow-untyped-decorators=False

+import ctypes
 import logging
 import operator
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

 import numpy as np
+import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -19,6 +21,11 @@
     enforce_tensor_types,
     get_positive_dim,
     is_only_operator_on_placeholder,
+    plugin_lib_path,
+)
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    tensorrt_fused_nccl_all_gather_op,
+    tensorrt_fused_nccl_reduce_scatter_op,
 )
 from torch_tensorrt.dynamo.types import TRTTensor

@@ -3558,3 +3565,55 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
+
+
+try:
+    import tensorrt_llm as trt_llm
+except (ImportError, AssertionError) as e:
+    _LOGGER.warning("tensorrt_llm is not installed. Please install tensorrt_llm: %s", e)
+    # note this is for Linux only
+    plugin_lib_path = plugin_lib_path()
+    handle = ctypes.CDLL(plugin_lib_path)
+    try:
+        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+        handle.initTrtLlmPlugins.restype = ctypes.c_bool
+    except AttributeError as e_1:
+        _LOGGER.warning("TensorRT-LLM plugin is unavailable: %s", e_1)
+    try:
+        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
+        assert handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8"))
+    except Exception as e_2:
+        _LOGGER.warning("Exception while initializing the TensorRT-LLM plugins: %s", e_2)
+else:
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
+    def insert_nccl_gather_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.distributed.gather_op(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+        )
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
+    def insert_nccl_reduce_scatter_plugin(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.distributed.gather_op(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+        )
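
The converter registration above is guarded: the converters are defined only when `import tensorrt_llm` succeeds, and the except branch instead loads the TensorRT-LLM plugin library directly via ctypes. A standalone sketch of that ctypes fallback check, with illustrative names that are not part of the library:

import ctypes
import logging

logger = logging.getLogger(__name__)


def nccl_plugins_available(plugin_path: str, namespace: str = "tensorrt_llm") -> bool:
    """Return True only if the TensorRT-LLM plugin library loads and initializes."""
    try:
        # Load libnvinfer_plugin_tensorrt_llm.so and initialize its plugin registry.
        handle = ctypes.CDLL(plugin_path)
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
        return bool(handle.initTrtLlmPlugins(None, namespace.encode("utf-8")))
    except (OSError, AttributeError) as exc:
        # OSError: the .so could not be loaded; AttributeError: the init symbol is missing.
        logger.warning("TensorRT-LLM plugins unavailable: %s", exc)
        return False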

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+7)

@@ -1,6 +1,7 @@
 import collections
 import functools
 import logging
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

 import numpy as np
@@ -913,3 +914,9 @@ def set_layer_name(
         else f"{source_ir}_ops.{target.__name__}"
     )
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
+
+
+def plugin_lib_path() -> str:
+    project_dir = Path(__file__).parent.parent.parent.parent.absolute()
+    dyn_lib = "libnvinfer_plugin_tensorrt_llm.so"
+    return str(project_dir.joinpath("libs", dyn_lib))
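
As used in aten_ops_converters.py above, the path returned by plugin_lib_path() is handed straight to ctypes. A short illustration; the resolved location depends on where torch_tensorrt is checked out or installed:

import ctypes

from torch_tensorrt.dynamo.conversion.converter_utils import plugin_lib_path

# plugin_lib_path() points at libs/libnvinfer_plugin_tensorrt_llm.so under the
# directory four levels above converter_utils.py (py/ in a source checkout).
path = plugin_lib_path()
print(path)

# CDLL raises OSError if the shared library is not present at that path.
handle = ctypes.CDLL(path)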
