Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support reply in slices for large block. #34

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
package org.apache.spark.shuffle.compat.spark_2_4

import org.openucx.jucx.UcxUtils
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport}
import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}

class UcxShuffleClient(val transport: UcxShuffleTransport) extends ShuffleClient{
override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val callbacks = Array.ofDim[OperationCallback](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
callbacks(i) = (result: OperationResult) => {
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
memBlock.close()
this
}
})
if (downloadFileManager == null) {
val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i))
.asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, blockId.mapId,
blockId.reduceId)
callbacks(i) = new UcxFetchCallBack(blockIds(i), listener)
}
val maxBlocksPerRequest= transport.maxBlocksPerRequest
val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _
for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
val j = i + maxBlocksPerRequest
transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
resultBufferAllocator,
callbacks.slice(i, j))
}
} else {
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i))
.asInstanceOf[SparkShuffleBlockId]
val ucxBlockId = UcxShuffleBockId(blockId.shuffleId, blockId.mapId,
blockId.reduceId)
val callback = new UcxDownloadCallBack(blockIds(i), listener,
downloadFileManager,
transport.sparkTransportConf)
transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback)
}
}
val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
}

override def close(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,47 @@
package org.apache.spark.shuffle.compat.spark_3_0

import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
import org.apache.spark.shuffle.ucx.{OperationCallback, OperationResult, UcxShuffleBockId, UcxShuffleTransport}
import org.apache.spark.shuffle.utils.UnsafeUtils
import org.apache.spark.shuffle.ucx.{UcxFetchCallBack, UcxDownloadCallBack, UcxShuffleBockId, UcxShuffleTransport}
import org.apache.spark.storage.{BlockId => SparkBlockId, ShuffleBlockId => SparkShuffleBlockId}

class UcxShuffleClient(val transport: UcxShuffleTransport, mapId2PartitionId: Map[Long, Int]) extends BlockStoreClient with Logging {

override def fetchBlocks(host: String, port: Int, execId: String, blockIds: Array[String],
listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
if (blockIds.length > transport.ucxShuffleConf.maxBlocksPerRequest) {
val (b1, b2) = blockIds.splitAt(blockIds.length / 2)
fetchBlocks(host, port, execId, b1, listener, downloadFileManager)
fetchBlocks(host, port, execId, b2, listener, downloadFileManager)
return
}

val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val callbacks = Array.ofDim[OperationCallback](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId, mapId2PartitionId(blockId.mapId), blockId.reduceId)
callbacks(i) = (result: OperationResult) => {
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address, memBlock.size.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
memBlock.close()
this
}
})
if (downloadFileManager == null) {
val ucxBlockIds = Array.ofDim[UcxShuffleBockId](blockIds.length)
val callbacks = Array.ofDim[UcxFetchCallBack](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i))
.asInstanceOf[SparkShuffleBlockId]
ucxBlockIds(i) = UcxShuffleBockId(blockId.shuffleId,
mapId2PartitionId(blockId.mapId),
blockId.reduceId)
callbacks(i) = new UcxFetchCallBack(blockIds(i), listener)
}
val maxBlocksPerRequest= transport.maxBlocksPerRequest
val resultBufferAllocator = transport.hostBounceBufferMemoryPool.get _
for (i <- 0 until blockIds.length by maxBlocksPerRequest) {
val j = i + maxBlocksPerRequest
transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds.slice(i, j),
resultBufferAllocator,
callbacks.slice(i, j))
}
} else {
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i))
.asInstanceOf[SparkShuffleBlockId]
val ucxBlockId = UcxShuffleBockId(blockId.shuffleId,
mapId2PartitionId(blockId.mapId),
blockId.reduceId)
val callback = new UcxDownloadCallBack(blockIds(i), listener,
downloadFileManager,
transport.sparkTransportConf)
transport.fetchBlockByStream(execId.toLong, ucxBlockId, callback)
}
}
val resultBufferAllocator = (size: Long) => transport.hostBounceBufferMemoryPool.get(size)
transport.fetchBlocksByBlockIds(execId.toLong, ucxBlockIds, resultBufferAllocator, callbacks)
transport.progress()
}

override def close(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ abstract class CommonUcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffle
val block = new Block {
private val fileOffset = offset

override def getBlock(byteBuffer: ByteBuffer): Unit = {
channel.read(byteBuffer, fileOffset)
override def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit = {
channel.read(byteBuffer, fileOffset + offset)
}

override def getSize: Long = blockLength
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ trait Block extends BlockLock {
def getMemoryBlock: MemoryBlock = ???

// Get block from a file into byte buffer backed bunce buffer
def getBlock(byteBuffer: ByteBuffer): Unit
def getBlock(byteBuffer: ByteBuffer, offset: Long): Unit
}

object OperationStatus extends Enumeration {
Expand Down Expand Up @@ -90,6 +90,7 @@ trait Request {
*/
trait OperationCallback {
def onComplete(result: OperationResult): Unit
def onData(buf: ByteBuffer): Unit = ???
}

/**
Expand Down
50 changes: 50 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxFetchCallBack.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package org.apache.spark.shuffle.ucx

import java.nio.ByteBuffer

import org.apache.spark.network.util.TransportConf
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}

import org.apache.spark.shuffle.utils.UnsafeUtils

class UcxFetchCallBack(
blockId: String, listener: BlockFetchingListener)
extends OperationCallback {

override def onComplete(result: OperationResult): Unit = {
val memBlock = result.getData
val buffer = UnsafeUtils.getByteBufferView(memBlock.address,
memBlock.size.toInt)
listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
memBlock.close()
this
}
})
}
}

class UcxDownloadCallBack(
blockId: String, listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager,
transportConf: TransportConf)
extends OperationCallback {

private[this] val targetFile = downloadFileManager.createTempFile(
transportConf)
private[this] val channel = targetFile.openForWriting();

override def onData(buffer: ByteBuffer): Unit = {
while (buffer.hasRemaining()) {
channel.write(buffer);
}
}

override def onComplete(result: OperationResult): Unit = {
listener.onBlockFetchSuccess(blockId, channel.closeAndRead());
if (!downloadFileManager.registerTempFileToClean(targetFile)) {
targetFile.delete();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,11 @@ class UcxShuffleConf(sparkConf: SparkConf) extends SparkConf {
.createWithDefault(50)

lazy val maxBlocksPerRequest: Int = sparkConf.getInt(MAX_BLOCKS_IN_FLIGHT.key, MAX_BLOCKS_IN_FLIGHT.defaultValue.get)

private lazy val MAX_REPLY_SIZE = ConfigBuilder(getUcxConf("maxReplySize"))
.doc("Maximum size of fetch reply message")
.bytesConf(ByteUnit.MiB)
.createWithDefault(32)

lazy val maxReplySize: Long = sparkConf.getSizeAsBytes(MAX_REPLY_SIZE.key, MAX_REPLY_SIZE.defaultValueString)
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
*/
package org.apache.spark.shuffle.ucx

import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.util.ThreadUtils
import org.apache.spark.shuffle.ucx.memory.UcxHostBounceBuffersPool
import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils}
Expand Down Expand Up @@ -88,6 +90,14 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
private var progressThread: Thread = _
var hostBounceBufferMemoryPool: UcxHostBounceBuffersPool = _

private[spark] lazy val replyThreadPool =
ThreadUtils.newDaemonFixedThreadPool(ucxShuffleConf.numListenerThreads,
"UcxListenerThread")
private[spark] lazy val sparkTransportConf = SparkTransportConf.fromSparkConf(
ucxShuffleConf.getSparkConf, "shuffle", ucxShuffleConf.numWorkers)
private[spark] lazy val maxBlocksPerRequest = maxBlocksInAmHeader.min(
ucxShuffleConf.maxBlocksPerRequest).toInt

private val errorHandler = new UcpEndpointErrorHandler {
override def onError(ucpEndpoint: UcpEndpoint, errorCode: Int, errorString: String): Unit = {
if (errorCode == UcsConstants.STATUS.UCS_ERR_CONNECTION_RESET) {
Expand Down Expand Up @@ -190,6 +200,10 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
}
}

def maxBlocksInAmHeader(): Long = {
(globalWorker.getMaxAmHeaderSize - 2) / UnsafeUtils.INT_SIZE
}

/**
* Add executor's worker address. For standalone testing purpose and for implementations that makes
* connection establishment outside of UcxShuffleManager.
Expand Down Expand Up @@ -268,33 +282,66 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
override def fetchBlocksByBlockIds(executorId: ExecutorId, blockIds: Seq[BlockId],
resultBufferAllocator: BufferAllocator,
callbacks: Seq[OperationCallback]): Seq[Request] = {
allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
.fetchBlocksByBlockIds(executorId, blockIds, resultBufferAllocator, callbacks)
selectClientWorker.fetchBlocksByBlockIds(executorId, blockIds,
resultBufferAllocator,
callbacks)
}

def fetchBlockByStream(executorId: ExecutorId, blockId: BlockId,
callback: OperationCallback): Unit = {
selectClientWorker.fetchBlockByStream(executorId, blockId, callback)
}

def connectServerWorkers(executorId: ExecutorId, workerAddress: ByteBuffer): Unit = {
allocatedServerWorkers.foreach(w => w.connectByWorkerAddress(executorId, workerAddress))
}

def handleFetchBlockRequest(replyTag: Int, amData: UcpAmData, replyExecutor: Long): Unit = {
val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt)
val blockIds = mutable.ArrayBuffer.empty[BlockId]

// 1. Deserialize blockIds from header
while (buffer.remaining() > 0) {
val blockId = UcxShuffleBockId.deserialize(buffer)
if (!registeredBlocks.contains(blockId)) {
throw new UcxException(s"$blockId is not registered")
replyThreadPool.submit(new Runnable {
override def run(): Unit = {
val buffer = UnsafeUtils.getByteBufferView(amData.getDataAddress,
amData.getLength.toInt)
val blockIds = mutable.ArrayBuffer.empty[BlockId]

// 1. Deserialize blockIds from header
while (buffer.remaining() > 0) {
val blockId = UcxShuffleBockId.deserialize(buffer)
if (!registeredBlocks.contains(blockId)) {
throw new UcxException(s"$blockId is not registered")
}
blockIds += blockId
}

val blocks = blockIds.map(bid => registeredBlocks(bid))
amData.close()
selectServerWorker.handleFetchBlockRequest(blocks, replyTag,
replyExecutor)
}
blockIds += blockId
}
})
}

def handleFetchBlockStream(replyTag: Int, blockId: BlockId,
replyExecutor: Long): Unit = {
replyThreadPool.submit(new Runnable {
override def run(): Unit = {
val block = registeredBlocks(blockId)
selectServerWorker.handleFetchBlockStream(block, replyTag,
replyExecutor)
}
})
}

val blocks = blockIds.map(bid => registeredBlocks(bid))
amData.close()
allocatedServerWorkers((Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
.handleFetchBlockRequest(blocks, replyTag, replyExecutor)
@inline
def selectClientWorker(): UcxWorkerWrapper = {
allocatedClientWorkers(
(Thread.currentThread().getId % allocatedClientWorkers.length).toInt)
}

@inline
def selectServerWorker(): UcxWorkerWrapper = {
allocatedServerWorkers(
(Thread.currentThread().getId % allocatedServerWorkers.length).toInt)
}

/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
Expand All @@ -304,7 +351,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo
* But not guaranteed that at least one [[ fetchBlocksByBlockIds ]] completed!
*/
override def progress(): Unit = {
allocatedClientWorkers((Thread.currentThread().getId % allocatedClientWorkers.length).toInt).progress()
selectClientWorker.progress()
}

def progressConnect(): Unit = {
Expand Down
11 changes: 11 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxStreamState.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package org.apache.spark.shuffle.ucx

class UcxStreamState(val callback: OperationCallback,
val request: UcxRequest,
var remaining: Int) {}

class UcxSliceState(val callback: OperationCallback,
val request: UcxRequest,
val mem: MemoryBlock,
var offset: Long,
var remaining: Int) {}
Loading
Loading