diff --git a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java index 494e3609218b..9ff68ce9017b 100644 --- a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java +++ b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java @@ -64,6 +64,8 @@ public class CoreOptions implements Serializable { public static final String FIELDS_PREFIX = "fields"; + public static final String FIELDS_SEPARATOR = ","; + public static final String AGG_FUNCTION = "aggregate-function"; public static final String DEFAULT_AGG_FUNCTION = "default-aggregate-function"; diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java index 737fc5284aa6..a3d0a284ce40 100644 --- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java +++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunction.java @@ -24,27 +24,16 @@ import org.apache.paimon.data.InternalRow; import org.apache.paimon.mergetree.compact.aggregate.FieldAggregator; import org.apache.paimon.options.Options; -import org.apache.paimon.types.BigIntType; -import org.apache.paimon.types.CharType; +import org.apache.paimon.types.DataField; import org.apache.paimon.types.DataType; -import org.apache.paimon.types.DataTypeDefaultVisitor; -import org.apache.paimon.types.DateType; -import org.apache.paimon.types.DecimalType; -import org.apache.paimon.types.DoubleType; -import org.apache.paimon.types.FloatType; -import org.apache.paimon.types.IntType; -import org.apache.paimon.types.LocalZonedTimestampType; import org.apache.paimon.types.RowKind; import org.apache.paimon.types.RowType; -import org.apache.paimon.types.SmallIntType; -import org.apache.paimon.types.TimestampType; -import org.apache.paimon.types.TinyIntType; -import org.apache.paimon.types.VarCharType; -import org.apache.paimon.utils.InternalRowUtils; import org.apache.paimon.utils.Projection; +import org.apache.paimon.utils.UserDefinedSeqComparator; import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashSet; @@ -56,6 +45,7 @@ import java.util.stream.Stream; import static org.apache.paimon.CoreOptions.FIELDS_PREFIX; +import static org.apache.paimon.CoreOptions.FIELDS_SEPARATOR; import static org.apache.paimon.utils.InternalRowUtils.createFieldGetters; /** @@ -68,7 +58,7 @@ public class PartialUpdateMergeFunction implements MergeFunction { private final InternalRow.FieldGetter[] getters; private final boolean ignoreDelete; - private final Map fieldSequences; + private final Map fieldSeqComparators; private final boolean fieldSequenceEnabled; private final Map fieldAggregators; @@ -80,12 +70,12 @@ public class PartialUpdateMergeFunction implements MergeFunction { protected PartialUpdateMergeFunction( InternalRow.FieldGetter[] getters, boolean ignoreDelete, - Map fieldSequences, + Map fieldSeqComparators, Map fieldAggregators, boolean fieldSequenceEnabled) { this.getters = getters; this.ignoreDelete = ignoreDelete; - this.fieldSequences = fieldSequences; + this.fieldSeqComparators = fieldSeqComparators; this.fieldAggregators = fieldAggregators; this.fieldSequenceEnabled = fieldSequenceEnabled; } @@ -126,7 +116,7 @@ public void add(KeyValue kv) { } latestSequenceNumber = kv.sequenceNumber(); - if (fieldSequences.isEmpty()) { + if (fieldSeqComparators.isEmpty()) { updateNonNullFields(kv); } else { updateWithSequenceGroup(kv); @@ -145,25 +135,31 @@ private void updateNonNullFields(KeyValue kv) { private void updateWithSequenceGroup(KeyValue kv) { for (int i = 0; i < getters.length; i++) { Object field = getters[i].getFieldOrNull(kv.value()); - SequenceGenerator sequenceGen = fieldSequences.get(i); + UserDefinedSeqComparator seqComparator = fieldSeqComparators.get(i); FieldAggregator aggregator = fieldAggregators.get(i); Object accumulator = getters[i].getFieldOrNull(row); - if (sequenceGen == null) { + if (seqComparator == null) { if (aggregator != null) { row.setField(i, aggregator.agg(accumulator, field)); } else if (field != null) { row.setField(i, field); } } else { - Long currentSeq = sequenceGen.generate(kv.value()); - if (currentSeq != null) { - Long previousSeq = sequenceGen.generate(row); - if (previousSeq == null || currentSeq >= previousSeq) { - row.setField( - i, aggregator == null ? field : aggregator.agg(accumulator, field)); - } else if (aggregator != null) { - row.setField(i, aggregator.agg(field, accumulator)); + if (seqComparator.compare(kv.value(), row) >= 0) { + int index = i; + + // Multiple sequence fields should be updated at once. + if (Arrays.stream(seqComparator.compareFields()) + .anyMatch(seqIndex -> seqIndex == index)) { + for (int fieldIndex : seqComparator.compareFields()) { + row.setField( + fieldIndex, getters[fieldIndex].getFieldOrNull(kv.value())); + } } + row.setField( + i, aggregator == null ? field : aggregator.agg(accumulator, field)); + } else if (aggregator != null) { + row.setField(i, aggregator.agg(field, accumulator)); } } } @@ -171,38 +167,37 @@ private void updateWithSequenceGroup(KeyValue kv) { private void retractWithSequenceGroup(KeyValue kv) { for (int i = 0; i < getters.length; i++) { - SequenceGenerator sequenceGen = fieldSequences.get(i); - if (sequenceGen != null) { - Long currentSeq = sequenceGen.generate(kv.value()); - if (currentSeq != null) { - Long previousSeq = sequenceGen.generate(row); - FieldAggregator aggregator = fieldAggregators.get(i); - if (previousSeq == null || currentSeq >= previousSeq) { - if (sequenceGen.index() == i) { - // update sequence field - row.setField(i, getters[i].getFieldOrNull(kv.value())); + UserDefinedSeqComparator seqComparator = fieldSeqComparators.get(i); + if (seqComparator != null) { + FieldAggregator aggregator = fieldAggregators.get(i); + if (seqComparator.compare(kv.value(), row) >= 0) { + int index = i; + + // Multiple sequence fields should be updated at once. + if (Arrays.stream(seqComparator.compareFields()) + .anyMatch(field -> field == index)) { + for (int field : seqComparator.compareFields()) { + row.setField(field, getters[field].getFieldOrNull(kv.value())); + } + } else { + // retract normal field + if (aggregator == null) { + row.setField(i, null); } else { - // retract normal field - if (aggregator == null) { - row.setField(i, null); - } else { - // retract agg field - Object accumulator = getters[i].getFieldOrNull(row); - row.setField( - i, - aggregator.retract( - accumulator, - getters[i].getFieldOrNull(kv.value()))); - } + // retract agg field + Object accumulator = getters[i].getFieldOrNull(row); + row.setField( + i, + aggregator.retract( + accumulator, getters[i].getFieldOrNull(kv.value()))); } - } else if (aggregator != null) { - // retract agg field for old sequence - Object accumulator = getters[i].getFieldOrNull(row); - row.setField( - i, - aggregator.retract( - accumulator, getters[i].getFieldOrNull(kv.value()))); } + } else if (aggregator != null) { + // retract agg field for old sequence + Object accumulator = getters[i].getFieldOrNull(row); + row.setField( + i, + aggregator.retract(accumulator, getters[i].getFieldOrNull(kv.value()))); } } } @@ -226,57 +221,71 @@ private static class Factory implements MergeFunctionFactory { private static final long serialVersionUID = 1L; private final boolean ignoreDelete; + private final RowType rowType; + + private final List primaryFields; + private final List tableTypes; - private final Map fieldSequences; + + private final Map fieldSeqComparators; private final Map fieldAggregators; private Factory(Options options, RowType rowType, List primaryKeys) { this.ignoreDelete = options.get(CoreOptions.IGNORE_DELETE); + this.rowType = rowType; + this.primaryFields = new ArrayList<>(); + for (String primaryKey : primaryKeys) { + this.primaryFields.add(rowType.getField(primaryKey)); + } this.tableTypes = rowType.getFieldTypes(); List fieldNames = rowType.getFieldNames(); - this.fieldSequences = new HashMap<>(); + this.fieldSeqComparators = new HashMap<>(); for (Map.Entry entry : options.toMap().entrySet()) { String k = entry.getKey(); String v = entry.getValue(); if (k.startsWith(FIELDS_PREFIX) && k.endsWith(SEQUENCE_GROUP)) { - String sequenceFieldName = - k.substring( - FIELDS_PREFIX.length() + 1, - k.length() - SEQUENCE_GROUP.length() - 1); - SequenceGenerator sequenceGen = - new SequenceGenerator(sequenceFieldName, rowType); - Arrays.stream(v.split(",")) + List sequenceFields = + Arrays.stream( + k.substring( + FIELDS_PREFIX.length() + 1, + k.length() + - SEQUENCE_GROUP.length() + - 1) + .split(FIELDS_SEPARATOR)) + .map(fieldName -> validateFieldName(fieldName, fieldNames)) + .collect(Collectors.toList()); + + UserDefinedSeqComparator userDefinedSeqComparator = + UserDefinedSeqComparator.create(rowType, sequenceFields); + Arrays.stream(v.split(FIELDS_SEPARATOR)) .map( - fieldName -> { - int field = fieldNames.indexOf(fieldName); - if (field == -1) { - throw new IllegalArgumentException( - String.format( - "Field %s can not be found in table schema", - fieldName)); - } - return field; - }) + fieldName -> + fieldNames.indexOf( + validateFieldName(fieldName, fieldNames))) .forEach( field -> { - if (fieldSequences.containsKey(field)) { + if (fieldSeqComparators.containsKey(field)) { throw new IllegalArgumentException( String.format( "Field %s is defined repeatedly by multiple groups: %s", fieldNames.get(field), k)); } - fieldSequences.put(field, sequenceGen); + fieldSeqComparators.put(field, userDefinedSeqComparator); }); // add self - fieldSequences.put(sequenceGen.index(), sequenceGen); + sequenceFields.forEach( + fieldName -> { + int index = fieldNames.indexOf(fieldName); + fieldSeqComparators.put(index, userDefinedSeqComparator); + }); } } this.fieldAggregators = createFieldAggregators(rowType, primaryKeys, new CoreOptions(options)); - if (fieldAggregators.size() > 0 && fieldSequences.isEmpty()) { + if (!fieldAggregators.isEmpty() && fieldSeqComparators.isEmpty()) { throw new IllegalArgumentException( "Must use sequence group for aggregation functions."); } @@ -285,29 +294,46 @@ private Factory(Options options, RowType rowType, List primaryKeys) { @Override public MergeFunction create(@Nullable int[][] projection) { if (projection != null) { - Map projectedSequences = new HashMap<>(); + Map projectedSeqComparators = new HashMap<>(); Map projectedAggregators = new HashMap<>(); int[] projects = Projection.of(projection).toTopLevelIndexes(); Map indexMap = new HashMap<>(); + List dataFields = rowType.getFields(); + List newDataTypes = new ArrayList<>(); + for (int i = 0; i < projects.length; i++) { indexMap.put(projects[i], i); + newDataTypes.add(dataFields.get(projects[i]).type()); } - fieldSequences.forEach( - (field, sequence) -> { + RowType newRowType = RowType.builder().fields(newDataTypes).build(); + + fieldSeqComparators.forEach( + (field, comparator) -> { int newField = indexMap.getOrDefault(field, -1); if (newField != -1) { - int newSequenceId = indexMap.getOrDefault(sequence.index(), -1); - if (newSequenceId == -1) { - throw new RuntimeException( - String.format( - "Can not find new sequence field for new field. new field index is %s", - newField)); - } else { - projectedSequences.put( - newField, - new SequenceGenerator( - newSequenceId, sequence.fieldType())); - } + int[] newSequenceFields = + Arrays.stream(comparator.compareFields()) + .map( + index -> { + int newIndex = + indexMap.getOrDefault( + index, -1); + if (newIndex == -1) { + throw new RuntimeException( + String.format( + "Can not find new sequence field " + + "for new field. new field " + + "index is %s", + newField)); + } else { + return newIndex; + } + }) + .toArray(); + projectedSeqComparators.put( + newField, + UserDefinedSeqComparator.create( + newRowType, newSequenceFields)); } }); for (int i = 0; i < projects.length; i++) { @@ -319,22 +345,22 @@ public MergeFunction create(@Nullable int[][] projection) { return new PartialUpdateMergeFunction( createFieldGetters(Projection.of(projection).project(tableTypes)), ignoreDelete, - projectedSequences, + projectedSeqComparators, projectedAggregators, - !fieldSequences.isEmpty()); + !fieldSeqComparators.isEmpty()); } else { return new PartialUpdateMergeFunction( createFieldGetters(tableTypes), ignoreDelete, - fieldSequences, + fieldSeqComparators, fieldAggregators, - !fieldSequences.isEmpty()); + !fieldSeqComparators.isEmpty()); } } @Override public AdjustedProjection adjustProjection(@Nullable int[][] projection) { - if (fieldSequences.isEmpty()) { + if (fieldSeqComparators.isEmpty()) { return new AdjustedProjection(projection, null); } @@ -345,9 +371,15 @@ public AdjustedProjection adjustProjection(@Nullable int[][] projection) { int[] topProjects = Projection.of(projection).toTopLevelIndexes(); Set indexSet = Arrays.stream(topProjects).boxed().collect(Collectors.toSet()); for (int index : topProjects) { - SequenceGenerator generator = fieldSequences.get(index); - if (generator != null && !indexSet.contains(generator.index())) { - extraFields.add(generator.index()); + UserDefinedSeqComparator comparator = fieldSeqComparators.get(index); + if (comparator == null) { + continue; + } + + for (int field : comparator.compareFields()) { + if (!indexSet.contains(field)) { + extraFields.add(field); + } } } @@ -356,11 +388,21 @@ public AdjustedProjection adjustProjection(@Nullable int[][] projection) { .mapToInt(Integer::intValue) .toArray(); - int[][] pushdown = Projection.of(allProjects).toNestedIndexes(); + int[][] pushDown = Projection.of(allProjects).toNestedIndexes(); int[][] outer = Projection.of(IntStream.range(0, topProjects.length).toArray()) .toNestedIndexes(); - return new AdjustedProjection(pushdown, outer); + return new AdjustedProjection(pushDown, outer); + } + + private String validateFieldName(String fieldName, List fieldNames) { + int field = fieldNames.indexOf(fieldName); + if (field == -1) { + throw new IllegalArgumentException( + String.format("Field %s can not be found in table schema", fieldName)); + } + + return fieldName; } /** @@ -408,137 +450,4 @@ private Map createFieldAggregators( return fieldAggregators; } } - - private static class SequenceGenerator { - - private final int index; - - private final Generator generator; - private final DataType fieldType; - - private SequenceGenerator(String field, RowType rowType) { - index = rowType.getFieldNames().indexOf(field); - if (index == -1) { - throw new RuntimeException( - String.format( - "Can not find sequence field %s in table schema: %s", - field, rowType)); - } - fieldType = rowType.getTypeAt(index); - generator = fieldType.accept(new SequenceGeneratorVisitor()); - } - - public SequenceGenerator(int index, DataType dataType) { - this.index = index; - - this.fieldType = dataType; - if (index == -1) { - throw new RuntimeException(String.format("Index : %s is invalid", index)); - } - generator = fieldType.accept(new SequenceGeneratorVisitor()); - } - - public int index() { - return index; - } - - public DataType fieldType() { - return fieldType; - } - - @Nullable - public Long generate(InternalRow row) { - return generator.generateNullable(row, index); - } - - private interface Generator { - long generate(InternalRow row, int i); - - @Nullable - default Long generateNullable(InternalRow row, int i) { - if (row.isNullAt(i)) { - return null; - } - return generate(row, i); - } - } - - private static class SequenceGeneratorVisitor extends DataTypeDefaultVisitor { - - @Override - public Generator visit(CharType charType) { - return stringGenerator(); - } - - @Override - public Generator visit(VarCharType varCharType) { - return stringGenerator(); - } - - private Generator stringGenerator() { - return (row, i) -> Long.parseLong(row.getString(i).toString()); - } - - @Override - public Generator visit(DecimalType decimalType) { - return (row, i) -> - InternalRowUtils.castToIntegral( - row.getDecimal( - i, decimalType.getPrecision(), decimalType.getScale())); - } - - @Override - public Generator visit(TinyIntType tinyIntType) { - return InternalRow::getByte; - } - - @Override - public Generator visit(SmallIntType smallIntType) { - return InternalRow::getShort; - } - - @Override - public Generator visit(IntType intType) { - return InternalRow::getInt; - } - - @Override - public Generator visit(BigIntType bigIntType) { - return InternalRow::getLong; - } - - @Override - public Generator visit(FloatType floatType) { - return (row, i) -> (long) row.getFloat(i); - } - - @Override - public Generator visit(DoubleType doubleType) { - return (row, i) -> (long) row.getDouble(i); - } - - @Override - public Generator visit(DateType dateType) { - return InternalRow::getInt; - } - - @Override - public Generator visit(TimestampType timestampType) { - return (row, i) -> - row.getTimestamp(i, timestampType.getPrecision()).getMillisecond(); - } - - @Override - public Generator visit(LocalZonedTimestampType localZonedTimestampType) { - return (row, i) -> - row.getTimestamp(i, localZonedTimestampType.getPrecision()) - .getMillisecond(); - } - - @Override - protected Generator defaultMethod(DataType dataType) { - throw new UnsupportedOperationException("Unsupported type: " + dataType); - } - } - } } diff --git a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java index 8b97c0911503..05f77e09fc2e 100644 --- a/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java +++ b/paimon-core/src/main/java/org/apache/paimon/schema/SchemaValidation.java @@ -35,6 +35,7 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.types.VarCharType; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -51,6 +52,7 @@ import static org.apache.paimon.CoreOptions.CHANGELOG_PRODUCER; import static org.apache.paimon.CoreOptions.DEFAULT_AGG_FUNCTION; import static org.apache.paimon.CoreOptions.FIELDS_PREFIX; +import static org.apache.paimon.CoreOptions.FIELDS_SEPARATOR; import static org.apache.paimon.CoreOptions.FULL_COMPACTION_DELTA_COMMITS; import static org.apache.paimon.CoreOptions.INCREMENTAL_BETWEEN; import static org.apache.paimon.CoreOptions.INCREMENTAL_BETWEEN_TIMESTAMP; @@ -349,13 +351,15 @@ private static void validateFieldsPrefix(TableSchema schema, CoreOptions options .forEach( k -> { if (k.startsWith(FIELDS_PREFIX)) { - String fieldName = k.split("\\.")[1]; - checkArgument( - DEFAULT_AGG_FUNCTION.equals(fieldName) - || fieldNames.contains(fieldName), - String.format( - "Field %s can not be found in table schema.", - fieldName)); + String[] fields = k.split("\\.")[1].split(FIELDS_SEPARATOR); + for (String field : fields) { + checkArgument( + DEFAULT_AGG_FUNCTION.equals(field) + || fieldNames.contains(field), + String.format( + "Field %s can not be found in table schema.", + field)); + } } }); } @@ -367,29 +371,42 @@ private static void validateSequenceGroup(TableSchema schema, CoreOptions option String v = entry.getValue(); List fieldNames = schema.fieldNames(); if (k.startsWith(FIELDS_PREFIX) && k.endsWith(SEQUENCE_GROUP)) { - String sequenceFieldName = + String[] sequenceFieldNames = k.substring( - FIELDS_PREFIX.length() + 1, - k.length() - SEQUENCE_GROUP.length() - 1); - if (!fieldNames.contains(sequenceFieldName)) { - throw new IllegalArgumentException( - String.format( - "The sequence field group: %s can not be found in table schema.", - sequenceFieldName)); - } + FIELDS_PREFIX.length() + 1, + k.length() - SEQUENCE_GROUP.length() - 1) + .split(FIELDS_SEPARATOR); - for (String field : v.split(",")) { + for (String field : v.split(FIELDS_SEPARATOR)) { if (!fieldNames.contains(field)) { throw new IllegalArgumentException( String.format("Field %s can not be found in table schema.", field)); } - Set group = fields2Group.computeIfAbsent(field, p -> new HashSet<>()); - if (group.add(sequenceFieldName) && group.size() > 1) { + + List sequenceFieldsList = new ArrayList<>(); + for (String sequenceFieldName : sequenceFieldNames) { + if (!fieldNames.contains(sequenceFieldName)) { + throw new IllegalArgumentException( + String.format( + "The sequence field group: %s can not be found in table schema.", + sequenceFieldName)); + } + sequenceFieldsList.add(sequenceFieldName); + } + + if (fields2Group.containsKey(field)) { + List> sequenceGroups = new ArrayList<>(); + sequenceGroups.add(new ArrayList<>(fields2Group.get(field))); + sequenceGroups.add(sequenceFieldsList); + throw new IllegalArgumentException( String.format( "Field %s is defined repeatedly by multiple groups: %s.", - field, group)); + field, sequenceGroups)); } + + Set group = fields2Group.computeIfAbsent(field, p -> new HashSet<>()); + group.addAll(sequenceFieldsList); } } } diff --git a/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java b/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java index 35fa7a66d775..ec7a00bcb3b2 100644 --- a/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java +++ b/paimon-core/src/main/java/org/apache/paimon/utils/UserDefinedSeqComparator.java @@ -62,8 +62,18 @@ public static UserDefinedSeqComparator create(RowType rowType, List sequ List fieldNames = rowType.getFieldNames(); int[] fields = sequenceFields.stream().mapToInt(fieldNames::indexOf).toArray(); + + return create(rowType, fields); + } + + @Nullable + public static UserDefinedSeqComparator create(RowType rowType, int[] sequenceFields) { + if (sequenceFields.length == 0) { + return null; + } + RecordComparator comparator = - CodeGenUtils.newRecordComparator(rowType.getFieldTypes(), fields); - return new UserDefinedSeqComparator(fields, comparator); + CodeGenUtils.newRecordComparator(rowType.getFieldTypes(), sequenceFields); + return new UserDefinedSeqComparator(sequenceFields, comparator); } } diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/SortBufferWriteBufferTestBase.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/SortBufferWriteBufferTestBase.java index 9315fb15dcfa..dea0268ec87f 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/SortBufferWriteBufferTestBase.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/SortBufferWriteBufferTestBase.java @@ -179,7 +179,7 @@ protected List getExpected(List input) { protected MergeFunction createMergeFunction() { Options options = new Options(); return PartialUpdateMergeFunction.factory( - options, RowType.of(DataTypes.BIGINT()), ImmutableList.of("key")) + options, RowType.of(DataTypes.BIGINT()), ImmutableList.of("f0")) .create(); } } diff --git a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java index fa41607d10c3..03c5b83dbcf1 100644 --- a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java +++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/PartialUpdateMergeFunctionTest.java @@ -95,6 +95,43 @@ public void testSequenceGroup() { validate(func, 1, null, null, 6, null, null, 6); } + @Test + public void testMultiSequenceFields() { + Options options = new Options(); + options.set("fields.f3,f4.sequence-group", "f1,f2"); + options.set("fields.f7,f8.sequence-group", "f5,f6"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + func.reset(); + add(func, 1, 1, 1, 1, 1, 1, 1, 1, 3); + add(func, 1, 2, 2, 2, 2, 2, 1, 1, null); + validate(func, 1, 2, 2, 2, 2, 1, 1, 1, 3); + add(func, 1, 1, 3, 1, 3, 3, 3, 3, 2); + validate(func, 1, 2, 2, 2, 2, 3, 3, 3, 2); + + // delete + add(func, RowKind.DELETE, 1, 1, 1, 3, 3, 1, 1, null, null); + validate(func, 1, null, null, 3, 3, 3, 3, 3, 2); + add(func, RowKind.DELETE, 1, 1, 1, 3, 1, 1, 1, 4, 4); + validate(func, 1, null, null, 3, 3, null, null, 4, 4); + add(func, 1, 4, 4, 4, 4, 5, 5, 5, 5); + validate(func, 1, 4, 4, 4, 4, 5, 5, 5, 5); + add(func, RowKind.DELETE, 1, 1, 1, 6, 1, 1, 1, 6, 1); + validate(func, 1, null, null, 6, 1, null, null, 6, 1); + } + @Test public void testSequenceGroupDefaultAggFunc() { Options options = new Options(); @@ -123,6 +160,36 @@ public void testSequenceGroupDefaultAggFunc() { validate(func, 1, 4, 2, 4, 5, 3, 5); } + @Test + public void testMultiSequenceFieldsDefaultAggFunc() { + Options options = new Options(); + options.set("fields.f3,f4.sequence-group", "f1,f2"); + options.set("fields.f7,f8.sequence-group", "f5,f6"); + options.set(FIELDS_DEFAULT_AGG_FUNC, "last_non_null_value"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + func.reset(); + add(func, 1, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 1, 2, 2, 2, 2, 2, 2, null, null); + validate(func, 1, 2, 2, 2, 2, 1, 1, 1, 1); + add(func, 1, 3, 3, 1, 1, 3, 3, 3, 3); + validate(func, 1, 2, 2, 2, 2, 3, 3, 3, 3); + add(func, 1, 4, null, 4, 4, 5, null, 5, 5); + validate(func, 1, 4, 2, 4, 4, 5, 3, 5, 5); + } + @Test public void testSequenceGroupDefinedNoField() { Options options = new Options(); @@ -144,6 +211,27 @@ public void testSequenceGroupDefinedNoField() { .hasMessageContaining("can not be found in table schema"); } + @Test + public void testMultiSequenceFieldsDefinedNoField() { + Options options = new Options(); + options.set("fields.f2,f3.sequence-group", "f1,f7"); + options.set("fields.f5,f6.sequence-group", "f4"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f0"))) + .hasMessageContaining("can not be found in table schema"); + } + @Test public void testSequenceGroupRepeatDefine() { Options options = new Options(); @@ -163,6 +251,27 @@ public void testSequenceGroupRepeatDefine() { .hasMessageContaining("is defined repeatedly by multiple groups"); } + @Test + public void testMultiSequenceFieldsRepeatDefine() { + Options options = new Options(); + options.set("fields.f3,f4.sequence-group", "f1,f2"); + options.set("fields.f5,f6.sequence-group", "f1,f2"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + assertThatThrownBy( + () -> + PartialUpdateMergeFunction.factory( + options, rowType, ImmutableList.of("f0"))) + .hasMessageContaining("is defined repeatedly by multiple groups"); + } + @Test public void testAdjustProjectionRepeatProject() { Options options = new Options(); @@ -198,6 +307,45 @@ public void testAdjustProjectionRepeatProject() { validate(func, 3, 3, null, 2, 4, 2); } + @Test + public void testMultiSequenceFieldsAdjustProjectionRepeatProject() { + Options options = new Options(); + options.set("fields.f2,f4.sequence-group", "f1,f3"); + options.set("fields.f5,f6.sequence-group", "f7"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + // the field 'f1' is projected twice + int[][] projection = new int[][] {{1}, {1}, {3}, {7}}; + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(projection); + + validate(adjustedProjection, new int[] {1, 1, 3, 7, 2, 4, 5, 6}, new int[] {0, 1, 2, 3}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + func.reset(); + add(func, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 2, 2, 6, 2, 2, 2, 2, 6); + validate(func, 2, 2, 6, 2, 2, 2, 2, 6); + + // update first sequence group + add(func, 3, 3, null, 7, 4, null, 1, 8); + validate(func, 3, 3, null, 2, 4, null, 2, 6); + + // update second sequence group + add(func, 5, 5, 3, 3, 3, 5, 5, 6); + validate(func, 5, 3, null, 3, 4, null, 5, 6); + } + @Test public void testAdjustProjectionSequenceFieldsProject() { Options options = new Options(); @@ -230,6 +378,38 @@ public void testAdjustProjectionSequenceFieldsProject() { validate(func, 1, 1, 1, 2, 2); } + @Test + public void testMultiSequenceFieldsAdjustProjectionProject() { + Options options = new Options(); + options.set("fields.f2,f4.sequence-group", "f1,f3"); + options.set("fields.f5,f6.sequence-group", "f7"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + // the sequence field 'f4' is projected too + int[][] projection = new int[][] {{1}, {4}, {3}, {7}}; + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(projection); + + validate(adjustedProjection, new int[] {1, 4, 3, 7, 2, 5, 6}, new int[] {0, 1, 2, 3}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + func.reset(); + // if sequence field is null, the related fields should not be updated + add(func, 1, 1, 1, 1, 1, 1, 1); + add(func, 1, null, 1, 3, 2, 2, 2); + validate(func, 1, null, 1, 3, 2, 2, 2); + } + @Test public void testAdjustProjectionAllFieldsProject() { Options options = new Options(); @@ -265,6 +445,41 @@ public void testAdjustProjectionAllFieldsProject() { validate(func, 4, 2, 4, 2, 2, 1, 1, 1); } + @Test + public void testMultiSequenceFieldsAdjustProjectionAllFieldsProject() { + Options options = new Options(); + options.set("fields.f2,f4.sequence-group", "f1,f3"); + options.set("fields.f5,f6.sequence-group", "f7"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + // all fields are projected + int[][] projection = new int[][] {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}; + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(projection); + + validate( + adjustedProjection, + new int[] {0, 1, 2, 3, 4, 5, 6, 7}, + new int[] {0, 1, 2, 3, 4, 5, 6, 7}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + func.reset(); + // 'f6' has no sequence group, it should not be updated by null + add(func, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 4, 2, 4, 2, 2, 0, null, 3); + validate(func, 4, 2, 4, 2, 2, 1, 1, 1); + } + @Test public void testAdjustProjectionNonProject() { Options options = new Options(); @@ -372,6 +587,33 @@ public void testFirstValue() { validate(func, 1, 2, 3, 2); } + @Test + public void testMultiSequenceFieldsFirstValue() { + Options options = new Options(); + options.set("fields.f1,f2.sequence-group", "f3,f4"); + options.set("fields.f3.aggregate-function", "first_value"); + options.set("fields.f4.aggregate-function", "last_value"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + + func.reset(); + + // f7 sequence group 2 + add(func, 1, 1, 1, 1, 1); + add(func, 1, 2, 2, 2, 2); + validate(func, 1, 2, 2, 1, 2); + add(func, 1, 0, 1, 3, 3); + validate(func, 1, 2, 2, 3, 2); + } + @Test public void testPartialUpdateWithAggregation() { Options options = new Options(); @@ -431,6 +673,65 @@ public void testPartialUpdateWithAggregation() { validate(func, 1, 3, -3, null, 1, 1, 1, 3); } + @Test + public void testMultiSequenceFieldsPartialUpdateWithAggregation() { + Options options = new Options(); + options.set("fields.f1,f2.sequence-group", "f3,f4,f5"); + options.set("fields.f7,f8.sequence-group", "f6"); + options.set("fields.f0.aggregate-function", "listagg"); + options.set("fields.f3.aggregate-function", "sum"); + options.set("fields.f4.aggregate-function", "first_value"); + options.set("fields.f5.aggregate-function", "last_value"); + options.set("fields.f6.aggregate-function", "last_non_null_value"); + options.set("fields.f4.ignore-retract", "true"); + options.set("fields.f6.ignore-retract", "true"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunction func = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")) + .create(); + + func.reset(); + // f0 pk + // f1, f2 sequence group 1 + // f3 in f1, f2 group with sum agg + // f4 in f1, f2 group with first_value agg + // f5 in f1, f2 group with last_value agg + // f6 in f7, f8 group with last_not_null agg + // f7, f8 sequence group 2 + add(func, 1, 1, 1, 1, 1, 1, 1, 1, 1); + add(func, 1, 1, 2, 1, 2, 2, null, 2, 0); + validate(func, 1, 1, 2, 2, 1, 2, 1, 2, 0); + + // sequence group not advanced + add(func, 1, 1, 1, 1, 3, 1, 1, 2, 0); + validate(func, 1, 1, 2, 3, 3, 2, 1, 2, 0); + + // test null + add(func, 1, 1, 3, null, null, null, null, 4, 2); + validate(func, 1, 1, 3, 3, 3, null, 1, 4, 2); + + // test retract + add(func, 1, 2, 3, 1, 1, 1, 1, 4, 3); + validate(func, 1, 2, 3, 4, 3, 1, 1, 4, 3); + add(func, RowKind.UPDATE_BEFORE, 1, 2, 3, 2, 1, 2, 1, 4, 3); + validate(func, 1, 2, 3, 2, 3, null, 1, 4, 3); + add(func, RowKind.DELETE, 1, 3, 2, 3, 1, 1, 4, 3); + validate(func, 1, 3, 2, -1, 3, null, 1, 4, 3); + // retract for old sequence + add(func, RowKind.DELETE, 1, 2, 2, 2, 1, 1, 1, 1, 3); + validate(func, 1, 3, 2, -3, 3, null, 1, 4, 3); + } + @Test public void testPartialUpdateWithAggregationProjectPushDown() { Options options = new Options(); @@ -485,6 +786,62 @@ public void testPartialUpdateWithAggregationProjectPushDown() { validate(func, null, -2, 2, 3); } + @Test + public void testMultiSequenceFieldsPartialUpdateWithAggregationProjectPushDown() { + Options options = new Options(); + options.set("fields.f1,f8.sequence-group", "f2,f3,f4"); + options.set("fields.f7,f9.sequence-group", "f6"); + options.set("fields.f0.aggregate-function", "listagg"); + options.set("fields.f2.aggregate-function", "sum"); + options.set("fields.f4.aggregate-function", "last_value"); + options.set("fields.f6.aggregate-function", "last_non_null_value"); + options.set("fields.f4.ignore-retract", "true"); + options.set("fields.f6.ignore-retract", "true"); + RowType rowType = + RowType.of( + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT(), + DataTypes.INT()); + MergeFunctionFactory factory = + PartialUpdateMergeFunction.factory(options, rowType, ImmutableList.of("f0")); + + MergeFunctionFactory.AdjustedProjection adjustedProjection = + factory.adjustProjection(new int[][] {{3}, {2}, {5}}); + + validate(adjustedProjection, new int[] {3, 2, 5, 1, 8}, new int[] {0, 1, 2}); + + MergeFunction func = factory.create(adjustedProjection.pushdownProjection); + + func.reset(); + // f0 pk + // f1, f8 sequence group + // f2 in f1, f8 group with sum agg + // f3 in f1, f8 group without agg + // f4 in f1, f8 group with last_value agg + // f5 not in group + // f6 in f7, f9 group with last_not_null agg + // f7, f9 sequence group 2 + add(func, 1, 1, 1, 1, 1); + add(func, 2, 1, 2, 2, 2); + validate(func, 2, 2, 2, 2, 2); + + add(func, RowKind.INSERT, null, null, null, 3, 3); + validate(func, null, 2, 2, 3, 3); + + // test retract + add(func, RowKind.UPDATE_BEFORE, 1, 2, 1, 3, 3); + validate(func, null, 0, 2, 3, 3); + add(func, RowKind.DELETE, 1, 2, 1, 3, 3); + validate(func, null, -2, 2, 3, 3); + } + @Test public void testAggregationWithoutSequenceGroup() { Options options = new Options(); diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java index e41df196b972..67294392463e 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/PartialUpdateITCase.java @@ -215,6 +215,43 @@ public void testSequenceGroup() { assertThat(sql("SELECT c, d FROM SG")).containsExactlyInAnyOrder(Row.of(5, null)); } + @Test + public void testMultiFieldsSequenceGroup() { + sql( + "CREATE TABLE SG (" + + "k INT, a INT, b INT, g_1 INT, c INT, d INT, g_2 INT, g_3 INT, PRIMARY KEY (k) NOT ENFORCED)" + + " WITH (" + + "'merge-engine'='partial-update', " + + "'fields.g_1.sequence-group'='a,b', " + + "'fields.g_2,g_3.sequence-group'='c,d');"); + + sql("INSERT INTO SG VALUES (1, 1, 1, 1, 1, 1, 1, 1)"); + + // g_2, g_3 should not be updated + sql("INSERT INTO SG VALUES (1, 2, 2, 2, 2, 2, 1, CAST(NULL AS INT))"); + + // select * + assertThat(sql("SELECT * FROM SG")) + .containsExactlyInAnyOrder(Row.of(1, 2, 2, 2, 1, 1, 1, 1)); + + // projection + assertThat(sql("SELECT c, d FROM SG")).containsExactlyInAnyOrder(Row.of(1, 1)); + + // g_1 should not be updated + sql("INSERT INTO SG VALUES (1, 3, 3, 1, 3, 3, 3, 1)"); + + assertThat(sql("SELECT * FROM SG")) + .containsExactlyInAnyOrder(Row.of(1, 2, 2, 2, 3, 3, 3, 1)); + + // d should be updated by null + sql("INSERT INTO SG VALUES (1, 3, 3, 3, 2, 2, CAST(NULL AS INT), 1)"); + sql("INSERT INTO SG VALUES (1, 4, 4, 4, 2, 2, CAST(NULL AS INT), 1)"); + sql("INSERT INTO SG VALUES (1, 5, 5, 3, 5, CAST(NULL AS INT), 4, 1)"); + + assertThat(sql("SELECT a, b FROM SG")).containsExactlyInAnyOrder(Row.of(4, 4)); + assertThat(sql("SELECT c, d FROM SG")).containsExactlyInAnyOrder(Row.of(5, null)); + } + @Test public void testSequenceGroupWithDefaultAggFunc() { sql( @@ -285,7 +322,19 @@ public void testInvalidSequenceGroup() { + "'fields.g_1.sequence-group'='a,b', " + "'fields.g_2.sequence-group'='a,d');")) .hasRootCauseMessage( - "Field a is defined repeatedly by multiple groups: [g_1, g_2]."); + "Field a is defined repeatedly by multiple groups: [[g_1], [g_2]]."); + + Assertions.assertThatThrownBy( + () -> + sql( + "CREATE TABLE SG (" + + "k INT, a INT, b INT, g_1 INT, c INT, d INT, g_2 INT, g_3 INT, PRIMARY KEY (k) NOT ENFORCED)" + + " WITH (" + + "'merge-engine'='partial-update', " + + "'fields.g_1.sequence-group'='a,b', " + + "'fields.g_2,g_3.sequence-group'='a,d');")) + .hasRootCauseMessage( + "Field a is defined repeatedly by multiple groups: [[g_1], [g_2, g_3]]."); } @Test