diff --git a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java index 78f8fc4f2409..0969dba38cb1 100644 --- a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java @@ -64,6 +64,8 @@ public class CoreOptions implements Serializable { public static final String FIELDS_PREFIX = "fields"; + public static final String FIELDS_SEPARATOR = ","; + public static final String AGG_FUNCTION = "aggregate-function"; public static final String DEFAULT_AGG_FUNCTION = "default-aggregate-function"; diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java index 680e3dd06dc1..6307c551d582 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java @@ -24,29 +24,19 @@ import org.apache.paimon.data.InternalRow; import org.apache.paimon.mergetree.compact.aggregate.FieldAggregator; import org.apache.paimon.options.Options; -import org.apache.paimon.types.BigIntType; -import org.apache.paimon.types.CharType; +import org.apache.paimon.types.DataField; import org.apache.paimon.types.DataType; -import org.apache.paimon.types.DataTypeDefaultVisitor; -import org.apache.paimon.types.DateType; -import org.apache.paimon.types.DecimalType; -import org.apache.paimon.types.DoubleType; -import org.apache.paimon.types.FloatType; -import org.apache.paimon.types.IntType; -import org.apache.paimon.types.LocalZonedTimestampType; import org.apache.paimon.types.RowKind; import org.apache.paimon.types.RowType; -import org.apache.paimon.types.SmallIntType; -import org.apache.paimon.types.TimestampType; -import org.apache.paimon.types.TinyIntType; -import 
org.apache.paimon.types.VarCharType; -import org.apache.paimon.utils.InternalRowUtils; import org.apache.paimon.utils.Projection; +import org.apache.paimon.utils.UserDefinedSeqComparator; import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -56,6 +46,7 @@ import java.util.stream.Stream; import static org.apache.paimon.CoreOptions.FIELDS_PREFIX; +import static org.apache.paimon.CoreOptions.FIELDS_SEPARATOR; import static org.apache.paimon.utils.InternalRowUtils.createFieldGetters; /** @@ -67,7 +58,7 @@ public class PartialUpdateMergeFunction implements MergeFunction { public static final String SEQUENCE_GROUP = "sequence-group"; private final InternalRow.FieldGetter[] getters; - private final Map fieldSequences; + private final Map fieldSeqComparators; private final boolean fieldSequenceEnabled; private final Map fieldAggregators; @@ -78,11 +69,11 @@ public class PartialUpdateMergeFunction implements MergeFunction { protected PartialUpdateMergeFunction( InternalRow.FieldGetter[] getters, - Map fieldSequences, + Map fieldSeqComparators, Map fieldAggregators, boolean fieldSequenceEnabled) { this.getters = getters; - this.fieldSequences = fieldSequences; + this.fieldSeqComparators = fieldSeqComparators; this.fieldAggregators = fieldAggregators; this.fieldSequenceEnabled = fieldSequenceEnabled; } @@ -117,7 +108,7 @@ public void add(KeyValue kv) { } latestSequenceNumber = kv.sequenceNumber(); - if (fieldSequences.isEmpty()) { + if (fieldSeqComparators.isEmpty()) { updateNonNullFields(kv); } else { updateWithSequenceGroup(kv); @@ -136,25 +127,31 @@ private void updateNonNullFields(KeyValue kv) { private void updateWithSequenceGroup(KeyValue kv) { for (int i = 0; i < getters.length; i++) { Object field = getters[i].getFieldOrNull(kv.value()); - SequenceGenerator sequenceGen = fieldSequences.get(i); + 
UserDefinedSeqComparator seqComparator = fieldSeqComparators.get(i); FieldAggregator aggregator = fieldAggregators.get(i); Object accumulator = getters[i].getFieldOrNull(row); - if (sequenceGen == null) { + if (seqComparator == null) { if (aggregator != null) { row.setField(i, aggregator.agg(accumulator, field)); } else if (field != null) { row.setField(i, field); } } else { - Long currentSeq = sequenceGen.generate(kv.value()); - if (currentSeq != null) { - Long previousSeq = sequenceGen.generate(row); - if (previousSeq == null || currentSeq >= previousSeq) { - row.setField( - i, aggregator == null ? field : aggregator.agg(accumulator, field)); - } else if (aggregator != null) { - row.setField(i, aggregator.agg(field, accumulator)); + if (seqComparator.compare(kv.value(), row) >= 0) { + int index = i; + + // Multiple sequence fields should be updated at once. + if (Arrays.stream(seqComparator.compareFields()) + .anyMatch(seqIndex -> seqIndex == index)) { + for (int fieldIndex : seqComparator.compareFields()) { + row.setField( + fieldIndex, getters[fieldIndex].getFieldOrNull(kv.value())); + } } + row.setField( + i, aggregator == null ? 
field : aggregator.agg(accumulator, field)); + } else if (aggregator != null) { + row.setField(i, aggregator.agg(field, accumulator)); } } } @@ -162,38 +159,37 @@ private void updateWithSequenceGroup(KeyValue kv) { private void retractWithSequenceGroup(KeyValue kv) { for (int i = 0; i < getters.length; i++) { - SequenceGenerator sequenceGen = fieldSequences.get(i); - if (sequenceGen != null) { - Long currentSeq = sequenceGen.generate(kv.value()); - if (currentSeq != null) { - Long previousSeq = sequenceGen.generate(row); - FieldAggregator aggregator = fieldAggregators.get(i); - if (previousSeq == null || currentSeq >= previousSeq) { - if (sequenceGen.index() == i) { - // update sequence field - row.setField(i, getters[i].getFieldOrNull(kv.value())); + UserDefinedSeqComparator seqComparator = fieldSeqComparators.get(i); + if (seqComparator != null) { + FieldAggregator aggregator = fieldAggregators.get(i); + if (seqComparator.compare(kv.value(), row) >= 0) { + int index = i; + + // Multiple sequence fields should be updated at once. 
+ if (Arrays.stream(seqComparator.compareFields()) + .anyMatch(field -> field == index)) { + for (int field : seqComparator.compareFields()) { + row.setField(field, getters[field].getFieldOrNull(kv.value())); + } + } else { + // retract normal field + if (aggregator == null) { + row.setField(i, null); } else { - // retract normal field - if (aggregator == null) { - row.setField(i, null); - } else { - // retract agg field - Object accumulator = getters[i].getFieldOrNull(row); - row.setField( - i, - aggregator.retract( - accumulator, - getters[i].getFieldOrNull(kv.value()))); - } + // retract agg field + Object accumulator = getters[i].getFieldOrNull(row); + row.setField( + i, + aggregator.retract( + accumulator, getters[i].getFieldOrNull(kv.value()))); } - } else if (aggregator != null) { - // retract agg field for old sequence - Object accumulator = getters[i].getFieldOrNull(row); - row.setField( - i, - aggregator.retract( - accumulator, getters[i].getFieldOrNull(kv.value()))); } + } else if (aggregator != null) { + // retract agg field for old sequence + Object accumulator = getters[i].getFieldOrNull(row); + row.setField( + i, + aggregator.retract(accumulator, getters[i].getFieldOrNull(kv.value()))); } } } @@ -216,56 +212,70 @@ private static class Factory implements MergeFunctionFactory { private static final long serialVersionUID = 1L; + private final RowType rowType; + + private final List primaryFields; + private final List tableTypes; - private final Map fieldSequences; + + private final Map fieldSeqComparators; private final Map fieldAggregators; private Factory(Options options, RowType rowType, List primaryKeys) { + this.rowType = rowType; + this.primaryFields = new ArrayList<>(); + for (String primaryKey : primaryKeys) { + this.primaryFields.add(rowType.getField(primaryKey)); + } this.tableTypes = rowType.getFieldTypes(); List fieldNames = rowType.getFieldNames(); - this.fieldSequences = new HashMap<>(); + this.fieldSeqComparators = new HashMap<>(); for 
(Map.Entry entry : options.toMap().entrySet()) { String k = entry.getKey(); String v = entry.getValue(); if (k.startsWith(FIELDS_PREFIX) && k.endsWith(SEQUENCE_GROUP)) { - String sequenceFieldName = - k.substring( - FIELDS_PREFIX.length() + 1, - k.length() - SEQUENCE_GROUP.length() - 1); - SequenceGenerator sequenceGen = - new SequenceGenerator(sequenceFieldName, rowType); - Arrays.stream(v.split(",")) + List sequenceFields = + Arrays.stream( + k.substring( + FIELDS_PREFIX.length() + 1, + k.length() + - SEQUENCE_GROUP.length() + - 1) + .split(FIELDS_SEPARATOR)) + .map(fieldName -> validateFieldName(fieldName, fieldNames)) + .collect(Collectors.toList()); + + UserDefinedSeqComparator userDefinedSeqComparator = + UserDefinedSeqComparator.create(rowType, sequenceFields); + Arrays.stream(v.split(FIELDS_SEPARATOR)) .map( - fieldName -> { - int field = fieldNames.indexOf(fieldName); - if (field == -1) { - throw new IllegalArgumentException( - String.format( - "Field %s can not be found in table schema", - fieldName)); - } - return field; - }) + fieldName -> + fieldNames.indexOf( + validateFieldName(fieldName, fieldNames))) .forEach( field -> { - if (fieldSequences.containsKey(field)) { + if (fieldSeqComparators.containsKey(field)) { throw new IllegalArgumentException( String.format( "Field %s is defined repeatedly by multiple groups: %s", fieldNames.get(field), k)); } - fieldSequences.put(field, sequenceGen); + fieldSeqComparators.put(field, userDefinedSeqComparator); }); // add self - fieldSequences.put(sequenceGen.index(), sequenceGen); + sequenceFields.forEach( + fieldName -> { + int index = fieldNames.indexOf(fieldName); + fieldSeqComparators.put(index, userDefinedSeqComparator); + }); } } this.fieldAggregators = createFieldAggregators(rowType, primaryKeys, new CoreOptions(options)); - if (fieldAggregators.size() > 0 && fieldSequences.isEmpty()) { + if (!fieldAggregators.isEmpty() && fieldSeqComparators.isEmpty()) { throw new IllegalArgumentException( "Must use 
sequence group for aggregation functions."); } @@ -274,29 +284,47 @@ private Factory(Options options, RowType rowType, List primaryKeys) { @Override public MergeFunction create(@Nullable int[][] projection) { if (projection != null) { - Map projectedSequences = new HashMap<>(); + Map projectedSeqComparators = new HashMap<>(); Map projectedAggregators = new HashMap<>(); int[] projects = Projection.of(projection).toTopLevelIndexes(); Map indexMap = new HashMap<>(); + List dataFields = rowType.getFields(); + Set newDataFields = new HashSet<>(primaryFields); + for (int i = 0; i < projects.length; i++) { indexMap.put(projects[i], i); + newDataFields.add(dataFields.get(projects[i])); } - fieldSequences.forEach( - (field, sequence) -> { + RowType newRowType = + new RowType(rowType.isNullable(), new ArrayList<>(newDataFields)); + + fieldSeqComparators.forEach( + (field, comparator) -> { int newField = indexMap.getOrDefault(field, -1); if (newField != -1) { - int newSequenceId = indexMap.getOrDefault(sequence.index(), -1); - if (newSequenceId == -1) { - throw new RuntimeException( - String.format( - "Can not find new sequence field for new field. new field index is %s", - newField)); - } else { - projectedSequences.put( - newField, - new SequenceGenerator( - newSequenceId, sequence.fieldType())); - } + int[] newSequenceFields = + Arrays.stream(comparator.compareFields()) + .map( + index -> { + int newIndex = + indexMap.getOrDefault( + index, -1); + if (newIndex == -1) { + throw new RuntimeException( + String.format( + "Can not find new sequence field " + + "for new field. 
new field " + + "index is %s", + newField)); + } else { + return newIndex; + } + }) + .toArray(); + projectedSeqComparators.put( + newField, + UserDefinedSeqComparator.create( + newRowType, newSequenceFields)); } }); for (int i = 0; i < projects.length; i++) { @@ -307,21 +335,21 @@ public MergeFunction create(@Nullable int[][] projection) { return new PartialUpdateMergeFunction( createFieldGetters(Projection.of(projection).project(tableTypes)), - projectedSequences, + projectedSeqComparators, projectedAggregators, - !fieldSequences.isEmpty()); + !fieldSeqComparators.isEmpty()); } else { return new PartialUpdateMergeFunction( createFieldGetters(tableTypes), - fieldSequences, + fieldSeqComparators, fieldAggregators, - !fieldSequences.isEmpty()); + !fieldSeqComparators.isEmpty()); } } @Override public AdjustedProjection adjustProjection(@Nullable int[][] projection) { - if (fieldSequences.isEmpty()) { + if (fieldSeqComparators.isEmpty()) { return new AdjustedProjection(projection, null); } @@ -332,9 +360,15 @@ public AdjustedProjection adjustProjection(@Nullable int[][] projection) { int[] topProjects = Projection.of(projection).toTopLevelIndexes(); Set indexSet = Arrays.stream(topProjects).boxed().collect(Collectors.toSet()); for (int index : topProjects) { - SequenceGenerator generator = fieldSequences.get(index); - if (generator != null && !indexSet.contains(generator.index())) { - extraFields.add(generator.index()); + UserDefinedSeqComparator comparator = fieldSeqComparators.get(index); + if (comparator == null) { + continue; + } + + for (int field : comparator.compareFields()) { + if (!indexSet.contains(field)) { + extraFields.add(field); + } } } @@ -343,11 +377,21 @@ public AdjustedProjection adjustProjection(@Nullable int[][] projection) { .mapToInt(Integer::intValue) .toArray(); - int[][] pushdown = Projection.of(allProjects).toNestedIndexes(); + int[][] pushDown = Projection.of(allProjects).toNestedIndexes(); int[][] outer = Projection.of(IntStream.range(0, 
topProjects.length).toArray()) .toNestedIndexes(); - return new AdjustedProjection(pushdown, outer); + return new AdjustedProjection(pushDown, outer); + } + + private String validateFieldName(String fieldName, List fieldNames) { + int field = fieldNames.indexOf(fieldName); + if (field == -1) { + throw new IllegalArgumentException( + String.format("Field %s can not be found in table schema", fieldName)); + } + + return fieldName; } /** @@ -395,137 +439,4 @@ private Map createFieldAggregators( return fieldAggregators; } } - - private static class SequenceGenerator { - - private final int index; - - private final Generator generator; - private final DataType fieldType; - - private SequenceGenerator(String field, RowType rowType) { - index = rowType.getFieldNames().indexOf(field); - if (index == -1) { - throw new RuntimeException( - String.format( - "Can not find sequence field %s in table schema: %s", - field, rowType)); - } - fieldType = rowType.getTypeAt(index); - generator = fieldType.accept(new SequenceGeneratorVisitor()); - } - - public SequenceGenerator(int index, DataType dataType) { - this.index = index; - - this.fieldType = dataType; - if (index == -1) { - throw new RuntimeException(String.format("Index : %s is invalid", index)); - } - generator = fieldType.accept(new SequenceGeneratorVisitor()); - } - - public int index() { - return index; - } - - public DataType fieldType() { - return fieldType; - } - - @Nullable - public Long generate(InternalRow row) { - return generator.generateNullable(row, index); - } - - private interface Generator { - long generate(InternalRow row, int i); - - @Nullable - default Long generateNullable(InternalRow row, int i) { - if (row.isNullAt(i)) { - return null; - } - return generate(row, i); - } - } - - private static class SequenceGeneratorVisitor extends DataTypeDefaultVisitor { - - @Override - public Generator visit(CharType charType) { - return stringGenerator(); - } - - @Override - public Generator visit(VarCharType 
varCharType) { - return stringGenerator(); - } - - private Generator stringGenerator() { - return (row, i) -> Long.parseLong(row.getString(i).toString()); - } - - @Override - public Generator visit(DecimalType decimalType) { - return (row, i) -> - InternalRowUtils.castToIntegral( - row.getDecimal( - i, decimalType.getPrecision(), decimalType.getScale())); - } - - @Override - public Generator visit(TinyIntType tinyIntType) { - return InternalRow::getByte; - } - - @Override - public Generator visit(SmallIntType smallIntType) { - return InternalRow::getShort; - } - - @Override - public Generator visit(IntType intType) { - return InternalRow::getInt; - } - - @Override - public Generator visit(BigIntType bigIntType) { - return InternalRow::getLong; - } - - @Override - public Generator visit(FloatType floatType) { - return (row, i) -> (long) row.getFloat(i); - } - - @Override - public Generator visit(DoubleType doubleType) { - return (row, i) -> (long) row.getDouble(i); - } - - @Override - public Generator visit(DateType dateType) { - return InternalRow::getInt; - } - - @Override - public Generator visit(TimestampType timestampType) { - return (row, i) -> - row.getTimestamp(i, timestampType.getPrecision()).getMillisecond(); - } - - @Override - public Generator visit(LocalZonedTimestampType localZonedTimestampType) { - return (row, i) -> - row.getTimestamp(i, localZonedTimestampType.getPrecision()) - .getMillisecond(); - } - - @Override - protected Generator defaultMethod(DataType dataType) { - throw new UnsupportedOperationException("Unsupported type: " + dataType); - } - } - } } diff --git a/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java b/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java index 35fa7a66d775..ec7a00bcb3b2 100644 --- a/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java +++ b/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java @@ -62,8 
+62,18 @@ public static UserDefinedSeqComparator create(RowType rowType, List sequ List fieldNames = rowType.getFieldNames(); int[] fields = sequenceFields.stream().mapToInt(fieldNames::indexOf).toArray(); + + return create(rowType, fields); + } + + @Nullable + public static UserDefinedSeqComparator create(RowType rowType, int[] sequenceFields) { + if (sequenceFields.length == 0) { + return null; + } + RecordComparator comparator = - CodeGenUtils.newRecordComparator(rowType.getFieldTypes(), fields); - return new UserDefinedSeqComparator(fields, comparator); + CodeGenUtils.newRecordComparator(rowType.getFieldTypes(), sequenceFields); + return new UserDefinedSeqComparator(sequenceFields, comparator); } } diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java index fa41607d10c3..03c5b83dbcf1 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java @@ -95,6 +95,43 @@ public void testSequenceGroup() { validate(func, 1, null, null, 6, null, null, 6); } + @Test + public void testMultiSequenceFields() { + Options options = new Options(); + options.set("fields.f3,f4.sequence-group", "f1,f2"); + options.set("fields.f7,f8.sequence-group", "f5,f6"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + func.reset(); + add(func, 1, 1, 1, 1, 1, 1, 1, 1, 3); + add(func, 1, 2, 2, 2, 2, 2, 1, 1, null); + validate(func, 1, 2, 2, 2, 2, 1, 1, 1, 3); + add(func, 1, 1, 3, 1, 3, 3, 3, 3, 2); + validate(func, 1, 2, 
2, 2, 2, 3, 3, 3, 2); + + // delete + add(func, RowKind.DELETE, 1, 1, 1, 3, 3, 1, 1, null, null); + validate(func, 1, null, null, 3, 3, 3, 3, 3, 2); + add(func, RowKind.DELETE, 1, 1, 1, 3, 1, 1, 1, 4, 4); + validate(func, 1, null, null, 3, 3, null, null, 4, 4); + add(func, 1, 4, 4, 4, 4, 5, 5, 5, 5); + validate(func, 1, 4, 4, 4, 4, 5, 5, 5, 5); + add(func, RowKind.DELETE, 1, 1, 1, 6, 1, 1, 1, 6, 1); + validate(func, 1, null, null, 6, 1, null, null, 6, 1); + } + @Test public void testSequenceGroupDefaultAggFunc() { Options options = new Options(); @@ -123,6 +160,36 @@ public void testSequenceGroupDefaultAggFunc() { validate(func, 1, 4, 2, 4, 5, 3, 5); } + @Test + public void testMultiSequenceFieldsDefaultAggFunc() { + Options options = new Options(); + options.set("fields.f3,f4.sequence-group", "f1,f2"); + options.set("fields.f7,f8.sequence-group", "f5,f6"); + options.set(FIELDS_DEFAULT_AGG_FUNC, "last_non_null_value"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + func.reset(); + add(func, 1, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 1, 2, 2, 2, 2, 2, 2, null, null); + validate(func, 1, 2, 2, 2, 2, 1, 1, 1, 1); + add(func, 1, 3, 3, 1, 1, 3, 3, 3, 3); + validate(func, 1, 2, 2, 2, 2, 3, 3, 3, 3); + add(func, 1, 4, null, 4, 4, 5, null, 5, 5); + validate(func, 1, 4, 2, 4, 4, 5, 3, 5, 5); + } + @Test public void testSequenceGroupDefinedNoField() { Options options = new Options(); @@ -144,6 +211,27 @@ public void testSequenceGroupDefinedNoField() { .hasMessageContaining("can not be found in table schema"); } + @Test + public void testMultiSequenceFieldsDefinedNoField() { + Options options = new Options(); + options.set("fields.f2,f3.sequence-group", "f1,f7"); + options.set("fields.f5,f6.sequence-group", 
"f4"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f0"))) + .hasMessageContaining("can not be found in table schema"); + } + @Test public void testSequenceGroupRepeatDefine() { Options options = new Options(); @@ -163,6 +251,27 @@ public void testSequenceGroupRepeatDefine() { .hasMessageContaining("is defined repeatedly by multiple groups"); } + @Test + public void testMultiSequenceFieldsRepeatDefine() { + Options options = new Options(); + options.set("fields.f3,f4.sequence-group", "f1,f2"); + options.set("fields.f5,f6.sequence-group", "f1,f2"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f0"))) + .hasMessageContaining("is defined repeatedly by multiple groups"); + } + @Test public void testAdjustProjectionRepeatProject() { Options options = new Options(); @@ -198,6 +307,45 @@ public void testAdjustProjectionRepeatProject() { validate(func, 3, 3, null, 2, 4, 2); } + @Test + public void testMultiSequenceFieldsAdjustProjectionRepeatProject() { + Options options = new Options(); + options.set("fields.f2,f4.sequence-group", "f1,f3"); + options.set("fields.f5,f6.sequence-group", "f7"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + // the field 'f1' is projected twice + int[][] projection = new int[][] {{1}, {1}, {3}, {7}}; + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + MergeFunctionFactory.AdjustedProjection 
adjustedProjection = + factory.adjustProjection(projection); + + validate(adjustedProjection, new int[] {1, 1, 3, 7, 2, 4, 5, 6}, new int[] {0, 1, 2, 3}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + func.reset(); + add(func, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 2, 2, 6, 2, 2, 2, 2, 6); + validate(func, 2, 2, 6, 2, 2, 2, 2, 6); + + // update first sequence group + add(func, 3, 3, null, 7, 4, null, 1, 8); + validate(func, 3, 3, null, 2, 4, null, 2, 6); + + // update second sequence group + add(func, 5, 5, 3, 3, 3, 5, 5, 6); + validate(func, 5, 3, null, 3, 4, null, 5, 6); + } + @Test public void testAdjustProjectionSequenceFieldsProject() { Options options = new Options(); @@ -230,6 +378,38 @@ public void testAdjustProjectionSequenceFieldsProject() { validate(func, 1, 1, 1, 2, 2); } + @Test + public void testMultiSequenceFieldsAdjustProjectionProject() { + Options options = new Options(); + options.set("fields.f2,f4.sequence-group", "f1,f3"); + options.set("fields.f5,f6.sequence-group", "f7"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + // the sequence field 'f4' is projected too + int[][] projection = new int[][] {{1}, {4}, {3}, {7}}; + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(projection); + + validate(adjustedProjection, new int[] {1, 4, 3, 7, 2, 5, 6}, new int[] {0, 1, 2, 3}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + func.reset(); + // sequence field 'f4' is null, but the leading sequence field 'f2' advances, so the group is still updated + add(func, 1, 1, 1, 1, 1, 1, 1); + add(func, 1, null, 1, 3, 2, 2, 2); + validate(func, 1, null, 1, 3, 2, 2, 2); + } + @Test public void testAdjustProjectionAllFieldsProject() { Options
options = new Options(); @@ -265,6 +445,41 @@ public void testAdjustProjectionAllFieldsProject() { validate(func, 4, 2, 4, 2, 2, 1, 1, 1); } + @Test + public void testMultiSequenceFieldsAdjustProjectionAllFieldsProject() { + Options options = new Options(); + options.set("fields.f2,f4.sequence-group", "f1,f3"); + options.set("fields.f5,f6.sequence-group", "f7"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + // all fields are projected + int[][] projection = new int[][] {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}; + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(projection); + + validate( + adjustedProjection, + new int[] {0, 1, 2, 3, 4, 5, 6, 7}, + new int[] {0, 1, 2, 3, 4, 5, 6, 7}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + func.reset(); + // the sequence (f5, f6) of the second group does not advance, so f5, f6 and f7 keep their values + add(func, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 4, 2, 4, 2, 2, 0, null, 3); + validate(func, 4, 2, 4, 2, 2, 1, 1, 1); + } + @Test public void testAdjustProjectionNonProject() { Options options = new Options(); @@ -372,6 +587,33 @@ public void testFirstValue() { validate(func, 1, 2, 3, 2); } + @Test + public void testMultiSequenceFieldsFirstValue() { + Options options = new Options(); + options.set("fields.f1,f2.sequence-group", "f3,f4"); + options.set("fields.f3.aggregate-function", "first_value"); + options.set("fields.f4.aggregate-function", "last_value"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + + func.reset(); + + // f1, f2 sequence
group + add(func, 1, 1, 1, 1, 1); + add(func, 1, 2, 2, 2, 2); + validate(func, 1, 2, 2, 1, 2); + add(func, 1, 0, 1, 3, 3); + validate(func, 1, 2, 2, 3, 2); + } + @Test public void testPartialUpdateWithAggregation() { Options options = new Options(); @@ -431,6 +673,65 @@ public void testPartialUpdateWithAggregation() { validate(func, 1, 3, -3, null, 1, 1, 1, 3); } + @Test + public void testMultiSequenceFieldsPartialUpdateWithAggregation() { + Options options = new Options(); + options.set("fields.f1,f2.sequence-group", "f3,f4,f5"); + options.set("fields.f7,f8.sequence-group", "f6"); + options.set("fields.f0.aggregate-function", "listagg"); + options.set("fields.f3.aggregate-function", "sum"); + options.set("fields.f4.aggregate-function", "first_value"); + options.set("fields.f5.aggregate-function", "last_value"); + options.set("fields.f6.aggregate-function", "last_non_null_value"); + options.set("fields.f4.ignore-retract", "true"); + options.set("fields.f6.ignore-retract", "true"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + + func.reset(); + // f0 pk + // f1, f2 sequence group 1 + // f3 in f1, f2 group with sum agg + // f4 in f1, f2 group with first_value agg + // f5 in f1, f2 group with last_value agg + // f6 in f7, f8 group with last_non_null_value agg + // f7, f8 sequence group 2 + add(func, 1, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 1, 1, 2, 1, 2, 2, null, 2, 0); + validate(func, 1, 1, 2, 2, 1, 2, 1, 2, 0); + + // sequence group not advanced + add(func, 1, 1, 1, 1, 3, 1, 1, 2, 0); + validate(func, 1, 1, 2, 3, 3, 2, 1, 2, 0); + + // test null + add(func, 1, 1, 3, null, null, null, null, 4, 2); + validate(func, 1, 1, 3, 3, 3, null, 1, 4, 2); + + // test retract + add(func, 1, 2, 3, 1, 1, 1, 1, 4,
3); + validate(func, 1, 2, 3, 4, 3, 1, 1, 4, 3); + add(func, RowKind.UPDATE_BEFORE, 1, 2, 3, 2, 1, 2, 1, 4, 3); + validate(func, 1, 2, 3, 2, 3, null, 1, 4, 3); + add(func, RowKind.DELETE, 1, 3, 2, 3, 1, 1, 4, 3); + validate(func, 1, 3, 2, -1, 3, null, 1, 4, 3); + // retract for old sequence + add(func, RowKind.DELETE, 1, 2, 2, 2, 1, 1, 1, 1, 3); + validate(func, 1, 3, 2, -3, 3, null, 1, 4, 3); + } + @Test public void testPartialUpdateWithAggregationProjectPushDown() { Options options = new Options(); @@ -485,6 +786,62 @@ public void testPartialUpdateWithAggregationProjectPushDown() { validate(func, null, -2, 2, 3); } + @Test + public void testMultiSequenceFieldsPartialUpdateWithAggregationProjectPushDown() { + Options options = new Options(); + options.set("fields.f1,f8.sequence-group", "f2,f3,f4"); + options.set("fields.f7,f9.sequence-group", "f6"); + options.set("fields.f0.aggregate-function", "listagg"); + options.set("fields.f2.aggregate-function", "sum"); + options.set("fields.f4.aggregate-function", "last_value"); + options.set("fields.f6.aggregate-function", "last_non_null_value"); + options.set("fields.f4.ignore-retract", "true"); + options.set("fields.f6.ignore-retract", "true"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(new int[][] {{3}, {2}, {5}}); + + validate(adjustedProjection, new int[] {3, 2, 5, 1, 8}, new int[] {0, 1, 2}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + + func.reset(); + // f0 pk + // f1, f8 sequence group + // f2 in f1, f8 group with sum agg + // f3 in f1, f8 group without agg + // f4 in f1, f8 group with last_value 
agg + // f5 not in group + // f6 in f7, f9 group with last_non_null_value agg + // f7, f9 sequence group 2 + add(func, 1, 1, 1, 1, 1); + add(func, 2, 1, 2, 2, 2); + validate(func, 2, 2, 2, 2, 2); + + add(func, RowKind.INSERT, null, null, null, 3, 3); + validate(func, null, 2, 2, 3, 3); + + // test retract + add(func, RowKind.UPDATE_BEFORE, 1, 2, 1, 3, 3); + validate(func, null, 0, 2, 3, 3); + add(func, RowKind.DELETE, 1, 2, 1, 3, 3); + validate(func, null, -2, 2, 3, 3); + } + @Test public void testAggregationWithoutSequenceGroup() { Options options = new Options();