diff --git a/build.sbt b/build.sbt
index f3eab10..35b1508 100644
--- a/build.sbt
+++ b/build.sbt
@@ -15,7 +15,7 @@ crossScalaVersions := Seq("2.11.6")
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 
 //tag::sparkVersion[]
-sparkVersion := "2.1.0"
+sparkVersion := "2.2.0"
 //end::sparkVersion[]
 
 //tag::sparkComponents[]
@@ -40,7 +40,7 @@ libraryDependencies ++= Seq(
   "org.scalacheck" %% "scalacheck" % "1.13.4",
   "junit" % "junit" % "4.12",
   "junit" % "junit" % "4.11",
-  "com.holdenkarau" %% "spark-testing-base" % "2.1.0_0.6.0",
+  "com.holdenkarau" %% "spark-testing-base" % "2.2.0_0.7.2",
   "com.novocode" % "junit-interface" % "0.11" % "test->default",
   //tag::scalaLogging[]
   "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",
diff --git a/build_windows.sbt b/build_windows.sbt
index a867c9f..b698ab9 100644
--- a/build_windows.sbt
+++ b/build_windows.sbt
@@ -15,7 +15,7 @@ crossScalaVersions := Seq("2.11.6")
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 
 //tag::sparkVersion[]
-sparkVersion := "2.1.0"
+sparkVersion := "2.2.0"
 //end::sparkVersion[]
 
 //tag::sparkComponents[]
@@ -40,7 +40,7 @@ libraryDependencies ++= Seq(
   "org.scalacheck" %% "scalacheck" % "1.13.4",
   "junit" % "junit" % "4.12",
   "junit" % "junit" % "4.11",
-  "com.holdenkarau" %% "spark-testing-base" % "2.1.0_0.6.0",
+  "com.holdenkarau" %% "spark-testing-base" % "2.2.0_0.7.2",
   "com.novocode" % "junit-interface" % "0.11" % "test->default",
   //tag::scalaLogging[]
   "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",
diff --git a/src/main/scala/com/high-performance-spark-examples/ml/CustomPipeline.scala b/src/main/scala/com/high-performance-spark-examples/ml/CustomPipeline.scala
index a7d80e3..2b87a7e 100644
--- a/src/main/scala/com/high-performance-spark-examples/ml/CustomPipeline.scala
+++ b/src/main/scala/com/high-performance-spark-examples/ml/CustomPipeline.scala
@@ -122,9 +122,10 @@ class SimpleIndexer(override val uid: String)
     import dataset.sparkSession.implicits._
     val words = dataset.select(dataset($(inputCol)).as[String]).distinct
       .collect()
+    // Construct the model
     val model = new SimpleIndexerModel(uid, words)
-    this.copyValues(model)
-    model
+    // Copy the parameters to the model
+    copyValues(model)
   }
 }
 
diff --git a/src/main/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala b/src/main/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala
index 7f0cfbb..13e937f 100644
--- a/src/main/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala
+++ b/src/main/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala
@@ -82,11 +82,14 @@ class SimpleNaiveBayes(val uid: String)
     // Unpersist now that we are done computing everything
     ds.unpersist()
     // Construct a model
-    new SimpleNaiveBayesModel(uid, numClasses, numFeatures, Vectors.dense(pi),
+    val model = new SimpleNaiveBayesModel(
+      uid, numClasses, numFeatures, Vectors.dense(pi),
       new DenseMatrix(numClasses, theta(0).length, theta.flatten, true))
+    // Copy the param values to the model
+    copyValues(model)
   }
 
-  override def copy(extra: ParamMap) = {
+  override def copy(extra: ParamMap): SimpleNaiveBayes = {
     defaultCopy(extra)
   }
 }
@@ -100,8 +103,9 @@ case class SimpleNaiveBayesModel(
   val theta: DenseMatrix) extends
     ClassificationModel[Vector, SimpleNaiveBayesModel] {
 
-  override def copy(extra: ParamMap) = {
-    defaultCopy(extra)
+  override def copy(extra: ParamMap): SimpleNaiveBayesModel = {
+    val copied = new SimpleNaiveBayesModel(uid, numClasses, numFeatures, pi, theta)
+    copyValues(copied, extra).setParent(parent)
   }
 
   // We have to do some tricks here because we are using Spark's
diff --git a/src/test/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala b/src/test/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala
index 783dfb2..1fa296a 100644
--- a/src/test/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala
+++ b/src/test/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala
@@ -9,6 +9,7 @@ import com.holdenkarau.spark.testing._
 
 import org.apache.spark.ml._
 import org.apache.spark.ml.feature._
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.scalatest.Matchers._
@@ -30,14 +31,17 @@ class SimpleNaiveBayesSuite extends FunSuite with DataFrameSuiteBase {
     val ds: Dataset[MiniPanda] = session.createDataset(miniPandasList)
     val assembler = new VectorAssembler()
     assembler.setInputCols(Array("fuzzy", "old"))
-    assembler.setOutputCol("features")
+    assembler.setOutputCol("magical_features")
     val snb = new SimpleNaiveBayes()
     snb.setLabelCol("happy")
-    snb.setFeaturesCol("features")
+    snb.setFeaturesCol("magical_features")
     val pipeline = new Pipeline().setStages(Array(assembler, snb))
     val model = pipeline.fit(ds)
     val test = ds.select("fuzzy", "old")
     val predicted = model.transform(test)
-    println(predicted.collect())
+    assert(predicted.count() === miniPandasList.size)
+    val nbModel = model.stages(1).asInstanceOf[SimpleNaiveBayesModel]
+    assert(nbModel.getFeaturesCol === "magical_features")
+    assert(nbModel.copy(ParamMap.empty).getFeaturesCol === "magical_features")
   }
 }
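
The recurring change above is one pattern: a custom Estimator's fit() should end with copyValues(model) so that params set on the estimator (featuresCol, labelCol, and so on) carry over to the fitted model, and a Model whose constructor takes more than a uid cannot rely on defaultCopy, so its copy() must rebuild the instance, copy the param values, and preserve the parent. A minimal sketch of the same pattern in isolation, using hypothetical IdentityEstimator/IdentityModel names (not part of this patch) and assuming the Spark 2.2 ml API:

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

class IdentityEstimator(override val uid: String)
    extends Estimator[IdentityModel] {

  def this() = this(Identifiable.randomUID("identity"))

  override def fit(dataset: Dataset[_]): IdentityModel = {
    val model = new IdentityModel(uid)
    // Without copyValues, params set on this estimator would be lost;
    // copyValues returns the model, so it can be the last expression.
    copyValues(model)
  }

  override def copy(extra: ParamMap): IdentityEstimator = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema
}

class IdentityModel(override val uid: String) extends Model[IdentityModel] {

  override def copy(extra: ParamMap): IdentityModel = {
    // Rebuild the instance by hand (SimpleNaiveBayesModel must do this
    // because its constructor takes more than a uid), copy the param
    // values plus any extra overrides, and keep the parent estimator.
    val copied = new IdentityModel(uid)
    copyValues(copied, extra).setParent(parent)
  }

  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()

  override def transformSchema(schema: StructType): StructType = schema
}

With this wiring, model.copy(ParamMap.empty).getFeaturesCol returns whatever was set on the estimator, which is exactly what the new assertions in SimpleNaiveBayesSuite check.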