Modification to PR 45052, to illustrate and fix the issue being discussed w.r.t. NPE
Mridul Muralidharan committed Feb 25, 2024
1 parent 18b8606 commit 2d73adc
Showing 6 changed files with 188 additions and 19 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -577,6 +577,7 @@ class SparkContext(config: SparkConf) extends Logging {
// Initialize any plugins before the task scheduler is initialized.
_plugins = PluginContainer(this, _resources.asJava)
_env.initializeShuffleManager()
_env.initializeMemoryManager(SparkContext.numDriverCores(master, conf))

// Create and start the scheduler
val (sched, ts) = SparkContext.createTaskScheduler(this, master)
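For illustration, a minimal self-contained sketch (hypothetical names, a plain mutable map instead of SparkConf) of why this call order matters: plugins must run before the memory manager is built so that their configuration overrides are visible.

import scala.collection.mutable

object OrderingSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for SparkConf: plugins may mutate configuration during startup.
    val conf = mutable.Map("memory.offHeap.enabled" -> "false")

    def loadPlugins(): Unit = conf("memory.offHeap.enabled") = "true"   // plugin overrides the conf
    def buildMemoryManager(): String = s"MemoryManager(offHeap=${conf("memory.offHeap.enabled")})"

    loadPlugins()                    // 1. initialize plugins first
    val mm = buildMemoryManager()    // 2. the memory manager now sees the override
    println(mm)                      // MemoryManager(offHeap=true)
  }
}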
20 changes: 15 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -67,7 +67,6 @@ class SparkEnv (
val blockManager: BlockManager,
val securityManager: SecurityManager,
val metricsSystem: MetricsSystem,
val memoryManager: MemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {

@@ -77,6 +76,12 @@ class SparkEnv (

def shuffleManager: ShuffleManager = _shuffleManager

// We initialize the MemoryManager later in SparkContext after DriverPlugin is loaded
// to allow the plugin to overwrite memory configurations
private var _memoryManager: MemoryManager = _

def memoryManager: MemoryManager = _memoryManager

@volatile private[spark] var isStopped = false

/**
@@ -199,6 +204,12 @@ class SparkEnv (
"Shuffle manager already initialized to %s", _shuffleManager)
_shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER)
}

private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
Preconditions.checkState(null == memoryManager,
"Memory manager already initialized to %s", _memoryManager)
_memoryManager = UnifiedMemoryManager(conf, numUsableCores)
}
}

object SparkEnv extends Logging {
@@ -276,6 +287,8 @@ object SparkEnv extends Logging {
numCores,
ioEncryptionKey
)
// Set the memory manager since it needs to be initialized explicitly
env.initializeMemoryManager(numCores)
SparkEnv.set(env)
env
}
@@ -358,8 +371,6 @@ object SparkEnv extends Logging {
new MapOutputTrackerMasterEndpoint(
rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)

val blockManagerPort = if (isDriver) {
conf.get(DRIVER_BLOCK_MANAGER_PORT)
} else {
@@ -418,7 +429,7 @@ object SparkEnv extends Logging {
blockManagerMaster,
serializerManager,
conf,
memoryManager,
_memoryManager = null,
mapOutputTracker,
_shuffleManager = null,
blockTransferService,
@@ -463,7 +474,6 @@ object SparkEnv extends Logging {
blockManager,
securityManager,
metricsSystem,
memoryManager,
outputCommitCoordinator,
conf)

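For illustration, a minimal self-contained sketch (a hypothetical LateInit class, plain require instead of Guava's Preconditions) of the initialize-once pattern that _shuffleManager and _memoryManager follow above:

final class LateInit[T](name: String) {
  private var value: Option[T] = None

  // Throws until initialize() has been called, mirroring how memoryManager
  // must not be touched before SparkContext calls initializeMemoryManager.
  def get: T =
    value.getOrElse(throw new IllegalStateException(s"$name is not initialized yet"))

  def initialize(v: T): Unit = {
    require(value.isEmpty, s"$name already initialized to ${value.get}")
    value = Some(v)
  }
}

object LateInitDemo {
  def main(args: Array[String]): Unit = {
    val memoryManager = new LateInit[String]("memoryManager")
    // ... driver plugins could rewrite memory configuration here ...
    memoryManager.initialize("UnifiedMemoryManager(conf, numCores)")
    println(memoryManager.get)
  }
}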
18 changes: 13 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -93,7 +93,16 @@ private[spark] abstract class Task[T](

require(cpus > 0, "CPUs per task should be > 0")

SparkEnv.get.blockManager.registerTask(taskAttemptId)
// Use the blockManager captured at the start of the task throughout the task - in particular,
// in local mode a new SparkEnv can be initialized when the spark context is restarted,
// and we want to ensure the right env and block manager are used (given the lazy
// initialization of the block manager).
// For @sunchao - to illustrate the bug, use the def blockManager below instead: the test
// TaskContextSuite."Ensure the right block manager is used to unroll memory for task"
// will fail.
val blockManager = SparkEnv.get.blockManager
// def blockManager = SparkEnv.get.blockManager
blockManager.registerTask(taskAttemptId)
// TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
// the stage is barrier.
val taskContext = new TaskContextImpl(
@@ -143,15 +152,14 @@ private[spark] abstract class Task[T](
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
MemoryMode.OFF_HEAP)
blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
// not be strictly necessary, we should revisit whether we can remove this in the
// future.
val memoryManager = SparkEnv.get.memoryManager
val memoryManager = blockManager.memoryManager
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
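For illustration, a self-contained sketch (hypothetical Env, not SparkEnv) of the difference between capturing the block manager in a val at task start and re-reading it through a def, which is the local-mode restart scenario the comment above describes:

object CaptureOnceSketch {
  final case class Env(id: Int)

  @volatile private var currentEnv: Env = Env(1)

  def main(args: Array[String]): Unit = {
    val capturedEnv = currentEnv     // snapshot taken once, at "task start"
    def liveEnv: Env = currentEnv    // re-evaluated on every access

    currentEnv = Env(2)              // a restarted context installs a new environment

    // The val still refers to the environment the task started with,
    // while the def now resolves to the new one.
    println(s"captured = $capturedEnv, live = $liveEnv")
  }
}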
20 changes: 14 additions & 6 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -185,7 +185,7 @@ private[spark] class BlockManager(
val master: BlockManagerMaster,
val serializerManager: SerializerManager,
val conf: SparkConf,
memoryManager: MemoryManager,
private val _memoryManager: MemoryManager,
mapOutputTracker: MapOutputTracker,
private val _shuffleManager: ShuffleManager,
val blockTransferService: BlockTransferService,
@@ -198,6 +198,12 @@ private[spark] class BlockManager(
// (except for tests) and we ask for the instance from the SparkEnv.
private lazy val shuffleManager = Option(_shuffleManager).getOrElse(SparkEnv.get.shuffleManager)

// Similarly, we also initialize MemoryManager later after DriverPlugin is loaded, to
// allow the plugin to overwrite certain memory configurations. The `_memoryManager` will be
// null here and we ask for the instance from SparkEnv
private[spark] lazy val memoryManager =
Option(_memoryManager).getOrElse(SparkEnv.get.memoryManager)

// same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
@@ -224,17 +230,19 @@ private[spark] class BlockManager(
ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))

// Actual storage of where blocks are kept
private[spark] val memoryStore =
new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
private[spark] lazy val memoryStore = {
val store = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
memoryManager.setMemoryStore(store)
store
}
private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
memoryManager.setMemoryStore(memoryStore)

// Note: depending on the memory manager, `maxMemory` may actually vary over time.
// However, since we use this only for reporting and logging, what we actually want here is
// the absolute maximum value that `maxMemory` can ever possibly reach. We may need
// to revisit whether reporting this value as the "max" is intuitive to the user.
private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
private lazy val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
private lazy val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory

private[spark] val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf)

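For illustration, a self-contained sketch (hypothetical types) of the constructor-or-environment fallback used for _memoryManager and _shuffleManager above: tests can inject an instance directly, while production code passes null and the instance is resolved lazily from the shared environment.

object LazyFallbackSketch {
  trait MemoryManagerLike { def maxOnHeapBytes: Long }

  object GlobalEnv {
    @volatile var memoryManager: MemoryManagerLike = _
  }

  final class Manager(_memoryManager: MemoryManagerLike) {
    // Resolved on first use, so GlobalEnv only has to be populated by then.
    lazy val memoryManager: MemoryManagerLike =
      Option(_memoryManager).getOrElse(GlobalEnv.memoryManager)
  }

  def main(args: Array[String]): Unit = {
    GlobalEnv.memoryManager = new MemoryManagerLike { def maxOnHeapBytes = 512L }
    val injected = new Manager(new MemoryManagerLike { def maxOnHeapBytes = 128L })
    val fromEnv  = new Manager(null)
    println(injected.memoryManager.maxOnHeapBytes)   // 128: uses the injected instance
    println(fromEnv.memoryManager.maxOnHeapBytes)    // 512: falls back to the environment
  }
}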
core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.TestUtils._
import org.apache.spark.api.plugin._
import org.apache.spark.internal.config._
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.memory.MemoryMode
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.resource.ResourceUtils.GPU
import org.apache.spark.resource.TestResourceIDs.{DRIVER_GPU_ID, EXECUTOR_GPU_ID, WORKER_GPU_ID}
@@ -228,6 +229,58 @@ class PluginContainerSuite extends SparkFunSuite with LocalSparkContext {
assert(driverResources.get(GPU).name === GPU)
}
}

test("memory override in plugin") {
val conf = new SparkConf()
.setAppName(getClass().getName())
.set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
.set(PLUGINS, Seq(classOf[MemoryOverridePlugin].getName()))

var sc: SparkContext = null
try {
sc = new SparkContext(conf)
val memoryManager = sc.env.memoryManager

assert(memoryManager.tungstenMemoryMode == MemoryMode.OFF_HEAP)
assert(memoryManager.maxOffHeapStorageMemory == MemoryOverridePlugin.offHeapMemory)

// Ensure all executors have started
TestUtils.waitUntilExecutorsUp(sc, 1, 60000)

// Check executor memory is also updated
val execInfo = sc.statusTracker.getExecutorInfos.head
assert(execInfo.totalOffHeapStorageMemory() == MemoryOverridePlugin.offHeapMemory)
} finally {
if (sc != null) {
sc.stop()
}
}
}
}

class MemoryOverridePlugin extends SparkPlugin {
override def driverPlugin(): DriverPlugin = {
new DriverPlugin {
override def init(sc: SparkContext, pluginContext: PluginContext): JMap[String, String] = {
// Take the original executor memory, and set `spark.memory.offHeap.size` to be the
// same value. Also set `spark.memory.offHeap.enabled` to true.
val originalExecutorMemBytes =
sc.conf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString)
sc.conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
sc.conf.set(MEMORY_OFFHEAP_SIZE.key, s"${originalExecutorMemBytes}M")
MemoryOverridePlugin.offHeapMemory = sc.conf.getSizeAsBytes(MEMORY_OFFHEAP_SIZE.key)
Map.empty[String, String].asJava
}
}
}

override def executorPlugin(): ExecutorPlugin = {
new ExecutorPlugin {}
}
}

object MemoryOverridePlugin {
var offHeapMemory: Long = _
}
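For reference, a hedged sketch of enabling such a plugin outside the test harness, purely through SparkConf (com.example.MemoryOverridePlugin is a placeholder class name that would need to exist on the application's classpath):

import org.apache.spark.{SparkConf, SparkContext}

object EnablePluginSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("plugin-demo")
      .set("spark.plugins", "com.example.MemoryOverridePlugin")  // placeholder class name

    val sc = new SparkContext(conf)
    try {
      // With an off-heap override applied by the plugin, executors report the
      // configured off-heap storage memory.
      sc.statusTracker.getExecutorInfos.foreach { info =>
        println(info.totalOffHeapStorageMemory())
      }
    } finally {
      sc.stop()
    }
  }
}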

class NonLocalModeSparkPlugin extends SparkPlugin {
core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -18,7 +18,8 @@
package org.apache.spark.scheduler

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.mutable.ArrayBuffer

@@ -27,9 +28,11 @@ import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.METRICS_CONF
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.memory.{MemoryManager, TaskMemoryManager}
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
@@ -680,14 +683,100 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
context.markTaskCompleted(None)
assert(isFailed)
}


// This test is for illustration purposes - needs to be cleaned up
test("Ensure the right block manager is used to unroll memory for task") {
import BlockManagerValidationPlugin._
BlockManagerValidationPlugin.resetState()

// run a task which ignores thread interruption when the spark context is shut down
sc = new SparkContext("local", "test")

val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext): Iterator[String] = {
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
var done = false
while (!done) {
try {
releaseTaskSem.acquire(1)
done = true
} catch {
case iEx: InterruptedException =>
// ignore thread interruption
logInfo("Ignoring thread interruption", iEx)
}
}
}
})
taskMemoryManager.set(SparkEnv.get.blockManager.memoryManager)
taskStartedSem.release()
Iterator("hi")
}
}
// submit the job, but don't block this thread
rdd.collectAsync()
// wait for task to start
taskStartedSem.acquire(1)

sc.stop()
assert(sc.isStopped)

// create a new SparkContext
val conf = new SparkConf()
conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName)
BlockManagerValidationPlugin.threadLocalState.set(
() => {
val tmm = taskMemoryManager.get()
tmm.synchronized {
releaseTaskSem.release(1)
tmm.wait()
}
Thread.sleep(2500)
}
)
sc = new SparkContext("local", "test", conf)
}
}

private object TaskContextSuite {
private object TaskContextSuite extends Logging {
@volatile var completed = false

@volatile var lastError: Throwable = _

class FakeTaskFailureException extends Exception("Fake task failure")
}

class BlockManagerValidationPlugin extends SparkPlugin {

override def driverPlugin(): DriverPlugin = {
new DriverPlugin() {
// We don't really do anything - other than notify that plugin creation has completed
// and then wait for a while
Option(BlockManagerValidationPlugin.threadLocalState.get()).foreach(_.apply())
}
}
override def executorPlugin(): ExecutorPlugin = {
new ExecutorPlugin() {
// nothing to see here
}
}
}

object BlockManagerValidationPlugin {
val threadLocalState = new ThreadLocal[() => Unit]()

val releaseTaskSem = new Semaphore(0)
val taskMemoryManager = new AtomicReference[MemoryManager](null)
val taskStartedSem = new Semaphore(0)

def resetState(): Unit = {
releaseTaskSem.drainPermits()
taskStartedSem.drainPermits()
taskMemoryManager.set(null)
}
}

private case class StubPartition(index: Int) extends Partition
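For illustration, a self-contained sketch of the semaphore handshake the test above relies on: the task thread signals that it has started and then blocks until the coordinating side releases it.

import java.util.concurrent.Semaphore

object HandshakeSketch {
  def main(args: Array[String]): Unit = {
    val taskStartedSem = new Semaphore(0)
    val releaseTaskSem = new Semaphore(0)

    val task = new Thread(() => {
      taskStartedSem.release()   // signal "the task has started"
      releaseTaskSem.acquire()   // block until the coordinator releases the task
      println("task completed")
    })
    task.start()

    taskStartedSem.acquire()     // wait for the task to start before proceeding
    println("coordinator observed task start")
    releaseTaskSem.release()     // let the task finish
    task.join()
  }
}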
