Upgrade to Spark 2.2; fix copying the params to the model (#92)
* Upgrade to Spark 2.2

* Fix copying the params to the model

* Fix long line

* Comment

* Simplify

* Add a test to verify we are copying the params correctly

* Test copy() while we are there

* And the testing lib too

* Fix some tests/copy methods

* Bump to Spark 2.2.0 and spark-testing-base 0.7.2 for the Windows build too

* Fix copy
holdenk authored Jul 27, 2017
1 parent 04f7d83 commit d9893b8
Showing 5 changed files with 22 additions and 13 deletions.
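For readers skimming the diffs below, the substance of the "fix copying the params to the model" bullets is a standard custom-pipeline-stage pattern: the Estimator's fit() should end with copyValues(model) so that params set on the estimator (input/output columns and so on) carry over to the fitted model, and the Model's copy() should rebuild the instance and copy params explicitly, because Spark's Params.defaultCopy constructs the copy reflectively through a single-String (uid) constructor that models carrying learned state usually do not have. A minimal sketch of that pattern follows; SketchEstimator, SketchModel, and the inputCol param are illustrative names, not code from this repository.

// Illustrative sketch only -- names are hypothetical, not from this repo.
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// A param shared by the estimator and the model, so copyValues has something to copy.
trait HasSketchInputCol extends Params {
  final val inputCol = new Param[String](this, "inputCol", "input column name")
  final def getInputCol: String = $(inputCol)
}

class SketchEstimator(override val uid: String)
    extends Estimator[SketchModel] with HasSketchInputCol {
  def this() = this(Identifiable.randomUID("sketch"))

  def setInputCol(value: String): this.type = set(inputCol, value)

  override def fit(dataset: Dataset[_]): SketchModel = {
    val model = new SketchModel(uid)
    // Copy this estimator's param values onto the model and return it.
    copyValues(model)
  }

  // The estimator has a uid-only constructor, so defaultCopy works here.
  override def copy(extra: ParamMap): SketchEstimator = defaultCopy(extra)

  override def transformSchema(schema: StructType): StructType = schema
}

class SketchModel(override val uid: String)
    extends Model[SketchModel] with HasSketchInputCol {
  override def copy(extra: ParamMap): SketchModel = {
    // Rebuild the model, copy the param values, and keep the parent estimator.
    val copied = new SketchModel(uid)
    copyValues(copied, extra).setParent(parent)
  }

  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()

  override def transformSchema(schema: StructType): StructType = schema
}

With that wiring, new SketchEstimator().setInputCol("fuzzy").fit(someDF).getInputCol returns "fuzzy" and the value survives a further copy(ParamMap.empty), which is essentially what the new assertions in the test file at the bottom of this commit check for SimpleNaiveBayesModel.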
4 changes: 2 additions & 2 deletions build.sbt
@@ -15,7 +15,7 @@ crossScalaVersions := Seq("2.11.6")
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 
 //tag::sparkVersion[]
-sparkVersion := "2.1.0"
+sparkVersion := "2.2.0"
 //end::sparkVersion[]
 
 //tag::sparkComponents[]
@@ -40,7 +40,7 @@ libraryDependencies ++= Seq(
   "org.scalacheck" %% "scalacheck" % "1.13.4",
   "junit" % "junit" % "4.12",
   "junit" % "junit" % "4.11",
-  "com.holdenkarau" %% "spark-testing-base" % "2.1.0_0.6.0",
+  "com.holdenkarau" %% "spark-testing-base" % "2.2.0_0.7.2",
   "com.novocode" % "junit-interface" % "0.11" % "test->default",
   //tag::scalaLogging[]
   "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",
4 changes: 2 additions & 2 deletions build_windows.sbt
@@ -15,7 +15,7 @@ crossScalaVersions := Seq("2.11.6")
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 
 //tag::sparkVersion[]
-sparkVersion := "2.1.0"
+sparkVersion := "2.2.0"
 //end::sparkVersion[]
 
 //tag::sparkComponents[]
@@ -40,7 +40,7 @@ libraryDependencies ++= Seq(
   "org.scalacheck" %% "scalacheck" % "1.13.4",
   "junit" % "junit" % "4.12",
   "junit" % "junit" % "4.11",
-  "com.holdenkarau" %% "spark-testing-base" % "2.1.0_0.6.0",
+  "com.holdenkarau" %% "spark-testing-base" % "2.2.0_0.7.2",
   "com.novocode" % "junit-interface" % "0.11" % "test->default",
   //tag::sacalLogging[]
   "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0",
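Side note on the version bump visible in both build files above: spark-testing-base artifacts embed the Spark version in their own version string (sparkVersion_libraryVersion), which is why moving sparkVersion from 2.1.0 to 2.2.0 also moves spark-testing-base from 2.1.0_0.6.0 to 2.2.0_0.7.2. One way to keep the two in lock-step, shown only as a sketch (this is not how this build is written, and sparkTestingBaseVersion is an invented helper), is to derive the dependency's version from the sparkVersion setting:

// Hypothetical alternative, not in this repo: derive the spark-testing-base
// version from sparkVersion so the two cannot drift apart.
val sparkTestingBaseVersion = "0.7.2"

libraryDependencies += "com.holdenkarau" %% "spark-testing-base" %
  s"${sparkVersion.value}_${sparkTestingBaseVersion}"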
@@ -122,9 +122,10 @@ class SimpleIndexer(override val uid: String)
     import dataset.sparkSession.implicits._
     val words = dataset.select(dataset($(inputCol)).as[String]).distinct
       .collect()
+    // Construct the model
     val model = new SimpleIndexerModel(uid, words)
-    this.copyValues(model)
-    model
+    // Copy the parameters to the model
+    copyValues(model)
   }
 }
 
@@ -82,11 +82,14 @@ class SimpleNaiveBayes(val uid: String)
     // Unpersist now that we are done computing everything
     ds.unpersist()
     // Construct a model
-    new SimpleNaiveBayesModel(uid, numClasses, numFeatures, Vectors.dense(pi),
+    val model = new SimpleNaiveBayesModel(
+      uid, numClasses, numFeatures, Vectors.dense(pi),
       new DenseMatrix(numClasses, theta(0).length, theta.flatten, true))
+    // Copy the params values to the model
+    copyValues(model)
   }
 
-  override def copy(extra: ParamMap) = {
+  override def copy(extra: ParamMap): SimpleNaiveBayes = {
     defaultCopy(extra)
   }
 }
@@ -100,8 +103,9 @@ case class SimpleNaiveBayesModel(
   val theta: DenseMatrix) extends
     ClassificationModel[Vector, SimpleNaiveBayesModel] {
 
-  override def copy(extra: ParamMap) = {
-    defaultCopy(extra)
+  override def copy(extra: ParamMap): SimpleNaiveBayesModel = {
+    val copied = new SimpleNaiveBayesModel(uid, numClasses, numFeatures, pi, theta)
+    copyValues(copied, extra).setParent(parent)
   }
 
   // We have to do some tricks here because we are using Spark's
@@ -9,6 +9,7 @@ import com.holdenkarau.spark.testing._
 
 import org.apache.spark.ml._
 import org.apache.spark.ml.feature._
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 import org.scalatest.Matchers._
@@ -30,14 +31,17 @@ class SimpleNaiveBayesSuite extends FunSuite with DataFrameSuiteBase {
     val ds: Dataset[MiniPanda] = session.createDataset(miniPandasList)
     val assembler = new VectorAssembler()
     assembler.setInputCols(Array("fuzzy", "old"))
-    assembler.setOutputCol("features")
+    assembler.setOutputCol("magical_features")
     val snb = new SimpleNaiveBayes()
     snb.setLabelCol("happy")
-    snb.setFeaturesCol("features")
+    snb.setFeaturesCol("magical_features")
     val pipeline = new Pipeline().setStages(Array(assembler, snb))
     val model = pipeline.fit(ds)
     val test = ds.select("fuzzy", "old")
     val predicted = model.transform(test)
-    println(predicted.collect())
+    assert(predicted.count() === miniPandasList.size)
+    val nbModel = model.stages(1).asInstanceOf[SimpleNaiveBayesModel]
+    assert(nbModel.getFeaturesCol === "magical_features")
+    assert(nbModel.copy(ParamMap.empty).getFeaturesCol === "magical_features")
   }
 }
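The new assertions run inside spark-testing-base's DataFrameSuiteBase, which supplies a shared SparkSession to the suite. For orientation, a minimal skeleton of such a suite is sketched below; the suite name and toy data are illustrative, not taken from this repository.

// Hypothetical skeleton, not from this repo.
import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.scalatest.FunSuite

class SketchSuite extends FunSuite with DataFrameSuiteBase {
  test("build a DataFrame from the shared SparkSession") {
    val session = spark
    import session.implicits._
    // Two toy rows, mirroring the fuzzy/old feature columns used above.
    val df = Seq((1.0, 0.0), (0.0, 1.0)).toDF("fuzzy", "old")
    assert(df.count() === 2)
  }
}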
