From 97a3ea4d76c7578d3c22ebadcb4d1bfea5de3249 Mon Sep 17 00:00:00 2001 From: zizhao Date: Thu, 4 Jan 2024 12:26:00 +0200 Subject: [PATCH] support spill to disk --- .../compat/spark_2_4/UcxShuffleClient.scala | 55 +++--- .../compat/spark_3_0/UcxShuffleClient.scala | 56 +++--- .../ucx/CommonUcxShuffleBlockResolver.scala | 4 +- .../spark/shuffle/ucx/ShuffleTransport.scala | 2 +- .../spark/shuffle/ucx/UcxFetchCallBack.scala | 2 +- .../spark/shuffle/ucx/UcxShuffleConf.scala | 7 + .../shuffle/ucx/UcxShuffleTransport.scala | 43 ++++- .../spark/shuffle/ucx/UcxStreamState.scala | 5 + .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 176 +++++++++++++++++- .../shuffle/ucx/perf/UcxPerfBenchmark.scala | 4 +- .../ucx/rpc/GlobalWorkerRpcThread.scala | 11 +- 11 files changed, 297 insertions(+), 68 deletions(-) create mode 100644 src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala index 81e3dd5f..37e9f63a 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_4/UcxShuffleClient.scala @@ -1,39 +1,42 @@ package org.apache.spark.shuffle.compat.spark_2_4 -import org.openucx.jucx.UcxUtils -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient} -import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport} -import org.apache.spark.shuffle.utils.UnsafeUtils +import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport} import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{ override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { - val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length) - val callbacks = Array.ofDim[OperationCallback](blockIds.length) - for (i <- blockIds.indices) { - val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId] - ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) - callbacks(i) = (result: OperationResult) => { - val memBlock = result.getData - val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt) - listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { - override def release: ManagedBuffer = { - memBlock.close() - this - } - }) + if (downloadFileManager == null) { + val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length) + val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, + blockId.reduceId) + callbacks(i) = new UcxFetchCallBack(blockIds(i), listener) + } + val maxBlocksPerRequest= transport.maxBlocksPerRequest + val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _ + for (i <- 0 until blockIds.length by maxBlocksPerRequest) { + val j = i + maxBlocksPerRequest + transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j), + resultBufferAllocator, + callbacks.slice(i, j)) + } + } else { + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + val ucxBlockId = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, + blockId.reduceId) + val callback = new UcxDownloadCallBack(blockIds(i), listener, + downloadFileManager, + transport.sparkTransportConf) + transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback) } - } - val maxBlocksPerRequest= transport.maxBlocksPerRequest - val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) - for (i <- 0 until blockIds.length by maxBlocksPerRequest) { - val j = i + maxBlocksPerRequest - transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j), - resultBufferAllocator, - callbacks.slice(i, j)) } } diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala index 42c53bb6..7ee0d443 100755 --- a/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_3_0/UcxShuffleClient.scala @@ -5,10 +5,8 @@ package org.apache.spark.shuffle.compat.spark_3_0 import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager} -import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport} -import org.apache.spark.shuffle.utils.UnsafeUtils +import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport} import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId} class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging { @@ -16,29 +14,37 @@ class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Ma override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String], listener: BlockFetchingListener, downloadFileManager: DownloadFileManager): Unit = { - val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length) - val callbacks = Array.ofDim[OperationCallback](blockIds.length) - for (i <- blockIds.indices) { - val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId] - ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId) - callbacks(i) = (result: OperationResult) => { - val memBlock = result.getData - val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt) - listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { - override def release: ManagedBuffer = { - memBlock.close() - this - } - }) + if (downloadFileManager == null) { + val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length) + val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, + mapId2PartitionId(blockId.mapId), + blockId.reduceId) + callbacks(i) = new UcxFetchCallBack(blockIds(i), listener) + } + val maxBlocksPerRequest= transport.maxBlocksPerRequest + val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _ + for (i <- 0 until blockIds.length by maxBlocksPerRequest) { + val j = i + maxBlocksPerRequest + transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j), + resultBufferAllocator, + callbacks.slice(i, j)) + } + } else { + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)) + .asInstanceOf[SparkShuffleBlockId] + val ucxBlockId = UcxShuffleBockId(blockId.shuffleId, + mapId2PartitionId(blockId.mapId), + blockId.reduceId) + val callback = new UcxDownloadCallBack(blockIds(i), listener, + downloadFileManager, + transport.sparkTransportConf) + transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback) } - } - val maxBlocksPerRequest= transport.maxBlocksPerRequest - val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size) - for (i <- 0 until blockIds.length by maxBlocksPerRequest) { - val j = i + maxBlocksPerRequest - transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j), - resultBufferAllocator, - callbacks.slice(i, j)) } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala index 41a2337f..7230d0be 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/CommonUcxShuffleBlockResolver.scala @@ -44,8 +44,8 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle val block = new Block { private val fileOffset = offset - override def getBlock(byteBuffer: ByteBuffer): Unit = { - channel.read(byteBuffer, fileOffset) + override def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit = { + channel.read(byteBuffer, fileOffset + offset) } override def getSize: Long = blockLength diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index 27f5d78c..68344409 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -43,7 +43,7 @@ trait Block extends BlockLock { def getMemoryBlock: MemoryBlock = ??? // Get block from a file into byte buffer backed bunce buffer - def getBlock(byteBuffer: ByteBuffer): Unit + def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit } object OperationStatus extends Enumeration { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala index 49e21631..ee03001f 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala @@ -47,4 +47,4 @@ class UcxDownloadCallBack( targetFile.delete(); } } -} \ No newline at end of file +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala index 86a26e20..a9741599 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala @@ -91,4 +91,11 @@ class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf { .createWithDefault(50) lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get) + + private lazy val MAX_REPLY_SIZE = ConfigBuilder(getUcxConf("maxReplySize")) + .doc("Maximum size of fetch reply message") + .bytesConf(ByteUnit.MiB) + .createWithDefault(32) + + lazy val maxReplySize: Long = sparkConf.getSizeAsBytes(MAX_REPLY_SIZE.key, MAX_REPLY_SIZE.defaultValueString) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index d23082ee..5366e8fc 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -90,7 +90,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo private var progressThread: Thread = _ var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _ - private[spark] lazy val replyWorkersThreadPool = + private[spark] lazy val replyThreadPool = ThreadUtils.newDaemonFixedThreadPool(ucxShuffleConf.numListenerThreads, "UcxListenerThread") private[spark] lazy val sparkTransportConf = SparkTransportConf.fromSparkConf( @@ -282,8 +282,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId], resultBufferAllocator: BufferAllocator, callbacks: Seq[OperationCallback]): Seq[Request] = { - allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt) - .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks) + selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds, + resultBufferAllocator, + callbacks) + } + + def fetchBlockByStream(executorId: ExecutorId, blockId: BlockId, + callback: OperationCallback): Unit = { + selectClientWorker.fetchBlockByStream(executorId, blockId, callback) } def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = { @@ -291,7 +297,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = { - replyWorkersThreadPool.submit(new Runnable { + replyThreadPool.submit(new Runnable { override def run(): Unit = { val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) @@ -308,13 +314,34 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo val blocks = blockIds.map(bid => registeredBlocks(bid)) amData.close() - allocatedServerWorkers( - (Thread.currentThread().getId % allocatedServerWorkers.length).toInt) - .handleFetchBlockRequest(blocks, replyTag, replyExecutor) + selectServerWorker.handleFetchBlockRequest(blocks, replyTag, + replyExecutor) + } + }) + } + + def handleFetchBlockStream(replyTag: Int, blockId: BlockId, + replyExecutor: Long): Unit = { + replyThreadPool.submit(new Runnable { + override def run(): Unit = { + val block = registeredBlocks(blockId) + selectServerWorker.handleFetchBlockStream(block, replyTag, + replyExecutor) } }) } + @inline + def selectClientWorker(): UcxWorkerWrapper = { + allocatedClientWorkers( + (Thread.currentThread().getId % allocatedClientWorkers.length).toInt) + } + + @inline + def selectServerWorker(): UcxWorkerWrapper = { + allocatedServerWorkers( + (Thread.currentThread().getId % allocatedServerWorkers.length).toInt) + } /** * Progress outstanding operations. This routine is blocking (though may poll for event). @@ -324,7 +351,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed! */ override def progress(): Unit = { - allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress() + selectClientWorker.progress() } def progressConnect(): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala new file mode 100644 index 00000000..b38975e1 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala @@ -0,0 +1,5 @@ +package org.apache.spark.shuffle.ucx + +class UcxStreamState(val callback: OperationCallback, + val request: UcxRequest, + var remaining: Int) {} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 2bb68cac..b7347247 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,7 +5,7 @@ package org.apache.spark.shuffle.ucx import java.io.Closeable -import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch} import java.util.concurrent.atomic.AtomicInteger import scala.collection.concurrent.TrieMap import scala.util.Random @@ -24,6 +24,16 @@ import java.nio.ByteBuffer import scala.collection.parallel.ForkJoinTaskSupport +class UcxSucceedOperationResult(mem: MemoryBlock, stats: OperationStats) extends OperationResult { + override def getStatus: OperationStatus.Value = OperationStatus.SUCCESS + + override def getError: TransportError = null + + override def getStats: Option[OperationStats] = Some(stats) + + override def getData: MemoryBlock = mem +} + class UcxFailureOperationResult(errorMsg: String) extends OperationResult { override def getStatus: OperationStatus.Value = OperationStatus.FAILURE @@ -65,6 +75,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i private final val connections = new TrieMap[transport.ExecutorId, UcpEndpoint] private val requestData = new TrieMap[Int, (Seq[OperationCallback], UcxRequest, transport.BufferAllocator)] + private[ucx] lazy val streamData = new TrieMap[Int, UcxStreamState] private val tag = new AtomicInteger(Random.nextInt()) private val flushRequests = new ConcurrentLinkedQueue[UcpRequest]() @@ -72,6 +83,74 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i transport.ucxShuffleConf.numIoThreads) private val ioTaskSupport = new ForkJoinTaskSupport(ioThreadPool) + private[ucx] lazy val maxReplySize = transport.ucxShuffleConf.maxReplySize + private[ucx] lazy val memPool = transport.hostBounceBufferMemoryPool + + private[this] case class UcxStreamReplyHandle() extends UcpAmRecvCallback() { + override def onReceive(headerAddress: Long, headerSize: Long, + ucpAmData: UcpAmData, ep: UcpEndpoint): Int = { + val headerBuffer = UnsafeUtils.getByteBufferView(headerAddress, + headerSize.toInt) + val i = headerBuffer.getInt + val remaining = headerBuffer.getInt + + val data = streamData.get(i) + if (data.isEmpty) { + throw new UcxException(s"Stream tag $i context not found.") + } + + val streamState = data.get + if (remaining >= streamState.remaining) { + throw new UcxException( + s"Stream tag $i out of order $remaining <= ${streamState.remaining}.") + } + streamState.remaining = remaining + + val stats = streamState.request.getStats.get.asInstanceOf[UcxStats] + stats.receiveSize += ucpAmData.getLength + + if (ucpAmData.isDataValid) { + stats.endTime = System.nanoTime() + logDebug(s"Stream receive amData ${ucpAmData} tag $i in " + + s"${stats.getElapsedTimeNs} ns") + val buffer = UnsafeUtils.getByteBufferView( + ucpAmData.getDataAddress, ucpAmData.getLength.toInt) + streamState.callback.onData(buffer) + if (remaining == 0) { + streamState.callback.onComplete( + new UcxSucceedOperationResult(null, stats)) + streamData.remove(i) + } + } else { + val mem = memPool.get(ucpAmData.getLength) + stats.amHandleTime = System.nanoTime() + worker.recvAmDataNonBlocking( + ucpAmData.getDataHandle, mem.address, ucpAmData.getLength, + new UcxCallback() { + override def onSuccess(r: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + logDebug(s"Stream receive rndv data ${ucpAmData.getLength} " + + s"tag $i in ${stats.getElapsedTimeNs} ns amHandle " + + s"${stats.endTime - stats.amHandleTime} ns") + val buffer = UnsafeUtils.getByteBufferView( + mem.address, ucpAmData.getLength.toInt) + streamState.callback.onData(buffer) + mem.close() + if (remaining == 0) { + streamState.callback.onComplete( + new UcxSucceedOperationResult(null, stats)) + streamData.remove(i) + } + } + }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + } + UcsConstants.STATUS.UCS_OK + } + } + + // Receive block data handler + worker.setAmRecvHandler(2, UcxStreamReplyHandle(), UcpConstants.UCP_AM_FLAG_WHOLE_MSG) + if (isClientWorker) { // Receive block data handler worker.setAmRecvHandler(1, @@ -297,7 +376,7 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i } for (i <- blocksCollection) { - blocks(i).getBlock(localBuffers(i)) + blocks(i).getBlock(localBuffers(i), 0) } val startTime = System.nanoTime() @@ -322,4 +401,97 @@ case class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, i case ex: Throwable => logError(s"Failed to read and send data: $ex") } + def fetchBlockByStream(executorId: transport.ExecutorId, blockId: BlockId, + callback: OperationCallback): Unit = { + val startTime = System.nanoTime() + val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.LONG_SIZE + + blockId.serializedSize + + val t = tag.incrementAndGet() + + val buffer = Platform.allocateDirectBuffer(headerSize) + buffer.putInt(t) + buffer.putLong(id) + blockId.serialize(buffer) + + val request = new UcxRequest(null, new UcxStats()) + streamData.put(t, new UcxStreamState(callback, request, Int.MaxValue)) + + val address = UnsafeUtils.getAdress(buffer) + + val ep = getConnection(executorId) + worker.synchronized { + ep.sendAmNonBlocking(2, address, headerSize, address, 0, + UcpConstants.UCP_AM_SEND_FLAG_EAGER, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + buffer.clear() + logDebug(s"Worker $id sent stream to $executorId block $blockId " + + s"tag $t in ${System.nanoTime() - startTime} ns") + } + }, MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + } + } + + def handleFetchBlockStream(block: Block, replyTag: Int, + replyExecutor: Long): Unit = { + val headerSize = UnsafeUtils.INT_SIZE + UnsafeUtils.INT_SIZE + val maxBodySize = maxReplySize - headerSize.toLong + val blockSize = block.getSize + val blockSlice = (0L until blockSize by maxBodySize) + val firstLatch = new CountDownLatch(1) + + def send(workerWrapper: UcxWorkerWrapper, currentId: Int, + sendLatch: CountDownLatch): Unit = try { + val mem = memPool.get(maxReplySize) + .asInstanceOf[UcxBounceBufferMemoryBlock] + val buffer = UcxUtils.getByteBufferView(mem.address, mem.size) + + val remaining = blockSlice.length - currentId - 1 + val currentOffset = blockSlice(currentId) + val currentSize = (blockSize - currentOffset).min(maxBodySize) + buffer.limit(headerSize + currentSize.toInt) + buffer.putInt(replyTag) + buffer.putInt(remaining) + block.getBlock(buffer, currentOffset) + + val nextLatch = new CountDownLatch(1) + sendLatch.await() + + val startTime = System.nanoTime() + val ep = workerWrapper.connections(replyExecutor) + val req = workerWrapper.worker.synchronized { + ep.sendAmNonBlocking(2, mem.address, headerSize, + mem.address + headerSize, currentSize, 0, new UcxCallback { + override def onSuccess(request: UcpRequest): Unit = { + logTrace(s"Reply stream block $currentId size $currentSize tag " + + s"$replyTag in ${System.nanoTime() - startTime} ns.") + mem.close() + nextLatch.countDown() + } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to reply stream $errorMsg") + mem.close() + nextLatch.countDown() + } + }, new UcpRequestParams() + .setMemoryType(UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + .setMemoryHandle(mem.memory)) + } + if (remaining > 0) { + transport.replyThreadPool.submit(new Runnable { + override def run = send(transport.selectServerWorker, currentId + 1, + nextLatch) + }) + } + while (!req.isCompleted) { + progress() + } + } catch { + case ex: Throwable => + logError(s"Failed to reply stream tag $replyTag id $currentId $ex.") + } + + firstLatch.countDown() + send(this, 0, firstLatch) + } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala index 16bd821b..433c7f05 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxPerfBenchmark.scala @@ -195,8 +195,8 @@ object UcxPerfBenchmark extends App with Logging { override def getSize: Long = options.blockSize - override def getBlock(byteBuffer: ByteBuffer): Unit = { - channel.read(byteBuffer, fileOffset) + override def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit = { + channel.read(byteBuffer, fileOffset + offset) } } ucxTransport.register(blockId, block) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index b005606f..40e28376 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -7,7 +7,7 @@ package org.apache.spark.shuffle.ucx.rpc import org.openucx.jucx.ucp.{UcpAmData, UcpConstants, UcpEndpoint, UcpWorker} import org.openucx.jucx.ucs.UcsConstants import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.{UcxShuffleTransport, UcxShuffleBockId} import org.apache.spark.shuffle.utils.UnsafeUtils class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransport) @@ -35,6 +35,15 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransp UcsConstants.STATUS.UCS_OK }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG) + globalWorker.setAmRecvHandler(2, (headerAddress: Long, headerSize: Long, amData: UcpAmData, _: UcpEndpoint) => { + val header = UnsafeUtils.getByteBufferView(headerAddress, headerSize.toInt) + val replyTag = header.getInt + val replyExecutor = header.getLong + val blockId = UcxShuffleBockId.deserialize(header) + transport.handleFetchBlockStream(replyTag, blockId, replyExecutor) + UcsConstants.STATUS.UCS_OK + }, UcpConstants.UCP_AM_FLAG_WHOLE_MSG ) + override def run(): Unit = { if (transport.ucxShuffleConf.useWakeup) { while (!isInterrupted) {