From 9b30983f7ea18c2a6e67bc85901e08727f60c815 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Tue, 29 Apr 2025 19:08:45 +0800
Subject: [PATCH 01/16] init

Signed-off-by: Weichen Xu
---
 .../spark/sql/connect/config/Connect.scala    |  30 +++--
 .../apache/spark/sql/connect/ml/MLCache.scala | 107 +++++++++++++----
 .../apache/spark/sql/connect/ml/MLSuite.scala |   2 +-
 3 files changed, 101 insertions(+), 38 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 1b9f770e9e96a..2c6f65f184061 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -334,23 +334,35 @@ object Connect {
     }
   }
 
-  val CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE =
-    buildConf("spark.connect.session.connectML.mlCache.maxSize")
-      .doc("Maximum size of the MLCache per session. The cache will evict the least recently" +
-        "used models if the size exceeds this limit. The size is in bytes.")
+  val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE =
+    buildConf("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize")
+      .doc("In-memory maximum size of the MLCache per session. The cache will offload the least " +
+        "recently used models to Spark driver local disk if the size exceeds this limit. " +
+        "The size is in bytes. This configuration only works when " +
+        "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.")
       .version("4.1.0")
      .internal()
      .bytesConf(ByteUnit.BYTE)
      // By default, 1/3 of total designated memory (the configured -Xmx).
      .createWithDefault(Runtime.getRuntime.maxMemory() / 3)
 
-  val CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT =
-    buildConf("spark.connect.session.connectML.mlCache.timeout")
+  val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT =
+    buildConf("spark.connect.session.connectML.mlCache.offloading.timeout")
       .doc(
-        "Timeout of models in MLCache. Models will be evicted from the cache if they are not " +
-          "used for this amount of time. The timeout is in minutes.")
+        "Timeout of model offloading in MLCache. Models will be offloaded to Spark driver local " +
+        "disk if they are not used for this amount of time. The timeout is in minutes. " +
" + + "This configuration only works when " + + "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.") .version("4.1.0") .internal() .timeConf(TimeUnit.MINUTES) - .createWithDefault(15) + .createWithDefault(5) + + val CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED = + buildConf("spark.connect.session.connectML.mlCache.offloading.enabled") + .doc("Enables ML cache offloading.") + .version("4.1.0") + .internal() + .booleanConf + .createWithDefault(false) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index 05fa976b5beab..89cb34bb9f3ff 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -16,19 +16,20 @@ */ package org.apache.spark.sql.connect.ml +import java.nio.file.{Files, Path} import java.util.UUID -import java.util.concurrent.{ConcurrentMap, TimeUnit} -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit} import scala.collection.mutable -import com.google.common.cache.{CacheBuilder, RemovalNotification} +import com.google.common.cache.CacheBuilder import org.apache.spark.internal.Logging import org.apache.spark.ml.Model -import org.apache.spark.ml.util.ConnectHelper +import org.apache.spark.ml.util.{ConnectHelper, MLWriter, Summary} import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.util.Utils /** * MLCache is for caching ML objects, typically for models and summaries evaluated by a model. @@ -36,30 +37,45 @@ import org.apache.spark.sql.connect.service.SessionHolder private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { private val helper = new ConnectHelper() private val helperID = "______ML_CONNECT_HELPER______" + private val modelClassNameFile = "__model_class_name__" - private def getMaxCacheSizeKB: Long = { - sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024 + val offloadedModelsDir: Path = Utils.createTempDir().toPath + private def getOffloadingEnabled: Boolean = { + sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED + ) } - private def getTimeoutMinute: Long = { - sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT) + private def getMaxInMemoryCacheSizeKB: Long = { + sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE + ) / 1024 + } + + private def getOffloadingTimeoutMinute: Long = { + sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT) } private[ml] case class CacheItem(obj: Object, sizeBytes: Long) - private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder - .newBuilder() - .softValues() - .maximumWeight(getMaxCacheSizeKB) - .expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES) - .weigher((key: String, value: CacheItem) => { - Math.ceil(value.sizeBytes.toDouble / 1024).toInt - }) - .removalListener((removed: RemovalNotification[String, CacheItem]) => - totalSizeBytes.addAndGet(-removed.getValue.sizeBytes)) - .build[String, CacheItem]() - .asMap() - - private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0) + private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = { + var builder = CacheBuilder + 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 05fa976b5beab..89cb34bb9f3ff 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -16,19 +16,20 @@
  */
 package org.apache.spark.sql.connect.ml
 
+import java.nio.file.{Files, Path}
 import java.util.UUID
-import java.util.concurrent.{ConcurrentMap, TimeUnit}
-import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 
 import scala.collection.mutable
 
-import com.google.common.cache.{CacheBuilder, RemovalNotification}
+import com.google.common.cache.CacheBuilder
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Model
-import org.apache.spark.ml.util.ConnectHelper
+import org.apache.spark.ml.util.{ConnectHelper, MLWriter, Summary}
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.util.Utils
 
 /**
  * MLCache is for caching ML objects, typically for models and summaries evaluated by a model.
@@ -36,30 +37,45 @@ import org.apache.spark.sql.connect.service.SessionHolder
 private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   private val helper = new ConnectHelper()
   private val helperID = "______ML_CONNECT_HELPER______"
+  private val modelClassNameFile = "__model_class_name__"
 
-  private def getMaxCacheSizeKB: Long = {
-    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE) / 1024
+  val offloadedModelsDir: Path = Utils.createTempDir().toPath
+  private def getOffloadingEnabled: Boolean = {
+    sessionHolder.session.conf.get(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED
+    )
   }
 
-  private def getTimeoutMinute: Long = {
-    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_TIMEOUT)
+  private def getMaxInMemoryCacheSizeKB: Long = {
+    sessionHolder.session.conf.get(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE
+    ) / 1024
+  }
+
+  private def getOffloadingTimeoutMinute: Long = {
+    sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_TIMEOUT)
   }
 
   private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
-  private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = CacheBuilder
-    .newBuilder()
-    .softValues()
-    .maximumWeight(getMaxCacheSizeKB)
-    .expireAfterAccess(getTimeoutMinute, TimeUnit.MINUTES)
-    .weigher((key: String, value: CacheItem) => {
-      Math.ceil(value.sizeBytes.toDouble / 1024).toInt
-    })
-    .removalListener((removed: RemovalNotification[String, CacheItem]) =>
-      totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
-    .build[String, CacheItem]()
-    .asMap()
-
-  private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
+  private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = {
+    var builder = CacheBuilder
+      .newBuilder()
+      .softValues()
+      .weigher((key: String, value: CacheItem) => {
+        Math.ceil(value.sizeBytes.toDouble / 1024).toInt
+      })
+
+    if (getOffloadingEnabled) {
+      builder = builder
+        .maximumWeight(getMaxInMemoryCacheSizeKB)
+        .expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit.MINUTES)
+    }
+    builder.build[String, CacheItem]().asMap()
+  }
+
+  private[ml] val cachedSummary: ConcurrentMap[String, Summary] = {
+    new ConcurrentHashMap[String, Summary]()
+  }
 
   private def estimateObjectSize(obj: Object): Long = {
     obj match {
@@ -67,7 +83,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
         model.asInstanceOf[Model[_]].estimatedSize
       case _ =>
         // There can only be Models in the cache, so we should never reach here.
-        1
+        throw new RuntimeException("Unexpected model object type.")
     }
   }
 
@@ -80,9 +96,29 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
    */
   def register(obj: Object): String = {
     val objectId = UUID.randomUUID().toString
-    val sizeBytes = estimateObjectSize(obj)
-    totalSizeBytes.addAndGet(sizeBytes)
-    cachedModel.put(objectId, CacheItem(obj, sizeBytes))
+
+    if (obj.isInstanceOf[Summary]) {
+      if (getOffloadingEnabled) {
+        throw new RuntimeException(
+          "SparkML 'model.summary' and 'model.evaluate' APIs are not supported when " +
+          "Spark Connect session ML cache offloading is enabled. You can use APIs in " +
+          "'pyspark.ml.evaluation' instead.")
+      }
+      cachedSummary.put(objectId, obj.asInstanceOf[Summary])
+    } else if (obj.isInstanceOf[Model[_]]) {
+      val sizeBytes = estimateObjectSize(obj)
+      cachedModel.put(objectId, CacheItem(obj, sizeBytes))
+      if (getOffloadingEnabled) {
+        val savePath = offloadedModelsDir.resolve(objectId)
+        obj.asInstanceOf[MLWriter].saveToLocal(savePath.toString)
+        Files.writeString(
+          savePath.resolve(modelClassNameFile),
+          obj.getClass.getName
+        )
+      }
+    } else {
+      throw new RuntimeException("'MLCache.register' only accepts model or summary objects.")
+    }
     objectId
   }
 
@@ -97,7 +133,18 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     if (refId == helperID) {
       helper
     } else {
-      Option(cachedModel.get(refId)).map(_.obj).orNull
+      var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(
+        () => cachedSummary.get(refId)
+      )
+      if (obj == null && getOffloadingEnabled) {
+        val loadPath = offloadedModelsDir.resolve(refId)
+        if (Files.isDirectory(loadPath)) {
+          val className = Files.readString(loadPath.resolve(modelClassNameFile))
+          obj = MLUtils.loadTransformer(sessionHolder, className, loadPath.toString)
+          cachedModel.put(refId, CacheItem(obj, estimateObjectSize(obj)))
+        }
+      }
+      obj
     }
   }
 
@@ -107,7 +154,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
    * the key used to look up the corresponding object
    */
   def remove(refId: String): Boolean = {
-    val removed = cachedModel.remove(refId)
+    var removed: Object = cachedModel.remove(refId)
+    if (removed == null) {
+      removed = cachedSummary.remove(refId)
+    }
     // remove returns null if the key is not present
     removed != null
   }
@@ -118,6 +168,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   def clear(): Int = {
     val size = cachedModel.size()
     cachedModel.clear()
+    cachedSummary.clear()
     size
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 73bc1f2086aef..1351522694121 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -385,7 +385,7 @@ class MLSuite extends MLHelper {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val memorySizeBytes = 1024 * 16
     sessionHolder.session.conf
-      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_MAX_SIZE.key, memorySizeBytes)
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key, memorySizeBytes)
     trainLogisticRegressionModel(sessionHolder)
     assert(sessionHolder.mlCache.cachedModel.size() == 1)
     assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
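The MLCache rewrite in this patch leans entirely on Guava's weight-based eviction: entries are weighed in KB, and once the accumulated weight passes maximumWeight, least-recently-used entries are dropped (and, later in this series, offloaded). A self-contained sketch of that mechanic with toy sizes, not Spark code:

    import java.util.concurrent.TimeUnit
    import com.google.common.cache.CacheBuilder

    case class Item(sizeBytes: Long)

    val cache = CacheBuilder
      .newBuilder()
      .maximumWeight(16L) // budget of 16 "KB"
      .weigher((_: String, v: Item) => math.ceil(v.sizeBytes.toDouble / 1024).toInt)
      .expireAfterAccess(5, TimeUnit.MINUTES) // idle entries also become evictable
      .build[String, Item]()

    (1 to 10).foreach(i => cache.put(s"model-$i", Item(4096))) // 4 KB each
    cache.cleanUp()
    println(cache.size()) // at most 4 entries fit the 16 KB budget
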
clazz.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] + loader.loadFromLocal(path) + } else { + clazz.getMethod("load", classOf[String]) + .invoke(null, path) + } + loaded.asInstanceOf[T] } catch { case e: InvocationTargetException if e.getCause != null => throw e.getCause @@ -443,7 +449,8 @@ private[ml] object MLUtils { def loadTransformer( sessionHolder: SessionHolder, className: String, - path: String): Transformer = { + path: String, + loadFromLocal: Boolean = false): Transformer = { loadOperator(sessionHolder, className, path, classOf[Transformer]) } From 0697e418bf53d8794df61cbd9e2569ca278f33ff Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Wed, 30 Apr 2025 12:45:21 +0800 Subject: [PATCH 03/16] update Signed-off-by: Weichen Xu --- .../apache/spark/sql/connect/ml/MLCache.scala | 6 ++++- .../spark/sql/connect/ml/MLHandler.scala | 23 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index 5f5b0b79ac56b..e165bd09dba7d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -106,7 +106,11 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } cachedSummary.put(objectId, obj.asInstanceOf[Summary]) } else if (obj.isInstanceOf[Model[_]]) { - val sizeBytes = estimateObjectSize(obj) + val sizeBytes = if (getOffloadingEnabled) { + estimateObjectSize(obj) + } else { + 0L // Don't need to calculate size if disables offloading. + } cachedModel.put(objectId, CacheItem(obj, sizeBytes)) if (getOffloadingEnabled) { val savePath = offloadedModelsDir.resolve(objectId) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index 5283639e4aa2b..cb07dfc5df59c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.connect.ml import scala.collection.mutable import scala.jdk.CollectionConverters.CollectionHasAsScala - import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.ml.Model @@ -27,6 +26,7 @@ import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.util.{MLWritable, Summary} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.common.LiteralValueProtoConverter +import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments import org.apache.spark.sql.connect.service.SessionHolder @@ -119,6 +119,9 @@ private[connect] object MLHandler extends Logging { mlCommand.getCommandCase match { case proto.MlCommand.CommandCase.FIT => + val offloadingEnabled = sessionHolder.session.conf.get( + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED + ) val fitCmd = mlCommand.getFit val estimatorProto = fitCmd.getEstimator assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR) @@ -126,6 +129,24 @@ private[connect] object MLHandler extends Logging { val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder) val estimator = MLUtils.getEstimator(sessionHolder, 
From 0697e418bf53d8794df61cbd9e2569ca278f33ff Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 30 Apr 2025 12:45:21 +0800
Subject: [PATCH 03/16] update

Signed-off-by: Weichen Xu
---
 .../apache/spark/sql/connect/ml/MLCache.scala |  6 ++++-
 .../spark/sql/connect/ml/MLHandler.scala      | 23 ++++++++++++++++++-
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 5f5b0b79ac56b..e165bd09dba7d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -106,7 +106,11 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
       }
       cachedSummary.put(objectId, obj.asInstanceOf[Summary])
     } else if (obj.isInstanceOf[Model[_]]) {
-      val sizeBytes = estimateObjectSize(obj)
+      val sizeBytes = if (getOffloadingEnabled) {
+        estimateObjectSize(obj)
+      } else {
+        0L // No need to calculate the size if offloading is disabled.
+      }
       cachedModel.put(objectId, CacheItem(obj, sizeBytes))
       if (getOffloadingEnabled) {
         val savePath = offloadedModelsDir.resolve(objectId)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 5283639e4aa2b..cb07dfc5df59c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.connect.ml
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters.CollectionHasAsScala
-
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Model
@@ -27,6 +26,7 @@ import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.ml.util.{MLWritable, Summary}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -119,6 +119,9 @@ private[connect] object MLHandler extends Logging {
 
     mlCommand.getCommandCase match {
       case proto.MlCommand.CommandCase.FIT =>
+        val offloadingEnabled = sessionHolder.session.conf.get(
+          Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED
+        )
         val fitCmd = mlCommand.getFit
         val estimatorProto = fitCmd.getEstimator
         assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
@@ -126,6 +129,24 @@ private[connect] object MLHandler extends Logging {
         val dataset = MLUtils.parseRelationProto(fitCmd.getDataset, sessionHolder)
         val estimator =
           MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))
+        if (offloadingEnabled) {
+          if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
+            throw new UnsupportedOperationException(
+              "FPGrowth algorithm is not supported " +
+                "if Spark Connect model cache offloading is enabled."
+            )
+          }
+          if (
+            estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
+              && estimator.asInstanceOf[org.apache.spark.ml.clustering.LDA]
+                .getOptimizer.toLowerCase() == "em"
+          ) {
+            throw new UnsupportedOperationException(
+              "LDA algorithm with 'em' optimizer is not supported " +
+                "if Spark Connect model cache offloading is enabled."
+            )
+          }
+        }
         val model = estimator.fit(dataset).asInstanceOf[Model[_]]
         val id = mlCache.register(model)
         proto.MlCommandResult

From 2d5ae261ba7976b1d4e1f8c26e012d5ddf0c1ff3 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 30 Apr 2025 17:12:17 +0800
Subject: [PATCH 04/16] update

Signed-off-by: Weichen Xu
---
 .../scala/org/apache/spark/sql/connect/ml/MLCache.scala  | 9 +++++++++
 .../org/apache/spark/sql/connect/ml/MLHandler.scala      | 1 +
 2 files changed, 10 insertions(+)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index e165bd09dba7d..8f620511d7406 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.connect.ml
 
+import java.io.File
 import java.nio.file.{Files, Path}
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
@@ -23,6 +24,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 import scala.collection.mutable
 
 import com.google.common.cache.CacheBuilder
+import org.apache.commons.io.FileUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Model
@@ -163,6 +165,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     var removed: Object = cachedModel.remove(refId)
     if (removed == null) {
       removed = cachedSummary.remove(refId)
+    } else {
+      if (getOffloadingEnabled) {
+        FileUtils.deleteDirectory(new File(offloadedModelsDir.resolve(refId).toString))
+      }
     }
     // remove returns null if the key is not present
     removed != null
@@ -175,6 +181,9 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     val size = cachedModel.size()
     cachedModel.clear()
     cachedSummary.clear()
+    if (getOffloadingEnabled) {
+      FileUtils.cleanDirectory(new File(offloadedModelsDir.toString))
+    }
     size
   }
 
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index cb07dfc5df59c..a605f06a23519 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.ml
 
 import scala.collection.mutable
 import scala.jdk.CollectionConverters.CollectionHasAsScala
+
 import org.apache.spark.connect.proto
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Model

From ad97210c5df2dc74426c293b2d18fea775db1664 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 30 Apr 2025 19:17:54 +0800
Subject: [PATCH 05/16] update

Signed-off-by: Weichen Xu
---
 .../apache/spark/sql/connect/ml/MLCache.scala | 45 ++++++++++++++-----
 1 file changed, 35 insertions(+), 10 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 8f620511d7406..0b08b07b4cd72 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -20,10 +20,11 @@ import java.io.File
 import java.nio.file.{Files, Path}
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
+import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable
 
-import com.google.common.cache.CacheBuilder
+import com.google.common.cache.{CacheBuilder, RemovalNotification}
 import org.apache.commons.io.FileUtils
 
 import org.apache.spark.internal.Logging
@@ -41,6 +42,10 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   private val helperID = "______ML_CONNECT_HELPER______"
   private val modelClassNameFile = "__model_class_name__"
 
+  // TODO: rename it to `totalInMemorySizeBytes` because it only counts the in-memory
+  // part of the data size.
+  private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
+
   val offloadedModelsDir: Path = Utils.createTempDir().toPath
   private def getOffloadingEnabled: Boolean = {
     sessionHolder.session.conf.get(
@@ -69,6 +74,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
 
     if (getOffloadingEnabled) {
       builder = builder
+        .removalListener((removed: RemovalNotification[String, CacheItem]) =>
+          totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
         .maximumWeight(getMaxInMemoryCacheSizeKB)
         .expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit.MINUTES)
     }
@@ -122,6 +129,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
           obj.getClass.getName
         )
       }
+      totalSizeBytes.addAndGet(sizeBytes)
     } else {
       throw new RuntimeException("'MLCache.register' only accepts model or summary objects.")
     }
@@ -149,29 +157,46 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
       if (Files.isDirectory(loadPath)) {
         val className = Files.readString(loadPath.resolve(modelClassNameFile))
         obj = MLUtils.loadTransformer(
           sessionHolder, className, loadPath.toString, loadFromLocal = true
         )
-          cachedModel.put(refId, CacheItem(obj, estimateObjectSize(obj)))
+          val sizeBytes = estimateObjectSize(obj)
+          cachedModel.put(refId, CacheItem(obj, sizeBytes))
+          totalSizeBytes.addAndGet(sizeBytes)
         }
       }
       obj
     }
   }
 
+  def _removeModel(refId: String): Boolean = {
+    val removedModel = cachedModel.remove(refId)
+    val removedFromMem = removedModel != null
+    val removedFromDisk = if (getOffloadingEnabled) {
+      val offloadingPath = new File(offloadedModelsDir.resolve(refId).toString)
+      if (offloadingPath.exists()) {
+        FileUtils.deleteDirectory(offloadingPath)
+        true
+      } else {
+        false
+      }
+    } else {
+      false
+    }
+    removedFromMem || removedFromDisk
+  }
+
   /**
    * Remove the object from MLCache
   * @param refId
   *   the key used to look up the corresponding object
   */
   def remove(refId: String): Boolean = {
-    val modelIsRemoved = cachedModel.remove(refId)
-    var removed: Object = cachedModel.remove(refId)
-    if (removed == null) {
-      removed = cachedSummary.remove(refId)
+    val modelIsRemoved = _removeModel(refId)
+
+    if (modelIsRemoved) {
+      true
     } else {
-      if (getOffloadingEnabled) {
-        FileUtils.deleteDirectory(new File(offloadedModelsDir.resolve(refId).toString))
-      }
+      val removedSummary = cachedSummary.remove(refId)
+      removedSummary != null
     }
-    // remove returns null if the key is not present
-    removed != null
   }

From 0736b37855d8f748304527e26db64a49a039e8b1 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 30 Apr 2025 19:36:12 +0800
Subject: [PATCH 06/16] update

Signed-off-by: Weichen Xu
---
 .../org/apache/spark/sql/connect/ml/MLCache.scala | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 0b08b07b4cd72..161c6b73dbf00 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.connect.ml
 
 import java.io.File
-import java.nio.file.{Files, Path}
+import java.nio.file.{Files, Path, Paths}
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
@@ -46,7 +46,13 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   // part of the data size.
   private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
 
-  val offloadedModelsDir: Path = Utils.createTempDir().toPath
+  val offloadedModelsDir: Path = {
+    Paths.get(
+      System.getProperty("java.io.tmpdir"),
+      "spark_connect_model_cache",
+      sessionHolder.sessionId
+    )
+  }
   private def getOffloadingEnabled: Boolean = {
     sessionHolder.session.conf.get(
       Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED

From 281cb9b5d952cfd658474846831c71f6bfccf334 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 30 Apr 2025 21:25:16 +0800
Subject: [PATCH 07/16] format

Signed-off-by: Weichen Xu
---
 .../src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 161c6b73dbf00..a7c4c69bc0f65 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -32,7 +32,6 @@ import org.apache.spark.ml.Model
 import org.apache.spark.ml.util.{ConnectHelper, MLWriter, Summary}
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
-import org.apache.spark.util.Utils
 
 /**
  * MLCache is for caching ML objects, typically for models and summaries evaluated by a model.
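Patches 06 and 07 pin the offload location to a session-scoped directory under the JVM temp dir instead of a fresh Utils.createTempDir(), which is why the Utils import can be dropped. A sketch of the resulting path logic; note patch 06 never creates the directory, a gap patch 09 closes with the Files.createDirectories call shown here:

    import java.nio.file.{Files, Paths}

    // Stand-in for sessionHolder.sessionId.
    val sessionId = java.util.UUID.randomUUID().toString
    val offloadDir = Paths.get(
      System.getProperty("java.io.tmpdir"),
      "spark_connect_model_cache",
      sessionId)
    Files.createDirectories(offloadDir) // added in patch 09; returns the created Path
    println(offloadDir)
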
From fbad3b5063b7069f3358729f2b1489c9980e5313 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Wed, 30 Apr 2025 23:15:20 +0800
Subject: [PATCH 08/16] update

Signed-off-by: Weichen Xu
---
 python/pyspark/testing/connectutils.py        |  6 ++---
 .../spark/sql/connect/config/Connect.scala    |  2 +-
 .../apache/spark/sql/connect/ml/MLCache.scala | 22 +++++++++----------
 .../apache/spark/sql/connect/ml/MLSuite.scala |  6 +++--
 4 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index c0d91fb8bd149..ed32014fc8064 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -158,9 +158,9 @@ def conf(cls):
         # Set a static token for all tests so the parallelism doesn't overwrite each
         # tests' environment variables
         conf.set("spark.connect.authenticate.token", "deadbeef")
-        # Make the max size of ML Cache larger, to avoid CONNECT_ML.CACHE_INVALID issues
-        # in tests.
-        conf.set("spark.connect.session.connectML.mlCache.maxSize", "1g")
+        # Disable ml cache offloading,
+        # offloading hasn't support APIs like model.summary / model.evaluate
+        conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "false")
         return conf
 
     @classmethod
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 2c6f65f184061..86f77c04570d1 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -364,5 +364,5 @@ object Connect {
       .version("4.1.0")
       .internal()
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)
 }
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index a7c4c69bc0f65..8f98ec1d12e50 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.Model
-import org.apache.spark.ml.util.{ConnectHelper, MLWriter, Summary}
+import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary}
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -70,21 +70,21 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
 
   private[ml] case class CacheItem(obj: Object, sizeBytes: Long)
   private[ml] val cachedModel: ConcurrentMap[String, CacheItem] = {
-    var builder = CacheBuilder
-      .newBuilder()
-      .softValues()
-      .weigher((key: String, value: CacheItem) => {
-        Math.ceil(value.sizeBytes.toDouble / 1024).toInt
-      })
-
     if (getOffloadingEnabled) {
-      builder = builder
+      CacheBuilder
+        .newBuilder()
+        .softValues()
         .removalListener((removed: RemovalNotification[String, CacheItem]) =>
           totalSizeBytes.addAndGet(-removed.getValue.sizeBytes))
         .maximumWeight(getMaxInMemoryCacheSizeKB)
+        .weigher((key: String, value: CacheItem) => {
+          Math.ceil(value.sizeBytes.toDouble / 1024).toInt
+        })
         .expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit.MINUTES)
+        .build[String, CacheItem]().asMap()
+    } else {
+      new ConcurrentHashMap[String, CacheItem]()
     }
-    builder.build[String, CacheItem]().asMap()
   }
 
   private[ml] val cachedSummary: ConcurrentMap[String, Summary] = {
@@ -128,7 +128,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
       cachedModel.put(objectId, CacheItem(obj, sizeBytes))
       if (getOffloadingEnabled) {
         val savePath = offloadedModelsDir.resolve(objectId)
-        obj.asInstanceOf[MLWriter].saveToLocal(savePath.toString)
+        obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString)
         Files.writeString(
           savePath.resolve(modelClassNameFile),
           obj.getClass.getName
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 1351522694121..5af64891f16a0 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -384,8 +384,10 @@ class MLSuite extends MLHelper {
   test("Memory limitation of MLCache works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
     val memorySizeBytes = 1024 * 16
-    sessionHolder.session.conf
-      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key, memorySizeBytes)
+    sessionHolder.session.conf.set(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key,
+      memorySizeBytes
+    )
     trainLogisticRegressionModel(sessionHolder)
     assert(sessionHolder.mlCache.cachedModel.size() == 1)
     assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
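One substantive fix hidden among the format changes above: a fitted model is MLWritable (it exposes a `write`), it is not itself a writer, so the earlier `asInstanceOf[MLWriter]` cast could only fail at runtime. A toy reduction of the two shapes (trait names mirror Spark's, bodies are illustrative):

    trait ToyWriter { def saveToLocal(path: String): Unit }
    trait ToyWritable { def write: ToyWriter }

    class ToyModel extends ToyWritable {
      def write: ToyWriter = (path: String) => println(s"saved to $path")
    }

    val model: AnyRef = new ToyModel
    // model.asInstanceOf[ToyWriter].saveToLocal("/tmp/m") // ClassCastException
    model.asInstanceOf[ToyWritable].write.saveToLocal("/tmp/m") // corrected pattern
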
From 8f9e7a4407be92d69cc13aeeaa2fbdebad882cd1 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 1 May 2025 13:51:25 +0800
Subject: [PATCH 09/16] update

Signed-off-by: Weichen Xu
---
 .../scala/org/apache/spark/sql/connect/ml/MLCache.scala   | 5 +++--
 .../scala/org/apache/spark/sql/connect/ml/MLHandler.scala | 6 +++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 8f98ec1d12e50..78d81462531ee 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -46,11 +46,12 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
   private[ml] val totalSizeBytes: AtomicLong = new AtomicLong(0)
 
   val offloadedModelsDir: Path = {
-    Paths.get(
+    val path = Paths.get(
       System.getProperty("java.io.tmpdir"),
       "spark_connect_model_cache",
       sessionHolder.sessionId
     )
+    Files.createDirectories(path)
   }
   private def getOffloadingEnabled: Boolean = {
     sessionHolder.session.conf.get(
@@ -153,7 +154,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
       helper
     } else {
       var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(
-        () => cachedSummary.get(refId)
+        cachedSummary.get(refId)
       )
       if (obj == null && getOffloadingEnabled) {
         val loadPath = offloadedModelsDir.resolve(refId)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index a605f06a23519..9733ad9b92d18 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -43,7 +43,7 @@ private class AttributeHelper(
     val sessionHolder: SessionHolder,
     val objRef: String,
     val methods: Array[Method]) {
-  protected lazy val instance = {
+  protected def instance(): Object = {
     val obj = sessionHolder.mlCache.get(objRef)
     if (obj == null) {
       throw MLCacheInvalidException(s"object $objRef")
@@ -53,7 +53,7 @@ private class AttributeHelper(
   // Get the attribute by reflection
   def getAttribute: Any = {
     assert(methods.length >= 1)
-    methods.foldLeft(instance) { (obj, m) =>
+    methods.foldLeft(instance()) { (obj, m) =>
       if (m.argValues.isEmpty) {
         MLUtils.invokeMethodAllowed(obj, m.name)
       } else {
@@ -72,7 +72,7 @@ private class ModelAttributeHelper(
 
   def transform(relation: proto.MlRelation.Transform): DataFrame = {
     // Create a copied model to avoid concurrently modify model params.
-    val model = instance.asInstanceOf[Model[_]]
+    val model = instance().asInstanceOf[Model[_]]
     val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
     MLUtils.setInstanceParams(copiedModel, relation.getParams)
     val inputDF = MLUtils.parseRelationProto(relation.getInput, sessionHolder)
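The one-character-looking change in `get` above fixes a real bug: `Option.getOrElse` already takes its default by name, so wrapping the fallback in `() => ...` stores the function value itself. Because the target type was `Object`, this compiled, and the later `obj == null` check could never see the missing-summary case. A minimal reproduction:

    val missing: Option[Object] = None

    val wrong: Object = missing.getOrElse(() => "fallback")
    println(wrong == null)                    // false: wrong is a Function0
    println(wrong.isInstanceOf[Function0[_]]) // true, so a null guard never fires

    val right: Object = missing.getOrElse("fallback")
    println(right)                            // "fallback"
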
" + + "This configuration only works when " + + "'spark.connect.session.connectML.mlCache.offloading.enabled' is 'true'.") .version("4.1.0") .internal() .timeConf(TimeUnit.MINUTES) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index 78d81462531ee..52f85aaa8214d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -49,20 +49,16 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { val path = Paths.get( System.getProperty("java.io.tmpdir"), "spark_connect_model_cache", - sessionHolder.sessionId - ) + sessionHolder.sessionId) Files.createDirectories(path) } private def getOffloadingEnabled: Boolean = { - sessionHolder.session.conf.get( - Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED - ) + sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED) } private def getMaxInMemoryCacheSizeKB: Long = { sessionHolder.session.conf.get( - Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE - ) / 1024 + Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE) / 1024 } private def getOffloadingTimeoutMinute: Long = { @@ -82,7 +78,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { Math.ceil(value.sizeBytes.toDouble / 1024).toInt }) .expireAfterAccess(getOffloadingTimeoutMinute, TimeUnit.MINUTES) - .build[String, CacheItem]().asMap() + .build[String, CacheItem]() + .asMap() } else { new ConcurrentHashMap[String, CacheItem]() } @@ -116,24 +113,21 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { if (getOffloadingEnabled) { throw new RuntimeException( "SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " + - "Spark Connect session ML cache offloading is enabled. You can use APIs in " + - "'pyspark.ml.evaluation' instead.") + "Spark Connect session ML cache offloading is enabled. You can use APIs in " + + "'pyspark.ml.evaluation' instead.") } cachedSummary.put(objectId, obj.asInstanceOf[Summary]) } else if (obj.isInstanceOf[Model[_]]) { val sizeBytes = if (getOffloadingEnabled) { estimateObjectSize(obj) } else { - 0L // Don't need to calculate size if disables offloading. + 0L // Don't need to calculate size if disables offloading. 
       }
       cachedModel.put(objectId, CacheItem(obj, sizeBytes))
       if (getOffloadingEnabled) {
         val savePath = offloadedModelsDir.resolve(objectId)
         obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString)
-        Files.writeString(
-          savePath.resolve(modelClassNameFile),
-          obj.getClass.getName
-        )
+        Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName)
       }
       totalSizeBytes.addAndGet(sizeBytes)
@@ -149,17 +143,17 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     if (refId == helperID) {
       helper
     } else {
-      var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(
-        cachedSummary.get(refId)
-      )
+      var obj: Object =
+        Option(cachedModel.get(refId)).map(_.obj).getOrElse(cachedSummary.get(refId))
       if (obj == null && getOffloadingEnabled) {
         val loadPath = offloadedModelsDir.resolve(refId)
         if (Files.isDirectory(loadPath)) {
           val className = Files.readString(loadPath.resolve(modelClassNameFile))
           obj = MLUtils.loadTransformer(
-            sessionHolder, className, loadPath.toString, loadFromLocal = true
-          )
+            sessionHolder,
+            className,
+            loadPath.toString,
+            loadFromLocal = true)
           val sizeBytes = estimateObjectSize(obj)
           cachedModel.put(refId, CacheItem(obj, sizeBytes))
           totalSizeBytes.addAndGet(sizeBytes)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 9733ad9b92d18..75c49e35832da 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -121,8 +121,7 @@ private[connect] object MLHandler extends Logging {
     mlCommand.getCommandCase match {
       case proto.MlCommand.CommandCase.FIT =>
         val offloadingEnabled = sessionHolder.session.conf.get(
-          Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED
-        )
+          Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED)
         val fitCmd = mlCommand.getFit
         val estimatorProto = fitCmd.getEstimator
         assert(estimatorProto.getType == proto.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR)
@@ -134,21 +133,19 @@ private[connect] object MLHandler extends Logging {
           MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))
         if (offloadingEnabled) {
           if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
             throw new UnsupportedOperationException(
               "FPGrowth algorithm is not supported " +
-                "if Spark Connect model cache offloading is enabled."
-            )
+                "if Spark Connect model cache offloading is enabled.")
           }
-          if (
-            estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
-              && estimator.asInstanceOf[org.apache.spark.ml.clustering.LDA]
-                .getOptimizer.toLowerCase() == "em"
-          ) {
+          if (estimator.getClass.getName == "org.apache.spark.ml.clustering.LDA"
+            && estimator
+              .asInstanceOf[org.apache.spark.ml.clustering.LDA]
+              .getOptimizer
+              .toLowerCase() == "em") {
             throw new UnsupportedOperationException(
               "LDA algorithm with 'em' optimizer is not supported " +
-                "if Spark Connect model cache offloading is enabled."
-            )
+                "if Spark Connect model cache offloading is enabled.")
           }
         }
         val model = estimator.fit(dataset).asInstanceOf[Model[_]]
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 56d3ba0763452..fdeffbba1dda9 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -423,7 +423,8 @@ private[ml] object MLUtils {
         val loader = clazz.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
         loader.loadFromLocal(path)
       } else {
-        clazz.getMethod("load", classOf[String])
+        clazz
+          .getMethod("load", classOf[String])
           .invoke(null, path)
       }
       loaded.asInstanceOf[T]
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 5af64891f16a0..b6908691ceee5 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -386,8 +386,7 @@ class MLSuite extends MLHelper {
     val memorySizeBytes = 1024 * 16
     sessionHolder.session.conf.set(
       Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key,
-      memorySizeBytes
-    )
+      memorySizeBytes)
     trainLogisticRegressionModel(sessionHolder)
     assert(sessionHolder.mlCache.cachedModel.size() == 1)
     assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)

From 060844ff8a1629fd1a7e2403f219d0dabe5e04b1 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 1 May 2025 14:34:28 +0800
Subject: [PATCH 11/16] update

Signed-off-by: Weichen Xu
---
 .../ml/tests/connect/test_parity_tuning.py   | 11 +++++++++
 python/pyspark/testing/connectutils.py       |  2 +-
 .../apache/spark/sql/connect/ml/MLCache.scala |  4 +++-
 .../apache/spark/sql/connect/ml/MLUtils.scala |  4 +++-
 .../apache/spark/sql/connect/ml/MLSuite.scala | 17 ++++++++++++-----
 5 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py
index 2d21644ceed53..801d4fcd584ff 100644
--- a/python/pyspark/ml/tests/connect/test_parity_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -25,6 +25,17 @@ class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
     pass
 
 
+class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConnectTestCase):
+
+    @classmethod
+    def conf(cls):
+        conf = super().conf()
+        conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "true")
+        conf.set("spark.connect.session.connectML.mlCache.offloading.maxInMemorySize", "1024")
+        conf.set("spark.connect.session.connectML.mlCache.offloading.timeout", "1")
+        return conf
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests.connect.test_parity_tuning import *  # noqa: F401
 
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index ed32014fc8064..ce727e78605e6 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -159,7 +159,7 @@ def conf(cls):
         # tests' environment variables
         conf.set("spark.connect.authenticate.token", "deadbeef")
         # Disable ml cache offloading,
-        # offloading hasn't support APIs like model.summary / model.evaluate
+        # offloading doesn't support APIs like model.summary / model.evaluate yet
conf.set("spark.connect.session.connectML.mlCache.offloading.enabled", "false") return conf diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index 52f85aaa8214d..4e270f455d568 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -114,7 +114,9 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { throw new RuntimeException( "SparkML 'model.summary' and 'model.evaluate' APIs are not supported' when " + "Spark Connect session ML cache offloading is enabled. You can use APIs in " + - "'pyspark.ml.evaluation' instead.") + "'pyspark.ml.evaluation' instead, or you can set Spark config " + + "'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " + + "disable Spark Connect session ML cache offloading.") } cachedSummary.put(objectId, obj.asInstanceOf[Summary]) } else if (obj.isInstanceOf[Model[_]]) { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index fdeffbba1dda9..c5f78c79b54c6 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -452,7 +452,9 @@ private[ml] object MLUtils { className: String, path: String, loadFromLocal: Boolean = false): Transformer = { - loadOperator(sessionHolder, className, path, classOf[Transformer]) + loadOperator( + sessionHolder, className, path, classOf[Transformer], loadFromLocal = loadFromLocal + ) } /** diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala index b6908691ceee5..40f1a1fa8549a 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala @@ -381,13 +381,14 @@ class MLSuite extends MLHelper { .toArray sameElements Array("a", "b", "c")) } - test("Memory limitation of MLCache works") { + test("MLCache offloading works") { val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) val memorySizeBytes = 1024 * 16 sessionHolder.session.conf.set( Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key, memorySizeBytes) - trainLogisticRegressionModel(sessionHolder) + val modelIdList = scala.collection.mutable.ListBuffer[String]() + modelIdList.append(trainLogisticRegressionModel(sessionHolder)) assert(sessionHolder.mlCache.cachedModel.size() == 1) assert(sessionHolder.mlCache.totalSizeBytes.get() > 0) val modelSizeBytes = sessionHolder.mlCache.totalSizeBytes.get() @@ -395,18 +396,24 @@ class MLSuite extends MLHelper { // All models will be kept if the total size is less than the memory limit. for (i <- 1 until maxNumModels) { - trainLogisticRegressionModel(sessionHolder) + modelIdList.append(trainLogisticRegressionModel(sessionHolder)) assert(sessionHolder.mlCache.cachedModel.size() == i + 1) assert(sessionHolder.mlCache.totalSizeBytes.get() > 0) assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes) } - // Old models will be removed if new ones are added and the total size exceeds the memory limit. 
+    // Old models will be offloaded
+    // if new ones are added and the total size exceeds the memory limit.
     for (_ <- 0 until 3) {
-      trainLogisticRegressionModel(sessionHolder)
+      modelIdList.append(trainLogisticRegressionModel(sessionHolder))
       assert(sessionHolder.mlCache.cachedModel.size() == maxNumModels)
       assert(sessionHolder.mlCache.totalSizeBytes.get() > 0)
       assert(sessionHolder.mlCache.totalSizeBytes.get() <= memorySizeBytes)
     }
+
+    // Assert all models can be loaded back from disk after they are offloaded.
+    for (modelId <- modelIdList) {
+      assert(sessionHolder.mlCache.get(modelId) != null)
+    }
   }
 }
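The reworked test above pins down the offloading arithmetic: with a fixed in-memory budget, at most floor(budget / modelSize) models stay resident, and everything older must come back from disk. The same bookkeeping in isolation (sizes are illustrative; the real test derives the model size from totalSizeBytes):

    val memorySizeBytes = 1024 * 16 // budget used by the test
    val modelSizeBytes = 1024 * 4   // assume each model weighs ~4 KB
    val maxNumModels = memorySizeBytes / modelSizeBytes // 4 models stay in memory

    val trained = maxNumModels + 3 // the test trains 3 models past the budget
    val offloaded = trained - maxNumModels
    println(s"resident=$maxNumModels offloaded=$offloaded") // resident=4 offloaded=3
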
From 64c18e53388a85587db8179dfea297fdafee44d4 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 1 May 2025 14:48:38 +0800
Subject: [PATCH 12/16] refine err

Signed-off-by: Weichen Xu
---
 .../apache/spark/sql/connect/ml/MLCache.scala | 22 +++++++++++--------
 .../spark/sql/connect/ml/MLHandler.scala      |  3 +++
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 4e270f455d568..5099e83985c2d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -52,7 +52,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
       sessionHolder.sessionId)
     Files.createDirectories(path)
   }
-  private def getOffloadingEnabled: Boolean = {
+  private[spark] def getOffloadingEnabled: Boolean = {
     sessionHolder.session.conf.get(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED)
   }
@@ -99,6 +99,17 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     }
   }
 
+  private[spark] def checkSummaryAvail(): Unit = {
+    if (getOffloadingEnabled) {
+      throw new RuntimeException(
+        "SparkML 'model.summary' and 'model.evaluate' APIs are not supported when " +
+          "Spark Connect session ML cache offloading is enabled. You can use APIs in " +
+          "'pyspark.ml.evaluation' instead, or you can set Spark config " +
+          "'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " +
+          "disable Spark Connect session ML cache offloading.")
+    }
+  }
+
   /**
    * Cache an object into a map of MLCache, and return its key
    * @param obj
@@ -110,14 +121,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
     val objectId = UUID.randomUUID().toString
 
     if (obj.isInstanceOf[Summary]) {
-      if (getOffloadingEnabled) {
-        throw new RuntimeException(
-          "SparkML 'model.summary' and 'model.evaluate' APIs are not supported when " +
-            "Spark Connect session ML cache offloading is enabled. You can use APIs in " +
-            "'pyspark.ml.evaluation' instead, or you can set Spark config " +
-            "'spark.connect.session.connectML.mlCache.offloading.enabled' to 'false' to " +
-            "disable Spark Connect session ML cache offloading.")
-      }
+      checkSummaryAvail()
       cachedSummary.put(objectId, obj.asInstanceOf[Summary])
     } else if (obj.isInstanceOf[Model[_]]) {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 75c49e35832da..9db55e8e45bc0 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -54,6 +54,9 @@ private class AttributeHelper(
   def getAttribute: Any = {
     assert(methods.length >= 1)
     methods.foldLeft(instance()) { (obj, m) =>
+      if (obj.isInstanceOf[Summary]) {
+        sessionHolder.mlCache.checkSummaryAvail()
+      }
       if (m.argValues.isEmpty) {
         MLUtils.invokeMethodAllowed(obj, m.name)
       } else {

From 6e3902fc63487289522ac3e88598e7868a907519 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 1 May 2025 14:59:40 +0800
Subject: [PATCH 13/16] update

Signed-off-by: Weichen Xu
---
 .../org/apache/spark/sql/connect/ml/MLSuite.scala | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 40f1a1fa8549a..32407f87a336e 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -135,6 +135,9 @@ class MLSuite extends MLHelper {
   // Estimator/Model works
   test("LogisticRegression works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    sessionHolder.session.conf.set(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key,
+      "false")
 
     // estimator read/write
     val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter)
@@ -259,6 +262,9 @@ class MLSuite extends MLHelper {
 
   test("Exception: cannot retrieve object") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    sessionHolder.session.conf.set(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key,
+      "false")
     val modelId = trainLogisticRegressionModel(sessionHolder)
 
     // Fetch summary attribute
@@ -383,6 +389,10 @@ class MLSuite extends MLHelper {
 
   test("MLCache offloading works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
+    sessionHolder.session.conf.set(
+      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key,
+      "true")
+
     val memorySizeBytes = 1024 * 16
     sessionHolder.session.conf.set(
       Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_MAX_IN_MEMORY_SIZE.key,
       memorySizeBytes)

From 83fc19fbd27f542a108ac162f73bfaad87cd9662 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 1 May 2025 17:46:18 +0800
Subject: [PATCH 14/16] format

Signed-off-by: Weichen Xu
---
 .../org/apache/spark/sql/connect/ml/MLUtils.scala |  7 +++++--
 .../org/apache/spark/sql/connect/ml/MLSuite.scala | 15 ++++++---------
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index c5f78c79b54c6..17f4b765830e0 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -453,8 +453,11 @@ private[ml] object MLUtils {
       path: String,
       loadFromLocal: Boolean = false): Transformer = {
     loadOperator(
-      sessionHolder, className, path, classOf[Transformer], loadFromLocal = loadFromLocal
-    )
+      sessionHolder,
+      className,
+      path,
+      classOf[Transformer],
+      loadFromLocal = loadFromLocal)
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
index 32407f87a336e..36b198b0db31a 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ml/MLSuite.scala
@@ -135,9 +135,8 @@ class MLSuite extends MLHelper {
   // Estimator/Model works
   test("LogisticRegression works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
-    sessionHolder.session.conf.set(
-      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key,
-      "false")
+    sessionHolder.session.conf
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key, "false")
 
     // estimator read/write
     val ret = readWrite(sessionHolder, getLogisticRegression, getMaxIter)
@@ -261,9 +260,8 @@ class MLSuite extends MLHelper {
 
   test("Exception: cannot retrieve object") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
-    sessionHolder.session.conf.set(
-      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key,
-      "false")
+    sessionHolder.session.conf
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key, "false")
     val modelId = trainLogisticRegressionModel(sessionHolder)
 
     // Fetch summary attribute
@@ -387,9 +385,8 @@ class MLSuite extends MLHelper {
 
   test("MLCache offloading works") {
     val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
-    sessionHolder.session.conf.set(
-      Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key,
-      "true")
+    sessionHolder.session.conf
+      .set(Connect.CONNECT_SESSION_CONNECT_ML_CACHE_OFFLOADING_ENABLED.key, "true")
 
     val memorySizeBytes = 1024 * 16
     sessionHolder.session.conf.set(

From cf015b568f1d7cfef3a0b761a5d548d9fc1a45af Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Thu, 1 May 2025 20:05:26 +0800
Subject: [PATCH 15/16] format

Signed-off-by: Weichen Xu
---
 python/pyspark/ml/tests/connect/test_parity_tuning.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py
index 801d4fcd584ff..9db783666d221 100644
--- a/python/pyspark/ml/tests/connect/test_parity_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -26,7 +26,6 @@ class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
 
 
 class TuningParityWithMLCacheOffloadingEnabledTests(TuningTestsMixin, ReusedConnectTestCase):
-
     @classmethod
     def conf(cls):
         conf = super().conf()

From a1577be97846b3cb4fc329c00d6aa42fa9f119a5 Mon Sep 17 00:00:00 2001
From: Weichen Xu
Date: Fri, 2 May 2025 14:09:40 +0800
Subject: [PATCH 16/16] update

Signed-off-by: Weichen Xu
---
 .../main/scala/org/apache/spark/sql/connect/ml/MLCache.scala  | 2 +-
 .../scala/org/apache/spark/sql/connect/ml/MLHandler.scala     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index 5099e83985c2d..2379284b62b02 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -101,7 +101,7 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging {
 
   private[spark] def checkSummaryAvail(): Unit = {
     if (getOffloadingEnabled) {
-      throw new RuntimeException(
+      throw MlUnsupportedException(
        "SparkML 'model.summary' and 'model.evaluate' APIs are not supported when " +
          "Spark Connect session ML cache offloading is enabled. You can use APIs in " +
          "'pyspark.ml.evaluation' instead, or you can set Spark config " +
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
index 9db55e8e45bc0..44595f3418318 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala
@@ -134,7 +134,7 @@ private[connect] object MLHandler extends Logging {
           MLUtils.getEstimator(sessionHolder, estimatorProto, Some(fitCmd.getParams))
         if (offloadingEnabled) {
           if (estimator.getClass.getName == "org.apache.spark.ml.fpm.FPGrowth") {
-            throw new UnsupportedOperationException(
+            throw MlUnsupportedException(
              "FPGrowth algorithm is not supported " +
                "if Spark Connect model cache offloading is enabled.")
          }
@@ -143,7 +143,7 @@ private[connect] object MLHandler extends Logging {
             && estimator
               .asInstanceOf[org.apache.spark.ml.clustering.LDA]
               .getOptimizer
               .toLowerCase() == "em") {
-            throw new UnsupportedOperationException(
+            throw MlUnsupportedException(
              "LDA algorithm with 'em' optimizer is not supported " +
                "if Spark Connect model cache offloading is enabled.")
          }
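
After the full series, a cache lookup resolves in three steps: the in-memory model map, then the summary map, then (when offloading is on) the on-disk copy, which is re-admitted to memory on a hit. A compressed sketch of that resolution order, with toy maps standing in for the real structures:

    import java.nio.file.{Files, Path}

    def lookup(
        refId: String,
        models: collection.concurrent.Map[String, AnyRef],
        summaries: collection.concurrent.Map[String, AnyRef],
        offloadDir: Path,
        loadFromDisk: Path => AnyRef): AnyRef = {
      models.get(refId)
        .orElse(summaries.get(refId))
        .getOrElse {
          val path = offloadDir.resolve(refId)
          if (Files.isDirectory(path)) {
            val obj = loadFromDisk(path)
            models.putIfAbsent(refId, obj) // re-admit to the in-memory cache
            obj
          } else null // caller surfaces this as MLCacheInvalidException
        }
    }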