From 2d73adcfa4606282cc9bf1e022fc15afa7aebd8f Mon Sep 17 00:00:00 2001
From: Mridul Muralidharan
Date: Sun, 25 Feb 2024 01:53:37 -0600
Subject: [PATCH] Modification to PR 45052, to illustrate and fix the issue being discussed w.r.t NPE

---
 .../scala/org/apache/spark/SparkContext.scala |  1 +
 .../scala/org/apache/spark/SparkEnv.scala     | 20 +++-
 .../org/apache/spark/scheduler/Task.scala     | 18 +++-
 .../apache/spark/storage/BlockManager.scala   | 20 ++--
 .../plugin/PluginContainerSuite.scala         | 53 +++++++++++
 .../spark/scheduler/TaskContextSuite.scala    | 95 ++++++++++++++++++-
 6 files changed, 188 insertions(+), 19 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 801b6dd85a2bd..d519617c4095d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -577,6 +577,7 @@ class SparkContext(config: SparkConf) extends Logging {
     // Initialize any plugins before the task scheduler is initialized.
     _plugins = PluginContainer(this, _resources.asJava)
     _env.initializeShuffleManager()
+    _env.initializeMemoryManager(SparkContext.numDriverCores(master, conf))
 
     // Create and start the scheduler
     val (sched, ts) = SparkContext.createTaskScheduler(this, master)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ca07c276fbff3..005681cc1a1a1 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -67,7 +67,6 @@ class SparkEnv (
     val blockManager: BlockManager,
     val securityManager: SecurityManager,
     val metricsSystem: MetricsSystem,
-    val memoryManager: MemoryManager,
     val outputCommitCoordinator: OutputCommitCoordinator,
     val conf: SparkConf) extends Logging {
 
@@ -77,6 +76,12 @@ class SparkEnv (
 
   def shuffleManager: ShuffleManager = _shuffleManager
 
+  // We initialize the MemoryManager later in SparkContext after DriverPlugin is loaded
+  // to allow the plugin to overwrite memory configurations
+  private var _memoryManager: MemoryManager = _
+
+  def memoryManager: MemoryManager = _memoryManager
+
   @volatile private[spark] var isStopped = false
 
   /**
@@ -199,6 +204,12 @@ class SparkEnv (
       "Shuffle manager already initialized to %s", _shuffleManager)
     _shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER)
   }
+
+  private[spark] def initializeMemoryManager(numUsableCores: Int): Unit = {
+    Preconditions.checkState(null == memoryManager,
+      "Memory manager already initialized to %s", _memoryManager)
+    _memoryManager = UnifiedMemoryManager(conf, numUsableCores)
+  }
 }
 
 object SparkEnv extends Logging {
@@ -276,6 +287,8 @@ object SparkEnv extends Logging {
       numCores,
       ioEncryptionKey
     )
+    // Set the memory manager since it needs to be initialized explicitly
+    env.initializeMemoryManager(numCores)
     SparkEnv.set(env)
     env
   }
@@ -358,8 +371,6 @@ object SparkEnv extends Logging {
       new MapOutputTrackerMasterEndpoint(
         rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
 
-    val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)
-
     val blockManagerPort = if (isDriver) {
       conf.get(DRIVER_BLOCK_MANAGER_PORT)
     } else {
@@ -418,7 +429,7 @@ object SparkEnv extends Logging {
       blockManagerMaster,
       serializerManager,
       conf,
-      memoryManager,
+      _memoryManager = null,
       mapOutputTracker,
       _shuffleManager = null,
       blockTransferService,
@@ -463,7 +474,6 @@
       blockManager,
       securityManager,
       metricsSystem,
-      memoryManager,
       outputCommitCoordinator,
       conf)
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1ecd185de557a..b40803308358e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -93,7 +93,16 @@ private[spark] abstract class Task[T](
 
     require(cpus > 0, "CPUs per task should be > 0")
 
-    SparkEnv.get.blockManager.registerTask(taskAttemptId)
+    // Capture the blockManager at the start of the task and use it throughout the task -
+    // particularly in local mode, a new SparkEnv can be initialized when the SparkContext is
+    // restarted, and we want to ensure the right env and block manager is used (given the lazy
+    // initialization of the block manager).
+    // For @sunchao - to illustrate the bug, use the def blockManager below and the
+    // test TaskContextSuite."Ensure the right block manager is used to unroll memory for task"
+    // will fail
+    val blockManager = SparkEnv.get.blockManager
+    // def blockManager = SparkEnv.get.blockManager
+    blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
     val taskContext = new TaskContextImpl(
@@ -143,15 +152,14 @@ private[spark] abstract class Task[T](
         try {
           Utils.tryLogNonFatalError {
             // Release memory used by this thread for unrolling blocks
-            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
-            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
-              MemoryMode.OFF_HEAP)
+            blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+            blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
             // Notify any tasks waiting for execution memory to be freed to wake up and try to
             // acquire memory again. This makes impossible the scenario where a task sleeps forever
             // because there are no other tasks left to notify it. Since this is safe to do but may
             // not be strictly necessary, we should revisit whether we can remove this in the
             // future.
-            val memoryManager = SparkEnv.get.memoryManager
+            val memoryManager = blockManager.memoryManager
             memoryManager.synchronized { memoryManager.notifyAll() }
           }
         } finally {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 42bbd025177b2..228ec5752e1b6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -185,7 +185,7 @@ private[spark] class BlockManager(
     val master: BlockManagerMaster,
     val serializerManager: SerializerManager,
     val conf: SparkConf,
-    memoryManager: MemoryManager,
+    private val _memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
     private val _shuffleManager: ShuffleManager,
     val blockTransferService: BlockTransferService,
@@ -198,6 +198,12 @@ private[spark] class BlockManager(
   // (except for tests) and we ask for the instance from the SparkEnv.
   private lazy val shuffleManager = Option(_shuffleManager).getOrElse(SparkEnv.get.shuffleManager)
 
+  // Similarly, we also initialize MemoryManager later after DriverPlugin is loaded, to
+  // allow the plugin to overwrite certain memory configurations. The `_memoryManager` will be
+  // null here and we ask for the instance from SparkEnv
+  private[spark] lazy val memoryManager =
+    Option(_memoryManager).getOrElse(SparkEnv.get.memoryManager)
+
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)`
   private[spark] val externalShuffleServiceEnabled: Boolean = externalBlockStoreClient.isDefined
   private val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
@@ -224,17 +230,19 @@ private[spark] class BlockManager(
     ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128))
 
   // Actual storage of where blocks are kept
-  private[spark] val memoryStore =
-    new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
+  private[spark] lazy val memoryStore = {
+    val store = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this)
+    memoryManager.setMemoryStore(store)
+    store
+  }
   private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
-  memoryManager.setMemoryStore(memoryStore)
 
   // Note: depending on the memory manager, `maxMemory` may actually vary over time.
   // However, since we use this only for reporting and logging, what we actually want here is
   // the absolute maximum value that `maxMemory` can ever possibly reach. We may need
   // to revisit whether reporting this value as the "max" is intuitive to the user.
-  private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
-  private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
+  private lazy val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
+  private lazy val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
 
   private[spark] val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf)
diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
index 197c2f13d807b..cdbe5553bc95d 100644
--- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.TestUtils._
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.config._
 import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.resource.ResourceUtils.GPU
 import org.apache.spark.resource.TestResourceIDs.{DRIVER_GPU_ID, EXECUTOR_GPU_ID, WORKER_GPU_ID}
@@ -228,6 +229,58 @@ class PluginContainerSuite extends SparkFunSuite with LocalSparkContext {
       assert(driverResources.get(GPU).name === GPU)
     }
   }
+
+  test("memory override in plugin") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]")
+      .set(PLUGINS, Seq(classOf[MemoryOverridePlugin].getName()))
+
+    var sc: SparkContext = null
+    try {
+      sc = new SparkContext(conf)
+      val memoryManager = sc.env.memoryManager
+
+      assert(memoryManager.tungstenMemoryMode == MemoryMode.OFF_HEAP)
+      assert(memoryManager.maxOffHeapStorageMemory == MemoryOverridePlugin.offHeapMemory)
+
+      // Ensure all executors have started
+      TestUtils.waitUntilExecutorsUp(sc, 1, 60000)
+
+      // Check executor memory is also updated
+      val execInfo = sc.statusTracker.getExecutorInfos.head
+      assert(execInfo.totalOffHeapStorageMemory() == MemoryOverridePlugin.offHeapMemory)
+    } finally {
+      if (sc != null) {
+        sc.stop()
+      }
+    }
+  }
+}
+
+class MemoryOverridePlugin extends SparkPlugin {
+  override def driverPlugin(): DriverPlugin = {
+    new DriverPlugin {
+      override def init(sc: SparkContext, pluginContext: PluginContext): JMap[String, String] = {
+        // Take the original executor memory, and set `spark.memory.offHeap.size` to be the
+        // same value. Also set `spark.memory.offHeap.enabled` to true.
+        val originalExecutorMemBytes =
+          sc.conf.getSizeAsMb(EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString)
+        sc.conf.set(MEMORY_OFFHEAP_ENABLED.key, "true")
+        sc.conf.set(MEMORY_OFFHEAP_SIZE.key, s"${originalExecutorMemBytes}M")
+        MemoryOverridePlugin.offHeapMemory = sc.conf.getSizeAsBytes(MEMORY_OFFHEAP_SIZE.key)
+        Map.empty[String, String].asJava
+      }
+    }
+  }
+
+  override def executorPlugin(): ExecutorPlugin = {
+    new ExecutorPlugin {}
+  }
+}
+
+object MemoryOverridePlugin {
+  var offHeapMemory: Long = _
 }
 
 class NonLocalModeSparkPlugin extends SparkPlugin {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 9aba41cea2150..df7b820672e71 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
-import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -27,9 +28,11 @@ import org.mockito.Mockito._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
+import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
 import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite}
+import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.METRICS_CONF
-import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.memory.{MemoryManager, TaskMemoryManager}
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
@@ -680,9 +683,65 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     context.markTaskCompleted(None)
     assert(isFailed)
   }
+
+
+  // This test is for illustration purposes - needs to be cleaned up
+  test("Ensure the right block manager is used to unroll memory for task") {
+    import BlockManagerValidationPlugin._
+    BlockManagerValidationPlugin.resetState()
+
+    // run a task which ignores thread interruption when the spark context is shut down
+    sc = new SparkContext("local", "test")
+
+    val rdd = new RDD[String](sc, List()) {
+      override def getPartitions = Array[Partition](StubPartition(0))
+      override def compute(split: Partition, context: TaskContext): Iterator[String] = {
+        context.addTaskCompletionListener(new TaskCompletionListener {
+          override def onTaskCompletion(context: TaskContext): Unit = {
+            var done = false
+            while (!done) {
+              try {
+                releaseTaskSem.acquire(1)
+                done = true
+              } catch {
+                case iEx: InterruptedException =>
+                  // ignore thread interruption
+                  logInfo("Ignoring thread interruption", iEx)
+              }
+            }
+          }
+        })
+        taskMemoryManager.set(SparkEnv.get.blockManager.memoryManager)
+        taskStartedSem.release()
+        Iterator("hi")
+      }
+    }
+    // submit the job, but don't block this thread
+    rdd.collectAsync()
+    // wait for the task to start
+    taskStartedSem.acquire(1)
+
+    sc.stop()
+    assert(sc.isStopped)
+
+    // create a new SparkContext
+    val conf = new SparkConf()
+    conf.set("spark.plugins", classOf[BlockManagerValidationPlugin].getName)
+    BlockManagerValidationPlugin.threadLocalState.set(
+      () => {
+        val tmm = taskMemoryManager.get()
+        tmm.synchronized {
+          releaseTaskSem.release(1)
+          tmm.wait()
+        }
+        Thread.sleep(2500)
+      }
+    )
+    sc = new SparkContext("local", "test", conf)
+  }
 }
 
-private object TaskContextSuite {
+private object TaskContextSuite extends Logging {
   @volatile var completed = false
 
   @volatile var lastError: Throwable = _
@@ -690,4 +749,34 @@
   class FakeTaskFailureException extends Exception("Fake task failure")
 }
 
+class BlockManagerValidationPlugin extends SparkPlugin {
+
+  override def driverPlugin(): DriverPlugin = {
+    new DriverPlugin() {
+      // We don't really do anything - other than notify that plugin creation has completed
+      // and then wait for a while
+      Option(BlockManagerValidationPlugin.threadLocalState.get()).foreach(_.apply())
+    }
+  }
+  override def executorPlugin(): ExecutorPlugin = {
+    new ExecutorPlugin() {
+      // nothing to see here
+    }
+  }
+}
+
+object BlockManagerValidationPlugin {
+  val threadLocalState = new ThreadLocal[() => Unit]()
+
+  val releaseTaskSem = new Semaphore(0)
+  val taskMemoryManager = new AtomicReference[MemoryManager](null)
+  val taskStartedSem = new Semaphore(0)
+
+  def resetState(): Unit = {
+    releaseTaskSem.drainPermits()
+    taskStartedSem.drainPermits()
+    taskMemoryManager.set(null)
+  }
+}
+
 private case class StubPartition(index: Int) extends Partition