From 9b0f24cf7032de4c336ae8618f4bd419a90a3b9d Mon Sep 17 00:00:00 2001 From: smadarasmi Date: Sun, 21 Oct 2018 23:05:27 +0700 Subject: [PATCH] Fixes #220: modify method assertDataFrameNoOrderEquals to handle duplicates in dataframe --- .../holdenkarau/spark/testing/DataFrameSuiteBase.scala | 9 ++++++--- .../spark/testing/SampleDataFrameTest.scala | 10 ++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 515ec380..aa50ef8a 100644 --- a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -172,12 +172,15 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider /** * Compares if two [[DataFrame]]s are equal without caring about order of rows, by - * finding elements in one DataFrame not in the other. The resulting DataFrame + * finding elements in one DataFrame that is not in the other. The resulting DataFrame * should be empty inferring the two DataFrames have the same elements. */ def assertDataFrameNoOrderEquals(expected: DataFrame, result: DataFrame) { - assertEmpty(expected.except(result).rdd.take(maxUnequalRowsToShow)) - assertEmpty(result.except(expected).rdd.take(maxUnequalRowsToShow)) + import org.apache.spark.sql.functions.col + val expectedElementsCount = expected.groupBy(expected.columns.map(s => col(s)): _*).count() + val resultElementsCount = result.groupBy(result.columns.map(s => col(s)): _*).count() + + assertDataFrameEquals(expectedElementsCount, resultElementsCount) } /** diff --git a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index 5e7160a2..de04e25f 100644 --- a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -39,8 +39,9 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { test("dataframe should be equal with different order of rows") { import sqlContext.implicits._ - val input = sc.parallelize(inputList).toDF - val reverseInput = sc.parallelize(inputList.reverse).toDF + val inputListWithDuplicates = inputList ++ List(inputList.head) + val input = sc.parallelize(inputListWithDuplicates).toDF + val reverseInput = sc.parallelize(inputListWithDuplicates.reverse).toDF assertDataFrameNoOrderEquals(input, reverseInput) } @@ -55,8 +56,9 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { test("unequal dataframe with different order should not equal") { import sqlContext.implicits._ - val input = sc.parallelize(inputList).toDF - val input2 = sc.parallelize(List(inputList.head)).toDF + val inputListWithDuplicates = inputList ++ List(inputList.head) + val input = sc.parallelize(inputListWithDuplicates).toDF + val input2 = sc.parallelize(inputList).toDF intercept[org.scalatest.exceptions.TestFailedException] { assertDataFrameNoOrderEquals(input, input2) }