diff --git a/engines/python/src/main/java/ai/djl/python/engine/Connection.java b/engines/python/src/main/java/ai/djl/python/engine/Connection.java index 81bb3a155..cb89096f8 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/Connection.java +++ b/engines/python/src/main/java/ai/djl/python/engine/Connection.java @@ -67,9 +67,9 @@ class Connection { private Channel channel; private RequestHandler requestHandler; - Connection(PyEnv pyEnv, int workerId, int rank) { + Connection(PyEnv pyEnv, int basePort, int rank) { requestHandler = new RequestHandler(); - port = 19000 + workerId; + port = 19000 + basePort; socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank); } @@ -99,6 +99,8 @@ CompletableFuture send(Input input) throws InterruptedException { static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) { int tensorParallelDegree = pyEnv.getTensorParallelDegree(); if (pyEnv.isMpiMode()) { + String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree); + logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices); String[] args = new String[36]; args[0] = "mpirun"; args[1] = "-N"; @@ -122,7 +124,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po args[16] = "-x"; args[17] = "PYTHONPATH"; args[18] = "-x"; - args[19] = "CUDA_VISIBLE_DEVICES=" + getVisibleDevices(workerId, tensorParallelDegree); + args[19] = "CUDA_VISIBLE_DEVICES=" + cudaDevices; args[20] = "-x"; args[21] = "MASTER_ADDR=" + MASTER_ADDR; args[22] = "-x"; diff --git a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java index 10f1de9b1..f9936cab3 100644 --- a/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java +++ b/engines/python/src/main/java/ai/djl/python/engine/PyProcess.java @@ -57,15 +57,16 @@ class PyProcess { PyProcess(Model model, PyEnv pyEnv, int workerId) { this.model = model; this.pyEnv = pyEnv; - this.workerId = workerId + counter.getAndIncrement(); + this.workerId = workerId; + int port = workerId + counter.getAndIncrement(); if (pyEnv.isMpiMode()) { int tensorParallelDegree = pyEnv.getTensorParallelDegree(); connections = new ArrayList<>(tensorParallelDegree); for (int i = 0; i < tensorParallelDegree; ++i) { - connections.add(new Connection(pyEnv, this.workerId, i)); + connections.add(new Connection(pyEnv, port, i)); } } else { - connections = Collections.singletonList(new Connection(pyEnv, this.workerId, -1)); + connections = Collections.singletonList(new Connection(pyEnv, port, -1)); } restartCount = new AtomicInteger(0); }