From 088792abb5cc4c6b7688dd5592122c5cf58c2a36 Mon Sep 17 00:00:00 2001 From: Martin Mauch Date: Wed, 8 Nov 2017 18:26:04 +0100 Subject: [PATCH] Compare Java and Scala BigDecimal with tolerance --- .../holdenkarau/spark/testing/DataFrameSuiteBase.scala | 8 +++++++- .../holdenkarau/spark/testing/DataFrameSuiteBase.scala | 6 +++++- .../holdenkarau/spark/testing/SampleDataFrameTest.scala | 6 ++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 52e1b343..6652a63e 100644 --- a/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -227,7 +227,13 @@ object DataFrameSuiteBase { } case d1: java.math.BigDecimal => - if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) { + if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs + .compareTo(new java.math.BigDecimal(tol)) > 0) { + return false + } + + case d1: scala.math.BigDecimal => + if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) { return false } diff --git a/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala b/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala index 4d63adaf..752836a6 100644 --- a/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala +++ b/src/main/pre-2.0/scala/com/holdenkarau/spark/testing/DataFrameSuiteBase.scala @@ -170,7 +170,11 @@ object DataFrameSuiteBase { if (abs(d1 - o2.asInstanceOf[Double]) > tol) return false case d1: java.math.BigDecimal => - if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) return false + if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs + .compareTo(new java.math.BigDecimal(tol)) > 0) return false + + case d1: scala.math.BigDecimal => + if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) return false case _ => if (o1 != o2) return false diff --git a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala index c7b2d92d..6f4b9a4b 100644 --- a/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala +++ b/src/test/1.3/scala/com/holdenkarau/spark/testing/SampleDataFrameTest.scala @@ -64,6 +64,10 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { val row6 = Row("1") val row6a = Row("2") val row7 = Row(1.toFloat) + val row8 = Row(new java.math.BigDecimal(1.0)) + val row8a = Row(new java.math.BigDecimal(1.0 + 1.0E-6)) + val row9 = Row(scala.math.BigDecimal(1.0)) + val row9a = Row(scala.math.BigDecimal(1.0 + 1.0E-6)) assert(false === approxEquals(row, row2, 1E-7)) assert(true === approxEquals(row, row2, 1E-5)) assert(true === approxEquals(row3, row3, 1E-5)) @@ -74,6 +78,8 @@ class SampleDataFrameTest extends FunSuite with DataFrameSuiteBase { assert(false === approxEquals(row6, row4, 1E-5)) assert(false === approxEquals(row6, row7, 1E-5)) assert(false === approxEquals(row6, row6a, 1E-5)) + assert(true === approxEquals(row8, row8a, 1.0E-6)) + assert(true === approxEquals(row9, row9a, 1.0E-6)) } test("verify hive function support") {