Skip to content

Commit

Permalink
Support insert statement rewrite use quote (#34259)
Browse files Browse the repository at this point in the history
* Support insert statement rewrite use quote

* Support insert statement rewrite use quote

* Support insert statement rewrite use quote
  • Loading branch information
FlyingZC authored Jan 6, 2025
1 parent aa84856 commit 52be4b5
Show file tree
Hide file tree
Showing 18 changed files with 109 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,17 @@ private Optional<EncryptAssignmentToken> generateSQLToken(final String schemaNam
}

private EncryptAssignmentToken generateParameterSQLToken(final EncryptColumn encryptColumn, final ColumnAssignmentSegment segment) {
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(segment.getColumns().get(0).getStartIndex(), segment.getStopIndex());
EncryptParameterAssignmentToken result =
new EncryptParameterAssignmentToken(segment.getColumns().get(0).getStartIndex(), segment.getStopIndex(), segment.getColumns().get(0).getIdentifier().getQuoteCharacter());
result.addColumnName(encryptColumn.getCipher().getName());
encryptColumn.getAssistedQuery().ifPresent(optional -> result.addColumnName(optional.getName()));
encryptColumn.getLikeQuery().ifPresent(optional -> result.addColumnName(optional.getName()));
return result;
}

private EncryptAssignmentToken generateLiteralSQLToken(final String schemaName, final String tableName, final EncryptColumn encryptColumn, final ColumnAssignmentSegment segment) {
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(segment.getColumns().get(0).getStartIndex(), segment.getStopIndex());
EncryptLiteralAssignmentToken result =
new EncryptLiteralAssignmentToken(segment.getColumns().get(0).getStartIndex(), segment.getStopIndex(), segment.getColumns().get(0).getIdentifier().getQuoteCharacter());
addCipherAssignment(schemaName, tableName, encryptColumn, segment, result);
addAssistedQueryAssignment(schemaName, tableName, encryptColumn, segment, result);
addLikeAssignment(schemaName, tableName, encryptColumn, segment, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic.SubstitutableColumnNameToken;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.InsertColumnsSegment;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;

import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -75,7 +76,8 @@ public Collection<SQLToken> generateSQLTokens(final InsertStatementContext inser
String columnName = each.getIdentifier().getValue();
if (encryptTable.isEncryptColumn(columnName)) {
Collection<Projection> projections =
Collections.singleton(new ColumnProjection(null, encryptTable.getEncryptColumn(columnName).getCipher().getName(), null, insertStatementContext.getDatabaseType()));
Collections.singleton(new ColumnProjection(null, new IdentifierValue(encryptTable.getEncryptColumn(columnName).getCipher().getName(), each.getIdentifier().getQuoteCharacter()),
null, insertStatementContext.getDatabaseType()));
result.add(new SubstitutableColumnNameToken(each.getStartIndex(), each.getStopIndex(), projections, insertStatementContext.getDatabaseType()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.InsertStatementContext;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.database.core.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.generator.OptionalSQLTokenGenerator;
Expand Down Expand Up @@ -96,8 +98,9 @@ private UseDefaultInsertColumnsToken generateNewSQLToken(final InsertStatementCo
ShardingSpherePreconditions.checkState(InsertSelectColumnsEncryptorComparator.isSame(derivedInsertColumns, projections, rule),
() -> new UnsupportedSQLOperationException("Can not use different encryptor in insert select columns"));
}
QuoteCharacter quoteCharacter = new DatabaseTypeRegistry(insertStatementContext.getDatabaseType()).getDialectDatabaseMetaData().getQuoteCharacter();
return new UseDefaultInsertColumnsToken(
insertColumnsSegment.get().getStopIndex(), getColumnNames(insertStatementContext, rule.getEncryptTable(tableName), insertStatementContext.getColumnNames()));
insertColumnsSegment.get().getStopIndex(), getColumnNames(insertStatementContext, rule.getEncryptTable(tableName), insertStatementContext.getColumnNames()), quoteCharacter);
}

private List<String> getColumnNames(final InsertStatementContext sqlStatementContext, final EncryptTable encryptTable, final List<String> currentColumnNames) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public Collection<SQLToken> generateSQLTokens(final InsertStatementContext inser
for (ColumnSegment each : insertStatementContext.getSqlStatement().getColumns()) {
List<String> derivedColumnNames = getDerivedColumnNames(encryptTable, each);
if (!derivedColumnNames.isEmpty()) {
result.add(new InsertColumnsToken(each.getStopIndex() + 1, derivedColumnNames));
result.add(new InsertColumnsToken(each.getStopIndex() + 1, derivedColumnNames, each.getIdentifier().getQuoteCharacter()));
}
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ private Optional<EncryptAssignmentToken> generateSQLToken(final String schemaNam
}

private EncryptAssignmentToken generateParameterSQLToken(final EncryptTable encryptTable, final ColumnAssignmentSegment assignmentSegment) {
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex());
EncryptParameterAssignmentToken result = new EncryptParameterAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex(),
assignmentSegment.getColumns().get(0).getIdentifier().getQuoteCharacter());
String columnName = assignmentSegment.getColumns().get(0).getIdentifier().getValue();
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(columnName);
result.addColumnName(encryptColumn.getCipher().getName());
Expand All @@ -126,7 +127,8 @@ private EncryptAssignmentToken generateParameterSQLToken(final EncryptTable encr

private EncryptAssignmentToken generateLiteralSQLToken(final String schemaName, final String tableName,
final EncryptColumn encryptColumn, final ColumnAssignmentSegment assignmentSegment) {
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex());
EncryptLiteralAssignmentToken result = new EncryptLiteralAssignmentToken(assignmentSegment.getColumns().get(0).getStartIndex(), assignmentSegment.getStopIndex(),
assignmentSegment.getColumns().get(0).getIdentifier().getQuoteCharacter());
addCipherAssignment(schemaName, tableName, encryptColumn, assignmentSegment, result);
addAssistedQueryAssignment(schemaName, tableName, encryptColumn, assignmentSegment, result);
addLikeAssignment(schemaName, tableName, encryptColumn, assignmentSegment, result);
Expand All @@ -139,7 +141,8 @@ private EncryptAssignmentToken generateValuesSQLToken(final EncryptTable encrypt
Optional<ExpressionSegment> valueColumnSegment = functionSegment.getParameters().stream().findFirst();
Preconditions.checkState(valueColumnSegment.isPresent());
String valueColumn = ((ColumnSegment) valueColumnSegment.get()).getIdentifier().getValue();
EncryptFunctionAssignmentToken result = new EncryptFunctionAssignmentToken(columnSegment.getStartIndex(), assignmentSegment.getStopIndex());
EncryptFunctionAssignmentToken result =
new EncryptFunctionAssignmentToken(columnSegment.getStartIndex(), assignmentSegment.getStopIndex(), assignmentSegment.getColumns().get(0).getIdentifier().getQuoteCharacter());
boolean isEncryptColumn = encryptTable.isEncryptColumn(column);
boolean isEncryptValueColumn = encryptTable.isEncryptColumn(valueColumn);
EncryptColumn encryptColumn = encryptTable.getEncryptColumn(column);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.encrypt.rewrite.token.pojo;

import lombok.Getter;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.Substitutable;

Expand All @@ -29,8 +30,11 @@ public abstract class EncryptAssignmentToken extends SQLToken implements Substit

private final int stopIndex;

protected EncryptAssignmentToken(final int startIndex, final int stopIndex) {
private final QuoteCharacter quoteCharacter;

protected EncryptAssignmentToken(final int startIndex, final int stopIndex, final QuoteCharacter quoteCharacter) {
super(startIndex);
this.stopIndex = stopIndex;
this.quoteCharacter = quoteCharacter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.encrypt.rewrite.token.pojo;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;

import java.util.Collection;
import java.util.LinkedList;
Expand All @@ -31,8 +32,8 @@ public final class EncryptFunctionAssignmentToken extends EncryptAssignmentToken

private final Collection<FunctionAssignment> assignments = new LinkedList<>();

public EncryptFunctionAssignmentToken(final int startIndex, final int stopIndex) {
super(startIndex, stopIndex);
public EncryptFunctionAssignmentToken(final int startIndex, final int stopIndex, final QuoteCharacter quoteCharacter) {
super(startIndex, stopIndex, quoteCharacter);
}

/**
Expand All @@ -42,7 +43,7 @@ public EncryptFunctionAssignmentToken(final int startIndex, final int stopIndex)
* @param value assignment value
*/
public void addAssignment(final String columnName, final Object value) {
FunctionAssignment functionAssignment = new FunctionAssignment(columnName, value);
FunctionAssignment functionAssignment = new FunctionAssignment(columnName, value, getQuoteCharacter());
assignments.add(functionAssignment);
builder.append(functionAssignment).append(", ");
}
Expand All @@ -68,9 +69,11 @@ private static final class FunctionAssignment {

private final Object value;

private final QuoteCharacter quoteCharacter;

@Override
public String toString() {
return String.format("%s = %s", columnName, value);
return quoteCharacter.wrap(columnName) + " = " + value;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.shardingsphere.encrypt.rewrite.token.pojo;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;

import java.util.Collection;
import java.util.LinkedList;
Expand All @@ -30,8 +31,8 @@ public final class EncryptLiteralAssignmentToken extends EncryptAssignmentToken

private final Collection<LiteralAssignment> assignments = new LinkedList<>();

public EncryptLiteralAssignmentToken(final int startIndex, final int stopIndex) {
super(startIndex, stopIndex);
public EncryptLiteralAssignmentToken(final int startIndex, final int stopIndex, final QuoteCharacter quoteCharacter) {
super(startIndex, stopIndex, quoteCharacter);
}

/**
Expand All @@ -41,7 +42,7 @@ public EncryptLiteralAssignmentToken(final int startIndex, final int stopIndex)
* @param value assignment value
*/
public void addAssignment(final String columnName, final Object value) {
assignments.add(new LiteralAssignment(columnName, value));
assignments.add(new LiteralAssignment(columnName, value, getQuoteCharacter()));
}

@Override
Expand All @@ -56,9 +57,11 @@ private static final class LiteralAssignment {

private final Object value;

private final QuoteCharacter quoteCharacter;

@Override
public String toString() {
return columnName + " = " + toString(value);
return quoteCharacter.wrap(columnName) + " = " + toString(value);
}

private String toString(final Object value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.shardingsphere.encrypt.rewrite.token.pojo;

import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;

import java.util.Collection;
import java.util.LinkedList;
import java.util.StringJoiner;
Expand All @@ -28,8 +30,8 @@ public final class EncryptParameterAssignmentToken extends EncryptAssignmentToke

private final Collection<String> columnNames = new LinkedList<>();

public EncryptParameterAssignmentToken(final int startIndex, final int stopIndex) {
super(startIndex, stopIndex);
public EncryptParameterAssignmentToken(final int startIndex, final int stopIndex, final QuoteCharacter quoteCharacter) {
super(startIndex, stopIndex, quoteCharacter);
}

/**
Expand All @@ -45,7 +47,7 @@ public void addColumnName(final String columnName) {
public String toString() {
StringJoiner result = new StringJoiner(", ");
for (String each : columnNames) {
result.add(each + " = ?");
result.add(getQuoteCharacter().wrap(each) + " = ?");
}
return result.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ void assertIsGenerateSQLToken() {
void assertGenerateSQLTokenFromGenerateNewSQLToken() {
generator.setPreviousSQLTokens(Collections.emptyList());
assertThat(generator.generateSQLToken(EncryptGeneratorFixtureBuilder.createInsertStatementContext(Collections.emptyList())).toString(),
is("(id, name, status, pwd_cipher, pwd_assist, pwd_like)"));
is("(`id`, `name`, `status`, `pwd_cipher`, `pwd_assist`, `pwd_like`)"));
}

@Test
void assertGenerateSQLTokenFromPreviousSQLTokens() {
generator.setPreviousSQLTokens(EncryptGeneratorFixtureBuilder.getPreviousSQLTokens());
assertThat(generator.generateSQLToken(EncryptGeneratorFixtureBuilder.createInsertStatementContext(Collections.emptyList())).toString(),
is("(id, name, status, pwd_cipher, pwd_assist, pwd_like)"));
is("(`id`, `name`, `status`, `pwd_cipher`, `pwd_assist`, `pwd_like`)"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.encrypt.rewrite.token.pojo;

import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.is;
Expand All @@ -26,7 +27,7 @@ class EncryptLiteralAssignmentTokenTest {

@Test
void assertToString() {
EncryptLiteralAssignmentToken actual = new EncryptLiteralAssignmentToken(0, 1);
EncryptLiteralAssignmentToken actual = new EncryptLiteralAssignmentToken(0, 1, QuoteCharacter.NONE);
actual.addAssignment("c1", "c1");
actual.addAssignment("c2", 1);
assertThat(actual.toString(), is("c1 = 'c1', c2 = 1"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.encrypt.rewrite.token.pojo;

import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.junit.jupiter.api.Test;

import static org.hamcrest.CoreMatchers.is;
Expand All @@ -26,7 +27,7 @@ class EncryptParameterAssignmentTokenTest {

@Test
void assertToString() {
EncryptParameterAssignmentToken actual = new EncryptParameterAssignmentToken(0, 1);
EncryptParameterAssignmentToken actual = new EncryptParameterAssignmentToken(0, 1, QuoteCharacter.NONE);
actual.addColumnName("c1");
actual.addColumnName("c2");
assertThat(actual.toString(), is("c1 = ?, c2 = ?"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic;

import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.Attachable;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;

Expand All @@ -30,9 +31,18 @@ public final class InsertColumnsToken extends SQLToken implements Attachable {

private final List<String> columns;

private final QuoteCharacter quoteCharacter;

public InsertColumnsToken(final int startIndex, final List<String> columns) {
super(startIndex);
this.columns = columns;
this.quoteCharacter = QuoteCharacter.NONE;
}

public InsertColumnsToken(final int startIndex, final List<String> columns, final QuoteCharacter quoteCharacter) {
super(startIndex);
this.columns = columns;
this.quoteCharacter = quoteCharacter;
}

@Override
Expand All @@ -41,7 +51,7 @@ public String toString() {
return "";
}
StringJoiner result = new StringJoiner(", ", ", ", "");
columns.forEach(result::add);
columns.forEach(each -> result.add(quoteCharacter.wrap(each)));
return result.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
package org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.generic;

import lombok.Getter;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.QuoteCharacter;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.Attachable;
import org.apache.shardingsphere.infra.rewrite.sql.token.common.pojo.SQLToken;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
Expand All @@ -31,14 +34,31 @@ public final class UseDefaultInsertColumnsToken extends SQLToken implements Atta

private final List<String> columns;

private final QuoteCharacter quoteCharacter;

public UseDefaultInsertColumnsToken(final int startIndex, final List<String> columns) {
super(startIndex);
this.columns = columns;
this.quoteCharacter = QuoteCharacter.NONE;
}

public UseDefaultInsertColumnsToken(final int startIndex, final List<String> columns, final QuoteCharacter quoteCharacter) {
super(startIndex);
this.columns = columns;
this.quoteCharacter = quoteCharacter;
}

@Override
public String toString() {
return columns.isEmpty() ? "" : "(" + String.join(", ", columns) + ")";
return columns.isEmpty() ? "" : "(" + String.join(", ", getColumnNames()) + ")";
}

private Collection<String> getColumnNames() {
Collection<String> result = new ArrayList<>(columns.size());
for (String each : columns) {
result.add(quoteCharacter.wrap(each));
}
return result;
}

@Override
Expand Down
Loading

0 comments on commit 52be4b5

Please sign in to comment.