diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 421336b4872a5a..e008669ca8ad79 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -227,6 +227,18 @@ ProcessGroupXCCL::ProcessGroupXCCL( : Backend(rank, size), store_(store) { blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); init(); + + { + int local_rank = getXCCLEnvVar("LOCAL_RANK"); + int local_world_size = getXCCLEnvVar("LOCAL_WORLD_SIZE"); + if (local_rank == -1 || local_world_size == -1) { + local_rank = rank; + local_world_size = size; + } + setXCCLEnvVar("CCL_PROCESS_LAUNCHER", "none"); + setXCCLEnvVar("CCL_LOCAL_RANK", local_rank); + setXCCLEnvVar("CCL_LOCAL_SIZE", local_world_size); + } } ProcessGroupXCCL::~ProcessGroupXCCL() = default; diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 7bb3a14d6e1446..eca66a33922d55 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -31,6 +31,32 @@ #include namespace c10d { +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +void setXCCLEnvVar(std::string envVarName, int val) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), val); +} + +void setXCCLEnvVar(std::string envVarName, std::string val) { + setenv(envVarName.c_str(), val.c_str(), 1); +} +} // namespace + static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"};