chore: [NCCL] reorg and better error messages (#3338)
Signed-off-by: Naren Dasan <[email protected]>
narendasan authored and apbose committed Dec 31, 2024
1 parent 095cec0 commit 45b28b7
Showing 11 changed files with 1,575 additions and 320 deletions.
8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,10 @@
from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters
from . import (
aten_ops_converters,
custom_ops_converters,
ops_evaluators,
plugins,
prims_ops_converters,
)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
46 changes: 1 addition & 45 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Expand Up @@ -2,10 +2,9 @@

import logging
import operator
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -21,11 +20,6 @@
enforce_tensor_types,
get_positive_dim,
is_only_operator_on_placeholder,
load_tensorrt_llm,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor

@@ -3591,41 +3585,3 @@ def aten_ops_full(
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)


if load_tensorrt_llm():

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def insert_nccl_gather_op(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.distributed.gather_op(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def insert_nccl_reduce_scatter_plugin(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.distributed.reduce_scatter_op(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

else:
_LOGGER.warning("Unable to load the TRT-LLM plugins")
28 changes: 12 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -6,7 +6,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Argument, Target
@@ -20,6 +19,8 @@
DynamoConverterImplSignature,
)

import tensorrt as trt

from ..types import Shape, TRTDataType, TRTLayer, TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -931,33 +932,28 @@ def load_tensorrt_llm() -> bool:
bool: True if the plugin was successfully loaded and initialized, False otherwise.
"""
try:
import tensorrt_llm as trt_llm
import tensorrt_llm as trt_llm # noqa: F401

_LOGGER.info("TensorRT_LLM successfully imported.")
_LOGGER.info("TensorRT-LLM successfully imported")
return True
except (ImportError, AssertionError) as e_import_error:
_LOGGER.warning(
"TensorRT_LLM is not installed. Please install TensorRT_LLM or set TRTLLM_PLUGINS_PATH",
exc_info=e_import_error,
)

# Check for environment variable for the plugin library path
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
if not plugin_lib_path:
_LOGGER.warning(
"Please specify a valid path for TRTLLM_PLUGINS_PATH libnvinfer_plugin_tensorrt_llm.so when using distributed examples in examples/distributed_inference."
"TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
)
return False

_LOGGER.info(f"Plugin lib path found: {plugin_lib_path}")
_LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
try:
# Load the shared library
handle = ctypes.CDLL(plugin_lib_path)
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
except OSError as e_os_error:
_LOGGER.error(
f"Failed to load the shared library at {plugin_lib_path}. "
f"Ensure the path is correct and the library is compatible.",
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
f"Ensure the path is correct and the library is compatible",
exc_info=e_os_error,
)
return False
@@ -968,7 +964,7 @@ def load_tensorrt_llm() -> bool:
handle.initTrtLlmPlugins.restype = ctypes.c_bool
except AttributeError as e_plugin_unavailable:
_LOGGER.warning(
"TensorRT-LLM Plugin initialization function is unavailable.",
"Unable to initialize the TensorRT-LLM plugin library",
exc_info=e_plugin_unavailable,
)
return False
@@ -977,14 +973,14 @@ def load_tensorrt_llm() -> bool:
# Initialize the plugin
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
_LOGGER.info("TensorRT-LLM Plugin successfully initialized.")
_LOGGER.info("TensorRT-LLM plugin successfully initialized")
return True
else:
_LOGGER.warning("TensorRT-LLM Plugin initialization failed.")
_LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
return False
except Exception as e_initialization_error:
_LOGGER.warning(
"Exception occurred during TensorRT-LLM plugin initialization.",
"Exception occurred during TensorRT-LLM plugin library initialization",
exc_info=e_initialization_error,
)
return False
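Note: when the tensorrt_llm Python package is absent, load_tensorrt_llm() falls back to loading the plugin shared library directly via ctypes.CDLL from TRTLLM_PLUGINS_PATH. A minimal usage sketch, assuming a locally built plugin library (the path below is hypothetical):

import os

# Hypothetical location of a locally built TensorRT-LLM plugin library;
# load_tensorrt_llm() hands this path straight to ctypes.CDLL.
os.environ["TRTLLM_PLUGINS_PATH"] = "/opt/tensorrt_llm/libnvinfer_plugin_tensorrt_llm.so"

from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

if load_tensorrt_llm():
    print("TRT-LLM plugins loaded; NCCL converters can be registered")
else:
    print("TRT-LLM plugins unavailable; torch.distributed ops will not be converted")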
61 changes: 61 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py
@@ -0,0 +1,61 @@
# mypy: disallow-untyped-decorators=False

import logging
from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
tensorrt_fused_nccl_all_gather_op,
tensorrt_fused_nccl_reduce_scatter_op,
)

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)

if load_tensorrt_llm():

@dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op)
def fused_nccl_gather(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.distributed.nccl_gather(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

@dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op)
def fused_nccl_reduce_scatter(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:
return impl.distributed.nccl_reduce_scatter(
ctx,
target,
SourceIR.ATEN,
name,
[args[0]],
)

else:
_LOGGER.debug(
"Did not load torch.distributed converters since TensorRT-LLM is not available"
)
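Note: these converters are registered at module import time, so their availability can be checked simply by importing the conversion package with debug logging enabled. A minimal sketch, assuming torch_tensorrt is installed (TensorRT-LLM is only needed for the converters to actually register):

import logging

logging.basicConfig(level=logging.DEBUG)

# Importing the conversion package pulls in custom_ops_converters (see the
# __init__.py change above), which calls load_tensorrt_llm() and either
# registers the NCCL converters or emits the debug message from the else branch.
import torch_tensorrt.dynamo.conversion  # noqa: F401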
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -3,7 +3,6 @@
from typing import Any, Callable, Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt import _enums
@@ -14,12 +13,13 @@
broadcast_to_same_shape,
cast_trt_tensor,
get_trt_tensor,
broadcast,
has_dynamic_shape,
set_layer_name,
)
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor

import tensorrt as trt


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
32 changes: 16 additions & 16 deletions py/torch_tensorrt/dynamo/conversion/impl/nccl_ops.py
@@ -1,14 +1,14 @@
import os
from enum import IntEnum, IntFlag, auto
from typing import Optional, Sequence, Union
from typing import Optional, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name

import tensorrt as trt


# class for AllReduce
class AllReduceStrategy(IntEnum):
@@ -33,25 +33,25 @@ class AllReduceConfig(IntFlag):
PUSH_MODE = auto()


def gather_op(
def nccl_gather(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
plug_inputs,
):
plug_inputs: Tuple[Argument, ...],
) -> trt.ITensor:
allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"AllGather", "1", "tensorrt_llm"
)
assert allgather_plg_creator is not None
_world_size = os.environ.get("WORLD_SIZE")
if _world_size is not None:
_world_size = int(_world_size)
world_size = int(_world_size)
else:
raise RuntimeError(
f"The WORLD_SIZE env variable is not set in distributed environment"
"The WORLD_SIZE env variable is not set in distributed environment"
)
group = list(range(_world_size))
group = list(range(world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
@@ -66,13 +66,13 @@ def gather_op(
return layer.get_output(0)


def reduce_scatter_op(
def nccl_reduce_scatter(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
plug_inputs,
):
plug_inputs: Tuple[Argument, ...],
) -> trt.ITensor:
allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
"ReduceScatter", "1", "tensorrt_llm"
)
@@ -84,12 +84,12 @@ def reduce_scatter_op(
config = AllReduceConfig(0)
_world_size = os.environ.get("WORLD_SIZE")
if _world_size is not None:
_world_size = int(_world_size)
world_size = int(_world_size)
else:
raise RuntimeError(
f"The WORLD_SIZE env variable is not set in distributed environment"
"The WORLD_SIZE env variable is not set in distributed environment"
)
group = list(range(_world_size))
group = list(range(world_size))
group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
)
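Note: both nccl_gather and nccl_reduce_scatter derive the participating ranks from the WORLD_SIZE environment variable, which launchers such as torchrun export automatically. A minimal sketch of how the "group" plugin field is built, assuming a two-process job (the value is set manually here only for illustration):

import os

import numpy as np
import tensorrt as trt

os.environ.setdefault("WORLD_SIZE", "2")  # normally exported by torchrun/mpirun
world_size = int(os.environ["WORLD_SIZE"])

# Ranks 0..world_size-1 take part in the collective, mirroring the "group"
# field passed to the TensorRT-LLM AllGather/ReduceScatter plugin creators.
group = trt.PluginField(
    "group", np.array(list(range(world_size)), dtype=np.int32), trt.PluginFieldType.INT32
)
print(group.name, world_size)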
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -2,7 +2,6 @@
from typing import Optional, Sequence, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -16,7 +15,6 @@
to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.converters.converter_utils import (
@@ -25,6 +23,8 @@
)
from torch_tensorrt.fx.types import TRTTensor

import tensorrt as trt

_LOGGER: logging.Logger = logging.getLogger(__name__)


