Skip to content

Commit

Permalink
Merge pull request #90 from MadDataScience/master
Browse files Browse the repository at this point in the history
Fix SimpleIndexer fit method to set inputCol and outputCol correctly
  • Loading branch information
holdenk authored Jul 22, 2017
2 parents bb5e995 + dd03847 commit 04f7d83
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ project/plugins/project/
# Scala-IDE specific
.scala_dependencies
.worksheet
.idea/

# emacs stuff
\#*\#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ class SimpleIndexer(override val uid: String)
import dataset.sparkSession.implicits._
val words = dataset.select(dataset($(inputCol)).as[String]).distinct
.collect()
new SimpleIndexerModel(uid, words)
val model = new SimpleIndexerModel(uid, words)
this.copyValues(model)
model
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
* Simple tests for our CustomPipeline demo pipeline stage
*/
package com.highperformancespark.examples.ml

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.sql.Dataset
import org.scalatest.FunSuite

case class TestRow(id: Int, inputColumn: String)

class CustomPipelineSuite extends FunSuite with DataFrameSuiteBase {
val d = List(
TestRow(0, "a"),
TestRow(1, "b"),
TestRow(2, "c"),
TestRow(3, "a"),
TestRow(4, "a"),
TestRow(5, "c")
)

test("test spark context") {
val session = spark
val rdd = session.sparkContext.parallelize(1 to 10)
assert(rdd.sum === 55)
}

test("simple indexer test") {
val session = spark
import session.implicits._
val ds: Dataset[TestRow] = session.createDataset(d)
val indexer = new SimpleIndexer()
indexer.setInputCol("inputColumn")
indexer.setOutputCol("categoryIndex")
val model = indexer.fit(ds)
val predicted = model.transform(ds)
assert(predicted.columns.contains("categoryIndex"))
predicted.show()
}
}

0 comments on commit 04f7d83

Please sign in to comment.