support spill to disk
JeynmannZ committed Jan 4, 2024
1 parent 3f8b34e commit 97a3ea4
Showing 11 changed files with 297 additions and 68 deletions.
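At a glance (as read from the diff below): in Spark's shuffle protocol, fetchBlocks receives a DownloadFileManager that is non-null when fetched blocks should be written to temporary files on disk rather than materialized in memory. This commit branches on exactly that: the in-memory path keeps batching fetches through host bounce buffers, while the new on-disk path streams each block via fetchBlockByStream with a UcxDownloadCallBack. Supporting pieces include a chunk offset in Block.getBlock, a maxReplySize config bounding each reply message, a server-side handleFetchBlockStream, and a new UcxStreamState for in-flight stream bookkeeping.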
@@ -1,39 +1,42 @@
 package org.apache.spark.shuffle.compat.spark_2_4
 
-import org.openucx.jucx.UcxUtils
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
-import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
-import org.apache.spark.shuffle.utils.UnsafeUtils
+import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport}
 import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
 
 class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient {
   override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
                            listener: BlockFetchingListener,
                            downloadFileManager: DownloadFileManager): Unit = {
-    val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
-    val callbacks = Array.ofDim[OperationCallback](blockIds.length)
-    for (i <- blockIds.indices) {
-      val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
-      ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
-      callbacks(i) = (result: OperationResult) => {
-        val memBlock = result.getData
-        val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
-        listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
-          override def release: ManagedBuffer = {
-            memBlock.close()
-            this
-          }
-        })
-      }
-    }
-    val maxBlocksPerRequest = transport.maxBlocksPerRequest
-    val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
-    for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
-      val j = i + maxBlocksPerRequest
-      transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
-        resultBufferAllocator,
-        callbacks.slice(i, j))
-    }
+    if (downloadFileManager == null) {
+      val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
+      val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length)
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId,
+          blockId.reduceId)
+        callbacks(i) = new UcxFetchCallBack(blockIds(i), listener)
+      }
+      val maxBlocksPerRequest = transport.maxBlocksPerRequest
+      val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _
+      for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
+        val j = i + maxBlocksPerRequest
+        transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
+          resultBufferAllocator,
+          callbacks.slice(i, j))
+      }
+    } else {
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        val ucxBlockId = UcxShuffleBockId(blockId.shuffleId, blockId.mapId,
+          blockId.reduceId)
+        val callback = new UcxDownloadCallBack(blockIds(i), listener,
+          downloadFileManager,
+          transport.sparkTransportConf)
+        transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback)
+      }
+    }
   }
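The fetch callback that used to be an inline lambda (the removed lines above) now lives in a dedicated UcxFetchCallBack class whose definition is not part of this view. Based on the lambda it replaces, it plausibly looks like the sketch below; the onComplete method name and the exact OperationCallback shape are assumptions, only the behavior is taken from the diff:

    package org.apache.spark.shuffle.ucx

    import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
    import org.apache.spark.network.shuffle.BlockFetchingListener
    import org.apache.spark.shuffle.utils.UnsafeUtils

    // Hypothetical sketch: wrap the fetched bounce buffer in a NioManagedBuffer,
    // hand it to Spark's listener, and free the buffer when Spark releases it.
    class UcxFetchCallBack(blockId: String, listener: BlockFetchingListener)
      extends OperationCallback {

      override def onComplete(result: OperationResult): Unit = {
        val memBlock = result.getData
        val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
        listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(buffer) {
          override def release: ManagedBuffer = {
            memBlock.close()
            this
          }
        })
      }
    }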

@@ -5,40 +5,46 @@
 package org.apache.spark.shuffle.compat.spark_3_0
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
-import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
-import org.apache.spark.shuffle.utils.UnsafeUtils
+import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport}
 import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}
 
 class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging {
 
   override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
                            listener: BlockFetchingListener,
                            downloadFileManager: DownloadFileManager): Unit = {
-    val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
-    val callbacks = Array.ofDim[OperationCallback](blockIds.length)
-    for (i <- blockIds.indices) {
-      val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
-      ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId)
-      callbacks(i) = (result: OperationResult) => {
-        val memBlock = result.getData
-        val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
-        listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
-          override def release: ManagedBuffer = {
-            memBlock.close()
-            this
-          }
-        })
-      }
-    }
-    val maxBlocksPerRequest = transport.maxBlocksPerRequest
-    val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
-    for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
-      val j = i + maxBlocksPerRequest
-      transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
-        resultBufferAllocator,
-        callbacks.slice(i, j))
-    }
+    if (downloadFileManager == null) {
+      val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
+      val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length)
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId,
+          mapId2PartitionId(blockId.mapId),
+          blockId.reduceId)
+        callbacks(i) = new UcxFetchCallBack(blockIds(i), listener)
+      }
+      val maxBlocksPerRequest = transport.maxBlocksPerRequest
+      val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _
+      for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
+        val j = i + maxBlocksPerRequest
+        transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
+          resultBufferAllocator,
+          callbacks.slice(i, j))
+      }
+    } else {
+      for (i <- blockIds.indices) {
+        val blockId = SparkBlockId.apply(blockIds(i))
+          .asInstanceOf[SparkShuffleBlockId]
+        val ucxBlockId = UcxShuffleBockId(blockId.shuffleId,
+          mapId2PartitionId(blockId.mapId),
+          blockId.reduceId)
+        val callback = new UcxDownloadCallBack(blockIds(i), listener,
+          downloadFileManager,
+          transport.sparkTransportConf)
+        transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback)
+      }
+    }
   }
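Both clients issue fetches in chunks of at most maxBlocksPerRequest blocks. Scala's Array.slice clamps the end index to the array length, so the overshoot of j on the final chunk is harmless. A standalone illustration of the arithmetic, with a hypothetical block count (the 50 matches the default shown in the UcxShuffleConf hunk further down):

    // Standalone illustration of the request-batching arithmetic.
    object BatchingDemo extends App {
      val blockIds = (0 until 120).toArray // pretend 120 blocks to fetch
      val maxBlocksPerRequest = 50         // default per the config below
      for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
        val j = i + maxBlocksPerRequest
        val chunk = blockIds.slice(i, j)   // slice clamps j: 50, 50, then 20 blocks
        println(s"request blocks [$i, ${i + chunk.length})")
      }
    }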

@@ -44,8 +44,8 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
     val block = new Block {
       private val fileOffset = offset
 
-      override def getBlock(byteBuffer: ByteBuffer): Unit = {
-        channel.read(byteBuffer, fileOffset)
+      override def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit = {
+        channel.read(byteBuffer, fileOffset + offset)
       }
 
       override def getSize: Long = blockLength
@@ -43,7 +43,7 @@ trait Block extends BlockLock {
   def getMemoryBlock: MemoryBlock = ???
 
   // Get a block from a file into a byte-buffer-backed bounce buffer
-  def getBlock(byteBuffer: ByteBuffer): Unit
+  def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit
 }
 
 object OperationStatus extends Enumeration {
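The offset parameter added to getBlock (implemented in the resolver hunk above as a read from fileOffset + offset) lets the server read one registered block in several chunks, which is what the streamed spill-to-disk path needs. A hedged sketch of a consumer, assuming a simple sequential loop; this is not the commit's actual handleFetchBlockStream:

    import java.nio.ByteBuffer

    object StreamedReadSketch {
      // Hypothetical consumer of the new getBlock(byteBuffer, offset) signature:
      // drain one registered block in fixed-size chunks instead of a single read.
      def streamBlock(block: Block, chunkSize: Int): Unit = {
        var offset = 0L
        while (offset < block.getSize) {
          val len = math.min(chunkSize.toLong, block.getSize - offset).toInt
          val buf = ByteBuffer.allocateDirect(len)
          block.getBlock(buf, offset) // resolver reads from fileOffset + offset
          buf.flip()
          // ... send buf as one reply message, e.g. capped by maxReplySize ...
          offset += len
        }
      }
    }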
@@ -47,4 +47,4 @@ class UcxDownloadCallBack(
       targetFile.delete();
     }
   }
-}
+}
@@ -91,4 +91,11 @@ class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf {
     .createWithDefault(50)
 
   lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get)
+
+  private lazy val MAX_REPLY_SIZE = ConfigBuilder(getUcxConf("maxReplySize"))
+    .doc("Maximum size of fetch reply message")
+    .bytesConf(ByteUnit.MiB)
+    .createWithDefault(32)
+
+  lazy val maxReplySize: Long = sparkConf.getSizeAsBytes(MAX_REPLY_SIZE.key, MAX_REPLY_SIZE.defaultValueString)
 }
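The new maxReplySize knob bounds the size of a single fetch reply, which matters once a block is streamed in multiple messages. Assuming getUcxConf prefixes keys with spark.shuffle.ucx. (the helper's body is not shown in this view), tuning it would look like:

    import org.apache.spark.SparkConf

    // Key prefix assumed from getUcxConf; "16m" is parsed by getSizeAsBytes,
    // lowering the 32 MiB default shown above.
    val conf = new SparkConf()
      .set("spark.shuffle.ucx.maxReplySize", "16m")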
@@ -90,7 +90,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
   private var progressThread: Thread = _
   var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _
 
-  private[spark] lazy val replyWorkersThreadPool =
+  private[spark] lazy val replyThreadPool =
     ThreadUtils.newDaemonFixedThreadPool(ucxShuffleConf.numListenerThreads,
       "UcxListenerThread")
   private[spark] lazy val sparkTransportConf = SparkTransportConf.fromSparkConf(
@@ -282,16 +282,22 @@
   override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
                                      resultBufferAllocator: BufferAllocator,
                                      callbacks: Seq[OperationCallback]): Seq[Request] = {
-    allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
-      .fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks)
+    selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds,
+      resultBufferAllocator,
+      callbacks)
+  }
+
+  def fetchBlockByStream(executorId: ExecutorId, blockId: BlockId,
+                         callback: OperationCallback): Unit = {
+    selectClientWorker.fetchBlockByStream(executorId, blockId, callback)
   }
 
   def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
     allocatedServerWorkers.foreach(w => w.connectByWorkerAddress(executorId, workerAddress))
   }
 
   def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = {
-    replyWorkersThreadPool.submit(new Runnable {
+    replyThreadPool.submit(new Runnable {
       override def run(): Unit = {
         val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress,
           amData.getLength.toInt)
@@ -308,13 +314,34 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo

         val blocks = blockIds.map(bid => registeredBlocks(bid))
         amData.close()
-        allocatedServerWorkers(
-          (Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
-          .handleFetchBlockRequest(blocks, replyTag, replyExecutor)
+        selectServerWorker.handleFetchBlockRequest(blocks, replyTag,
+          replyExecutor)
       }
     })
   }
 
+  def handleFetchBlockStream(replyTag: Int, blockId: BlockId,
+                             replyExecutor: Long): Unit = {
+    replyThreadPool.submit(new Runnable {
+      override def run(): Unit = {
+        val block = registeredBlocks(blockId)
+        selectServerWorker.handleFetchBlockStream(block, replyTag,
+          replyExecutor)
+      }
+    })
+  }
+
+  @inline
+  def selectClientWorker(): UcxWorkerWrapper = {
+    allocatedClientWorkers(
+      (Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
+  }
+
+  @inline
+  def selectServerWorker(): UcxWorkerWrapper = {
+    allocatedServerWorkers(
+      (Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
+  }
+
   /**
    * Progress outstanding operations. This routine is blocking (though may poll for event).
@@ -324,7 +351,7 @@
    * But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed!
    */
   override def progress(): Unit = {
-    allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress()
+    selectClientWorker.progress()
   }
 
   def progressConnect(): Unit = {
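The repeated worker-lookup expressions are now factored into selectClientWorker and selectServerWorker, which map the calling thread to a fixed worker via thread id modulo pool size; a given thread therefore keeps reusing the same UCX worker, and workers need no cross-thread locking. A tiny standalone illustration of the mapping:

    // Each thread id maps deterministically to one worker slot: id % poolSize.
    val poolSize = 4
    (1L to 6L).foreach { tid =>
      println(s"thread $tid -> worker ${tid % poolSize}") // prints 1, 2, 3, 0, 1, 2
    }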
@@ -0,0 +1,5 @@
+package org.apache.spark.shuffle.ucx
+
+class UcxStreamState(val callback: OperationCallback,
+                     val request: UcxRequest,
+                     var remaining: Int) {}
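UcxStreamState appears to be the client-side bookkeeping for one streamed fetch: the user callback, the outstanding UcxRequest, and how many pieces are still expected. A hypothetical handler showing how the fields could interact; only the constructor comes from the commit, and the decrement-and-complete logic (including the onComplete name) is an assumption for illustration:

    // Hypothetical receive-side handler for one arriving chunk of a streamed block.
    def onChunkReceived(state: UcxStreamState, result: OperationResult): Unit = {
      state.remaining -= 1                 // one more chunk has arrived
      if (state.remaining == 0) {
        state.callback.onComplete(result)  // all chunks in: fire the user callback
      }
    }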
[Remaining changed files not loaded in this capture of the page.]
