import logging
import os

import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
    # Walk upward from this file's directory until MODULE.bazel marks the repo root.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    for _ in range(max_depth):
        if "MODULE.bazel" in os.listdir(dir_path):
            return dir_path
        dir_path = os.path.dirname(dir_path)

    raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)
    return logger


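# Illustrative example (not part of the original file): calling
#     initialize_logger(rank=0, logger_file_name="./tensor_parallel")
# attaches a FileHandler to the root logger, so rank 0 writes its INFO-level
# records to "./tensor_parallel_0.log"; each rank gets its own log file.

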
# Required for environment initialization since we launch with mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    local_rank = int(
        os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
    )
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up the environment variables torch.distributed expects when launched
    # with mpirun. Single-node setup: the local rank doubles as the global rank.
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)
    # Note: setting this here does not take effect for this initialization;
    # as a user you would need to set it externally before launching.
    os.environ["trtllm_env"] = (
        find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
    )

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # Use the NCCL backend for GPU communication.
    dist.init_process_group("nccl")

    # Set a manual seed for reproducibility.
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert rank == local_rank
    logger = initialize_logger(rank, logger_file_name)

    # Ensure each rank gets a unique device.
    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
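

# Minimal usage sketch (illustrative; the log prefix and process count below
# are assumptions, not part of the original example). Each process spawned by,
# e.g.,
#     mpirun -n 2 python <this_file>.py
# initializes the distributed environment once, runs a trivial all_reduce to
# verify the NCCL process group, and then tears it down.
if __name__ == "__main__":
    device_mesh, world_size, rank, logger = initialize_distributed_env(
        "./tensor_parallel_smoke_test"  # hypothetical per-rank log file prefix
    )
    logger.info(f"rank {rank}/{world_size} using cuda:{torch.cuda.current_device()}")

    # Sum of ones across ranks should equal the world size.
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t)
    logger.info(f"all_reduce on rank {rank}: {t.item()} (expected {world_size})")

    dist.destroy_process_group()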