diff --git a/java/sketches/src/main/java/sleeper/sketches/Sketches.java b/java/sketches/src/main/java/sleeper/sketches/Sketches.java index 0b65286fd9..e8008b5d48 100644 --- a/java/sketches/src/main/java/sleeper/sketches/Sketches.java +++ b/java/sketches/src/main/java/sleeper/sketches/Sketches.java @@ -23,6 +23,7 @@ import sleeper.core.schema.Field; import sleeper.core.schema.Schema; import sleeper.core.schema.type.ByteArrayType; +import sleeper.core.schema.type.IntType; import sleeper.core.schema.type.PrimitiveType; import sleeper.core.schema.type.Type; @@ -50,19 +51,11 @@ public static Sketches from(Schema schema) { } public static ItemsSketch createSketch(Type type, int k) { - if (type instanceof PrimitiveType) { - return (ItemsSketch) ItemsSketch.getInstance(k, Comparator.naturalOrder()); - } else { - throw new IllegalArgumentException("Unknown key type of " + type); - } + return (ItemsSketch) ItemsSketch.getInstance(k, createComparator(type)); } public static ItemsUnion createUnion(Type type, int maxK) { - if (type instanceof PrimitiveType) { - return (ItemsUnion) ItemsUnion.getInstance(maxK, Comparator.naturalOrder()); - } else { - throw new IllegalArgumentException("Unknown key type of " + type); - } + return (ItemsUnion) ItemsUnion.getInstance(maxK, createComparator(type)); } public static Comparator createComparator(Type type) { @@ -95,6 +88,8 @@ public static void update(ItemsSketch sketch, Record record, Field field) { public static Object readValueFromSketchWithWrappedBytes(Object value, Field field) { if (value == null) { return null; + } else if (field.getType() instanceof IntType) { + return ((Long) value).intValue(); } else { return value; } @@ -104,6 +99,8 @@ private static Object convertValueForSketch(Record record, Field field) { Object value = record.get(field.getName()); if (value == null) { return null; + } else if (field.getType() instanceof IntType) { + return ((Integer) value).longValue(); } else if (field.getType() instanceof ByteArrayType) { return ByteArray.wrap((byte[]) value); } else { diff --git a/java/sketches/src/main/java/sleeper/sketches/SketchesSerDe.java b/java/sketches/src/main/java/sleeper/sketches/SketchesSerDe.java index 927c4f2a6e..5d5aa6a354 100644 --- a/java/sketches/src/main/java/sleeper/sketches/SketchesSerDe.java +++ b/java/sketches/src/main/java/sleeper/sketches/SketchesSerDe.java @@ -17,7 +17,7 @@ import com.facebook.collections.ByteArray; import org.apache.datasketches.ArrayOfItemsSerDe; -import org.apache.datasketches.ArrayOfNumbersSerDe; +import org.apache.datasketches.ArrayOfLongsSerDe; import org.apache.datasketches.ArrayOfStringsSerDe; import org.apache.datasketches.Util; import org.apache.datasketches.memory.Memory; @@ -35,7 +35,6 @@ import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; -import java.util.Comparator; import java.util.HashMap; import java.util.Map; @@ -49,8 +48,8 @@ public SketchesSerDe(Schema schema) { public void serialise(Sketches sketches, DataOutputStream dos) throws IOException { for (Field field : schema.getRowKeyFields()) { if (field.getType() instanceof IntType || field.getType() instanceof LongType) { - ItemsSketch sketch = sketches.getQuantilesSketch(field.getName()); - byte[] b = sketch.toByteArray(new ArrayOfNumbersSerDe()); + ItemsSketch sketch = sketches.getQuantilesSketch(field.getName()); + byte[] b = sketch.toByteArray(new ArrayOfLongsSerDe()); dos.writeInt(b.length); dos.write(b); } else if (field.getType() instanceof StringType) { @@ -81,16 +80,18 @@ private static ItemsSketch deserialise(DataInputStream dis, Type type) throws int length = dis.readInt(); byte[] b = new byte[length]; dis.readFully(b); - if (type instanceof IntType) { - return ItemsSketch.getInstance(Memory.wrap(b), Comparator.comparing(Number::intValue), new ArrayOfNumbersSerDe()); - } else if (type instanceof LongType) { - return ItemsSketch.getInstance(Memory.wrap(b), Comparator.comparing(Number::longValue), new ArrayOfNumbersSerDe()); + return ItemsSketch.getInstance(Memory.wrap(b), Sketches.createComparator(type), getItemsSerDe(type)); + } + + private static ArrayOfItemsSerDe getItemsSerDe(Type type) { + if (type instanceof IntType || type instanceof LongType) { + return (ArrayOfItemsSerDe) new ArrayOfLongsSerDe(); } else if (type instanceof StringType) { - return ItemsSketch.getInstance(Memory.wrap(b), Comparator.naturalOrder(), new ArrayOfStringsSerDe()); + return (ArrayOfItemsSerDe) new ArrayOfStringsSerDe(); } else if (type instanceof ByteArrayType) { - return ItemsSketch.getInstance(Memory.wrap(b), Comparator.naturalOrder(), new ArrayOfByteArraysSerSe()); + return (ArrayOfItemsSerDe) new ArrayOfByteArraysSerSe(); } else { - throw new IOException("Unknown key type of " + type); + throw new IllegalArgumentException("Unknown key type of " + type); } }