diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index ce2a3a2e40411..7c81cb96e07f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -212,14 +212,13 @@ object CrossValidator extends MLReadable[CrossValidator] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new CrossValidator(metadata.uid)
+      val cv = new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setNumFolds(numFolds)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(cv, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      cv
     }
   }
 }
@@ -302,17 +301,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val seed = (metadata.params \ "seed").extract[Long]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
+
       val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-        .set(model.numFolds, numFolds)
-        .set(model.seed, seed)
+      DefaultParamsReader.getAndSetParams(model, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      model
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 16db0f5f12c77..6e3ad40706803 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.ml.tuning
 
+import java.io.IOException
 import java.util.{List => JList}
 
 import scala.collection.JavaConverters._
@@ -207,14 +208,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
-      val seed = (metadata.params \ "seed").extract[Long]
-      new TrainValidationSplit(metadata.uid)
+      val tvs = new TrainValidationSplit(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-        .setTrainRatio(trainRatio)
-        .setSeed(seed)
+      DefaultParamsReader.getAndSetParams(tvs, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      tvs
     }
   }
 }
@@ -295,17 +295,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
 
       val (metadata, estimator, evaluator, estimatorParamMaps) =
         ValidatorParams.loadImpl(path, sc, className)
-      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
-      val seed = (metadata.params \ "seed").extract[Long]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
+
       val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-        .set(model.trainRatio, trainRatio)
-        .set(model.seed, seed)
+      DefaultParamsReader.getAndSetParams(model, metadata,
+        skipParams = Option(List("estimatorParamMaps")))
+      model
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 0ab6eed959381..363304ef10147 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -150,20 +150,14 @@ private[ml] object ValidatorParams {
       }.toSeq
     ))
 
-    val validatorSpecificParams = instance match {
-      case cv: CrossValidatorParams =>
-        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
-      case tvs: TrainValidationSplitParams =>
-        List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
-      case _ =>
-        // This should not happen.
-        throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
-          instance.getClass.getCanonicalName)
-    }
-
-    val jsonParams = validatorSpecificParams ++ List(
-      "estimatorParamMaps" -> parse(estimatorParamMapsJson),
-      "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
+    val params = instance.extractParamMap().toSeq
+    val skipParams = List("estimator", "evaluator", "estimatorParamMaps")
+    val jsonParams = render(params
+      .filter { case ParamPair(p, v) => !skipParams.contains(p.name)}
+      .map { case ParamPair(p, v) =>
+        p.name -> parse(p.jsonEncode(v))
+      }.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
+    )
 
     DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 65f142cfbbcb6..7188da3531267 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -396,17 +396,27 @@ private[ml] object DefaultParamsReader {
 
   /**
    * Extract Params from metadata, and set them in the instance.
-   * This works if all Params implement [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   * This works if all Params (except params included by `skipParams` list) implement
+   * [[org.apache.spark.ml.param.Param.jsonDecode()]].
+   *
+   * @param skipParams The params included in `skipParams` won't be set. This is useful if some
+   *                   params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
+   *                   and need special handling.
    * TODO: Move to [[Metadata]] method
    */
-  def getAndSetParams(instance: Params, metadata: Metadata): Unit = {
+  def getAndSetParams(
+      instance: Params,
+      metadata: Metadata,
+      skipParams: Option[List[String]] = None): Unit = {
     implicit val format = DefaultFormats
     metadata.params match {
       case JObject(pairs) =>
         pairs.foreach { case (paramName, jsonValue) =>
-          val param = instance.getParam(paramName)
-          val value = param.jsonDecode(compact(render(jsonValue)))
-          instance.set(param, value)
+          if (skipParams == None || !skipParams.get.contains(paramName)) {
+            val param = instance.getParam(paramName)
+            val value = param.jsonDecode(compact(render(jsonValue)))
+            instance.set(param, value)
+          }
         }
       case _ =>
         throw new IllegalArgumentException(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index a8d4377cff2d1..a01744f7b67fd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -159,12 +159,15 @@ class CrossValidatorSuite
       .setEvaluator(evaluator)
       .setNumFolds(20)
       .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
+      .setParallelism(2)
 
     val cv2 = testDefaultReadWrite(cv, testParams = false)
 
     assert(cv.uid === cv2.uid)
     assert(cv.getNumFolds === cv2.getNumFolds)
     assert(cv.getSeed === cv2.getSeed)
+    assert(cv.getParallelism === cv2.getParallelism)
     assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
     val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 74801733381c1..2ed4fbb601b61 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressio
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.param.{ParamMap}
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -160,11 +160,13 @@ class TrainValidationSplitSuite
       .setTrainRatio(0.5)
       .setEstimatorParamMaps(paramMaps)
       .setSeed(42L)
+      .setParallelism(2)
 
     val tvs2 = testDefaultReadWrite(tvs, testParams = false)
 
     assert(tvs.getTrainRatio === tvs2.getTrainRatio)
     assert(tvs.getSeed === tvs2.getSeed)
+    assert(tvs.getParallelism === tvs2.getParallelism)
     ValidatorParamsSuiteHelpers
       .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
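
Illustrative usage, not part of the patch: with the generic save path in ValidatorParams.saveImpl and the skipParams-aware getAndSetParams above, every JSON-encodable param of a validator now survives a save/load round trip, instead of only the ones each reader hand-restored (previously just numFolds/trainRatio and seed, so e.g. parallelism was silently dropped on load). A minimal sketch, assuming an active SparkSession and Spark 2.3-style APIs; the save path is a placeholder:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
  .setSeed(42L)
  .setParallelism(2)  // persisted via the generic param path after this change

cv.write.overwrite().save("/tmp/cv-demo")  // placeholder path
val loaded = CrossValidator.load("/tmp/cv-demo")
assert(loaded.getParallelism == 2)  // dropped by the old hand-written loader

estimatorParamMaps remains the one param handled specially (it cannot be decoded by Param.jsonDecode alone), which is why both readers pass skipParams = Option(List("estimatorParamMaps")) and set it explicitly from ValidatorParams.loadImpl.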