Skip to content

Commit

Permalink
simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
JeynmannZ committed Jun 3, 2024
1 parent 9a931e9 commit 53ce91c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
* connection establishment outside of UcxShuffleManager.
*/
override def addExecutor(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
executorAddresses.put(executorId, workerAddress)
allocatedClientWorkers.foreach(_.getConnection(executorId))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package org.apache.spark.shuffle.ucx

import java.io.Closeable
import java.util.concurrent.{ConcurrentLinkedQueue, Callable, Future, FutureTask}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap
import scala.util.Random
Expand Down Expand Up @@ -211,13 +210,12 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i

def getConnection(executorId: transport.ExecutorId): UcpEndpoint = {

if (!connections.contains(executorId)) {
if (!transport.executorAddresses.contains(executorId)) {
val startTime = System.currentTimeMillis()
while (!transport.executorAddresses.contains(executorId)) {
if (System.currentTimeMillis() - startTime > timeout) {
throw new UcxException(s"Don't get a worker address for $executorId")
}
if ((!connections.contains(executorId)) &&
(!transport.executorAddresses.contains(executorId))) {
val startTime = System.currentTimeMillis()
while (!transport.executorAddresses.contains(executorId)) {
if (System.currentTimeMillis() - startTime > timeout) {
throw new UcxException(s"Don't get a worker address for $executorId")
}
}
}
Expand Down

0 comments on commit 53ce91c

Please sign in to comment.