diff --git a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 186a0efe..3c3f0a99 100644 --- a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -208,7 +208,13 @@ object DataFrameSuiteBase { } case d1: java.math.BigDecimal => - if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) { + if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs + .compareTo(new java.math.BigDecimal(tol)) > 0) { + return false + } + + case d1: scala.math.BigDecimal => + if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) { return false } diff --git a/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 64fe68c1..a66f9423 100644 --- a/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -171,7 +171,11 @@ object DataFrameSuiteBase { if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false case d1: java.math.BigDecimal => - if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false + if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs + .compareTo(new java.math.BigDecimal(tol)) > 0) return false + + case d1: scala.math.BigDecimal => + if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) return false case t1: Timestamp => if (abs(t1.getTime - o2.asInstanceOf[Timestamp].getTime) > tol) { diff --git a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index 8177abbb..1ca7faae 100644 --- a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -70,6 +70,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { val row8 = Row(Timestamp.valueOf("2018-01-12 20:22:13")) val row9 = Row(Timestamp.valueOf("2018-01-12 20:22:18")) val row10 = Row(Timestamp.valueOf("2018-01-12 20:23:13")) + val row11 = Row(new java.math.BigDecimal(1.0)) + val row11a = Row(new java.math.BigDecimal(1.0 + 1.0E-6)) + val row12 = Row(scala.math.BigDecimal(1.0)) + val row12a = Row(scala.math.BigDecimal(1.0 + 1.0E-6)) assert(false === approxEquals(row, row2, 1E-7)) assert(true === approxEquals(row, row2, 1E-5)) assert(true === approxEquals(row3, row3, 1E-5)) @@ -84,6 +88,8 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { assert(false === approxEquals(row9, row8, 3000)) assert(true === approxEquals(row9, row10, 60000)) assert(false === approxEquals(row9, row10, 53000)) + assert(true === approxEquals(row11, row11a, 1.0E-6)) + assert(true === approxEquals(row12, row12a, 1.0E-6)) } test("verify hive function support") {