import logging
import os
from enum import IntEnum, IntFlag, auto
from typing import Dict, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt

# tensorrt_llm is imported for its side effect of making the TensorRT-LLM
# plugins ("AllGather", "ReduceScatter", ...) visible in the TensorRT plugin
# registry; torch_tensorrt is imported so its torch.compile backend is available.
import tensorrt_llm  # noqa: F401
import torch
import torch.distributed as dist
import torch_tensorrt  # noqa: F401
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)
from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
    custom_fused_all_gather_op,
    custom_fused_reduce_scatter_op,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name


# Enums mirroring the TensorRT-LLM AllReduce plugin parameters.
class AllReduceStrategy(IntEnum):
    """Warning: the actual definition is in kernels/customAllReduceKernels.h.

    They must be kept in sync.
    """

    NCCL = 0
    ONESHOT = 1
    TWOSHOT = 2
    AUTO = 3


class AllReduceConfig(IntFlag):
    """Warning: the actual definition is in kernels/customAllReduceKernels.h.

    They must be kept in sync.
    """

    USE_MEMCPY = auto()
    PUSH_MODE = auto()


def initialize_logger(rank, logger_file_name):
    """Create a per-rank logger that writes to `<logger_file_name>_<rank>.log`."""
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(f"{logger_file_name}_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


# Environment initialization. This is required because the example is launched
# with mpirun, so rank/world-size information arrives via OMPI_* environment
# variables rather than through torchrun (see the launch note below).
def initialize_distributed_env(rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up the environment variables torch.distributed expects when running
    # under mpirun (single-node setup, so the master is localhost).
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # Use the NCCL backend for GPU collectives.
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility.
    torch.manual_seed(1111)

    return local_rank, world_size


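# Launch note (illustrative): the helpers above and below assume one process
# per GPU started via mpirun; the script name is a hypothetical placeholder
# for an example that calls register_nccl_ops():
#
#   mpirun -n 2 --allow-run-as-root python tensor_parallel_example.py

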
def register_nccl_ops(logger_file_name):
    """Set up distributed state and register the NCCL plugin converters."""
    # Initialization
    initialize_distributed_env()
    # Create a device mesh based on the given world_size.
    _world_size = int(os.environ["WORLD_SIZE"])

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
    _rank = device_mesh.get_rank()
    logger = initialize_logger(_rank, logger_file_name)
    device_id = (
        _rank % torch.cuda.device_count()
    )  # Ensure each rank gets a unique device
    torch.cuda.set_device(device_id)

    # Log all registered TensorRT plugin creators; the TensorRT-LLM NCCL plugins
    # ("AllGather", "ReduceScatter", ...) should appear in the "tensorrt_llm"
    # namespace.
    plugin_registry = trt.get_plugin_registry()
    for plugin_creator in plugin_registry.plugin_creator_list:
        logger.info(
            f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
        )

    # Converter that lowers the fused all_gather op to the TensorRT-LLM
    # "AllGather" NCCL plugin.
    @dynamo_tensorrt_converter(custom_fused_all_gather_op)
    def insert_nccl_gather_op(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        plug_inputs = [args[0]]
        allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
            "AllGather", "1", "tensorrt_llm"
        )
        assert allgather_plg_creator is not None
        # Gather across every rank in the world.
        _world_size = int(os.environ["WORLD_SIZE"])
        group = list(range(_world_size))
        group = trt.PluginField(
            "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
        )
        # The plugin is configured for FP16 tensors here.
        p_dtype = trt.float16
        pf_type = trt.PluginField(
            "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
        )
        pfc = trt.PluginFieldCollection([group, pf_type])
        allgather = allgather_plg_creator.create_plugin("allgather", pfc)
        layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
        set_layer_name(layer, target, name)
        return layer.get_output(0)

    # Converter that lowers the fused reduce_scatter op to the TensorRT-LLM
    # "ReduceScatter" NCCL plugin.
    @dynamo_tensorrt_converter(custom_fused_reduce_scatter_op)
    def insert_nccl_reduce_scatter_plugin(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        plug_inputs = [args[0]]
        allreduce_plg_creator = trt.get_plugin_registry().get_plugin_creator(
            "ReduceScatter", "1", "tensorrt_llm"
        )

        assert allreduce_plg_creator is not None

        # Use the plain NCCL strategy with no extra AllReduce config flags.
        counter = 0
        strategy = AllReduceStrategy.NCCL
        config = AllReduceConfig(0)

        world_size = dist.get_world_size()
        group = list(range(world_size))
        group = trt.PluginField(
            "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
        )

        # The plugin is configured for FP16 tensors here.
        p_dtype = trt.float16
        pf_dtype = trt.PluginField(
            "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
        )
        pfc = [group, pf_dtype]
        p_strategy = trt.PluginField(
            "strategy", np.array([int(strategy)], np.int8), trt.PluginFieldType.INT8
        )
        pfc.append(p_strategy)
        p_config = trt.PluginField(
            "config", np.array([int(config)], np.int8), trt.PluginFieldType.INT8
        )
        pfc.append(p_config)
        p_counter = trt.PluginField(
            "counter", np.array([counter], np.int32), trt.PluginFieldType.INT32
        )
        pfc.append(p_counter)

        pfc = trt.PluginFieldCollection(pfc)
        ar_plug = allreduce_plg_creator.create_plugin("allreduce", pfc)

        layer = ctx.net.add_plugin_v2(plug_inputs, ar_plug)
        set_layer_name(layer, target, name)
        return layer.get_output(0)

    return device_mesh, _world_size, _rank, logger
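

# Minimal usage sketch (illustrative, not part of this module). It assumes a
# model already sharded with torch.distributed tensor parallelism; `ShardedModel`
# and the input shape are hypothetical placeholders, and the "torch_tensorrt"
# torch.compile backend picks up the converters registered above:
#
#     device_mesh, world_size, rank, logger = register_nccl_ops("./tp_nccl_logs")
#     model = ShardedModel().to(f"cuda:{rank}")
#     inputs = torch.randn(1, 16, device=f"cuda:{rank}", dtype=torch.float16)
#     compiled = torch.compile(model, backend="torch_tensorrt", dynamic=False)
#     logger.info(compiled(inputs))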