Skip to content

Commit

Permalink
Fixes #220: modify method assertDataFrameNoOrderEquals to handle dupl…
Browse files Browse the repository at this point in the history
…icates in dataframe
  • Loading branch information
smadarasmi committed Oct 21, 2018
1 parent 154ba1e commit 9b0f24c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,15 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider

/**
* Compares if two [[DataFrame]]s are equal without caring about order of rows, by
* finding elements in one DataFrame not in the other. The resulting DataFrame
* finding elements in one DataFrame that is not in the other. The resulting DataFrame
* should be empty inferring the two DataFrames have the same elements.
*/
def assertDataFrameNoOrderEquals(expected: DataFrame, result: DataFrame) {
assertEmpty(expected.except(result).rdd.take(maxUnequalRowsToShow))
assertEmpty(result.except(expected).rdd.take(maxUnequalRowsToShow))
import org.apache.spark.sql.functions.col
val expectedElementsCount = expected.groupBy(expected.columns.map(s => col(s)): _*).count()
val resultElementsCount = result.groupBy(result.columns.map(s => col(s)): _*).count()

assertDataFrameEquals(expectedElementsCount, resultElementsCount)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {

test("dataframe should be equal with different order of rows") {
import sqlContext.implicits._
val input = sc.parallelize(inputList).toDF
val reverseInput = sc.parallelize(inputList.reverse).toDF
val inputListWithDuplicates = inputList ++ List(inputList.head)
val input = sc.parallelize(inputListWithDuplicates).toDF
val reverseInput = sc.parallelize(inputListWithDuplicates.reverse).toDF
assertDataFrameNoOrderEquals(input, reverseInput)
}

Expand All @@ -55,8 +56,9 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {

test("unequal dataframe with different order should not equal") {
import sqlContext.implicits._
val input = sc.parallelize(inputList).toDF
val input2 = sc.parallelize(List(inputList.head)).toDF
val inputListWithDuplicates = inputList ++ List(inputList.head)
val input = sc.parallelize(inputListWithDuplicates).toDF
val input2 = sc.parallelize(inputList).toDF
intercept[org.scalatest.exceptions.TestFailedException] {
assertDataFrameNoOrderEquals(input, input2)
}
Expand Down

0 comments on commit 9b0f24c

Please sign in to comment.