diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java index 1919c222ba91..499ef7125004 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatHashStrategyCompiler.java @@ -23,15 +23,18 @@ import io.airlift.bytecode.Variable; import io.airlift.bytecode.control.ForLoop; import io.airlift.bytecode.control.IfStatement; +import io.airlift.bytecode.expression.BytecodeExpression; import io.trino.operator.scalar.CombineHashFunction; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.CallSiteBinder; import java.lang.invoke.MethodHandle; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -51,6 +54,7 @@ import static io.airlift.bytecode.expression.BytecodeExpressions.constantInt; import static io.airlift.bytecode.expression.BytecodeExpressions.constantLong; import static io.airlift.bytecode.expression.BytecodeExpressions.constantTrue; +import static io.airlift.bytecode.expression.BytecodeExpressions.equal; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeDynamic; import static io.airlift.bytecode.expression.BytecodeExpressions.invokeStatic; import static io.airlift.bytecode.expression.BytecodeExpressions.lessThan; @@ -378,6 +382,8 @@ private static void generateHashBlocksBatched(ClassDefinition definition, List typeMethods = new HashMap<>(); for (KeyField keyField : keyFields) { MethodDefinition method; @@ -393,9 +399,13 @@ private static void generateHashBlocksBatched(ClassDefinition definition, List types = createTestingTypes(typeOperators); + FlatHashStrategy flatHashStrategy = joinCompiler.getFlatHashStrategy(types); + + int positionCount = 10; + // Attempting to touch any of the blocks would result in a NullPointerException + assertDoesNotThrow(() -> flatHashStrategy.hashBlocksBatched(new Block[types.size()], new long[positionCount], 0, 0)); + } @Test public void testBatchedRawHashesMatchSinglePositionHashes() { - List types = createTestingTypes(); - FlatHashStrategy flatHashStrategy = JOIN_COMPILER.getFlatHashStrategy(types); + List types = createTestingTypes(typeOperators); + FlatHashStrategy flatHashStrategy = joinCompiler.getFlatHashStrategy(types); int positionCount = 1024; Block[] blocks = new Block[types.size()]; @@ -66,17 +79,30 @@ public void testBatchedRawHashesMatchSinglePositionHashes() long[] hashes = new long[positionCount]; flatHashStrategy.hashBlocksBatched(blocks, hashes, 0, positionCount); - for (int position = 0; position < hashes.length; position++) { - long singleRowHash = flatHashStrategy.hash(blocks, position); - if (hashes[position] != singleRowHash) { - fail("Hash mismatch: %s <> %s at position %s - Values: %s".formatted(hashes[position], singleRowHash, position, singleRowTypesAndValues(types, blocks, position))); - } + assertHashesEqual(types, blocks, hashes, flatHashStrategy); + + // Convert all blocks to RunLengthEncoded and re-check results match + for (int i = 0; i < blocks.length; i++) { + blocks[i] = RunLengthEncodedBlock.create(blocks[i].getSingleValueBlock(0), positionCount); } + flatHashStrategy.hashBlocksBatched(blocks, hashes, 0, positionCount); + assertHashesEqual(types, blocks, hashes, flatHashStrategy); + // Ensure the formatting logic produces a real string and doesn't blow up since otherwise this code wouldn't be exercised assertNotNull(singleRowTypesAndValues(types, blocks, 0)); } - private static List createTestingTypes() + private static void assertHashesEqual(List types, Block[] blocks, long[] batchedHashes, FlatHashStrategy flatHashStrategy) + { + for (int position = 0; position < batchedHashes.length; position++) { + long singleRowHash = flatHashStrategy.hash(blocks, position); + if (batchedHashes[position] != singleRowHash) { + fail("Hash mismatch: %s <> %s at position %s - Values: %s".formatted(batchedHashes[position], singleRowHash, position, singleRowTypesAndValues(types, blocks, position))); + } + } + } + + private static List createTestingTypes(TypeOperators typeOperators) { List baseTypes = List.of( BIGINT, @@ -102,7 +128,7 @@ private static List createTestingTypes() builder.add(RowType.anonymous(baseTypes)); for (Type baseType : baseTypes) { builder.add(new ArrayType(baseType)); - builder.add(new MapType(baseType, baseType, TYPE_OPERATORS)); + builder.add(new MapType(baseType, baseType, typeOperators)); } return builder.build(); }