diff --git a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 5ae5998b..cfe59469 100644 --- a/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/core/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -250,7 +250,13 @@ object DataFrameSuiteBase { } case d1: java.math.BigDecimal => - if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) { + if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs + .compareTo(new java.math.BigDecimal(tol)) > 0) { + return false + } + + case d1: scala.math.BigDecimal => + if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) { return false } diff --git a/core/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/core/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 78c5d8de..0c5a7262 100644 --- a/core/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/core/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -171,7 +171,11 @@ object DataFrameSuiteBase { if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false case d1: java.math.BigDecimal => - if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false + if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs + .compareTo(new java.math.BigDecimal(tol)) > 0) return false + + case d1: scala.math.BigDecimal => + if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) return false case t1: Timestamp => if (abs(t1.getTime - o2.asInstanceOf[Timestamp].getTime) > tol) { diff --git a/core/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/core/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index 8907cd03..e325fc40 100644 --- a/core/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/core/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -117,6 +117,10 @@ class SampleDataFrameTest extends AnyFunSuite with DataFrameSuiteBase { val row8 = Row(Timestamp.valueOf("2018-01-12 20:22:13")) val row9 = Row(Timestamp.valueOf("2018-01-12 20:22:18")) val row10 = Row(Timestamp.valueOf("2018-01-12 20:23:13")) + val row11 = Row(new java.math.BigDecimal(1.0)) + val row11a = Row(new java.math.BigDecimal(1.0 + 1.0E-6)) + val row12 = Row(scala.math.BigDecimal(1.0)) + val row12a = Row(scala.math.BigDecimal(1.0 + 1.0E-6)) assert(false === approxEquals(row, row2, 1E-7)) assert(true === approxEquals(row, row2, 1E-5)) assert(true === approxEquals(row3, row3, 1E-5)) @@ -131,6 +135,8 @@ class SampleDataFrameTest extends AnyFunSuite with DataFrameSuiteBase { assert(false === approxEquals(row9, row8, 3000)) assert(true === approxEquals(row9, row10, 60000)) assert(false === approxEquals(row9, row10, 53000)) + assert(true === approxEquals(row11, row11a, 1.0E-6)) + assert(true === approxEquals(row12, row12a, 1.0E-6)) } test("verify hive function support") {