diff --git a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 6652a63e..a79f6c68 100644 --- a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -146,7 +146,11 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider * @param tol max acceptable tolerance, should be less than 1. */ def assertDataFrameApproximateEquals( - expected: DataFrame, result: DataFrame, tol: Double) { + expected: DataFrame, + result: DataFrame, + tol: Double = 0.0, + relTol: Double = 0.0 + ) { assert(expected.schema, result.schema) @@ -160,7 +164,7 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider val unequalRDD = expectedIndexValue.join(resultIndexValue). filter{case (idx, (r1, r2)) => - !DataFrameSuiteBase.approxEquals(r1, r2, tol)} + !DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol)} assertEmpty(unequalRDD.take(maxUnequalRowsToShow)) } finally { @@ -178,15 +182,55 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider rdd.zipWithIndex().map{ case (row, idx) => (idx, row) } } - def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = { - DataFrameSuiteBase.approxEquals(r1, r2, tol) + def approxEquals( + r1: Row, + r2: Row, + tol: Double = 0.0, + relTol: Double = 0.0 + ): Boolean = { + DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol) } } object DataFrameSuiteBase { + trait WithinToleranceChecker { + def apply(a: Double, b: Double): Boolean + def apply(a: BigDecimal, b: BigDecimal): Boolean + } + object WithinToleranceChecker { + def apply(tol: Double = 0.0, relTol: Double = 0.0) = + if(tol != 0.0 || relTol == 0.0) { + new WithinAbsoluteToleranceChecker(tol) + } else { + new WithinRelativeToleranceChecker(relTol) + } + } + + class WithinAbsoluteToleranceChecker(tolerance: Double) + extends WithinToleranceChecker { + def apply(a: Double, b: Double): 
Boolean = + (a - b).abs <= tolerance + def apply(a: BigDecimal, b: BigDecimal): Boolean = + (a - b).abs <= tolerance + } + + class WithinRelativeToleranceChecker(relativeTolerance: Double) + extends WithinToleranceChecker { + def apply(a: Double, b: Double): Boolean = + (a - b).abs <= relativeTolerance * (a.abs max b.abs) + def apply(a: BigDecimal, b: BigDecimal): Boolean = + (a - b).abs <= relativeTolerance * (a.abs max b.abs) + } /** Approximate equality, based on equals from [[Row]] */ - def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = { + def approxEquals( + r1: Row, + r2: Row, + tol: Double = 0.0, + relTol: Double = 0.0 + ): Boolean = { + val withinTolerance = WithinToleranceChecker(tol, relTol) + if (r1.length != r2.length) { return false } else { @@ -212,7 +256,7 @@ object DataFrameSuiteBase { { return false } - if (abs(f1 - o2.asInstanceOf[Float]) > tol) { + if (!withinTolerance(f1, o2.asInstanceOf[Float])) { return false } @@ -222,18 +266,20 @@ object DataFrameSuiteBase { { return false } - if (abs(d1 - o2.asInstanceOf[Double]) > tol) { + if (!withinTolerance(d1, o2.asInstanceOf[Double])) { return false } case d1: java.math.BigDecimal => - if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs - .compareTo(new java.math.BigDecimal(tol)) > 0) { + if (!withinTolerance( + BigDecimal(d1), + BigDecimal(o2.asInstanceOf[java.math.BigDecimal] + ))) { return false } case d1: scala.math.BigDecimal => - if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) { + if (!withinTolerance(d1, o2.asInstanceOf[scala.math.BigDecimal])) { return false } diff --git a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index 6f4b9a4b..6664d75d 100644 --- a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -18,6 +18,7 @@ package
com.holdenkarau.spark.testing import org.apache.spark.sql.Row import org.scalatest.FunSuite +import java.math.{ BigDecimal => JBigDecimal } class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { val byteArray = new Array[Byte](1) @@ -64,10 +65,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { val row6 = Row("1") val row6a = Row("2") val row7 = Row(1.toFloat) - val row8 = Row(new java.math.BigDecimal(1.0)) - val row8a = Row(new java.math.BigDecimal(1.0 + 1.0E-6)) - val row9 = Row(scala.math.BigDecimal(1.0)) - val row9a = Row(scala.math.BigDecimal(1.0 + 1.0E-6)) + val row8 = Row(new JBigDecimal(1.0)) + val row8a = Row(new JBigDecimal(1.0 + 1.0E-6)) + val row9 = Row(BigDecimal(1.0)) + val row9a = Row(BigDecimal(1.0 + 1.0E-6)) assert(false === approxEquals(row, row2, 1E-7)) assert(true === approxEquals(row, row2, 1E-5)) assert(true === approxEquals(row3, row3, 1E-5)) @@ -82,6 +83,35 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { assert(true === approxEquals(row9, row9a, 1.0E-6)) } + test("dataframe approxEquals on rows with relative tolerance") { + import sqlContext.implicits._ + // Use 1 / 2^n as example numbers to avoid numeric errors + val relTol = 0.125 + val orig = 0.25 + val within = orig - relTol * orig + val outside = within - 1.0E-4 + def assertRelativeApproxEqualsWorksFor[T](constructor: Double => T) = { + assertResult(true) { + approxEquals( + Row(constructor(orig)), + Row(constructor(within)), + relTol = relTol + ) + } + assertResult(false) { + approxEquals( + Row(constructor(orig)), + Row(constructor(outside)), + relTol = relTol + ) + } + } + assertRelativeApproxEqualsWorksFor[Double](identity) + assertRelativeApproxEqualsWorksFor[Float](_.toFloat) + assertRelativeApproxEqualsWorksFor[BigDecimal](BigDecimal.apply) + assertRelativeApproxEqualsWorksFor[JBigDecimal](new JBigDecimal(_)) + } + test("verify hive function support") { import sqlContext.implicits._ // Convert to int since old versions 
of hive only support percentile on