
Commit 7fbb857

chore: [NCCL] reorg and better error messages (#3338)
Signed-off-by: Naren Dasan <[email protected]>
1 parent: b77a971

File tree

11 files changed: +1578 -319 lines

py/torch_tensorrt/dynamo/conversion/__init__.py (+7 -1)

@@ -1,4 +1,10 @@
-from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters
+from . import (
+    aten_ops_converters,
+    custom_ops_converters,
+    ops_evaluators,
+    plugins,
+    prims_ops_converters,
+)
 from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+1 -45)

@@ -2,10 +2,9 @@
 
 import logging
 import operator
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
-import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -21,11 +20,6 @@
     enforce_tensor_types,
     get_positive_dim,
     is_only_operator_on_placeholder,
-    load_tensorrt_llm,
-)
-from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
-    tensorrt_fused_nccl_all_gather_op,
-    tensorrt_fused_nccl_reduce_scatter_op,
 )
 from torch_tensorrt.dynamo.types import TRTTensor
 
@@ -3591,41 +3585,3 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
-
-
-if load_tensorrt_llm():
-
-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
-    def insert_nccl_gather_op(
-        ctx: ConversionContext,
-        target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        name: str,
-    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-        return impl.distributed.gather_op(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name,
-            [args[0]],
-        )
-
-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
-    def insert_nccl_reduce_scatter_plugin(
-        ctx: ConversionContext,
-        target: Target,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        name: str,
-    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-        return impl.distributed.reduce_scatter_op(
-            ctx,
-            target,
-            SourceIR.ATEN,
-            name,
-            [args[0]],
-        )
-
-else:
-    _LOGGER.warning("Unable to load the TRT-LLM plugins")

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+12 -16)

@@ -6,7 +6,6 @@
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
 
 import numpy as np
-import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Argument, Target
@@ -20,6 +19,8 @@
     DynamoConverterImplSignature,
 )
 
+import tensorrt as trt
+
 from ..types import Shape, TRTDataType, TRTLayer, TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -931,33 +932,28 @@ def load_tensorrt_llm() -> bool:
         bool: True if the plugin was successfully loaded and initialized, False otherwise.
     """
     try:
-        import tensorrt_llm as trt_llm
+        import tensorrt_llm as trt_llm  # noqa: F401
 
-        _LOGGER.info("TensorRT_LLM successfully imported.")
+        _LOGGER.info("TensorRT-LLM successfully imported")
         return True
     except (ImportError, AssertionError) as e_import_error:
-        _LOGGER.warning(
-            "TensorRT_LLM is not installed. Please install TensorRT_LLM or set TRTLLM_PLUGINS_PATH",
-            exc_info=e_import_error,
-        )
-
         # Check for environment variable for the plugin library path
         plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
         if not plugin_lib_path:
             _LOGGER.warning(
-                "Please specify a valid path for TRTLLM_PLUGINS_PATH libnvinfer_plugin_tensorrt_llm.so when using distributed examples in examples/distributed_inference."
+                "TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
             )
             return False
 
-        _LOGGER.info(f"Plugin lib path found: {plugin_lib_path}")
+        _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
         try:
            # Load the shared library
            handle = ctypes.CDLL(plugin_lib_path)
            _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
        except OSError as e_os_error:
            _LOGGER.error(
-                f"Failed to load the shared library at {plugin_lib_path}. "
-                f"Ensure the path is correct and the library is compatible.",
+                f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
+                f"Ensure the path is correct and the library is compatible",
                exc_info=e_os_error,
            )
            return False
@@ -968,7 +964,7 @@ def load_tensorrt_llm() -> bool:
             handle.initTrtLlmPlugins.restype = ctypes.c_bool
         except AttributeError as e_plugin_unavailable:
             _LOGGER.warning(
-                "TensorRT-LLM Plugin initialization function is unavailable.",
+                "Unable to initialize the TensorRT-LLM plugin library",
                 exc_info=e_plugin_unavailable,
             )
             return False
@@ -977,14 +973,14 @@ def load_tensorrt_llm() -> bool:
             # Initialize the plugin
             TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
             if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
-                _LOGGER.info("TensorRT-LLM Plugin successfully initialized.")
+                _LOGGER.info("TensorRT-LLM plugin successfully initialized")
                 return True
             else:
-                _LOGGER.warning("TensorRT-LLM Plugin initialization failed.")
+                _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
                 return False
         except Exception as e_initialization_error:
             _LOGGER.warning(
-                "Exception occurred during TensorRT-LLM plugin initialization.",
+                "Exception occurred during TensorRT-LLM plugin library initialization",
                 exc_info=e_initialization_error,
             )
             return False
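
The reworked load_tensorrt_llm() can be exercised on its own before compiling a distributed model. Below is a minimal sketch, assuming TensorRT-LLM is not pip-installed and the plugin library is supplied via TRTLLM_PLUGINS_PATH instead; the path is illustrative, not part of the commit. Note that the loader passes the value directly to ctypes.CDLL, so it should point at the .so file itself.

# Minimal sketch: point TRTLLM_PLUGINS_PATH at the TensorRT-LLM plugin library
# (illustrative path) and check whether the NCCL converters can be enabled.
import os

os.environ.setdefault(
    "TRTLLM_PLUGINS_PATH",
    "/opt/tensorrt_llm/libnvinfer_plugin_tensorrt_llm.so",  # illustrative path
)

from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

if load_tensorrt_llm():
    print("TensorRT-LLM plugins loaded; torch.distributed converters will be registered")
else:
    print("TensorRT-LLM plugins unavailable; torch.distributed ops will not be converted")
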
py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py (new file, +61)

@@ -0,0 +1,61 @@
+# mypy: disallow-untyped-decorators=False
+
+import logging
+from typing import Dict, Sequence, Tuple, Union
+
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
+from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+    tensorrt_fused_nccl_all_gather_op,
+    tensorrt_fused_nccl_reduce_scatter_op,
+)
+
+import tensorrt as trt
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+if load_tensorrt_llm():
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
+    def fused_nccl_gather(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+        return impl.distributed.nccl_gather(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
+    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
+    def fused_nccl_reduce_scatter(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
+        return impl.distributed.nccl_reduce_scatter(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            [args[0]],
+        )
+
+    breakpoint()
+else:
+    _LOGGER.debug(
+        "Did not load torch.distributed converters since TensorRT-LLM is not available"
+    )
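
With the converters relocated to custom_ops_converters.py and the module imported from the package __init__, registration now happens at import time, but only when load_tensorrt_llm() succeeds. A minimal sketch for checking the outcome; it assumes the registry is exported as DYNAMO_CONVERTERS from _ConverterRegistry and supports membership tests on converter targets, which should be verified against your torch_tensorrt version.

# Minimal sketch: confirm whether the fused NCCL ops gained converters.
# DYNAMO_CONVERTERS and its membership semantics are assumptions here.
import torch_tensorrt.dynamo.conversion  # noqa: F401  (importing the package registers converters)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    tensorrt_fused_nccl_all_gather_op,
)

if tensorrt_fused_nccl_all_gather_op in DYNAMO_CONVERTERS:
    print("NCCL all_gather converter registered (TensorRT-LLM plugins loaded)")
else:
    print("NCCL converters not registered (TensorRT-LLM plugins unavailable)")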

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py (+2 -2)

@@ -3,7 +3,6 @@
 from typing import Any, Callable, Optional, Union
 
 import numpy as np
-import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -14,12 +13,13 @@
     broadcast_to_same_shape,
     cast_trt_tensor,
     get_trt_tensor,
-    broadcast,
     has_dynamic_shape,
     set_layer_name,
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
 
+import tensorrt as trt
+
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,

py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py (+16 -16)

@@ -1,14 +1,14 @@
 import os
 from enum import IntEnum, IntFlag, auto
-from typing import Optional, Sequence, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
-import tensorrt as trt
-import torch
-from torch.fx.node import Target
+from torch.fx.node import Argument, Target
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name
 
+import tensorrt as trt
+
 
 # class for AllReduce
 class AllReduceStrategy(IntEnum):
@@ -33,25 +33,25 @@ class AllReduceConfig(IntFlag):
     PUSH_MODE = auto()
 
 
-def gather_op(
+def nccl_gather(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
-    plug_inputs,
-):
+    plug_inputs: Tuple[Argument, ...],
+) -> trt.ITensor:
     allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
         "AllGather", "1", "tensorrt_llm"
     )
     assert allgather_plg_creator is not None
     _world_size = os.environ.get("WORLD_SIZE")
     if _world_size is not None:
-        _world_size = int(_world_size)
+        world_size = int(_world_size)
     else:
         raise RuntimeError(
-            f"The WORLD_SIZE env variable is not set in distributed environment"
+            "The WORLD_SIZE env variable is not set in distributed environment"
         )
-    group = list(range(_world_size))
+    group = list(range(world_size))
     group = trt.PluginField(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
@@ -66,13 +66,13 @@ def gather_op(
     return layer.get_output(0)
 
 
-def reduce_scatter_op(
+def nccl_reduce_scatter(
     ctx: ConversionContext,
     target: Union[Target, str],
     source_ir: Optional[SourceIR],
     name: str,
-    plug_inputs,
-):
+    plug_inputs: Tuple[Argument, ...],
+) -> trt.ITensor:
     allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
         "ReduceScatter", "1", "tensorrt_llm"
     )
@@ -84,12 +84,12 @@ def reduce_scatter_op(
     config = AllReduceConfig(0)
     _world_size = os.environ.get("WORLD_SIZE")
     if _world_size is not None:
-        _world_size = int(_world_size)
+        world_size = int(_world_size)
     else:
         raise RuntimeError(
-            f"The WORLD_SIZE env variable is not set in distributed environment"
+            "The WORLD_SIZE env variable is not set in distributed environment"
         )
-    group = list(range(_world_size))
+    group = list(range(world_size))
     group = trt.PluginField(
         "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
     )
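
The renamed nccl_gather and nccl_reduce_scatter helpers rely on two runtime preconditions: WORLD_SIZE must be set (torchrun or another launcher normally sets it) and the TensorRT-LLM "AllGather" / "ReduceScatter" plugin creators must be visible in TensorRT's plugin registry. A minimal sketch for checking both before building a distributed engine, assuming the plugin library is loadable via load_tensorrt_llm():

# Minimal sketch: verify the preconditions that nccl_gather and
# nccl_reduce_scatter depend on.
import os

import tensorrt as trt
from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

# torchrun (or another launcher) is expected to set WORLD_SIZE.
assert os.environ.get("WORLD_SIZE") is not None, "WORLD_SIZE is not set; launch with torchrun"

# Loading the TensorRT-LLM plugin library makes the creators visible in the registry.
assert load_tensorrt_llm(), "TensorRT-LLM plugins could not be loaded"

registry = trt.get_plugin_registry()
for plugin_name in ("AllGather", "ReduceScatter"):
    creator = registry.get_plugin_creator(plugin_name, "1", "tensorrt_llm")
    print(f"{plugin_name} plugin creator available: {creator is not None}")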

py/torch_tensorrt/dynamo/conversion/impl/select.py (+2 -2)

@@ -2,7 +2,6 @@
 from typing import Optional, Sequence, Union
 
 import numpy as np
-import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -16,7 +15,6 @@
     to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
-from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.converters.converter_utils import (
@@ -25,6 +23,8 @@
 )
 from torch_tensorrt.fx.types import TRTTensor
 
+import tensorrt as trt
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 