From ad739d2f91a8c72d1f16093cc776fc1047385c44 Mon Sep 17 00:00:00 2001 From: wyb Date: Fri, 20 Sep 2024 09:21:48 +0800 Subject: [PATCH] [Enhancement] Support insert match column by name Signed-off-by: wyb --- .../java/com/starrocks/common/ErrorCode.java | 3 +- .../java/com/starrocks/sql/InsertPlanner.java | 79 ++++++----- .../sql/analyzer/InsertAnalyzer.java | 64 +++++++-- .../com/starrocks/sql/ast/InsertStmt.java | 35 +++++ .../test_insert_empty/R/test_insert_by_name | 126 ++++++++++++++++++ .../test_insert_empty/T/test_insert_by_name | 46 +++++++ 6 files changed, 310 insertions(+), 43 deletions(-) create mode 100644 test/sql/test_insert_empty/R/test_insert_by_name create mode 100644 test/sql/test_insert_empty/T/test_insert_by_name diff --git a/fe/fe-core/src/main/java/com/starrocks/common/ErrorCode.java b/fe/fe-core/src/main/java/com/starrocks/common/ErrorCode.java index a9cfcb2857c84..e9d7589180ce4 100644 --- a/fe/fe-core/src/main/java/com/starrocks/common/ErrorCode.java +++ b/fe/fe-core/src/main/java/com/starrocks/common/ErrorCode.java @@ -320,7 +320,7 @@ public enum ErrorCode { "No partitions have data available for loading. If you are sure there may be no data to be loaded, " + "you can use `ADMIN SET FRONTEND CONFIG ('empty_load_as_error' = 'false')` " + "to ensure such load jobs can succeed"), - ERR_INSERTED_COLUMN_MISMATCH(5604, new byte[] {'2', '2', '0', '0', '0'}, + ERR_INSERT_COLUMN_COUNT_MISMATCH(5604, new byte[] {'2', '2', '0', '0', '0'}, "Inserted target column count: %d doesn't match select/value column count: %d"), ERR_ILLEGAL_BYTES_LENGTH(5605, new byte[] {'4', '2', '0', '0', '0'}, "The valid bytes length for '%s' is [%d, %d]"), ERR_TOO_MANY_ERROR_ROWS(5606, new byte[] {'2', '2', '0', '0', '0'}, @@ -329,6 +329,7 @@ public enum ErrorCode { ERR_ROUTINE_LOAD_OFFSET_INVALID(5607, new byte[] {'0', '2', '0', '0', '0'}, "Consume offset: %d is greater than the latest offset: %d in kafka partition: %d. 
" + "You can modify 'kafka_offsets' property through ALTER ROUTINE LOAD and RESUME the job"), + ERR_INSERT_COLUMN_NAME_MISMATCH(5608, new byte[] {'2', '2', '0', '0', '0'}, "%s column: %s has no matching %s column"), /** * 5700 - 5799: Partition diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/InsertPlanner.java b/fe/fe-core/src/main/java/com/starrocks/sql/InsertPlanner.java index 226749f433664..56eac05d95e76 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/InsertPlanner.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/InsertPlanner.java @@ -17,6 +17,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Stopwatch; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.starrocks.alter.SchemaChangeHandler; import com.starrocks.analysis.DescriptorTable; @@ -592,8 +593,31 @@ private void castLiteralToTargetColumnsType(InsertStmt insertStatement) { private OptExprBuilder fillDefaultValue(LogicalPlan logicalPlan, ColumnRefFactory columnRefFactory, InsertStmt insertStatement, List outputColumns) { - Map columnRefMap = new HashMap<>(); + // targetColumnNames is for check whether schema column is in target column list or not + Set targetColumnNames = Sets.newTreeSet(String.CASE_INSENSITIVE_ORDER); + targetColumnNames.addAll( + insertStatement.getTargetColumnNames() != null ? insertStatement.getTargetColumnNames() : + outputBaseSchema.stream().map(Column::getName).collect(Collectors.toList())); + + // sourceColumnMappedNames is the mapped name of source columns corresponding to the target columns. + // 1. if match by position, source mapped column name can be converted from target column name one by one. + // 2. if match by name, source mapped column name is same as the source column name. 
+ Map mappedColumnToSourceIdx = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER); + List sourceColumnMappedNames = null; + if (insertStatement.isColumnMatchByPosition()) { + sourceColumnMappedNames = insertStatement.getTargetColumnNames() != null ? insertStatement.getTargetColumnNames() : + outputBaseSchema.stream().map(Column::getName).collect(Collectors.toList()); + } else { + Preconditions.checkState(insertStatement.isColumnMatchByName()); + sourceColumnMappedNames = insertStatement.getQueryStatement().getQueryRelation().getColumnOutputNames(); + } + Preconditions.checkState(sourceColumnMappedNames != null); + for (int columnIdx = 0; columnIdx < sourceColumnMappedNames.size(); ++columnIdx) { + mappedColumnToSourceIdx.put(sourceColumnMappedNames.get(columnIdx), columnIdx); + } + // generate columnRefMap (fill default value) + Map columnRefMap = new HashMap<>(); for (int columnIdx = 0; columnIdx < outputBaseSchema.size(); ++columnIdx) { if (needToSkip(insertStatement, columnIdx)) { continue; @@ -603,40 +627,35 @@ private OptExprBuilder fillDefaultValue(LogicalPlan logicalPlan, ColumnRefFactor if (targetColumn.isGeneratedColumn()) { continue; } - if (insertStatement.getTargetColumnNames() == null) { - outputColumns.add(logicalPlan.getOutputColumn().get(columnIdx)); - columnRefMap.put(logicalPlan.getOutputColumn().get(columnIdx), - logicalPlan.getOutputColumn().get(columnIdx)); + + String targetColumnName = targetColumn.getName(); + if (mappedColumnToSourceIdx.containsKey(targetColumnName) && targetColumnNames.contains(targetColumnName)) { + ColumnRefOperator col = logicalPlan.getOutputColumn().get(mappedColumnToSourceIdx.get(targetColumnName)); + outputColumns.add(col); + columnRefMap.put(col, col); } else { - int idx = insertStatement.getTargetColumnNames().indexOf(targetColumn.getName().toLowerCase()); - if (idx == -1) { - ScalarOperator scalarOperator; - Column.DefaultValueType defaultValueType = targetColumn.getDefaultValueType(); - if (defaultValueType == 
Column.DefaultValueType.NULL || targetColumn.isAutoIncrement()) { - scalarOperator = ConstantOperator.createNull(targetColumn.getType()); - } else if (defaultValueType == Column.DefaultValueType.CONST) { - scalarOperator = ConstantOperator.createVarchar(targetColumn.calculatedDefaultValue()); - } else if (defaultValueType == Column.DefaultValueType.VARY) { - if (SUPPORTED_DEFAULT_FNS.contains(targetColumn.getDefaultExpr().getExpr())) { - scalarOperator = SqlToScalarOperatorTranslator. - translate(targetColumn.getDefaultExpr().obtainExpr()); - } else { - throw new SemanticException( - "Column:" + targetColumn.getName() + " has unsupported default value:" - + targetColumn.getDefaultExpr().getExpr()); - } + ScalarOperator scalarOperator; + Column.DefaultValueType defaultValueType = targetColumn.getDefaultValueType(); + if (defaultValueType == Column.DefaultValueType.NULL || targetColumn.isAutoIncrement()) { + scalarOperator = ConstantOperator.createNull(targetColumn.getType()); + } else if (defaultValueType == Column.DefaultValueType.CONST) { + scalarOperator = ConstantOperator.createVarchar(targetColumn.calculatedDefaultValue()); + } else if (defaultValueType == Column.DefaultValueType.VARY) { + if (SUPPORTED_DEFAULT_FNS.contains(targetColumn.getDefaultExpr().getExpr())) { + scalarOperator = SqlToScalarOperatorTranslator. 
+ translate(targetColumn.getDefaultExpr().obtainExpr()); } else { - throw new SemanticException("Unknown default value type:%s", defaultValueType.toString()); + throw new SemanticException("Column:" + targetColumnName + " has unsupported default value:" + + targetColumn.getDefaultExpr().getExpr()); } - ColumnRefOperator col = columnRefFactory - .create(scalarOperator, scalarOperator.getType(), scalarOperator.isNullable()); - - outputColumns.add(col); - columnRefMap.put(col, scalarOperator); } else { - outputColumns.add(logicalPlan.getOutputColumn().get(idx)); - columnRefMap.put(logicalPlan.getOutputColumn().get(idx), logicalPlan.getOutputColumn().get(idx)); + throw new SemanticException("Unknown default value type:%s", defaultValueType.toString()); } + ColumnRefOperator col = columnRefFactory + .create(scalarOperator, scalarOperator.getType(), scalarOperator.isNullable()); + + outputColumns.add(col); + columnRefMap.put(col, scalarOperator); } } return logicalPlan.getRootBuilder().withNewRoot(new LogicalProjectOperator(new HashMap<>(columnRefMap))); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/InsertAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/InsertAnalyzer.java index 08db3754c08de..a24993d90f91d 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/InsertAnalyzer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/InsertAnalyzer.java @@ -46,6 +46,7 @@ import com.starrocks.sql.ast.DefaultValueExpr; import com.starrocks.sql.ast.FileTableFunctionRelation; import com.starrocks.sql.ast.InsertStmt; +import com.starrocks.sql.ast.InsertStmt.ColumnMatchPolicy; import com.starrocks.sql.ast.LoadStmt; import com.starrocks.sql.ast.PartitionNames; import com.starrocks.sql.ast.QueryRelation; @@ -213,13 +214,11 @@ public static void analyzeWithDeferredLock(InsertStmt insertStmt, ConnectContext if (insertStmt.getTargetColumnNames() == null) { if (table instanceof OlapTable) { targetColumns = new 
ArrayList<>(((OlapTable) table).getBaseSchemaWithoutGeneratedColumn()); - mentionedColumns = - ((OlapTable) table).getBaseSchemaWithoutGeneratedColumn().stream() - .map(Column::getName).collect(Collectors.toSet()); + mentionedColumns.addAll(((OlapTable) table).getBaseSchemaWithoutGeneratedColumn().stream().map(Column::getName) + .collect(Collectors.toSet())); } else { targetColumns = new ArrayList<>(table.getBaseSchema()); - mentionedColumns = - table.getBaseSchema().stream().map(Column::getName).collect(Collectors.toSet()); + mentionedColumns.addAll(table.getBaseSchema().stream().map(Column::getName).collect(Collectors.toSet())); } } else { targetColumns = new ArrayList<>(); @@ -235,7 +234,7 @@ public static void analyzeWithDeferredLock(InsertStmt insertStmt, ConnectContext throw new SemanticException("generated column '%s' can not be specified", colName); } if (!mentionedColumns.add(colName)) { - throw new SemanticException("Column '%s' specified twice", colName); + ErrorReport.reportSemanticException(ErrorCode.ERR_DUP_FIELDNAME, colName); } requiredKeyColumns.remove(colName.toLowerCase()); targetColumns.add(column); @@ -274,13 +273,40 @@ public static void analyzeWithDeferredLock(InsertStmt insertStmt, ConnectContext if ((table.isIcebergTable() || table.isHiveTable()) && insertStmt.isStaticKeyPartitionInsert()) { // full column size = mentioned column size + partition column size for static partition insert mentionedColumnSize -= table.getPartitionColumnNames().size(); + mentionedColumns.removeAll(table.getPartitionColumnNames()); } + // check target and source columns match QueryRelation query = insertStmt.getQueryStatement().getQueryRelation(); - if (query.getRelationFields().size() != mentionedColumnSize) { - ErrorReport.reportSemanticException(ErrorCode.ERR_INSERTED_COLUMN_MISMATCH, mentionedColumnSize, - query.getRelationFields().size()); + if (insertStmt.isColumnMatchByPosition()) { + if (query.getRelationFields().size() != mentionedColumnSize) { + 
ErrorReport.reportSemanticException(ErrorCode.ERR_INSERT_COLUMN_COUNT_MISMATCH, mentionedColumnSize, + query.getRelationFields().size()); + } + } else { + Preconditions.checkState(insertStmt.isColumnMatchByName()); + if (query instanceof ValuesRelation) { + throw new SemanticException("Insert match column by name does not support values()"); + } + + Set selectColumnNames = Sets.newTreeSet(String.CASE_INSENSITIVE_ORDER); + for (String colName : insertStmt.getQueryStatement().getQueryRelation().getColumnOutputNames()) { + if (!selectColumnNames.add(colName)) { + ErrorReport.reportSemanticException(ErrorCode.ERR_DUP_FIELDNAME, colName); + } + } + if (!selectColumnNames.containsAll(mentionedColumns)) { + mentionedColumns.removeAll(selectColumnNames); + ErrorReport.reportSemanticException( + ErrorCode.ERR_INSERT_COLUMN_NAME_MISMATCH, "Target", String.join(", ", mentionedColumns), "source"); + } + if (!mentionedColumns.containsAll(selectColumnNames)) { + selectColumnNames.removeAll(mentionedColumns); + ErrorReport.reportSemanticException( + ErrorCode.ERR_INSERT_COLUMN_NAME_MISMATCH, "Source", String.join(", ", selectColumnNames), "target"); + } } + // check default value expr if (query instanceof ValuesRelation) { ValuesRelation valuesRelation = (ValuesRelation) query; @@ -308,15 +334,29 @@ public static void analyzeWithDeferredLock(InsertStmt insertStmt, ConnectContext private static void analyzeProperties(InsertStmt insertStmt, ConnectContext session) { Map properties = insertStmt.getProperties(); + + // column match by related properties + // parse the property and remove it for 'LoadStmt.checkProperties' validation + if (properties.containsKey(InsertStmt.PROPERTY_MATCH_COLUMN_BY)) { + String property = properties.remove(InsertStmt.PROPERTY_MATCH_COLUMN_BY); + ColumnMatchPolicy columnMatchPolicy = ColumnMatchPolicy.fromString(property); + if (columnMatchPolicy == null) { + String msg = String.format("%s (case insensitive)", String.join(", ", 
ColumnMatchPolicy.getCandidates())); + ErrorReport.reportSemanticException( + ErrorCode.ERR_INVALID_VALUE, InsertStmt.PROPERTY_MATCH_COLUMN_BY, property, msg); + } + insertStmt.setColumnMatchPolicy(columnMatchPolicy); + } + + // check common properties // use session variable if not set max_filter_ratio property if (!properties.containsKey(LoadStmt.MAX_FILTER_RATIO_PROPERTY)) { properties.put(LoadStmt.MAX_FILTER_RATIO_PROPERTY, String.valueOf(session.getSessionVariable().getInsertMaxFilterRatio())); } // use session variable if not set strict_mode property - if (!properties.containsKey(LoadStmt.STRICT_MODE) && - session.getSessionVariable().getEnableInsertStrict()) { - properties.put(LoadStmt.STRICT_MODE, "true"); + if (!properties.containsKey(LoadStmt.STRICT_MODE)) { + properties.put(LoadStmt.STRICT_MODE, String.valueOf(session.getSessionVariable().getEnableInsertStrict())); } try { diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/ast/InsertStmt.java b/fe/fe-core/src/main/java/com/starrocks/sql/ast/InsertStmt.java index 40f7ef3b022f5..81804b8e99402 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/ast/InsertStmt.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/ast/InsertStmt.java @@ -29,6 +29,7 @@ import com.starrocks.sql.parser.NodePosition; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -53,6 +54,7 @@ */ public class InsertStmt extends DmlStmt { public static final String STREAMING = "STREAMING"; + public static final String PROPERTY_MATCH_COLUMN_BY = "match_column_by"; private final TableName tblName; private PartitionNames targetPartitionNames; @@ -96,6 +98,9 @@ public class InsertStmt extends DmlStmt { private boolean isVersionOverwrite = false; + // column match by position or name + private ColumnMatchPolicy columnMatchPolicy = ColumnMatchPolicy.POSITION; + public InsertStmt(TableName tblName, PartitionNames targetPartitionNames, String label, List 
/**
 * Policy that decides how the columns produced by the source query of an
 * INSERT statement are paired with the target table columns.
 *
 * <p>The policy is selected with the {@code match_column_by} insert property.
 */
public enum ColumnMatchPolicy {
    /** Pair source and target columns by their ordinal position. */
    POSITION,
    /** Pair source and target columns by their (case-insensitive) names. */
    NAME;

    /**
     * Resolves a user-supplied property value to a policy, ignoring case.
     *
     * @param value the raw property value; may be {@code null}
     * @return the matching policy, or {@code null} when {@code value} is
     *         {@code null} or names no policy (callers report the error)
     */
    public static ColumnMatchPolicy fromString(String value) {
        for (ColumnMatchPolicy policy : values()) {
            // equalsIgnoreCase is locale-independent and tolerates a null argument.
            if (policy.name().equalsIgnoreCase(value)) {
                return policy;
            }
        }
        return null;
    }

    /**
     * Lists the accepted property values in lower case, in declaration order.
     * Used to build the "expected values" part of the invalid-value error.
     *
     * @return lower-cased policy names, e.g. {@code [position, name]}
     */
    public static List<String> getCandidates() {
        // Locale.ROOT keeps the result stable on JVMs whose default locale has
        // special casing rules (e.g. Turkish would turn "POSITION" into "posıtıon").
        // Fully qualified so the block does not need an extra import.
        return Arrays.stream(values())
                .map(policy -> policy.name().toLowerCase(java.util.Locale.ROOT))
                .collect(Collectors.toList());
    }
}
+select * from t1; +-- result: +4 d +-- !result + +truncate table t1; +-- result: +-- !result + +insert into t1 properties("match_column_by" = "name") values(1, "a"); +-- result: +[REGEX].*Insert match column by name does not support values\(\). +-- !result + +insert into t1 properties("match_column_by" = "name") select "a" as k2, 1 as k1, 2 as k3; +-- result: +[REGEX].*Source column: k3 has no matching target column. +-- !result + +insert into t1 properties("match_column_by" = "name") select 1 as k1; +-- result: +[REGEX].*Target column: k2 has no matching source column. +-- !result + +insert into t1 properties("match_column_by" = "invalid_value") values(1, "a"); +-- result: +[REGEX].*Invalid match_column_by: 'invalid_value'. Expected values should be position, name \(case insensitive\). +-- !result + + +create table t2 (k1 int, k2 varchar(100), k3 int default "10"); + +insert into t1 values(3, "c"); +-- result: +-- !result + +select * from t1; +-- result: +3 c +-- !result + +insert into t2 (k1, k2) properties("match_column_by" = "name") select * from t1; +-- result: +-- !result + +select * from t2; +-- result: +3 c 10 +-- !result + +truncate table t2; +-- result: +-- !result + +insert into t2 properties("match_column_by" = "name") select *, 11 as k3 from t1; +-- result: +-- !result + +select * from t2; +-- result: +3 c 11 +-- !result + +truncate table t2; +-- result: +-- !result + +insert into t2 properties("match_column_by" = "name") select k1 + 1 as k1, k2, 12 as k3 from t1; +-- result: +-- !result + +select * from t2; +-- result: +4 c 12 +-- !result + +truncate table t2; +-- result: +-- !result + +insert into t2 properties("match_column_by" = "name") select * from t1; +-- result: +[REGEX].*Target column: k3 has no matching source column. +-- !result + +insert into t2 properties("match_column_by" = "name") select k1 + 1 as k1, k1, 12 as k3 from t1; +-- result: +[REGEX].*Duplicate column name 'k1'. 
+-- !result diff --git a/test/sql/test_insert_empty/T/test_insert_by_name b/test/sql/test_insert_empty/T/test_insert_by_name new file mode 100644 index 0000000000000..e522db17cbb42 --- /dev/null +++ b/test/sql/test_insert_empty/T/test_insert_by_name @@ -0,0 +1,46 @@ +-- name: test_insert_by_name + +create database db_${uuid0}; +use db_${uuid0}; + +create table t1 (k1 int, k2 varchar(100)); + +insert into t1 properties("match_column_by" = "name") select "a" as k2, 1 as k1; +select * from t1; +truncate table t1; + +insert into t1 (k2, k1) properties("match_column_by" = "name") select 2 as k1, "b" as k2; +select * from t1; +truncate table t1; + +insert into t1 (k2, k1) properties("match_column_by" = "position") select "d" as k1, 4 as k2; +select * from t1; +truncate table t1; + +-- error case +insert into t1 properties("match_column_by" = "name") values(1, "a"); +insert into t1 properties("match_column_by" = "name") select "a" as k2, 1 as k1, 2 as k3; +insert into t1 properties("match_column_by" = "name") select 1 as k1; +insert into t1 properties("match_column_by" = "invalid_value") values(1, "a"); + + +create table t2 (k1 int, k2 varchar(100), k3 int default "10"); + +insert into t1 values(3, "c"); +select * from t1; + +insert into t2 (k1, k2) properties("match_column_by" = "name") select * from t1; +select * from t2; +truncate table t2; + +insert into t2 properties("match_column_by" = "name") select *, 11 as k3 from t1; +select * from t2; +truncate table t2; + +insert into t2 properties("match_column_by" = "name") select k1 + 1 as k1, k2, 12 as k3 from t1; +select * from t2; +truncate table t2; + +-- error case +insert into t2 properties("match_column_by" = "name") select * from t1; +insert into t2 properties("match_column_by" = "name") select k1 + 1 as k1, k1, 12 as k3 from t1;