Commit d9893b8

Upgrade to Spark 2.2; fix copying the params to the model (#92)

* Upgrade to Spark 2.2
* Fix copying the params to the model
* Fix long line
* Comment
* Simplify
* Add a test to verify we copy the params correctly
* Test copy while we are there
* And the testing lib too
* Push to 2.2.0_0.7.2 for the Windows build too
* Fix some tests/copy methods
* Fix copy
1 parent 04f7d83 commit d9893b8
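
For context on the headline fix: in Spark ML, param values set on an Estimator do not automatically reach the Model that fit() returns. The estimator has to push them across with copyValues, which copies its param values onto the argument (matching params by name) and returns that argument. A minimal sketch of the pattern, with hypothetical names (ToyParams/ToyEstimator/ToyModel are illustrative, not from this repo):

import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Shared params trait so the same param is defined, by name, on both
// the estimator and the model.
trait ToyParams extends Params {
  final val inputCol = new Param[String](this, "inputCol", "input column")
}

class ToyEstimator(override val uid: String)
    extends Estimator[ToyModel] with ToyParams {
  def this() = this(Identifiable.randomUID("toy"))
  def setInputCol(value: String): this.type = set(inputCol, value)

  override def fit(dataset: Dataset[_]): ToyModel = {
    val model = new ToyModel(uid)  // learned state would be computed here
    copyValues(model)              // copy params onto the model; returns it
  }
  override def copy(extra: ParamMap): ToyEstimator = defaultCopy(extra)
  override def transformSchema(schema: StructType): StructType = schema
}

class ToyModel(override val uid: String)
    extends Model[ToyModel] with ToyParams {
  override def copy(extra: ParamMap): ToyModel =
    copyValues(new ToyModel(uid), extra).setParent(parent)
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
  override def transformSchema(schema: StructType): StructType = schema
}

Without the copyValues call in fit(), model.get(model.inputCol) comes back empty even after setInputCol on the estimator, which is the bug this commit fixes in SimpleIndexer and SimpleNaiveBayes.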

5 files changed (+22, -13 lines)


build.sbt

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@ crossScalaVersions := Seq("2.11.6")
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 
 //tag::sparkVersion[]
-sparkVersion := "2.1.0"
+sparkVersion := "2.2.0"
 //end::sparkVersion[]
 
 //tag::sparkComponents[]

@@ -40,7 +40,7 @@ libraryDependencies ++= Seq(
   "org.scalacheck" %% "scalacheck" % "1.13.4",
   "junit" % "junit" % "4.12",
   "junit" % "junit" % "4.11",
-  "com.holdenkarau" %% "spark-testing-base" % "2.1.0_0.6.0",
+  "com.holdenkarau" %% "spark-testing-base" % "2.2.0_0.7.2",
   "com.novocode" % "junit-interface" % "0.11" % "test->default",
   //tag::scalaLogging[]
   "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",

build_windows.sbt

Lines changed: 2 additions & 2 deletions

@@ -15,7 +15,7 @@ crossScalaVersions := Seq("2.11.6")
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 
 //tag::sparkVersion[]
-sparkVersion := "2.1.0"
+sparkVersion := "2.2.0"
 //end::sparkVersion[]
 
 //tag::sparkComponents[]

@@ -40,7 +40,7 @@ libraryDependencies ++= Seq(
   "org.scalacheck" %% "scalacheck" % "1.13.4",
   "junit" % "junit" % "4.12",
   "junit" % "junit" % "4.11",
-  "com.holdenkarau" %% "spark-testing-base" % "2.1.0_0.6.0",
+  "com.holdenkarau" %% "spark-testing-base" % "2.2.0_0.7.2",
   "com.novocode" % "junit-interface" % "0.11" % "test->default",
   //tag::scalaLogging[]
   "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",

src/main/scala/com/high-performance-spark-examples/ml/CustomPipeline.scala

Lines changed: 3 additions & 2 deletions

@@ -122,9 +122,10 @@ class SimpleIndexer(override val uid: String)
     import dataset.sparkSession.implicits._
     val words = dataset.select(dataset($(inputCol)).as[String]).distinct
       .collect()
+    // Construct the model
     val model = new SimpleIndexerModel(uid, words)
-    this.copyValues(model)
-    model
+    // Copy the parameters to the model
+    copyValues(model)
   }
 }
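The practical effect of this change, as a hedged usage sketch (it assumes the setInputCol/setOutputCol setters and shared params trait that CustomPipeline.scala defines elsewhere, and that inputCol has no default value):

// Params set on the estimator are now visible on the fitted model,
// because fit() ends with copyValues(model).
val indexer = new SimpleIndexer().setInputCol("name").setOutputCol("name_index")
val model = indexer.fit(df)  // df: an assumed DataFrame with a "name" column
assert(model.get(model.inputCol) == Some("name"))  // empty before this fix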

src/main/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala

Lines changed: 8 additions & 4 deletions

@@ -82,11 +82,14 @@ class SimpleNaiveBayes(val uid: String)
     // Unpersist now that we are done computing everything
     ds.unpersist()
     // Construct a model
-    new SimpleNaiveBayesModel(uid, numClasses, numFeatures, Vectors.dense(pi),
+    val model = new SimpleNaiveBayesModel(
+      uid, numClasses, numFeatures, Vectors.dense(pi),
       new DenseMatrix(numClasses, theta(0).length, theta.flatten, true))
+    // Copy the params values to the model
+    copyValues(model)
   }
 
-  override def copy(extra: ParamMap) = {
+  override def copy(extra: ParamMap): SimpleNaiveBayes = {
     defaultCopy(extra)
   }
 }

@@ -100,8 +103,9 @@ case class SimpleNaiveBayesModel(
     val theta: DenseMatrix) extends
   ClassificationModel[Vector, SimpleNaiveBayesModel] {
 
-  override def copy(extra: ParamMap) = {
-    defaultCopy(extra)
+  override def copy(extra: ParamMap): SimpleNaiveBayesModel = {
+    val copied = new SimpleNaiveBayesModel(uid, numClasses, numFeatures, pi, theta)
+    copyValues(copied, extra).setParent(parent)
   }
 
   // We have to do some tricks here because we are using Spark's
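
The two copy overrides above differ for a reason: defaultCopy rebuilds an instance reflectively through the single-String (uid) constructor, which works for the estimator but cannot supply a model's learned state. So the model reconstructs itself explicitly; the contrast, as applied in this commit:

// Estimator side: no learned state, reflective defaultCopy is enough.
override def copy(extra: ParamMap): SimpleNaiveBayes = defaultCopy(extra)

// Model side: defaultCopy cannot supply numClasses/numFeatures/pi/theta,
// so rebuild explicitly, copy param values (honoring `extra` overrides),
// and keep the link to the parent estimator.
override def copy(extra: ParamMap): SimpleNaiveBayesModel = {
  val copied = new SimpleNaiveBayesModel(uid, numClasses, numFeatures, pi, theta)
  copyValues(copied, extra).setParent(parent)
}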

src/test/scala/com/high-performance-spark-examples/ml/SimpleNaiveBayes.scala

Lines changed: 7 additions & 3 deletions

@@ -9,6 +9,7 @@ import com.holdenkarau.spark.testing._
 
 import org.apache.spark.ml._
 import org.apache.spark.ml.feature._
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.scalatest.Matchers._

@@ -30,14 +31,17 @@ class SimpleNaiveBayesSuite extends FunSuite with DataFrameSuiteBase {
     val ds: Dataset[MiniPanda] = session.createDataset(miniPandasList)
     val assembler = new VectorAssembler()
     assembler.setInputCols(Array("fuzzy", "old"))
-    assembler.setOutputCol("features")
+    assembler.setOutputCol("magical_features")
     val snb = new SimpleNaiveBayes()
     snb.setLabelCol("happy")
-    snb.setFeaturesCol("features")
+    snb.setFeaturesCol("magical_features")
     val pipeline = new Pipeline().setStages(Array(assembler, snb))
     val model = pipeline.fit(ds)
     val test = ds.select("fuzzy", "old")
     val predicted = model.transform(test)
-    println(predicted.collect())
+    assert(predicted.count() === miniPandasList.size)
+    val nbModel = model.stages(1).asInstanceOf[SimpleNaiveBayesModel]
+    assert(nbModel.getFeaturesCol === "magical_features")
+    assert(nbModel.copy(ParamMap.empty).getFeaturesCol === "magical_features")
   }
 }
