Skip to content

Commit

Permalink
[core] Introduce projection fields to EqualiserCodeGenerator (#4154)
Browse files Browse the repository at this point in the history
  • Loading branch information
JingsongLi authored Sep 10, 2024
1 parent b4447ef commit 588b6e0
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ public GeneratedClass<RecordComparator> generateRecordComparator(
}

/**
 * Generates a {@link RecordEqualiser} class for the given field types and field indices.
 *
 * <p>The field types are copied into a {@code DataType[]} and handed, together with {@code
 * fields}, to {@code EqualiserCodeGenerator}; {@code "RecordEqualiser"} is passed as the name
 * of the class to generate.
 *
 * @param fieldTypes types of the input row fields.
 * @param fields field indices forwarded to the code generator.
 * @return the generated equaliser class.
 */
@Override
// NOTE(review): the next two lines look like the pre-change signature and body retained by the
// diff view (single-argument, RowType-based) — confirm before treating them as live code.
public GeneratedClass<RecordEqualiser> generateRecordEqualiser(List<DataType> fieldTypes) {
return new EqualiserCodeGenerator(RowType.builder().fields(fieldTypes).build())
public GeneratedClass<RecordEqualiser> generateRecordEqualiser(
List<DataType> fieldTypes, int[] fields) {
return new EqualiserCodeGenerator(fieldTypes.toArray(new DataType[0]), fields)
.generateRecordEqualiser("RecordEqualiser");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,28 @@ package org.apache.paimon.codegen

import org.apache.paimon.codegen.GenerateUtils._
import org.apache.paimon.codegen.ScalarOperatorGens.{generateEquals, generateRowEqualiser}
import org.apache.paimon.types.{BooleanType, DataType, RowType}
import org.apache.paimon.types.{BooleanType, DataType}
import org.apache.paimon.types.DataTypeChecks.isCompositeType
import org.apache.paimon.types.DataTypeRoot._
import org.apache.paimon.utils.TypeUtils.isPrimitive

import scala.collection.JavaConverters._

class EqualiserCodeGenerator(fieldTypes: Array[DataType]) {
class EqualiserCodeGenerator(fieldTypes: Array[DataType], fields: Array[Int]) {

private val RECORD_EQUALISER = className[RecordEqualiser]
private val LEFT_INPUT = "left"
private val RIGHT_INPUT = "right"

/**
 * Auxiliary constructor that takes only the field types.
 *
 * Delegates to the primary constructor with `fieldTypes.indices.toArray`, i.e. every index of
 * `fieldTypes` is used as a field.
 */
// NOTE(review): the first two lines below appear to be the removed RowType-based constructor
// kept by the diff view — confirm before treating them as live code.
def this(rowType: RowType) = {
this(rowType.getFieldTypes.asScala.toArray)
def this(fieldTypes: Array[DataType]) = {
this(fieldTypes, fieldTypes.indices.toArray)
}

def generateRecordEqualiser(name: String): GeneratedClass[RecordEqualiser] = {
// ignore time zone
val ctx = new CodeGeneratorContext
val className = newName(name)

val equalsMethodCodes = for (idx <- fieldTypes.indices) yield generateEqualsMethod(ctx, idx)
val equalsMethodCalls = for (idx <- fieldTypes.indices) yield {
val equalsMethodCodes = for (idx <- fields) yield generateEqualsMethod(ctx, idx)
val equalsMethodCalls = for (idx <- fields) yield {
val methodName = getEqualsMethodName(idx)
s"""result = result && $methodName($LEFT_INPUT, $RIGHT_INPUT);"""
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;

Expand Down Expand Up @@ -204,6 +205,30 @@ public void testSingleField(DataTypeRoot dataTypeRoot) {
assertBoolean(equaliser, func, testData.left(), testData.right(), false);
}

/**
 * Checks an equaliser generated with the projection {@code {1, 2}} over three fields
 * (INTEGER, VARCHAR, BIGINT): its verdict on two three-field rows must match plain equality
 * of rows built from fields 1 and 2 only.
 */
@RepeatedTest(100)
public void testProjection() {
    GeneratedData intData = TEST_DATA.get(DataTypeRoot.INTEGER);
    GeneratedData stringData = TEST_DATA.get(DataTypeRoot.VARCHAR);
    GeneratedData longData = TEST_DATA.get(DataTypeRoot.BIGINT);

    // Generate and instantiate the equaliser restricted to field indices 1 and 2.
    DataType[] rowTypes = new DataType[] {intData.dataType, stringData.dataType, longData.dataType};
    int[] projection = new int[] {1, 2};
    GeneratedClass<RecordEqualiser> generated =
            new EqualiserCodeGenerator(rowTypes, projection)
                    .generateRecordEqualiser("projectionFieldEquals");
    RecordEqualiser equaliser =
            generated.newInstance(Thread.currentThread().getContextClassLoader());

    // Full three-field rows fed to the generated equaliser.
    GenericRow fullLeft = GenericRow.of(intData.left(), stringData.left(), longData.left());
    GenericRow fullRight = GenericRow.of(intData.right(), stringData.right(), longData.right());
    boolean result = equaliser.equals(fullLeft, fullRight);

    // Reference answer: equality of rows that contain only the projected fields.
    GenericRow projectedLeft = GenericRow.of(stringData.left(), longData.left());
    GenericRow projectedRight = GenericRow.of(stringData.right(), longData.right());
    boolean expected = Objects.equals(projectedLeft, projectedRight);

    assertThat(result).isEqualTo(expected);
}

@RepeatedTest(100)
public void testManyFields() {
int size = 499;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@ GeneratedClass<NormalizedKeyComputer> generateNormalizedKeyComputer(
GeneratedClass<RecordComparator> generateRecordComparator(
List<DataType> inputTypes, int[] sortFields);

/**
* Generate a {@link RecordEqualiser}.
*
* @param fieldTypes Both the input row field types and the sort key field types. Records are
* compared by the first field, then the second field, then the third field and so on. All
* fields are compared in ascending order.
*/
GeneratedClass<RecordEqualiser> generateRecordEqualiser(List<DataType> fieldTypes);
/**
 * Generate a {@link RecordEqualiser} restricted to a subset of the row's fields.
 *
 * @param fieldTypes types of the input row fields.
 * @param fields indices into {@code fieldTypes} selecting the fields the equaliser checks.
 */
GeneratedClass<RecordEqualiser> generateRecordEqualiser(
List<DataType> fieldTypes, int[] fields);
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,15 @@ public static RecordComparator newRecordComparator(
}

/**
 * Creates a {@link RecordEqualiser} over every field of {@code fieldTypes}.
 *
 * <p>Delegates to {@link #newRecordEqualiser(List, int[])} with the indices {@code 0} up to
 * {@code fieldTypes.size() - 1}.
 *
 * @param fieldTypes types of the row fields.
 * @return the equaliser produced by the two-argument overload.
 */
public static RecordEqualiser newRecordEqualiser(List<DataType> fieldTypes) {
    int[] allFields = IntStream.range(0, fieldTypes.size()).toArray();
    return newRecordEqualiser(fieldTypes, allFields);
}

/**
 * Creates a {@link RecordEqualiser} for the given {@code fields} of {@code fieldTypes}.
 *
 * <p>Delegates to {@code generate(...)} with {@code RecordEqualiser.class}, the field types,
 * the field indices, and a supplier that asks the code generator for the equaliser class.
 *
 * @param fieldTypes types of the row fields.
 * @param fields field indices forwarded to the code generator.
 */
public static RecordEqualiser newRecordEqualiser(List<DataType> fieldTypes, int[] fields) {
return generate(
RecordEqualiser.class,
fieldTypes,
// NOTE(review): the next two lines look like the removed side of the diff (all-fields index
// array and single-argument generator call); the two lines after them are the replacements —
// confirm before treating both pairs as live code.
IntStream.range(0, fieldTypes.size()).toArray(),
() -> getCodeGenerator().generateRecordEqualiser(fieldTypes));
fields,
() -> getCodeGenerator().generateRecordEqualiser(fieldTypes, fields));
}

private static <T> T generate(
Expand Down

0 comments on commit 588b6e0

Please sign in to comment.