Skip to content

Commit

Permalink
Change return value to Optional with SelectStatement#getFrom method a…
Browse files Browse the repository at this point in the history
…nd remove useless table segment in ProjectionEngine (#30692)

* Change return value to Optional with SelectStatement#getFrom method and remove useless table segment in ProjectionEngine

* fix sql parse test

* fix unit test

* fix unit test

* fix checkstyle
  • Loading branch information
strongduanmu authored Mar 29, 2024
1 parent e7e2b07 commit 0cff1a0
Show file tree
Hide file tree
Showing 40 changed files with 134 additions and 136 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ private boolean isSubqueryContainsShardingCondition(final List<ShardingCondition
startIndexShardingConditions.computeIfAbsent(each.getStartIndex(), unused -> new LinkedList<>()).add(each);
}
for (SelectStatement each : selectStatements) {
if (each.getFrom() instanceof SubqueryTableSegment) {
if (each.getFrom().isPresent() && each.getFrom().get() instanceof SubqueryTableSegment) {
continue;
}
if (!each.getWhere().isPresent() || !startIndexShardingConditions.containsKey(each.getWhere().get().getExpr().getStartIndex())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.mockito.quality.Strictness;

import java.util.Collections;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertThrows;
Expand Down Expand Up @@ -75,7 +76,7 @@ class ShardingCreateViewStatementValidatorTest {
void setUp() {
when(createViewStatementContext.getSqlStatement()).thenReturn(createViewStatement);
when(createViewStatement.getSelect()).thenReturn(selectStatement);
when(selectStatement.getFrom()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))));
when(selectStatement.getFrom()).thenReturn(Optional.of(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))));
when(createViewStatement.getView()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("order_view"))));
when(routeContext.getRouteUnits().size()).thenReturn(2);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ void assertIsAllBindingTableWithJoinQueryWithDatabaseJoinCondition() {
JoinTableSegment joinTable = mock(JoinTableSegment.class);
when(joinTable.getCondition()).thenReturn(condition);
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(selectStatement.getFrom()).thenReturn(joinTable);
when(selectStatement.getFrom()).thenReturn(Optional.of(joinTable));
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true);
Expand All @@ -622,7 +622,7 @@ void assertIsAllBindingTableWithJoinQueryWithDatabaseJoinConditionInUpperCaseAnd
JoinTableSegment joinTable = mock(JoinTableSegment.class);
when(joinTable.getCondition()).thenReturn(condition);
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(selectStatement.getFrom()).thenReturn(joinTable);
when(selectStatement.getFrom()).thenReturn(Optional.of(joinTable));
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true);
Expand Down Expand Up @@ -651,7 +651,7 @@ void assertIsAllBindingTableWithJoinQueryWithDatabaseTableJoinCondition() {
BinaryOperationExpression condition = createBinaryOperationExpression(databaseJoin, tableJoin, AND);
when(joinTable.getCondition()).thenReturn(condition);
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(selectStatement.getFrom()).thenReturn(joinTable);
when(selectStatement.getFrom()).thenReturn(Optional.of(joinTable));
SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.SQLServerStatement;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
Expand Down Expand Up @@ -75,7 +76,7 @@ private boolean containsRowNumberPagination(final SelectStatement selectStatemen
}

private Optional<TopProjectionSegment> findTopProjection(final SelectStatement selectStatement) {
List<SubqueryTableSegment> subqueryTableSegments = SQLUtils.getSubqueryTableSegmentFromTableSegment(selectStatement.getFrom());
List<SubqueryTableSegment> subqueryTableSegments = selectStatement.getFrom().map(SQLUtils::getSubqueryTableSegmentFromTableSegment).orElse(Collections.emptyList());
for (SubqueryTableSegment subquery : subqueryTableSegments) {
SelectStatement subquerySelect = subquery.getSubquery().getSelect();
for (ProjectionSegment each : subquerySelect.getProjections().getProjections()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.SubqueryProjectionSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;

import java.util.Collection;
Expand All @@ -61,13 +60,12 @@ public final class ProjectionEngine {
/**
* Create projection.
*
* @param table table segment
* @param projectionSegment projection segment
* @return projection
*/
public Optional<Projection> createProjection(final TableSegment table, final ProjectionSegment projectionSegment) {
public Optional<Projection> createProjection(final ProjectionSegment projectionSegment) {
if (projectionSegment instanceof ShorthandProjectionSegment) {
return Optional.of(createProjection(table, (ShorthandProjectionSegment) projectionSegment));
return Optional.of(createProjection((ShorthandProjectionSegment) projectionSegment));
}
if (projectionSegment instanceof ColumnProjectionSegment) {
return Optional.of(createProjection((ColumnProjectionSegment) projectionSegment));
Expand All @@ -82,7 +80,7 @@ public Optional<Projection> createProjection(final TableSegment table, final Pro
return Optional.of(createProjection((AggregationProjectionSegment) projectionSegment));
}
if (projectionSegment instanceof SubqueryProjectionSegment) {
return Optional.of(createProjection(table, (SubqueryProjectionSegment) projectionSegment));
return Optional.of(createProjection((SubqueryProjectionSegment) projectionSegment));
}
if (projectionSegment instanceof ParameterMarkerExpressionSegment) {
return Optional.of(createProjection((ParameterMarkerExpressionSegment) projectionSegment));
Expand All @@ -94,16 +92,16 @@ private ParameterMarkerProjection createProjection(final ParameterMarkerExpressi
return new ParameterMarkerProjection(projectionSegment.getParameterMarkerIndex(), projectionSegment.getParameterMarkerType(), projectionSegment.getAlias().orElse(null));
}

private SubqueryProjection createProjection(final TableSegment table, final SubqueryProjectionSegment projectionSegment) {
Projection subqueryProjection = createProjection(table, projectionSegment.getSubquery().getSelect().getProjections().getProjections().iterator().next())
private SubqueryProjection createProjection(final SubqueryProjectionSegment projectionSegment) {
Projection subqueryProjection = createProjection(projectionSegment.getSubquery().getSelect().getProjections().getProjections().iterator().next())
.orElseThrow(() -> new IllegalArgumentException("Subquery projection must have at least one projection column."));
return new SubqueryProjection(projectionSegment, subqueryProjection, projectionSegment.getAlias().orElse(null), databaseType);
}

private ShorthandProjection createProjection(final TableSegment table, final ShorthandProjectionSegment projectionSegment) {
private ShorthandProjection createProjection(final ShorthandProjectionSegment projectionSegment) {
IdentifierValue owner = projectionSegment.getOwner().map(OwnerSegment::getIdentifier).orElse(null);
Collection<Projection> projections = new LinkedHashSet<>();
projectionSegment.getActualProjectionSegments().forEach(each -> createProjection(table, each).ifPresent(projections::add));
projectionSegment.getActualProjectionSegments().forEach(each -> createProjection(each).ifPresent(projections::add));
return new ShorthandProjection(owner, projections);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.OrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.TextOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;

Expand All @@ -55,25 +54,23 @@ public ProjectionsContextEngine(final DatabaseType databaseType) {
/**
* Create projections context.
*
* @param table table segment
* @param projectionsSegment projection segments
* @param groupByContext group by context
* @param orderByContext order by context
* @return projections context
*/
public ProjectionsContext createProjectionsContext(final TableSegment table, final ProjectionsSegment projectionsSegment,
final GroupByContext groupByContext, final OrderByContext orderByContext) {
Collection<Projection> projections = getProjections(table, projectionsSegment);
public ProjectionsContext createProjectionsContext(final ProjectionsSegment projectionsSegment, final GroupByContext groupByContext, final OrderByContext orderByContext) {
Collection<Projection> projections = getProjections(projectionsSegment);
ProjectionsContext result = new ProjectionsContext(projectionsSegment.getStartIndex(), projectionsSegment.getStopIndex(), projectionsSegment.isDistinctRow(), projections);
result.getProjections().addAll(getDerivedGroupByColumns(groupByContext, projections));
result.getProjections().addAll(getDerivedOrderByColumns(orderByContext, projections));
return result;
}

private Collection<Projection> getProjections(final TableSegment table, final ProjectionsSegment projectionsSegment) {
private Collection<Projection> getProjections(final ProjectionsSegment projectionsSegment) {
Collection<Projection> result = new LinkedList<>();
for (ProjectionSegment each : projectionsSegment.getProjections()) {
projectionEngine.createProjection(table, each).ifPresent(result::add);
projectionEngine.createProjection(each).ifPresent(result::add);
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public final class SubqueryTableContextEngine {
*/
public Map<String, SubqueryTableContext> createSubqueryTableContexts(final SelectStatementContext subqueryContext, final String aliasName) {
Map<String, SubqueryTableContext> result = new LinkedHashMap<>();
TableSegment tableSegment = subqueryContext.getSqlStatement().getFrom();
TableSegment tableSegment = subqueryContext.getSqlStatement().getFrom().orElse(null);
for (Projection each : subqueryContext.getProjectionsContext().getExpandProjections()) {
if (!(each instanceof ColumnProjection)) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ public SelectStatementContext(final ShardingSphereMetaData metaData, final List<
tablesContext = new TablesContext(getAllTableSegments(), subqueryContexts, getDatabaseType());
groupByContext = new GroupByContextEngine().createGroupByContext(sqlStatement);
orderByContext = new OrderByContextEngine().createOrderBy(sqlStatement, groupByContext);
projectionsContext = new ProjectionsContextEngine(getDatabaseType())
.createProjectionsContext(getSqlStatement().getFrom(), getSqlStatement().getProjections(), groupByContext, orderByContext);
projectionsContext = new ProjectionsContextEngine(getDatabaseType()).createProjectionsContext(getSqlStatement().getProjections(), groupByContext, orderByContext);
paginationContext = new PaginationContextEngine().createPaginationContext(sqlStatement, projectionsContext, params, whereSegments);
String databaseName = tablesContext.getDatabaseName().orElse(defaultDatabaseName);
containsEnhancedTable = isContainsEnhancedTable(metaData, databaseName, getTablesContext().getTableNames());
Expand Down Expand Up @@ -166,7 +165,7 @@ private Map<Integer, SelectStatementContext> createSubqueryContexts(final Shardi
* @return whether contains join query or not
*/
public boolean isContainsJoinQuery() {
return getSqlStatement().getFrom() instanceof JoinTableSegment;
return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof JoinTableSegment;
}

/**
Expand Down Expand Up @@ -389,7 +388,7 @@ private Collection<TableSegment> getAllTableSegments() {
* @return whether sql statement contains table subquery segment or not
*/
public boolean containsTableSubquery() {
return getSqlStatement().getFrom() instanceof SubqueryTableSegment;
return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof SubqueryTableSegment;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;

/**
* Select statement binder.
Expand All @@ -56,9 +57,9 @@ private SelectStatement bind(final SelectStatement sqlStatement, final ShardingS
statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts);
SelectStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> SelectStatementHandler.setWithSegment(result,
WithSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts, statementBinderContext.getExternalTableBinderContexts())));
TableSegment boundedTableSegment = TableSegmentBinder.bind(sqlStatement.getFrom(), statementBinderContext, tableBinderContexts, outerTableBinderContexts);
result.setFrom(boundedTableSegment);
result.setProjections(ProjectionsSegmentBinder.bind(sqlStatement.getProjections(), statementBinderContext, boundedTableSegment, tableBinderContexts, outerTableBinderContexts));
Optional<TableSegment> boundedTableSegment = sqlStatement.getFrom().map(optional -> TableSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts, outerTableBinderContexts));
boundedTableSegment.ifPresent(result::setFrom);
result.setProjections(ProjectionsSegmentBinder.bind(sqlStatement.getProjections(), statementBinderContext, boundedTableSegment.orElse(null), tableBinderContexts, outerTableBinderContexts));
sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts, outerTableBinderContexts)));
// TODO support other segment bind in select statement
sqlStatement.getGroupBy().ifPresent(result::setGroupBy);
Expand Down
Loading

0 comments on commit 0cff1a0

Please sign in to comment.