Skip to content

Commit

Permalink
Add relative tolerance checks to approxEquals
Browse files Browse the repository at this point in the history
  • Loading branch information
nightscape committed Dec 29, 2019
1 parent 380b6b1 commit 468a1d6
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider
* @param tol max acceptable absolute tolerance, should be less than 1.
* @param relTol max acceptable relative tolerance; only takes effect when tol is 0.0.
*/
def assertDataFrameApproximateEquals(
expected: DataFrame, result: DataFrame, tol: Double) {
expected: DataFrame,
result: DataFrame,
tol: Double = 0.0,
relTol: Double = 0.0
) {

assert(expected.schema, result.schema)

Expand All @@ -142,7 +146,7 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider

val unequalRDD = expectedIndexValue.join(resultIndexValue).
filter{case (idx, (r1, r2)) =>
!(r1.equals(r2) || DataFrameSuiteBase.approxEquals(r1, r2, tol))}
!(r1.equals(r2) || DataFrameSuiteBase.approxEquals(r1, r2, tol, relTol))}

assertEmpty(unequalRDD.take(maxUnequalRowsToShow))
} finally {
Expand All @@ -160,15 +164,67 @@ trait DataFrameSuiteBaseLike extends SparkContextProvider
rdd.zipWithIndex().map{ case (row, idx) => (idx, row) }
}

def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
DataFrameSuiteBase.approxEquals(r1, r2, tol)
/**
  * Approximate equality of two rows.
  *
  * Forwards all four arguments unchanged to `DataFrameSuiteBase.approxEquals`.
  *
  * @param r1     first row
  * @param r2     second row
  * @param tol    tolerance forwarded as `tol` (defaults to 0.0)
  * @param relTol tolerance forwarded as `relTol` (defaults to 0.0)
  * @return whatever `DataFrameSuiteBase.approxEquals` returns for these arguments
  */
def approxEquals(
    r1: Row,
    r2: Row,
    tol: Double = 0.0,
    relTol: Double = 0.0): Boolean =
  DataFrameSuiteBase.approxEquals(r1 = r1, r2 = r2, tol = tol, relTol = relTol)
}

object DataFrameSuiteBase {
/**
  * Decides whether two numeric values are close enough to be treated as
  * equal. One overload handles `Double`, the other `BigDecimal`.
  */
trait WithinToleranceChecker {
  /** @return true when `a` and `b` are within the configured tolerance */
  def apply(a: Double, b: Double): Boolean

  /** @return true when `a` and `b` are within the configured tolerance */
  def apply(a: BigDecimal, b: BigDecimal): Boolean
}

object WithinToleranceChecker {
  /**
    * Selects the checker matching the supplied tolerances.
    *
    * The absolute checker is chosen whenever `tol` is non-zero or both
    * tolerances are zero; the relative checker is chosen only when `tol` is
    * exactly 0.0 and `relTol` is non-zero. If both are non-zero, `relTol`
    * is ignored.
    *
    * @param tol    absolute tolerance
    * @param relTol relative tolerance
    * @return the checker to use for comparisons
    */
  def apply(tol: Double = 0.0, relTol: Double = 0.0): WithinToleranceChecker =
    if (tol != 0.0 || relTol == 0.0) {
      new WithinAbsoluteToleranceChecker(tol)
    } else {
      new WithinRelativeToleranceChecker(relTol)
    }
}

/**
  * Accepts two values when `|a - b| <= tolerance`.
  *
  * @param tolerance largest accepted absolute difference
  */
class WithinAbsoluteToleranceChecker(tolerance: Double)
  extends WithinToleranceChecker {

  // For two NaNs or two identical infinities `a - b` is NaN, and every
  // comparison against NaN is false. Treat those pairs as matching so that
  // identical non-finite values are not reported as different.
  private def sameNonFinite(a: Double, b: Double): Boolean =
    (java.lang.Double.isNaN(a) && java.lang.Double.isNaN(b)) ||
      (java.lang.Double.isInfinite(a) && a == b)

  def apply(a: Double, b: Double): Boolean =
    sameNonFinite(a, b) || (a - b).abs <= tolerance

  def apply(a: BigDecimal, b: BigDecimal): Boolean =
    (a - b).abs <= tolerance
}

/**
  * Accepts two values when `|a - b| / max(|a|, |b|) <= relativeTolerance`.
  * Two zeros are always accepted.
  *
  * @param relativeTolerance largest accepted difference, as a fraction of the
  *                          larger magnitude of the two values
  */
class WithinRelativeToleranceChecker(relativeTolerance: Double)
  extends WithinToleranceChecker {

  // Same reasoning as for the absolute checker: the ratio is NaN for two
  // NaNs or two identical infinities, so handle those pairs explicitly.
  private def sameNonFinite(a: Double, b: Double): Boolean =
    (java.lang.Double.isNaN(a) && java.lang.Double.isNaN(b)) ||
      (java.lang.Double.isInfinite(a) && a == b)

  def apply(a: Double, b: Double): Boolean = {
    val largest = a.abs max b.abs
    // `largest == 0.0` guards the division: both values are zero.
    sameNonFinite(a, b) || largest == 0.0 ||
      (a - b).abs / largest <= relativeTolerance
  }

  def apply(a: BigDecimal, b: BigDecimal): Boolean = {
    val largest = a.abs max b.abs
    // signum avoids comparing a BigDecimal against a Double literal.
    largest.signum == 0 || (a - b).abs / largest <= relativeTolerance
  }
}

/** Approximate equality, based on equals from [[Row]] */
def approxEquals(r1: Row, r2: Row, tol: Double): Boolean = {
def approxEquals(
r1: Row,
r2: Row,
tol: Double = 0.0,
relTol: Double = 0.0
): Boolean = {
val withinTolerance = WithinToleranceChecker(tol, relTol)

if (r1.length != r2.length) {
return false
} else {
Expand All @@ -192,7 +248,7 @@ object DataFrameSuiteBase {
{
return false
}
if (abs(f1 - o2.asInstanceOf[Float]) > tol) {
if (!withinTolerance(f1, o2.asInstanceOf[Float])) {
return false
}

Expand All @@ -202,18 +258,20 @@ object DataFrameSuiteBase {
{
return false
}
if (abs(d1 - o2.asInstanceOf[Double]) > tol) {
if (!withinTolerance(d1, o2.asInstanceOf[Double])) {
return false
}

case d1: java.math.BigDecimal =>
if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs
.compareTo(new java.math.BigDecimal(tol)) > 0) {
if (!withinTolerance(
BigDecimal(d1),
BigDecimal(o2.asInstanceOf[java.math.BigDecimal]
))) {
return false
}

case d1: scala.math.BigDecimal =>
if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) {
if (!withinTolerance(d1, o2.asInstanceOf[scala.math.BigDecimal])) {
return false
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.sql.Timestamp
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import java.math.{ BigDecimal => JBigDecimal }

class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
val byteArray = new Array[Byte](1)
Expand Down Expand Up @@ -70,10 +71,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
val row8 = Row(Timestamp.valueOf("2018-01-12 20:22:13"))
val row9 = Row(Timestamp.valueOf("2018-01-12 20:22:18"))
val row10 = Row(Timestamp.valueOf("2018-01-12 20:23:13"))
val row11 = Row(new java.math.BigDecimal(1.0))
val row11a = Row(new java.math.BigDecimal(1.0 + 1.0E-6))
val row12 = Row(scala.math.BigDecimal(1.0))
val row12a = Row(scala.math.BigDecimal(1.0 + 1.0E-6))
val row11 = Row(new JBigDecimal(1.0))
val row11a = Row(new JBigDecimal(1.0 + 1.0E-6))
val row12 = Row(BigDecimal(1.0))
val row12a = Row(BigDecimal(1.0 + 1.0E-6))
assert(false === approxEquals(row, row2, 1E-7))
assert(true === approxEquals(row, row2, 1E-5))
assert(true === approxEquals(row3, row3, 1E-5))
Expand All @@ -92,6 +93,42 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase {
assert(true === approxEquals(row12, row12a, 1.0E-6))
}

test("dataframe approxEquals on rows with relative tolerance") {
  import sqlContext.implicits._
  // All sample values are sums of powers of two, so they are exactly
  // representable and the boundary comparison is free of rounding error.
  val relTol = scala.math.pow(2, -6)
  val base = 0.25
  // Exactly on the tolerance boundary below `base`.
  val onBoundary = base - relTol * base
  // Slightly past the boundary, so it must be rejected.
  val pastBoundary = onBoundary - 1.0E-4

  // Runs the same three checks for one numeric column type; `build`
  // converts the Double sample values into that type.
  def assertRelativeApproxEqualsWorksFor[T](build: Double => T) = {
    def relativelyEqual(left: Double, right: Double): Boolean =
      approxEquals(Row(build(left)), Row(build(right)), relTol = relTol)

    assertResult(true)(relativelyEqual(base, onBoundary))
    assertResult(false)(relativelyEqual(base, pastBoundary))
    // Two zeros must compare equal under a relative tolerance.
    assertResult(true)(relativelyEqual(0.0, 0.0))
  }

  assertRelativeApproxEqualsWorksFor[Double](identity)
  assertRelativeApproxEqualsWorksFor[Float](_.toFloat)
  assertRelativeApproxEqualsWorksFor[BigDecimal](BigDecimal.apply)
  assertRelativeApproxEqualsWorksFor[JBigDecimal](new JBigDecimal(_))
}

test("verify hive function support") {
import sqlContext.implicits._
// Convert to int since old versions of hive only support percentile on
Expand Down

0 comments on commit 468a1d6

Please sign in to comment.