Skip to content

Commit

Permalink
Refactor select combine statement parse result to SubquerySegment (#3…
Browse files Browse the repository at this point in the history
…0693)

* Refactor select combine statement parse result to SubquerySegment

* setSubqueryType when CombineSegment bind
  • Loading branch information
strongduanmu authored Mar 29, 2024
1 parent 0cff1a0 commit 4e85074
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementBinder;
import org.apache.shardingsphere.infra.metadata.ShardingSphereMetaData;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.combine.CombineSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;

import java.util.Map;

/**
* Combine segment binder.
*/
Expand All @@ -41,10 +45,15 @@ public final class CombineSegmentBinder {
public static CombineSegment bind(final CombineSegment segment, final SQLStatementBinderContext statementBinderContext) {
ShardingSphereMetaData metaData = statementBinderContext.getMetaData();
String defaultDatabaseName = statementBinderContext.getDefaultDatabaseName();
SelectStatement boundedLeftSelect = new SelectStatementBinder().bindWithExternalTableContexts(segment.getLeft(), metaData, defaultDatabaseName,
statementBinderContext.getExternalTableBinderContexts());
SelectStatement boundedRightSelect = new SelectStatementBinder().bindWithExternalTableContexts(segment.getRight(), metaData, defaultDatabaseName,
statementBinderContext.getExternalTableBinderContexts());
return new CombineSegment(segment.getStartIndex(), segment.getStopIndex(), boundedLeftSelect, segment.getCombineType(), boundedRightSelect);
Map<String, TableSegmentBinderContext> externalTableBinderContexts = statementBinderContext.getExternalTableBinderContexts();
SelectStatement boundedLeftSelect = new SelectStatementBinder().bindWithExternalTableContexts(segment.getLeft().getSelect(), metaData, defaultDatabaseName, externalTableBinderContexts);
SelectStatement boundedRightSelect = new SelectStatementBinder().bindWithExternalTableContexts(segment.getRight().getSelect(), metaData, defaultDatabaseName, externalTableBinderContexts);
SubquerySegment boundedLeft = new SubquerySegment(segment.getLeft().getStartIndex(), segment.getLeft().getStopIndex(), segment.getLeft().getText());
boundedLeft.setSelect(boundedLeftSelect);
boundedLeft.setSubqueryType(segment.getLeft().getSubqueryType());
SubquerySegment boundedRight = new SubquerySegment(segment.getRight().getStartIndex(), segment.getRight().getStopIndex(), segment.getRight().getText());
boundedRight.setSelect(boundedRightSelect);
boundedRight.setSubqueryType(segment.getRight().getSubqueryType());
return new CombineSegment(segment.getStartIndex(), segment.getStopIndex(), boundedLeft, segment.getCombineType(), boundedRight);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private SqlNode convertCombine(final SqlNode sqlNode, final SelectStatement sele
if (selectStatement.getCombine().isPresent()) {
CombineSegment combineSegment = selectStatement.getCombine().get();
return new SqlBasicCall(CombineOperatorConverter.convert(combineSegment.getCombineType()),
Arrays.asList(convert(combineSegment.getLeft()), convert(combineSegment.getRight())), SqlParserPos.ZERO);
Arrays.asList(convert(combineSegment.getLeft().getSelect()), convert(combineSegment.getRight().getSelect())), SqlParserPos.ZERO);
}
return sqlNode;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -766,26 +766,28 @@ public ASTNode visitQueryExpressionBody(final QueryExpressionBodyContext ctx) {
}
if (null != ctx.queryExpressionBody()) {
MySQLSelectStatement result = new MySQLSelectStatement();
MySQLSelectStatement left = (MySQLSelectStatement) visit(ctx.queryExpressionBody());
result.setProjections(left.getProjections());
left.getFrom().ifPresent(result::setFrom);
left.getTable().ifPresent(result::setTable);
SubquerySegment left = new SubquerySegment(ctx.queryExpressionBody().start.getStartIndex(), ctx.queryExpressionBody().stop.getStopIndex(),
(MySQLSelectStatement) visit(ctx.queryExpressionBody()), getOriginalText(ctx.queryExpressionBody()));
result.setProjections(left.getSelect().getProjections());
left.getSelect().getFrom().ifPresent(result::setFrom);
((MySQLSelectStatement) left.getSelect()).getTable().ifPresent(result::setTable);
result.setCombine(createCombineSegment(ctx.combineClause(), left));
return result;
}
if (null != ctx.queryExpressionParens()) {
MySQLSelectStatement result = new MySQLSelectStatement();
MySQLSelectStatement left = (MySQLSelectStatement) visit(ctx.queryExpressionParens());
result.setProjections(left.getProjections());
left.getFrom().ifPresent(result::setFrom);
left.getTable().ifPresent(result::setTable);
SubquerySegment left = new SubquerySegment(ctx.queryExpressionParens().start.getStartIndex(), ctx.queryExpressionParens().stop.getStopIndex(),
(MySQLSelectStatement) visit(ctx.queryExpressionParens()), getOriginalText(ctx.queryExpressionParens()));
result.setProjections(left.getSelect().getProjections());
left.getSelect().getFrom().ifPresent(result::setFrom);
((MySQLSelectStatement) left.getSelect()).getTable().ifPresent(result::setTable);
result.setCombine(createCombineSegment(ctx.combineClause(), left));
return result;
}
return visit(ctx.queryExpressionParens());
}

private CombineSegment createCombineSegment(final CombineClauseContext ctx, final MySQLSelectStatement left) {
private CombineSegment createCombineSegment(final CombineClauseContext ctx, final SubquerySegment left) {
CombineType combineType;
if (null != ctx.EXCEPT()) {
combineType = CombineType.EXCEPT;
Expand All @@ -794,7 +796,8 @@ private CombineSegment createCombineSegment(final CombineClauseContext ctx, fina
} else {
combineType = null == ctx.combineOption() || null == ctx.combineOption().ALL() ? CombineType.UNION : CombineType.UNION_ALL;
}
MySQLSelectStatement right = null == ctx.queryPrimary() ? (MySQLSelectStatement) visit(ctx.queryExpressionParens()) : (MySQLSelectStatement) visit(ctx.queryPrimary());
ParserRuleContext ruleContext = null == ctx.queryPrimary() ? ctx.queryExpressionParens() : ctx.queryPrimary();
SubquerySegment right = new SubquerySegment(ruleContext.start.getStartIndex(), ruleContext.stop.getStopIndex(), (MySQLSelectStatement) visit(ruleContext), getOriginalText(ruleContext));
return new CombineSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), left, combineType, right);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,14 +969,18 @@ public ASTNode visitSelectClauseN(final SelectClauseNContext ctx) {
OpenGaussSelectStatement left = (OpenGaussSelectStatement) visit(ctx.selectClauseN(0));
result.setProjections(left.getProjections());
left.getFrom().ifPresent(result::setFrom);
CombineSegment combineSegment = new CombineSegment(((TerminalNode) ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(), left, getCombineType(ctx),
(OpenGaussSelectStatement) visit(ctx.selectClauseN(1)));
CombineSegment combineSegment = new CombineSegment(((TerminalNode) ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
createSubquerySegment(ctx.selectClauseN(0), left), getCombineType(ctx), createSubquerySegment(ctx.selectClauseN(1), (OpenGaussSelectStatement) visit(ctx.selectClauseN(1))));
result.setCombine(combineSegment);
return result;
}
return visit(ctx.selectWithParens());
}

private SubquerySegment createSubquerySegment(final SelectClauseNContext ctx, final OpenGaussSelectStatement selectStatement) {
return new SubquerySegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
}

private CombineType getCombineType(final SelectClauseNContext ctx) {
boolean isDistinct = null == ctx.allOrDistinct() || null != ctx.allOrDistinct().DISTINCT();
if (null != ctx.UNION()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,12 @@ private void createSelectCombineClause(final SelectSubqueryContext ctx, final Or
} else {
combineType = CombineType.MINUS;
}
result.setCombine(new CombineSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), left, combineType, (OracleSelectStatement) visit(ctx.selectSubquery(1))));
result.setCombine(new CombineSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), createSubquerySegment(ctx.selectSubquery(0), left), combineType,
createSubquerySegment(ctx.selectSubquery(1), (OracleSelectStatement) visit(ctx.selectSubquery(1)))));
}

private SubquerySegment createSubquerySegment(final SelectSubqueryContext ctx, final OracleSelectStatement selectStatement) {
return new SubquerySegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -939,14 +939,18 @@ public ASTNode visitSelectClauseN(final SelectClauseNContext ctx) {
PostgreSQLSelectStatement left = (PostgreSQLSelectStatement) visit(ctx.selectClauseN(0));
result.setProjections(left.getProjections());
left.getFrom().ifPresent(result::setFrom);
CombineSegment combineSegment = new CombineSegment(((TerminalNode) ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(), left, getCombineType(ctx),
(PostgreSQLSelectStatement) visit(ctx.selectClauseN(1)));
CombineSegment combineSegment = new CombineSegment(((TerminalNode) ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(),
createSubquerySegment(ctx.selectClauseN(0), left), getCombineType(ctx), createSubquerySegment(ctx.selectClauseN(1), (PostgreSQLSelectStatement) visit(ctx.selectClauseN(1))));
result.setCombine(combineSegment);
return result;
}
return visit(ctx.selectWithParens());
}

private SubquerySegment createSubquerySegment(final SelectClauseNContext ctx, final PostgreSQLSelectStatement selectStatement) {
return new SubquerySegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), selectStatement, getOriginalText(ctx));
}

private CombineType getCombineType(final SelectClauseNContext ctx) {
boolean isDistinct = null == ctx.allOrDistinct() || null != ctx.allOrDistinct().DISTINCT();
if (null != ctx.UNION()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ public final class TableExtractor {
public void extractTablesFromSelect(final SelectStatement selectStatement) {
if (selectStatement.getCombine().isPresent()) {
CombineSegment combineSegment = selectStatement.getCombine().get();
extractTablesFromSelect(combineSegment.getLeft());
extractTablesFromSelect(combineSegment.getRight());
extractTablesFromSelect(combineSegment.getLeft().getSelect());
extractTablesFromSelect(combineSegment.getRight().getSelect());
}
if (selectStatement.getFrom().isPresent() && !selectStatement.getCombine().isPresent()) {
extractTablesFromTableSegment(selectStatement.getFrom().get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.subquery.SubquerySegment;

/**
* Combine segment.
Expand All @@ -34,9 +34,9 @@ public final class CombineSegment implements SQLSegment {

private final int stopIndex;

private final SelectStatement left;
private final SubquerySegment left;

private final CombineType combineType;

private final SelectStatement right;
private final SubquerySegment right;
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ public static void extractFromSelectStatementWithoutProjection(final Collection<
statement.getGroupBy().ifPresent(optional -> extractFromGroupBy(columnSegments, optional, containsSubQuery));
statement.getHaving().ifPresent(optional -> extractFromHaving(columnSegments, optional, containsSubQuery));
statement.getOrderBy().ifPresent(optional -> extractFromOrderBy(columnSegments, optional, containsSubQuery));
statement.getCombine().ifPresent(optional -> extractFromSelectStatement(columnSegments, optional.getRight(), containsSubQuery));
statement.getCombine().ifPresent(optional -> extractFromSelectStatement(columnSegments, optional.getRight().getSelect(), containsSubQuery));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ private static void extractSubquerySegmentsFromExpression(final List<SubquerySeg
}

private static void extractSubquerySegmentsFromCombine(final List<SubquerySegment> result, final CombineSegment combineSegment) {
extractSubquerySegments(result, combineSegment.getLeft());
extractSubquerySegments(result, combineSegment.getRight());
extractSubquerySegments(result, combineSegment.getLeft().getSelect());
extractSubquerySegments(result, combineSegment.getRight().getSelect());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ private void assertTableSegment(final SimpleTableSegment actual, final int expec
@Test
void assertExtractTablesFromCombineSegment() {
MySQLSelectStatement selectStatement = createSelectStatement("t_order");
selectStatement.setCombine(new CombineSegment(0, 0, createSelectStatement("t_order"), CombineType.UNION, createSelectStatement("t_order_item")));
SubquerySegment left = new SubquerySegment(0, 0, createSelectStatement("t_order"), "");
SubquerySegment right = new SubquerySegment(0, 0, createSelectStatement("t_order_item"), "");
selectStatement.setCombine(new CombineSegment(0, 0, left, CombineType.UNION, right));
tableExtractor.extractTablesFromSelect(selectStatement);
Collection<SimpleTableSegment> actual = tableExtractor.getRewriteTables();
assertThat(actual.size(), is(2));
Expand All @@ -172,7 +174,9 @@ private static MySQLSelectStatement createSelectStatement(final String tableName
@Test
void assertExtractTablesFromCombineSegmentWithColumnProjection() {
MySQLSelectStatement selectStatement = createSelectStatementWithColumnProjection("t_order");
selectStatement.setCombine(new CombineSegment(0, 0, createSelectStatementWithColumnProjection("t_order"), CombineType.UNION, createSelectStatementWithColumnProjection("t_order_item")));
SubquerySegment left = new SubquerySegment(0, 0, createSelectStatementWithColumnProjection("t_order"), "");
SubquerySegment right = new SubquerySegment(0, 0, createSelectStatementWithColumnProjection("t_order_item"), "");
selectStatement.setCombine(new CombineSegment(0, 0, left, CombineType.UNION, right));
tableExtractor.extractTablesFromSelect(selectStatement);
Collection<SimpleTableSegment> actual = tableExtractor.getRewriteTables();
assertThat(actual.size(), is(2));
Expand All @@ -197,7 +201,9 @@ private MySQLSelectStatement createSelectStatementWithColumnProjection(final Str
@Test
void assertExtractTablesFromCombineWithSubQueryProjection() {
MySQLSelectStatement selectStatement = createSelectStatementWithSubQueryProjection("t_order");
selectStatement.setCombine(new CombineSegment(0, 0, createSelectStatementWithSubQueryProjection("t_order"), CombineType.UNION, createSelectStatementWithSubQueryProjection("t_order_item")));
SubquerySegment left = new SubquerySegment(0, 0, createSelectStatementWithSubQueryProjection("t_order"), "");
SubquerySegment right = new SubquerySegment(0, 0, createSelectStatementWithSubQueryProjection("t_order_item"), "");
selectStatement.setCombine(new CombineSegment(0, 0, left, CombineType.UNION, right));
tableExtractor.extractTablesFromSelect(selectStatement);
Collection<SimpleTableSegment> actual = tableExtractor.getRewriteTables();
assertThat(actual.size(), is(2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,17 @@ private SubquerySegment createSubquerySegmentForFrom() {
@Test
void assertGetSubquerySegmentsWithCombineSegment() {
SelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setCombine(new CombineSegment(0, 0, new MySQLSelectStatement(), CombineType.UNION, createSelectStatementForCombineSegment()));
SubquerySegment left = new SubquerySegment(0, 0, new MySQLSelectStatement(), "");
selectStatement.setCombine(new CombineSegment(0, 0, left, CombineType.UNION, createSelectStatementForCombineSegment()));
Collection<SubquerySegment> actual = SubqueryExtractUtils.getSubquerySegments(selectStatement);
assertThat(actual.size(), is(1));
}

private SelectStatement createSelectStatementForCombineSegment() {
SelectStatement result = new MySQLSelectStatement();
private SubquerySegment createSelectStatementForCombineSegment() {
SelectStatement selectStatement = new MySQLSelectStatement();
ExpressionSegment left = new ColumnSegment(0, 0, new IdentifierValue("order_id"));
result.setWhere(new WhereSegment(0, 0, new InExpression(0, 0,
left, new SubqueryExpressionSegment(new SubquerySegment(0, 0, new MySQLSelectStatement(), "")), false)));
return result;
selectStatement.setWhere(new WhereSegment(0, 0, new InExpression(0, 0, left, new SubqueryExpressionSegment(new SubquerySegment(0, 0, new MySQLSelectStatement(), "")), false)));
return new SubquerySegment(0, 0, selectStatement, "");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ private static void assertCombineClause(final SQLCaseAssertContext assertContext
assertTrue(combineSegment.isPresent(), assertContext.getText("Actual combine segment should exist."));
assertThat(assertContext.getText("Combine type assertion error: "), combineSegment.get().getCombineType().name(), is(expected.getCombineClause().getCombineType()));
SQLSegmentAssert.assertIs(assertContext, combineSegment.get(), expected.getCombineClause());
assertIs(assertContext, combineSegment.get().getLeft(), expected.getCombineClause().getLeft());
assertIs(assertContext, combineSegment.get().getRight(), expected.getCombineClause().getRight());
assertIs(assertContext, combineSegment.get().getLeft().getSelect(), expected.getCombineClause().getLeft());
assertIs(assertContext, combineSegment.get().getRight().getSelect(), expected.getCombineClause().getRight());
}
}

Expand Down

0 comments on commit 4e85074

Please sign in to comment.