[SPARK-51955] Adding release() to ReadStateStore interface and reusing ReadStore for Streaming Aggregations #50742

Open · wants to merge 11 commits into base: master
@@ -89,6 +89,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with

override def abort(): Unit = {}

override def release(): Unit = {}

override def toString(): String = {
s"HDFSReadStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
}
@@ -112,6 +114,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
case object UPDATING extends STATE
case object COMMITTED extends STATE
case object ABORTED extends STATE
case object RELEASED extends STATE

private val newVersion = version + 1
@volatile private var state: STATE = UPDATING
@@ -194,6 +197,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
log"for ${MDC(LogKeys.STATE_STORE_PROVIDER, this)}")
}

override def release(): Unit = {
state = RELEASED
}

/**
* Get an iterator of all the store data.
* This can be called only after committing all the updates made in the current thread.
@@ -1013,6 +1013,12 @@ class RocksDB(
}
}

def release(): Unit = {
if (db != null) {
release(LoadStore)
}
}

/**
* Commit all the updates made as a version to DFS. The steps it needs to do to commit are:
* - Flush all changes to disk
@@ -49,6 +49,7 @@ private[sql] class RocksDBStateStoreProvider
case object UPDATING extends STATE
case object COMMITTED extends STATE
case object ABORTED extends STATE
case object RELEASED extends STATE

@volatile private var state: STATE = UPDATING
@volatile private var isValidated = false
@@ -365,6 +366,18 @@ private[sql] class RocksDBStateStoreProvider
}
result
}

override def release(): Unit = {
if (state != RELEASED) {
logInfo(log"Releasing ${MDC(VERSION_NUM, version + 1)} " +
log"for ${MDC(STATE_STORE_ID, id)}")
rocksDB.release()
state = RELEASED
} else {
      // Already released; release() is idempotent, so repeated calls are no-ops.
logDebug(log"State store already released")
}
}
}

// Test-visible method to fetch the internal RocksDBStateStore class
@@ -446,17 +459,49 @@ private[sql] class RocksDBStateStoreProvider

override def stateStoreId: StateStoreId = stateStoreId_

/**
* Creates and returns a state store with the specified parameters.
*
* @param version The version of the state store to load
* @param uniqueId Optional unique identifier for checkpoint
* @param readOnly Whether to open the store in read-only mode
* @param existingStore Optional existing store to reuse instead of creating a new one
* @return The loaded state store
*/
private def loadStateStore(
version: Long,
uniqueId: Option[String],
readOnly: Boolean,
existingStore: Option[ReadStateStore] = None): StateStore = {
try {
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
      // Load RocksDB store
      rocksDB.load(
        version,
        stateStoreCkptId = if (storeConf.enableStateStoreCheckpointIds) uniqueId else None,
        readOnly = readOnly)

      // Return the appropriate store instance
      existingStore match {
        // We need to match like this as opposed to case Some(ss: RocksDBStateStore)
        // because of how the tests create the class in StateStoreRDDSuite
        case Some(stateStore: ReadStateStore) if stateStore.isInstanceOf[RocksDBStateStore] =>
          stateStore.asInstanceOf[StateStore]
        case Some(other) =>
          throw new IllegalArgumentException("Existing store must be a RocksDBStateStore," +
            s" store is actually ${other.getClass.getSimpleName}")
        case None =>
          // Create a new store instance for the getStore/getReadStore cases
          new RocksDBStateStore(version)
      }
} catch {
case e: SparkException
if Option(e.getCondition).exists(_.contains("CANNOT_LOAD_STATE_STORE")) =>
throw e
@@ -469,28 +514,24 @@ private[sql] class RocksDBStateStoreProvider
}
}

override def getStore(version: Long, uniqueId: Option[String] = None): StateStore = {
loadStateStore(version, uniqueId, readOnly = false)
}

override def upgradeReadStoreToWriteStore(
readStore: ReadStateStore,
version: Long,
uniqueId: Option[String] = None): StateStore = {
assert(version == readStore.version,
s"Can only upgrade readStore to writeStore with the same version," +
s" readStoreVersion: ${readStore.version}, writeStoreVersion: ${version}")
assert(this.stateStoreId == readStore.id, "Can only upgrade readStore to writeStore with" +
" the same stateStoreId")
loadStateStore(version, uniqueId, readOnly = false, existingStore = Some(readStore))
}

override def getReadStore(version: Long, uniqueId: Option[String] = None): StateStore = {
loadStateStore(version, uniqueId, readOnly = true)
}

override def doMaintenance(): Unit = {
@@ -117,6 +117,21 @@ trait ReadStateStore {
* The method name is to respect backward compatibility on [[StateStore]].
*/
def abort(): Unit

/**
* Releases resources associated with this read-only state store.
*
 * This method should be called when the store is no longer needed and reading
 * completed successfully (i.e., no errors occurred). It performs any necessary
 * cleanup without invalidating or rolling back the data that was read.
*
* In contrast to `abort()`, which is called on error paths to cancel operations,
* `release()` is the proper method to call in success scenarios when a read-only
* store is no longer needed.
*
* This method is idempotent and safe to call multiple times.
*/
def release(): Unit
}
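
To make the release/abort contract concrete, here is a minimal caller-side sketch (hypothetical; provider and processRow are illustrative stand-ins, not part of this change):

  val readStore = provider.getReadStore(version)
  try {
    // Read-only phase: scan the state loaded for this version.
    readStore.iterator().foreach(processRow)
    // Success path: release resources; nothing is rolled back.
    readStore.release()
  } catch {
    case e: Throwable =>
      // Error path: abort instead, mirroring the write-side contract.
      readStore.abort()
      throw e
  }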

/**
@@ -234,6 +249,8 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {

override def abort(): Unit = store.abort()

override def release(): Unit = {}

override def prefixScan(prefixKey: UnsafeRow,
colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] =
store.prefixScan(prefixKey, colFamilyName)
@@ -565,6 +582,29 @@ trait StateStoreProvider {
version: Long,
stateStoreCkptId: Option[String] = None): StateStore

/**
* Creates a writable store from an existing read-only store for the specified version.
*
* This method enables an important optimization pattern for stateful operations where
* the same state store needs to be accessed for both reading and writing within a task.
* Instead of opening two separate state store instances (which can cause contention issues),
* this method converts an existing read-only store to a writable store that can commit changes.
*
* This approach is particularly beneficial when:
* - A stateful operation needs to first read the existing state, then update it
* - The state store has locking mechanisms that prevent concurrent access
* - Multiple state store connections would cause unnecessary resource duplication
*
* @param readStore The existing read-only store instance to convert to a writable store
* @param version The version of the state store (must match the read store's version)
* @param uniqueId Optional unique identifier for checkpointing
* @return A writable StateStore instance that can be used to update and commit changes
*/
def upgradeReadStoreToWriteStore(
readStore: ReadStateStore,
version: Long,
uniqueId: Option[String] = None): StateStore = getStore(version, uniqueId)
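
As a rough sketch of the read-then-write pattern this enables (hypothetical; provider, key, and newValue are placeholders, and the provider is assumed to override this method rather than fall back to the default getStore call above):

  val readStore = provider.getReadStore(version)
  val existing = readStore.get(key) // read phase against the loaded version
  val writeStore = provider.upgradeReadStoreToWriteStore(readStore, version)
  writeStore.put(key, newValue) // write phase reuses the already-loaded store
  writeStore.commit() // commits version + 1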

/**
* Return an instance of [[ReadStateStore]] representing state data of the given version
* and uniqueID if provided.
@@ -823,7 +863,6 @@ class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
}
}


/**
* Companion object to [[StateStore]] that provides helper methods to create and retrieve stores
* by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null),
@@ -950,6 +989,29 @@ object StateStore extends Logging {
storeProvider.getReadStore(version, stateStoreCkptId)
}

def getWriteStore(
readStore: ReadStateStore,
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
keyStateEncoderSpec: KeyStateEncoderSpec,
version: Long,
stateStoreCkptId: Option[String],
stateSchemaBroadcast: Option[StateSchemaBroadcast],
useColumnFamilies: Boolean,
storeConf: StateStoreConf,
hadoopConf: Configuration,
useMultipleValuesPerKey: Boolean = false): StateStore = {
hadoopConf.set(StreamExecution.RUN_ID_KEY, storeProviderId.queryRunId.toString)
if (version < 0) {
throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
}
val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
keyStateEncoderSpec, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey,
stateSchemaBroadcast)
storeProvider.upgradeReadStoreToWriteStore(readStore, version, stateStoreCkptId)
}

/** Get or create a store associated with the id. */
def get(
storeProviderId: StateStoreProviderId,
@@ -27,6 +27,43 @@ import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration

/**
 * Thread-local storage for sharing StateStore instances between RDDs.
 * This allows a ReadStateStore created by one RDD to be reused by a subsequent
 * StateStore operation in the same task thread.
*/
object StateStoreThreadLocalTracker {
/** Case class to hold both the store and its usage state */
case class StoreInfo(store: ReadStateStore, usedForWriteStore: Boolean = false)
Review comment (Contributor): nit: maybe move members to a new line each?

private val storeInfo: ThreadLocal[StoreInfo] = new ThreadLocal[StoreInfo]

def setStore(store: ReadStateStore): Unit = {
Option(storeInfo.get()) match {
case Some(info) => storeInfo.set(info.copy(store = store))
case None => storeInfo.set(StoreInfo(store))
}
}

def getStore: Option[ReadStateStore] = {
Option(storeInfo.get()).map(_.store)
}

def setUsedForWriteStore(used: Boolean): Unit = {
Option(storeInfo.get()) match {
case Some(info) => storeInfo.set(info.copy(usedForWriteStore = used))
case None => // If there's no store set, we don't need to track usage
}
}

def isUsedForWriteStore: Boolean = {
Option(storeInfo.get()).exists(_.usedForWriteStore)
}

def clearStore(): Unit = {
storeInfo.remove()
}
}
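
For illustration, a condensed sketch of how the two RDDs below cooperate through this tracker on a single task thread (simplified from ReadStateStoreRDD/StateStoreRDD; provider and version are placeholders):

  // Read stage: publish the read store for this task thread.
  StateStoreThreadLocalTracker.setStore(readStore)

  // Write stage, same thread: upgrade the read store if one is present.
  val store = StateStoreThreadLocalTracker.getStore match {
    case Some(rs) => provider.upgradeReadStoreToWriteStore(rs, version)
    case None => provider.getStore(version)
  }

  // On task completion: clear the slot so nothing leaks across tasks.
  StateStoreThreadLocalTracker.clearStore()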

abstract class BaseStateStoreRDD[T: ClassTag, U: ClassTag](
dataRDD: RDD[T],
checkpointLocation: String,
@@ -95,6 +132,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
stateStoreCkptIds.map(_.apply(partition.index).head),
stateSchemaBroadcast,
useColumnFamilies, storeConf, hadoopConfBroadcast.value.value)
StateStoreThreadLocalTracker.setStore(store)
storeReadFunction(store, inputIter)
}
}
@@ -130,12 +168,26 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
val storeProviderId = getStateProviderId(partition)

val inputIter = dataRDD.iterator(partition, ctxt)
val store = StateStoreThreadLocalTracker.getStore match {
case Some(readStateStore: ReadStateStore) =>
val writeStore = StateStore.getWriteStore(readStateStore, storeProviderId,
keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
uniqueId.map(_.apply(partition.index).head),
stateSchemaBroadcast,
useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
useMultipleValuesPerKey)
if (writeStore.equals(readStateStore)) {
StateStoreThreadLocalTracker.setUsedForWriteStore(true)
}
writeStore
case None =>
StateStore.get(
storeProviderId, keySchema, valueSchema, keyStateEncoderSpec, storeVersion,
uniqueId.map(_.apply(partition.index).head),
stateSchemaBroadcast,
useColumnFamilies, storeConf, hadoopConfBroadcast.value.value,
useMultipleValuesPerKey)
}

if (storeConf.unloadOnCommit) {
ctxt.addTaskCompletionListener[Unit](_ => {