[SPARK-51947] Spark connect model cache offloading #50752

Open · wants to merge 19 commits into base: master
10 changes: 10 additions & 0 deletions python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -25,6 +25,16 @@ class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
pass


class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConnectTestCase):
@classmethod
def conf(cls):
conf = super().conf()
conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "true")
conf.set("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize", "1024")
conf.set("spark.connect.session.connectML.mlCache.offloading.timeout", "1")
return conf


if __name__ == "__main__":
from pyspark.ml.tests.connect.test_parity_tuning import * # noqa: F401

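The new test class forces the offloading path by shrinking the in-memory budget to 1024 bytes and the idle timeout to one minute, so every fitted model is offloaded almost immediately. A minimal Scala sketch of the same knobs applied to a locally built session (the config keys come from this PR; building a plain local session this way is an assumption for illustration, not how the parity test runs):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: make offloading trigger almost immediately, as the parity test does.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.connect.session.connectML.mlCache.offloading.enabled", "true")
  // 1024 bytes: any realistic model exceeds this, so it is written to disk at once.
  .config("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize", "1024")
  // 1 minute: idle models expire from the in-memory tier quickly.
  .config("spark.connect.session.connectML.mlCache.offloading.timeout", "1")
  .getOrCreate()
```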
6 changes: 3 additions & 3 deletions python/pyspark/testing/connectutils.py
@@ -158,9 +158,9 @@ def conf(cls):
# Set a static token for all tests so the parallelism doesn't overwrite each
# tests' environment variables
conf.set("spark.connect.authenticate.token", "deadbeef")
# Make the max size of ML Cache larger, to avoid CONNECT_ML.CACHE_INVALID issues
# in tests.
conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
# Disable ML cache offloading;
# offloading doesn't yet support APIs like model.summary / model.evaluate.
conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "false")
return conf

@classmethod
org/apache/spark/sql/connect/config/Connect.scala
@@ -334,23 +334,36 @@ object Connect {
}
}

val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
buildConf("spark.connect.session.connectML.mlCache.maxSize")
.doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
"used models if the size exceeds this limit. The size is in bytes.")
val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE =
buildConf("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize")
.doc(
"In-memory maximum size of the MLCache per session. The cache will offload the least " +
"recently used models to Spark driver local disk if the size exceeds this limit. " +
"The size is in bytes. This configuration only works when " +
"'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
.version("4.1.0")
.internal()
.bytesConf(ByteUnit.BYTE)
// By default, 1/3 of total designated memory (the configured -Xmx).
.createWithDefault(Runtime.getRuntime.maxMemory() / 3)

val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
buildConf("spark.connect.session.connectML.mlCache.timeout")
val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT =
buildConf("spark.connect.session.connectML.mlCache.offloading.timeout")
.doc(
"Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
"used for this amount of time. The timeout is in minutes.")
"Timeout of model offloading in MLCache. Models will be offloaded to Spark driver local " +
"disk if they are not used for this amount of time. The timeout is in minutes. " +
"This configuration only works when " +
"'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
.version("4.1.0")
.internal()
.timeConf(TimeUnit.MINUTES)
.createWithDefault(15)
.createWithDefault(5)

val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED =
buildConf("spark.connect.session.connectML.mlCache.offloading.enabled")
.doc("Enables ML cache offloading.")
.version("4.1.0")
.internal()
.booleanConf
.createWithDefault(true)
}
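The three entries above form one feature: `offloading.enabled` gates the other two, `maxInMemorySize` caps the in-memory tier (defaulting to a third of the driver's configured `-Xmx`), and `timeout` controls when idle models move to driver-local disk. A hedged sketch of reading them back off a session, using the defaults from this PR (standalone illustration, not code from the PR):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: resolve the offloading settings with the PR's documented defaults.
def offloadingSettings(spark: SparkSession): (Boolean, Long, Long) = {
  val enabled = spark.conf
    .get("spark.connect.session.connectML.mlCache.offloading.enabled", "true")
    .toBoolean
  // Default mirrors the PR: one third of the JVM's configured max heap, in bytes.
  val maxInMemoryBytes = spark.conf
    .get(
      "spark.connect.session.connectML.mlCache.offloading.maxInMemorySize",
      (Runtime.getRuntime.maxMemory() / 3).toString)
    .toLong
  // Default mirrors the PR: 5 minutes.
  val timeoutMinutes = spark.conf
    .get("spark.connect.session.connectML.mlCache.offloading.timeout", "5")
    .toLong
  (enabled, maxInMemoryBytes, timeoutMinutes)
}
```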
org/apache/spark/sql/connect/ml/MLCache.scala
@@ -16,17 +16,20 @@
*/
package org.apache.spark.sql.connect.ml

import java.io.File
import java.nio.file.{Files, Path, Paths}
import java.util.UUID
import java.util.concurrent.{ConcurrentMap, TimeUnit}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable

import com.google.common.cache.{CacheBuilder, RemovalNotification}
import org.apache.commons.io.FileUtils

import org.apache.spark.internal.Logging
import org.apache.spark.ml.Model
import org.apache.spark.ml.util.ConnectHelper
import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary}
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SessionHolder

@@ -36,38 +39,74 @@ import org.apache.spark.sql.connect.service.SessionHolder
private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
private val helper = new ConnectHelper()
private val helperID = "______ML_CONNECT_HELPER______"
private val modelClassNameFile = "__model_class_name__"

private def getMaxCacheSizeKB: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
// TODO: rename it to `totalInMemorySizeBytes` because it only counts the in-memory
// part data size.
private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)

val offloadedModelsDir: Path = {
val path = Paths.get(
System.getProperty("java.io.tmpdir"),
"spark_connect_model_cache",
sessionHolder.sessionId)
Files.createDirectories(path)
}
private[spark] def getOffloadingEnabled: Boolean = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED)
}

private def getTimeoutMinute: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
private def getMaxInMemoryCacheSizeKB: Long = {
sessionHolder.session.conf.get(
Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE) / 1024
}

private def getOffloadingTimeoutMinute: Long = {
sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT)
}

private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
.newBuilder()
.softValues()
.maximumWeight(getMaxCacheSizeKB)
.expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
.weigher((key: String, value: CacheItem) => {
Math.ceil(value.sizeBytes.toDouble / 1024).toInt
})
.removalListener((removed: RemovalNotification[String, CacheItem]) =>
totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
.build[String, CacheItem]()
.asMap()
private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = {
if (getOffloadingEnabled) {
CacheBuilder
.newBuilder()
.softValues()
.removalListener((removed: RemovalNotification[String, CacheItem]) =>
totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
.maximumWeight(getMaxInMemoryCacheSizeKB)
.weigher((key: String, value: CacheItem) => {
Math.ceil(value.sizeBytes.toDouble / 1024).toInt
})
.expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit.MINUTES)
.build[String, CacheItem]()
.asMap()
} else {
new ConcurrentHashMap[String, CacheItem]()
[Review comment — Contributor] Do you think it is risky that the MLCache has no memory limit when offloading is disabled? It probably makes sense to keep the old behaviour in this case: the MLCache would still have a memory limit and retention time when offloading is disabled, but once a model is evicted, future access to it throws a CACHE_INVALID error. WDYT?
}
}

private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
private[ml] val cachedSummary: ConcurrentMap[String, Summary] = {
new ConcurrentHashMap[String, Summary]()
}

private def estimateObjectSize(obj: Object): Long = {
obj match {
case model: Model[_] =>
model.asInstanceOf[Model[_]].estimatedSize
case _ =>
// There can only be Models in the cache, so we should never reach here.
1
throw new RuntimeException("Unexpected model object type.")
}
}

private[spark] def checkSummaryAvail(): Unit = {
if (getOffloadingEnabled) {
throw MlUnsupportedException(
"SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " +
"Spark Connect session ML cache offloading is enabled. You can use APIs in " +
"'pyspark.ml.evaluation' instead, or you can set Spark config " +
"'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " +
"disable Spark Connect session ML cache offloading.")
}
}

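The in-memory tier above is a plain Guava cache: weights are tracked in whole kilobytes (which is why `maximumWeight` is fed the byte budget divided by 1024), `expireAfterAccess` implements the idle timeout, `softValues` lets the GC reclaim entries under memory pressure, and the removal listener keeps `totalSizeBytes` consistent. A self-contained sketch of the same construction (Guava's `CacheBuilder` API as used in the PR; the names and the 64 MB budget are illustrative):

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong

import com.google.common.cache.{CacheBuilder, RemovalNotification}

final case class Entry(obj: AnyRef, sizeBytes: Long)

val totalBytes = new AtomicLong(0)

// The weigher works in KB so the Int weight cannot overflow for multi-GB budgets.
val inMemoryTier = CacheBuilder
  .newBuilder()
  .softValues()
  .removalListener((n: RemovalNotification[String, Entry]) =>
    totalBytes.addAndGet(-n.getValue.sizeBytes))
  .maximumWeight(64L * 1024) // budget in KB (here: 64 MB)
  .weigher((_: String, v: Entry) => math.ceil(v.sizeBytes.toDouble / 1024).toInt)
  .expireAfterAccess(5, TimeUnit.MINUTES)
  .build[String, Entry]()
  .asMap()
```

Entries evicted for any reason (weight, idle timeout, GC-cleared soft reference) pass through the removal listener, so the byte counter never drifts.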
@@ -80,9 +119,26 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
*/
def register(obj: Object): String = {
val objectId = UUID.randomUUID().toString
val sizeBytes = estimateObjectSize(obj)
totalSizeBytes.addAndGet(sizeBytes)
cachedModel.put(objectId, CacheItem(obj, sizeBytes))

if (obj.isInstanceOf[Summary]) {
checkSummaryAvail()
cachedSummary.put(objectId, obj.asInstanceOf[Summary])
} else if (obj.isInstanceOf[Model[_]]) {
val sizeBytes = if (getOffloadingEnabled) {
estimateObjectSize(obj)
} else {
0L // No need to compute the size when offloading is disabled.
}
cachedModel.put(objectId, CacheItem(obj, sizeBytes))
if (getOffloadingEnabled) {
val savePath = offloadedModelsDir.resolve(objectId)
obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString)
Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName)
}
totalSizeBytes.addAndGet(sizeBytes)
} else {
throw new RuntimeException("'MLCache.register' only accepts model or summary objects.")
}
objectId
}

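When offloading is enabled, `register` writes every model through to driver-local disk at registration time, alongside a marker file (`__model_class_name__`) recording the concrete class so the object can be re-instantiated later. A sketch of that write path (`MLWritable.write.saveToLocal` is the API the PR calls; the helper function itself is hypothetical):

```scala
import java.nio.file.{Files, Path}

import org.apache.spark.ml.util.MLWritable

// Sketch: persist a registered model under the per-session offload directory.
def offload(model: MLWritable, offloadDir: Path, objectId: String): Unit = {
  val savePath = offloadDir.resolve(objectId)
  model.write.saveToLocal(savePath.toString)
  // Record the concrete class name so `get` can reload the model reflectively.
  Files.writeString(savePath.resolve("__model_class_name__"), model.getClass.getName)
}
```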
@@ -97,8 +153,41 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
if (refId == helperID) {
helper
} else {
Option(cachedModel.get(refId)).map(_.obj).orNull
var obj: Object =
Option(cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId))
if (obj == null && getOffloadingEnabled) {
val loadPath = offloadedModelsDir.resolve(refId)
if (Files.isDirectory(loadPath)) {
val className = Files.readString(loadPath.resolve(modelClassNameFile))
obj = MLUtils.loadTransformer(
sessionHolder,
className,
loadPath.toString,
loadFromLocal = true)
val sizeBytes = estimateObjectSize(obj)
cachedModel.put(refId, CacheItem(obj, sizeBytes))
totalSizeBytes.addAndGet(sizeBytes)
}
}
obj
}
}

def _removeModel(refId: String): Boolean = {
val removedModel = cachedModel.remove(refId)
val removedFromMem = removedModel != null
val removedFromDisk = if (getOffloadingEnabled) {
val offloadingPath = new File(offloadedModelsDir.resolve(refId).toString)
if (offloadingPath.exists()) {
FileUtils.deleteDirectory(offloadingPath)
true
} else {
false
}
} else {
false
}
removedFromMem || removedFromDisk
}

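On a cache miss, `get` falls back to the offload directory: it reads the class-name marker, reloads the model via `MLUtils.loadTransformer`, and re-inserts it into the in-memory tier so its size is counted again. A hedged sketch of the miss path, with a `load` callback standing in for the internal `MLUtils.loadTransformer`:

```scala
import java.nio.file.{Files, Path}
import java.util.concurrent.ConcurrentMap

// Sketch: two-tier lookup, memory first, then driver-local disk.
def lookup(
    refId: String,
    inMemory: ConcurrentMap[String, AnyRef],
    offloadDir: Path,
    load: (String, String) => AnyRef): Option[AnyRef] = {
  Option(inMemory.get(refId)).orElse {
    val loadPath = offloadDir.resolve(refId)
    if (Files.isDirectory(loadPath)) {
      val className = Files.readString(loadPath.resolve("__model_class_name__"))
      val obj = load(className, loadPath.toString) // stands in for MLUtils.loadTransformer
      inMemory.put(refId, obj) // promote back into the in-memory tier
      Some(obj)
    } else {
      None
    }
  }
}
```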
/**
@@ -107,9 +196,14 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
* the key used to look up the corresponding object
*/
def remove(refId: String): Boolean = {
val removed = cachedModel.remove(refId)
// remove returns null if the key is not present
removed != null
val modelIsRemoved = _removeModel(refId)

if (modelIsRemoved) {
true
} else {
val removedSummary = cachedSummary.remove(refId)
removedSummary != null
}
}

/**
@@ -118,6 +212,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
def clear(): Int = {
val size = cachedModel.size()
cachedModel.clear()
cachedSummary.clear()
if (getOffloadingEnabled) {
FileUtils.cleanDirectory(new File(offloadedModelsDir.toString))
}
size
}

org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.util.{MLWritable, Summary}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments
import org.apache.spark.sql.connect.service.SessionHolder

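The switch from `protected lazy val instance` to `protected def instance()` below matters for offloading: a lazy val would pin the first resolved object for the helper's lifetime, while a def re-queries the cache on every access, so a model that was offloaded and evicted can be reloaded transparently. A toy illustration of the difference (illustrative names only, not PR code):

```scala
// Toy illustration: lazy val freezes the first lookup; def re-resolves each time.
class Holder(lookup: () => AnyRef) {
  lazy val once: AnyRef = lookup() // pinned after first access, even if evicted later
  def always(): AnyRef = lookup()  // goes back to the cache every time
}
```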
@@ -42,7 +43,7 @@ private class AttributeHelper(
val sessionHolder: SessionHolder,
val objRef: String,
val methods: Array[Method]) {
protected lazy val instance = {
protected def instance(): Object = {
val obj = sessionHolder.mlCache.get(objRef)
if (obj == null) {
throw MLCacheInvalidException(s"object $objRef")
@@ -52,7 +53,10 @@ private class AttributeHelper(
// Get the attribute by reflection
def getAttribute: Any = {
assert(methods.length >= 1)
methods.foldLeft(instance) { (obj, m) =>
methods.foldLeft(instance()) { (obj, m) =>
if (obj.isInstanceOf[Summary]) {
sessionHolder.mlCache.checkSummaryAvail()
}
if (m.argValues.isEmpty) {
MLUtils.invokeMethodAllowed(obj, m.name)
} else {
@@ -71,7 +75,7 @@ private class ModelAttributeHelper(

def transform(relation: proto.MlRelation.Transform): DataFrame = {
// Create a copied model to avoid concurrently modify model params.
val model = instance.asInstanceOf[Model[_]]
val model = instance().asInstanceOf[Model[_]]
val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
MLUtils.setInstanceParams(copiedModel, relation.getParams)
val inputDF = MLUtils.parseRelationProto(relation.getInput, sessionHolder)
@@ -119,13 +123,31 @@ private[connect] object MLHandler extends Logging {

mlCommand.getCommandCase match {
case proto.MlCommand.CommandCase.FIT =>
val offloadingEnabled = sessionHolder.session.conf.get(
Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED)
val fitCmd = mlCommand.getFit
val estimatorProto = fitCmd.getEstimator
assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)

val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
val estimator =
MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))
if (offloadingEnabled) {
if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
throw MlUnsupportedException(
"FPGrowth algorithm is not supported " +
"if Spark Connect model cache offloading is enabled.")
}
if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
&& estimator
.asInstanceOf[org.apache.spark.ml.clustering.LDA]
.getOptimizer
.toLowerCase() == "em") {
throw MlUnsupportedException(
"LDA algorithm with 'em' optimizer is not supported " +
"if Spark Connect model cache offloading is enabled.")
}
}
val model = estimator.fit(dataset).asInstanceOf[Model[_]]
val id = mlCache.register(model)
proto.MlCommandResult
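The FIT handler now rejects two estimators up front when offloading is enabled: FPGrowth, and LDA with the 'em' optimizer, whose models cannot round-trip through driver-local disk. A condensed sketch of that guard (class names as in the PR; `UnsupportedOperationException` stands in for the internal `MlUnsupportedException`):

```scala
import org.apache.spark.ml.Estimator
import org.apache.spark.ml.clustering.LDA

// Sketch: fail fast on estimators whose models cannot be offloaded to local disk.
def checkOffloadable(estimator: Estimator[_]): Unit = estimator match {
  case _ if estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth" =>
    throw new UnsupportedOperationException(
      "FPGrowth is not supported when ML cache offloading is enabled.")
  case lda: LDA if lda.getOptimizer.toLowerCase == "em" =>
    throw new UnsupportedOperationException(
      "LDA with the 'em' optimizer is not supported when ML cache offloading is enabled.")
  case _ => // offloadable; proceed with estimator.fit(...)
}
```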