diff --git a/lua/SlurmTree.lua b/lua/SlurmTree.lua index b12adea..ba10312 100644 --- a/lua/SlurmTree.lua +++ b/lua/SlurmTree.lua @@ -2,7 +2,7 @@ local ipc = require 'libipc' local Tree = require 'ipc.Tree' local NullTree = require 'ipc.NullTree' -local function SlurmTree(fn, tasksPerGpu) +local function SlurmTree(fn, tasksPerGpu, hostAddress) tasksPerGpu = tasksPerGpu or 1 local slurmProcId = tonumber(os.getenv("SLURM_PROCID")) local numNodes = tonumber(os.getenv("SLURM_NTASKS")) @@ -72,7 +72,7 @@ local function SlurmTree(fn, tasksPerGpu) if numNodes == 1 then tree = NullTree() else - local nodeHost = sys.execute('/bin/hostname') + local nodeHost = hostAddress or sys.execute('/bin/hostname') local nodePort = nil if nodeIndex == 1 then local server,nodePort = ipc.server(nodeHost, nodePort)