From ebe96cb3006c925238a727f3ea926be0087f2790 Mon Sep 17 00:00:00 2001 From: James Petty Date: Mon, 13 Nov 2023 13:28:50 -0500 Subject: [PATCH 1/3] Short circuit on zero-length hash length --- .../io/trino/operator/FlatHashStrategyCompiler.java | 11 +++++++++-- .../java/io/trino/operator/TestFlatHashStrategy.java | 12 ++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java index 1919c222ba91..97541befb65a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -51,6 +51,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; +import static io.airlift.bytecode.expression.BytecodeExpressions.equal; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; @@ -378,6 +379,8 @@ private static void generateHashBlocksBatched(ClassDefinition definition, List typeMethods = new HashMap<>(); for (KeyField keyField : keyFields) { MethodDefinition method; @@ -393,9 +396,13 @@ private static void generateHashBlocksBatched(ClassDefinition definition, List types = createTestingTypes(); + FlatHashStrategy flatHashStrategy = JOIN_COMPILER.getFlatHashStrategy(types); + + int positionCount = 10; + // Attempting to touch any of the blocks would result in a NullPointerException + assertDoesNotThrow(() -> flatHashStrategy.hashBlocksBatched(new Block[types.size()], new long[positionCount], 0, 0)); + } + @Test public void testBatchedRawHashesMatchSinglePositionHashes() { From 5617bce3f2eb8eeb3761ba50720092bf50f30654 Mon Sep 17 00:00:00 2001 From: James Petty Date: Mon, 13 Nov 2023 13:58:47 -0500 Subject: [PATCH 2/3] Special case handling of RLE Blocks for batched hashing --- .../operator/FlatHashStrategyCompiler.java | 68 +++++++++++++------ .../trino/operator/TestFlatHashStrategy.java | 24 +++++-- 2 files changed, 68 insertions(+), 24 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java index 97541befb65a..499ef7125004 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -23,15 +23,18 @@ import io.airlift.bytecode.Variable; import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.expression.BytecodeExpression; import io.trino.operator.scalar.CombineHashFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.CallSiteBinder; import java.lang.invoke.MethodHandle; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -429,35 +432,62 @@ private static MethodDefinition generateHashBlockVectorized(ClassDefinition defi Variable mayHaveNull = scope.declareVariable(boolean.class, "mayHaveNull"); Variable hash = scope.declareVariable(long.class, "hash"); - body.append(mayHaveNull.set(block.invoke("mayHaveNull", boolean.class))); body.append(position.set(invokeStatic(Objects.class, "checkFromToIndex", int.class, offset, add(offset, length), block.invoke("getPositionCount", int.class)))); body.append(invokeStatic(Objects.class, "checkFromIndexSize", int.class, constantInt(0), length, hashes.length()).pop()); - BytecodeBlock loopBody = new BytecodeBlock().append(new IfStatement("if (mayHaveNull && block.isNull(position))") - .condition(and(mayHaveNull, block.invoke("isNull", boolean.class, position))) - .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) - .ifFalse(hash.set(invokeDynamic( - BOOTSTRAP_METHOD, - ImmutableList.of(callSiteBinder.bind(field.hashBlockMethod()).getBindingId()), - "hash", - long.class, - block, - position)))); + BytecodeExpression computeHashNonNull = invokeDynamic( + BOOTSTRAP_METHOD, + ImmutableList.of(callSiteBinder.bind(field.hashBlockMethod()).getBindingId()), + "hash", + long.class, + block, + position); + + BytecodeExpression setHashExpression; if (field.index() == 0) { // hashes[index] = hash; - loopBody.append(hashes.setElement(index, hash)); + setHashExpression = hashes.setElement(index, hash); } else { // hashes[index] = CombineHashFunction.getHash(hashes[index], hash); - loopBody.append(hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash))); + setHashExpression = hashes.setElement(index, invokeStatic(CombineHashFunction.class, "getHash", long.class, hashes.getElement(index), hash)); + } + + BytecodeBlock rleHandling = new BytecodeBlock() + .append(new IfStatement("hash = block.isNull(position) ? NULL_HASH_CODE : hash(block, position)") + .condition(block.invoke("isNull", boolean.class, position)) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(computeHashNonNull))); + if (field.index() == 0) { + // Arrays.fill(hashes, 0, length, hash) + rleHandling.append(invokeStatic(Arrays.class, "fill", void.class, hashes, constantInt(0), length, hash)); + } + else { + rleHandling.append(new ForLoop("for (int index = 0; index < length; index++) { hashes[index] = CombineHashFunction.getHash(hashes[index], hash); }") + .initialize(index.set(constantInt(0))) + .condition(lessThan(index, length)) + .update(index.increment()) + .body(setHashExpression)); } - loopBody.append(position.increment()); - body.append(new ForLoop("for (index = 0; index < length; index++)") - .initialize(index.set(constantInt(0))) - .condition(lessThan(index, length)) - .update(index.increment()) - .body(loopBody)) + BytecodeBlock computeHashLoop = new BytecodeBlock() + .append(mayHaveNull.set(block.invoke("mayHaveNull", boolean.class))) + .append(new ForLoop("for (int index = 0; index < length; index++)") + .initialize(index.set(constantInt(0))) + .condition(lessThan(index, length)) + .update(index.increment()) + .body(new BytecodeBlock() + .append(new IfStatement("if (mayHaveNull && block.isNull(position))") + .condition(and(mayHaveNull, block.invoke("isNull", boolean.class, position))) + .ifTrue(hash.set(constantLong(NULL_HASH_CODE))) + .ifFalse(hash.set(computeHashNonNull))) + .append(setHashExpression) + .append(position.increment()))); + + body.append(new IfStatement("if (block instanceof RunLengthEncodedBlock)") + .condition(block.instanceOf(RunLengthEncodedBlock.class)) + .ifTrue(rleHandling) + .ifFalse(computeHashLoop)) .ret(); return methodDefinition; diff --git a/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java b/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java index d9c7716d2f66..1eafd9a2f094 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; @@ -78,16 +79,29 @@ public void testBatchedRawHashesMatchSinglePositionHashes() long[] hashes = new long[positionCount]; flatHashStrategy.hashBlocksBatched(blocks, hashes, 0, positionCount); - for (int position = 0; position < hashes.length; position++) { - long singleRowHash = flatHashStrategy.hash(blocks, position); - if (hashes[position] != singleRowHash) { - fail("Hash mismatch: %s <> %s at position %s - Values: %s".formatted(hashes[position], singleRowHash, position, singleRowTypesAndValues(types, blocks, position))); - } + assertHashesEqual(types, blocks, hashes, flatHashStrategy); + + // Convert all blocks to RunLengthEncoded and re-check results match + for (int i = 0; i < blocks.length; i++) { + blocks[i] = RunLengthEncodedBlock.create(blocks[i].getSingleValueBlock(0), positionCount); } + flatHashStrategy.hashBlocksBatched(blocks, hashes, 0, positionCount); + assertHashesEqual(types, blocks, hashes, flatHashStrategy); + // Ensure the formatting logic produces a real string and doesn't blow up since otherwise this code wouldn't be exercised assertNotNull(singleRowTypesAndValues(types, blocks, 0)); } + private static void assertHashesEqual(List types, Block[] blocks, long[] batchedHashes, FlatHashStrategy flatHashStrategy) + { + for (int position = 0; position < batchedHashes.length; position++) { + long singleRowHash = flatHashStrategy.hash(blocks, position); + if (batchedHashes[position] != singleRowHash) { + fail("Hash mismatch: %s <> %s at position %s - Values: %s".formatted(batchedHashes[position], singleRowHash, position, singleRowTypesAndValues(types, blocks, position))); + } + } + } + private static List createTestingTypes() { List baseTypes = List.of( From 9ac52d5776775ec8582c580e62ad845e34156140 Mon Sep 17 00:00:00 2001 From: James Petty Date: Mon, 13 Nov 2023 14:07:18 -0500 Subject: [PATCH 3/3] Avoid retaining generated classes in TestFlatHashStrategy --- .../io/trino/operator/TestFlatHashStrategy.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java b/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java index 1eafd9a2f094..1948f04ea03f 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestFlatHashStrategy.java @@ -51,14 +51,14 @@ public class TestFlatHashStrategy { - private static final TypeOperators TYPE_OPERATORS = new TypeOperators(); - private static final JoinCompiler JOIN_COMPILER = new JoinCompiler(TYPE_OPERATORS); + private final TypeOperators typeOperators = new TypeOperators(); + private final JoinCompiler joinCompiler = new JoinCompiler(typeOperators); @Test public void testBatchedRawHashesZeroLength() { - List types = createTestingTypes(); - FlatHashStrategy flatHashStrategy = JOIN_COMPILER.getFlatHashStrategy(types); + List types = createTestingTypes(typeOperators); + FlatHashStrategy flatHashStrategy = joinCompiler.getFlatHashStrategy(types); int positionCount = 10; // Attempting to touch any of the blocks would result in a NullPointerException @@ -68,8 +68,8 @@ public void testBatchedRawHashesZeroLength() @Test public void testBatchedRawHashesMatchSinglePositionHashes() { - List types = createTestingTypes(); - FlatHashStrategy flatHashStrategy = JOIN_COMPILER.getFlatHashStrategy(types); + List types = createTestingTypes(typeOperators); + FlatHashStrategy flatHashStrategy = joinCompiler.getFlatHashStrategy(types); int positionCount = 1024; Block[] blocks = new Block[types.size()]; @@ -102,7 +102,7 @@ private static void assertHashesEqual(List types, Block[] blocks, long[] b } } - private static List createTestingTypes() + private static List createTestingTypes(TypeOperators typeOperators) { List baseTypes = List.of( BIGINT, @@ -128,7 +128,7 @@ private static List createTestingTypes() builder.add(RowType.anonymous(baseTypes)); for (Type baseType : baseTypes) { builder.add(new ArrayType(baseType)); - builder.add(new MapType(baseType, baseType, TYPE_OPERATORS)); + builder.add(new MapType(baseType, baseType, typeOperators)); } return builder.build(); }