From 1780c667bd1e29fb3b0411c4f9e212be8a4e9dcf Mon Sep 17 00:00:00 2001 From: Kedar Sankar Behera Date: Thu, 2 Aug 2018 14:49:35 -0700 Subject: [PATCH] MD-3579: Compare double values with float precision (#493) Co-authored-by: kedarbcs16 Co-authored-by: agirish --- .../drill/test/framework/ColumnList.java | 53 +++++++++++------ .../drill/test/framework/TestVerifier.java | 58 +++++++++---------- 2 files changed, 63 insertions(+), 48 deletions(-) diff --git a/framework/src/main/java/org/apache/drill/test/framework/ColumnList.java b/framework/src/main/java/org/apache/drill/test/framework/ColumnList.java index 4d24b5b0c..c5565032f 100644 --- a/framework/src/main/java/org/apache/drill/test/framework/ColumnList.java +++ b/framework/src/main/java/org/apache/drill/test/framework/ColumnList.java @@ -30,18 +30,19 @@ public class ColumnList { private final List values; private final List types; - private final boolean Simba; + private final boolean simba; public static final String SIMBA_JDBC = "sjdbc"; + private static final double MAX_DIFF_VALUE = 1.0E-6; public ColumnList(List types, List values) { this.values = values; this.types = types; if (TestDriver.cmdParam.driverExt != null && TestDriver.cmdParam.driverExt.equals(ColumnList.SIMBA_JDBC)) { - this.Simba = true; + this.simba = true; } else { - this.Simba = false; + this.simba = false; } } @@ -54,9 +55,15 @@ public List getValues() { * "loosened" logic to handle float, double and decimal types. The algorithm * used for the comparison follows: * - * Floats: f1 equals f2 iff |((f1-f2)/(average(f1,f2)))| < 0.000001 + * Floats: f1 and f2 are equal + * if f1 == f2 + * if f1 > f2 && if |(f1-f2)/f1| < 0.000001 + * if f1 < f2 && if |(f1-f2)/f2| < 0.000001 * - * Doubles: d1 equals d2 iff |((d1-d2)/(average(d1,d2)))| < 0.000000000001 + * Doubles: d1 and d2 are equal + * if d1 == d2 + * if d1 > d2 && if |(d1-d2)/d1| < 0.000001 + * if d1 < d2 && if |(d1-d2)/d2| < 0.000001 * * Decimals: dec1 equals dec2 iff value(dec1) == value(dec2) and scale(dec1) * == scale(dec2) @@ -76,7 +83,7 @@ public int hashCode() { if (values.get(i) == null) { continue; } - int type = (Integer) (types.get(i)); + int type = types.get(i); switch (type) { case Types.FLOAT: case Types.DOUBLE: @@ -98,8 +105,8 @@ public int hashCode() { public String toString() { StringBuilder sb = new StringBuilder(); for (int i = 0; i < values.size() - 1; i++) { - int type = (Integer) (types.get(i)); - if (Simba && (type == Types.VARCHAR)) { + int type = types.get(i); + if (simba && (type == Types.VARCHAR)) { String s1 = String.valueOf(values.get(i)); // if the field has a JSON string or list, then remove newlines so that // each record fits on a single line in the actual output file. @@ -114,8 +121,8 @@ public String toString() { } sb.append(values.get(i) + "\t"); } - int type = (Integer) (types.get(values.size()-1)); - if (Simba && (type == Types.VARCHAR)) { + int type = types.get(values.size()-1); + if (simba && (type == Types.VARCHAR)) { String s1 = String.valueOf(values.get(values.size()-1)); // if the field has a JSON string or list, then remove newlines so that // each record fits on a single line in the actual output file. @@ -147,17 +154,23 @@ private boolean compare(ColumnList o1, ColumnList o2) { if (oneNull(list1.get(i), list2.get(i))) { return false; } - int type = (Integer) (types.get(i)); + int type = types.get(i); try { switch (type) { case Types.FLOAT: case Types.REAL: float f1 = (Float) list1.get(i); float f2 = (Float) list2.get(i); - if ((f1 + f2) / 2 != 0) { - if (!(Math.abs((f1 - f2) / ((f1 + f2) / 2)) < 1.0E-6)) return false; - } else if (f1 != 0) { - return false; + if (f1 != f2) { + double relativeError; + if (f1 > f2) { + relativeError = Math.abs((f1-f2)/f1); + } else { + relativeError = Math.abs((f1-f2)/f2); + } + if (relativeError > MAX_DIFF_VALUE) { + return false; + } } break; case Types.DOUBLE: @@ -167,9 +180,13 @@ private boolean compare(ColumnList o1, ColumnList o2) { // especially for the cases when doubles are NaN / POSITIVE_INFINITY / NEGATIVE_INFINITY // otherwise proceed with "loosened" logic if (!d1.equals(d2)) { - if ((d1 + d2) / 2 != 0) { - if (!(Math.abs((d1 - d2) / ((d1 + d2) / 2)) < 1.0E-12)) return false; - } else if (d1 != 0) { + double relativeError; + if (d1 > d2) { + relativeError = Math.abs((d1-d2)/d1); + } else { + relativeError = Math.abs((d1-d2)/d2); + } + if (relativeError > MAX_DIFF_VALUE) { return false; } } diff --git a/framework/src/main/java/org/apache/drill/test/framework/TestVerifier.java b/framework/src/main/java/org/apache/drill/test/framework/TestVerifier.java index 089dd457f..d18ee40ec 100755 --- a/framework/src/main/java/org/apache/drill/test/framework/TestVerifier.java +++ b/framework/src/main/java/org/apache/drill/test/framework/TestVerifier.java @@ -20,7 +20,6 @@ import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; -import java.io.FileNotFoundException; import java.io.FileReader; import java.io.FileWriter; import java.io.IOException; @@ -30,7 +29,6 @@ import java.nio.file.Paths; import java.sql.Types; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; @@ -51,6 +49,7 @@ public class TestVerifier { private static final Logger LOG = Logger.getLogger("DrillTestLogger"); private static final int MAX_MISMATCH_SIZE = 10; + private static final double MAX_DIFF_VALUE = 1.0E-6; public TestStatus testStatus = TestStatus.PENDING; private int mapSize = 0; private List resultSet = null; @@ -253,28 +252,28 @@ private Map loadFromFileToMap(String filename, boolean orde typedFields.add(null); continue; } - int type = (Integer) (types.get(i)); + int type = types.get(i); try { switch (type) { - case Types.INTEGER: - case Types.BIGINT: - case Types.SMALLINT: - case Types.TINYINT: - typedFields.add(new BigInteger(fields[i])); - break; - case Types.FLOAT: - case Types.REAL: - typedFields.add(new Float(fields[i])); - break; - case Types.DOUBLE: - typedFields.add(new Double(fields[i])); - break; - case Types.DECIMAL: - typedFields.add(new BigDecimal(fields[i])); - break; - default: - typedFields.add(fields[i]); - break; + case Types.INTEGER: + case Types.BIGINT: + case Types.SMALLINT: + case Types.TINYINT: + typedFields.add(new BigInteger(fields[i])); + break; + case Types.FLOAT: + case Types.REAL: + typedFields.add(new Float(fields[i])); + break; + case Types.DOUBLE: + typedFields.add(new Double(fields[i])); + break; + case Types.DECIMAL: + typedFields.add(new BigDecimal(fields[i])); + break; + default: + typedFields.add(fields[i]); + break; } } catch (Exception e) { typedFields.add(fields[i]); @@ -399,7 +398,7 @@ private static int compareTo(ColumnList list1, ColumnList list2, return 0; } int idx = columnIndexAndOrder.get(start).index; - int result = -1; + int result; Object o1 = list1.getValues().get(idx); Object o2 = list2.getValues().get(idx); if (ColumnList.bothNull(o1, o2)) { @@ -408,9 +407,9 @@ private static int compareTo(ColumnList list1, ColumnList list2, return 0; // TODO handle NULLS FIRST and NULLS LAST cases } if (o1 instanceof Number) { - Number number1 = (Number) o1; - Number number2 = (Number) o2; - double diff = number1.doubleValue() - number2.doubleValue(); + double d1 = ((Number) o1).doubleValue(); + double d2 = ((Number) o2).doubleValue(); + double diff = d1 - d2; if (diff == 0) { return compareTo(list1, list2, columnIndexAndOrder, start + 1, orderByColumns); } else { @@ -447,9 +446,9 @@ private static int compareTo(ColumnList list1, ColumnList list2, return 1; } if (idNode1.isNumber()) { - Number number1 = (Number) idNode1.asInt(); - Number number2 = (Number) idNode2.asInt(); - double diff = number1.doubleValue() - number2.doubleValue(); + double d1 = idNode1.doubleValue(); + double d2 = idNode2.doubleValue(); + double diff = d1 - d2; if (diff == 0) { return compareTo(list1, list2, columnIndexAndOrder, start + 1, orderByColumns); } else { @@ -872,7 +871,6 @@ public static Map getOrderByColumns(String statement, column = column.trim(); String[] columnOrder = column.split("\\s+"); String columnName = columnOrder[0].trim(); - String fieldName; if (columnName.indexOf('.') >= 0) { // table name precedes column name. remove it columnName = columnName.substring(columnName.indexOf('.') + 1);