Skip to content

Commit

Permalink
Add relative tolerance checks to approxEquals
Browse files Browse the repository at this point in the history
  • Loading branch information
nightscape committed Nov 9, 2017
1 parent 088792a commit 1236c8f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider
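* @param tol max acceptable absolute tolerance, should be less than 1.
* @param relTol max acceptable relative tolerance, measured against the
*               larger magnitude of the two compared values; only
*               consulted when tol is 0.0.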
*/
def assertDataFrameApproximateEquals(
expected: DataFrame, result: DataFrame, tol: Double) {
expected: DataFrame,
result: DataFrame,
tol: Double = 0.0,
relTol: Double = 0.0
) {

assert(expected.schema, result.schema)

Expand All @@ -160,7 +164,7 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider

val unequalRDD = expectedIndexValue.join(resultIndexValue).
filter{case (idx, (r1, r2)) =>
!DataFrameSuiteBase.approxEquals(r1, r2, tol)}
!DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol)}

assertEmpty(unequalRDD.take(maxUnequalRowsToShow))
} finally {
Expand All @@ -178,15 +182,55 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider
rdd.zipWithIndex().map{ case (row, idx) => (idx, row) }
}

def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
DataFrameSuiteBase.approxEquals(r1, r2, tol)
/**
 * Row-level approximate equality. Forwards unchanged to
 * `DataFrameSuiteBase.approxEquals` so that suites mixing in this trait
 * can call it unqualified.
 *
 * @param r1 first row to compare
 * @param r2 second row to compare
 * @param tol absolute tolerance passed through to the shared implementation
 * @param relTol relative tolerance passed through to the shared implementation
 * @return whatever the shared implementation decides for the two rows
 */
def approxEquals(
    r1: Row,
    r2: Row,
    tol: Double = 0.0,
    relTol: Double = 0.0
): Boolean =
  DataFrameSuiteBase.approxEquals(r1, r2, tol = tol, relTol = relTol)
}

object DataFrameSuiteBase {
/**
 * Decides whether two numeric values are close enough to be treated as
 * equal. One overload exists per numeric representation handed in by the
 * row comparison.
 */
trait WithinToleranceChecker {
  /** True when the arbitrary-precision values `a` and `b` are within tolerance. */
  def apply(a: BigDecimal, b: BigDecimal): Boolean

  /** True when the floating-point values `a` and `b` are within tolerance. */
  def apply(a: Double, b: Double): Boolean
}
/** Factory that picks the comparison strategy from the two tolerance settings. */
object WithinToleranceChecker {
  /**
   * Builds the checker matching the given tolerances.
   *
   * A relative checker is returned only when `tol` is exactly 0.0 and
   * `relTol` is not. In every other case an absolute checker built from
   * `tol` is returned, so a non-zero `tol` takes precedence and `relTol`
   * is ignored. With both left at 0.0 the result is an absolute checker
   * with zero tolerance.
   *
   * @param tol absolute tolerance
   * @param relTol relative tolerance
   * @return the checker to apply to each pair of numeric values
   */
  def apply(tol: Double = 0.0, relTol: Double = 0.0): WithinToleranceChecker =
    if (tol == 0.0 && relTol != 0.0) {
      new WithinRelativeToleranceChecker(relTol)
    } else {
      new WithinAbsoluteToleranceChecker(tol)
    }
}

/**
 * Accepts two values when the magnitude of their difference does not
 * exceed a fixed bound.
 *
 * @param tolerance largest accepted value of |a - b| (inclusive)
 */
class WithinAbsoluteToleranceChecker(tolerance: Double)
    extends WithinToleranceChecker {

  def apply(a: Double, b: Double): Boolean = {
    val distance = math.abs(a - b)
    distance <= tolerance
  }

  def apply(a: BigDecimal, b: BigDecimal): Boolean = {
    val distance = (a - b).abs
    // The Double bound is lifted to BigDecimal by the standard implicit conversion.
    distance <= tolerance
  }
}

/**
 * Accepts two values when their difference is small relative to the larger
 * of the two magnitudes: |a - b| <= relativeTolerance * max(|a|, |b|).
 *
 * The bound is written as a multiplication rather than the equivalent
 * division |a - b| / max(|a|, |b|) so that comparing zero with zero is well
 * defined. With the division, two Double zeros gave NaN (never within
 * tolerance) and two BigDecimal zeros threw ArithmeticException.
 *
 * @param relativeTolerance largest accepted ratio between the difference
 *                          and the larger magnitude (inclusive)
 */
class WithinRelativeToleranceChecker(relativeTolerance: Double)
    extends WithinToleranceChecker {

  def apply(a: Double, b: Double): Boolean =
    (a - b).abs <= relativeTolerance * (a.abs max b.abs)

  def apply(a: BigDecimal, b: BigDecimal): Boolean =
    (a - b).abs <= (a.abs max b.abs) * BigDecimal(relativeTolerance)
}

/** Approximate equality, based on equals from [[Row]] */
def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
def approxEquals(
r1: Row,
r2: Row,
tol: Double = 0.0,
relTol: Double = 0.0
): Boolean = {
val withinTolerance = WithinToleranceChecker(tol, relTol)

if (r1.length != r2.length) {
return false
} else {
Expand All @@ -212,7 +256,7 @@ object DataFrameSuiteBase {
{
return false
}
if (abs(f1 - o2.asInstanceOf[Float]) > tol) {
if (!withinTolerance(f1, o2.asInstanceOf[Float])) {
return false
}

Expand All @@ -222,18 +266,20 @@ object DataFrameSuiteBase {
{
return false
}
if (abs(d1 - o2.asInstanceOf[Double]) > tol) {
if (!withinTolerance(d1, o2.asInstanceOf[Double])) {
return false
}

case d1: java.math.BigDecimal =>
if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs
.compareTo(new java.math.BigDecimal(tol)) > 0) {
if (!withinTolerance(
BigDecimal(d1),
BigDecimal(o2.asInstanceOf[java.math.BigDecimal]
))) {
return false
}

case d1: scala.math.BigDecimal =>
if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) {
if (!withinTolerance(d1, o2.asInstanceOf[scala.math.BigDecimal])) {
return false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package com.holdenkarau.spark.testing

import org.apache.spark.sql.Row
import org.scalatest.FunSuite
import java.math.{ BigDecimal => JBigDecimal }

class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
val byteArray = new Array[Byte](1)
Expand Down Expand Up @@ -64,10 +65,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
val row6 = Row("1")
val row6a = Row("2")
val row7 = Row(1.toFloat)
val row8 = Row(new java.math.BigDecimal(1.0))
val row8a = Row(new java.math.BigDecimal(1.0 + 1.0E-6))
val row9 = Row(scala.math.BigDecimal(1.0))
val row9a = Row(scala.math.BigDecimal(1.0 + 1.0E-6))
val row8 = Row(new JBigDecimal(1.0))
val row8a = Row(new JBigDecimal(1.0 + 1.0E-6))
val row9 = Row(BigDecimal(1.0))
val row9a = Row(BigDecimal(1.0 + 1.0E-6))
assert(false === approxEquals(row, row2, 1E-7))
assert(true === approxEquals(row, row2, 1E-5))
assert(true === approxEquals(row3, row3, 1E-5))
Expand All @@ -82,6 +83,35 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
assert(true === approxEquals(row9, row9a, 1.0E-6))
}

test("dataframe approxEquals on rows with relative tolerance") {
  // Use 1 / 2^n as example numbers to avoid numeric errors
  val relTol = 0.125
  val orig = 0.25
  // Sits exactly on the boundary: |orig - within| / orig == relTol
  val within = orig - relTol * orig
  // Slightly past the boundary, so it must be rejected
  val outside = within - 1.0E-4

  // Wraps the sample numbers with `constructor` and checks both sides of
  // the boundary for that numeric representation.
  def assertRelativeApproxEqualsWorksFor[T](constructor: Double => T) = {
    val reference = Row(constructor(orig))
    assertResult(true) {
      approxEquals(reference, Row(constructor(within)), relTol = relTol)
    }
    assertResult(false) {
      approxEquals(reference, Row(constructor(outside)), relTol = relTol)
    }
  }

  assertRelativeApproxEqualsWorksFor[Double](identity)
  assertRelativeApproxEqualsWorksFor[Float](_.toFloat)
  assertRelativeApproxEqualsWorksFor[BigDecimal](BigDecimal.apply)
  assertRelativeApproxEqualsWorksFor[JBigDecimal](new JBigDecimal(_))
}

test("verify hive function support") {
import sqlContext.implicits._
// Convert to int since old versions of hive only support percentile on
Expand Down

0 comments on commit 1236c8f

Please sign in to comment.