From 0cff1a0ba86193adfb7de57800f2d31741b67dd5 Mon Sep 17 00:00:00 2001 From: Zhengqiang Duan Date: Fri, 29 Mar 2024 13:41:14 +0800 Subject: [PATCH] Change return value to Optional with SelectStatement#getFrom method and remove useless table segment in ProjectionEngine (#30692) * Change return value to Optional with SelectStatement#getFrom method and remove useless table segment in ProjectionEngine * fix sql parse test * fix unit test * fix unit test * fix checkstyle --- .../engine/condition/ShardingConditions.java | 2 +- ...rdingCreateViewStatementValidatorTest.java | 3 +- .../sharding/rule/ShardingRuleTest.java | 6 ++-- .../engine/PaginationContextEngine.java | 3 +- .../projection/engine/ProjectionEngine.java | 16 +++++----- .../engine/ProjectionsContextEngine.java | 11 +++---- .../engine/SubqueryTableContextEngine.java | 2 +- .../statement/dml/SelectStatementContext.java | 7 ++--- .../statement/dml/SelectStatementBinder.java | 7 +++-- .../engine/ProjectionEngineTest.java | 29 ++++++------------- .../engine/ProjectionsContextEngineTest.java | 23 +++++++-------- .../ddl/AlterViewStatementContextTest.java | 2 +- .../impl/SubqueryTableSegmentBinderTest.java | 7 +++-- .../statement/SelectStatementBinderTest.java | 7 +++-- .../from/impl/SubqueryTableConverter.java | 2 +- .../select/SelectStatementConverter.java | 2 +- .../transaction/util/AutoCommitUtils.java | 2 +- .../statement/MySQLStatementVisitor.java | 4 +-- .../statement/OpenGaussStatementVisitor.java | 21 +++++++------- .../type/OracleDMLStatementVisitor.java | 2 +- .../statement/PostgreSQLStatementVisitor.java | 4 +-- .../sql/common/extractor/TableExtractor.java | 4 +-- .../common/statement/dml/SelectStatement.java | 9 ++++++ .../sql/common/util/ColumnExtractor.java | 5 +--- .../sql/common/util/SubqueryExtractUtils.java | 2 +- .../sql/common/util/WhereExtractUtils.java | 2 +- .../common/util/SubqueryExtractUtilsTest.java | 7 ++--- .../common/util/WhereExtractUtilsTest.java | 2 +- .../data/DatabaseBackendHandlerFactory.java | 3 +- .../HeterogeneousSelectStatementChecker.java | 2 +- .../hbase/result/query/HBaseGetResultSet.java | 13 +++++++-- .../MySQLDialectSaneQueryResultEngine.java | 2 +- .../admin/MySQLAdminExecutorCreator.java | 4 +-- ...MySQLInformationSchemaExecutorFactory.java | 4 +-- .../MySQLMySQLSchemaExecutorFactory.java | 4 +-- ...MySQLPerformanceSchemaExecutorFactory.java | 4 +-- .../admin/MySQLSysSchemaExecutorFactory.java | 4 +-- .../executor/NoResourceShowExecutor.java | 4 +-- .../admin/MySQLAdminExecutorCreatorTest.java | 26 ++++++++--------- .../dml/impl/SelectStatementAssert.java | 7 +++-- 40 files changed, 134 insertions(+), 136 deletions(-) diff --git a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java index 3344c2ee6e053..eb9bb54e1049d 100644 --- a/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java +++ b/features/sharding/core/src/main/java/org/apache/shardingsphere/sharding/route/engine/condition/ShardingConditions.java @@ -125,7 +125,7 @@ private boolean isSubqueryContainsShardingCondition(final List new LinkedList<>()).add(each); } for (SelectStatement each : selectStatements) { - if (each.getFrom() instanceof SubqueryTableSegment) { + if (each.getFrom().isPresent() && each.getFrom().get() instanceof SubqueryTableSegment) { continue; } if (!each.getWhere().isPresent() || !startIndexShardingConditions.containsKey(each.getWhere().get().getExpr().getStartIndex())) { diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/ddl/ShardingCreateViewStatementValidatorTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/ddl/ShardingCreateViewStatementValidatorTest.java index 62e0ffddec82d..2b32fbfba1651 100644 --- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/ddl/ShardingCreateViewStatementValidatorTest.java +++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/route/engine/validator/ddl/ShardingCreateViewStatementValidatorTest.java @@ -44,6 +44,7 @@ import org.mockito.quality.Strictness; import java.util.Collections; +import java.util.Optional; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -75,7 +76,7 @@ class ShardingCreateViewStatementValidatorTest { void setUp() { when(createViewStatementContext.getSqlStatement()).thenReturn(createViewStatement); when(createViewStatement.getSelect()).thenReturn(selectStatement); - when(selectStatement.getFrom()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))); + when(selectStatement.getFrom()).thenReturn(Optional.of(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))))); when(createViewStatement.getView()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("order_view")))); when(routeContext.getRouteUnits().size()).thenReturn(2); } diff --git a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java index cb5366a2a954a..dd412dac5eb61 100644 --- a/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java +++ b/features/sharding/core/src/test/java/org/apache/shardingsphere/sharding/rule/ShardingRuleTest.java @@ -600,7 +600,7 @@ void assertIsAllBindingTableWithJoinQueryWithDatabaseJoinCondition() { JoinTableSegment joinTable = mock(JoinTableSegment.class); when(joinTable.getCondition()).thenReturn(condition); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(joinTable); + when(selectStatement.getFrom()).thenReturn(Optional.of(joinTable)); SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true); @@ -622,7 +622,7 @@ void assertIsAllBindingTableWithJoinQueryWithDatabaseJoinConditionInUpperCaseAnd JoinTableSegment joinTable = mock(JoinTableSegment.class); when(joinTable.getCondition()).thenReturn(condition); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(joinTable); + when(selectStatement.getFrom()).thenReturn(Optional.of(joinTable)); SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true); @@ -651,7 +651,7 @@ void assertIsAllBindingTableWithJoinQueryWithDatabaseTableJoinCondition() { BinaryOperationExpression condition = createBinaryOperationExpression(databaseJoin, tableJoin, AND); when(joinTable.getCondition()).thenReturn(condition); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(joinTable); + when(selectStatement.getFrom()).thenReturn(Optional.of(joinTable)); SelectStatementContext sqlStatementContext = mock(SelectStatementContext.class, RETURNS_DEEP_STUBS); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); when(sqlStatementContext.isContainsJoinQuery()).thenReturn(true); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java index 5d5ccbcd1056d..ea045fa783c63 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/pagination/engine/PaginationContextEngine.java @@ -32,6 +32,7 @@ import org.apache.shardingsphere.sql.parser.sql.dialect.statement.sqlserver.SQLServerStatement; import java.util.Collection; +import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Optional; @@ -75,7 +76,7 @@ private boolean containsRowNumberPagination(final SelectStatement selectStatemen } private Optional findTopProjection(final SelectStatement selectStatement) { - List subqueryTableSegments = SQLUtils.getSubqueryTableSegmentFromTableSegment(selectStatement.getFrom()); + List subqueryTableSegments = selectStatement.getFrom().map(SQLUtils::getSubqueryTableSegmentFromTableSegment).orElse(Collections.emptyList()); for (SubqueryTableSegment subquery : subqueryTableSegments) { SelectStatement subquerySelect = subquery.getSubquery().getSelect(); for (ProjectionSegment each : subquerySelect.getProjections().getProjections()) { diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java index 20460a0522935..8523576fabb07 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngine.java @@ -39,7 +39,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.SubqueryProjectionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment; -import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue; import java.util.Collection; @@ -61,13 +60,12 @@ public final class ProjectionEngine { /** * Create projection. * - * @param table table segment * @param projectionSegment projection segment * @return projection */ - public Optional createProjection(final TableSegment table, final ProjectionSegment projectionSegment) { + public Optional createProjection(final ProjectionSegment projectionSegment) { if (projectionSegment instanceof ShorthandProjectionSegment) { - return Optional.of(createProjection(table, (ShorthandProjectionSegment) projectionSegment)); + return Optional.of(createProjection((ShorthandProjectionSegment) projectionSegment)); } if (projectionSegment instanceof ColumnProjectionSegment) { return Optional.of(createProjection((ColumnProjectionSegment) projectionSegment)); @@ -82,7 +80,7 @@ public Optional createProjection(final TableSegment table, final Pro return Optional.of(createProjection((AggregationProjectionSegment) projectionSegment)); } if (projectionSegment instanceof SubqueryProjectionSegment) { - return Optional.of(createProjection(table, (SubqueryProjectionSegment) projectionSegment)); + return Optional.of(createProjection((SubqueryProjectionSegment) projectionSegment)); } if (projectionSegment instanceof ParameterMarkerExpressionSegment) { return Optional.of(createProjection((ParameterMarkerExpressionSegment) projectionSegment)); @@ -94,16 +92,16 @@ private ParameterMarkerProjection createProjection(final ParameterMarkerExpressi return new ParameterMarkerProjection(projectionSegment.getParameterMarkerIndex(), projectionSegment.getParameterMarkerType(), projectionSegment.getAlias().orElse(null)); } - private SubqueryProjection createProjection(final TableSegment table, final SubqueryProjectionSegment projectionSegment) { - Projection subqueryProjection = createProjection(table, projectionSegment.getSubquery().getSelect().getProjections().getProjections().iterator().next()) + private SubqueryProjection createProjection(final SubqueryProjectionSegment projectionSegment) { + Projection subqueryProjection = createProjection(projectionSegment.getSubquery().getSelect().getProjections().getProjections().iterator().next()) .orElseThrow(() -> new IllegalArgumentException("Subquery projection must have at least one projection column.")); return new SubqueryProjection(projectionSegment, subqueryProjection, projectionSegment.getAlias().orElse(null), databaseType); } - private ShorthandProjection createProjection(final TableSegment table, final ShorthandProjectionSegment projectionSegment) { + private ShorthandProjection createProjection(final ShorthandProjectionSegment projectionSegment) { IdentifierValue owner = projectionSegment.getOwner().map(OwnerSegment::getIdentifier).orElse(null); Collection projections = new LinkedHashSet<>(); - projectionSegment.getActualProjectionSegments().forEach(each -> createProjection(table, each).ifPresent(projections::add)); + projectionSegment.getActualProjectionSegments().forEach(each -> createProjection(each).ifPresent(projections::add)); return new ShorthandProjection(owner, projections); } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngine.java index 6231d4daceb1e..4cb28a2b94bdb 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngine.java @@ -34,7 +34,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.OrderByItemSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.TextOrderByItemSegment; -import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.util.SQLUtils; import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue; @@ -55,25 +54,23 @@ public ProjectionsContextEngine(final DatabaseType databaseType) { /** * Create projections context. * - * @param table table segment * @param projectionsSegment projection segments * @param groupByContext group by context * @param orderByContext order by context * @return projections context */ - public ProjectionsContext createProjectionsContext(final TableSegment table, final ProjectionsSegment projectionsSegment, - final GroupByContext groupByContext, final OrderByContext orderByContext) { - Collection projections = getProjections(table, projectionsSegment); + public ProjectionsContext createProjectionsContext(final ProjectionsSegment projectionsSegment, final GroupByContext groupByContext, final OrderByContext orderByContext) { + Collection projections = getProjections(projectionsSegment); ProjectionsContext result = new ProjectionsContext(projectionsSegment.getStartIndex(), projectionsSegment.getStopIndex(), projectionsSegment.isDistinctRow(), projections); result.getProjections().addAll(getDerivedGroupByColumns(groupByContext, projections)); result.getProjections().addAll(getDerivedOrderByColumns(orderByContext, projections)); return result; } - private Collection getProjections(final TableSegment table, final ProjectionsSegment projectionsSegment) { + private Collection getProjections(final ProjectionsSegment projectionsSegment) { Collection result = new LinkedList<>(); for (ProjectionSegment each : projectionsSegment.getProjections()) { - projectionEngine.createProjection(table, each).ifPresent(result::add); + projectionEngine.createProjection(each).ifPresent(result::add); } return result; } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/subquery/engine/SubqueryTableContextEngine.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/subquery/engine/SubqueryTableContextEngine.java index 7a1cab950f499..fb06534e050d0 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/subquery/engine/SubqueryTableContextEngine.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/select/subquery/engine/SubqueryTableContextEngine.java @@ -44,7 +44,7 @@ public final class SubqueryTableContextEngine { */ public Map createSubqueryTableContexts(final SelectStatementContext subqueryContext, final String aliasName) { Map result = new LinkedHashMap<>(); - TableSegment tableSegment = subqueryContext.getSqlStatement().getFrom(); + TableSegment tableSegment = subqueryContext.getSqlStatement().getFrom().orElse(null); for (Projection each : subqueryContext.getProjectionsContext().getExpandProjections()) { if (!(each instanceof ColumnProjection)) { continue; diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java index 3c26672b2aecf..6ee546274dd8b 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java @@ -121,8 +121,7 @@ public SelectStatementContext(final ShardingSphereMetaData metaData, final List< tablesContext = new TablesContext(getAllTableSegments(), subqueryContexts, getDatabaseType()); groupByContext = new GroupByContextEngine().createGroupByContext(sqlStatement); orderByContext = new OrderByContextEngine().createOrderBy(sqlStatement, groupByContext); - projectionsContext = new ProjectionsContextEngine(getDatabaseType()) - .createProjectionsContext(getSqlStatement().getFrom(), getSqlStatement().getProjections(), groupByContext, orderByContext); + projectionsContext = new ProjectionsContextEngine(getDatabaseType()).createProjectionsContext(getSqlStatement().getProjections(), groupByContext, orderByContext); paginationContext = new PaginationContextEngine().createPaginationContext(sqlStatement, projectionsContext, params, whereSegments); String databaseName = tablesContext.getDatabaseName().orElse(defaultDatabaseName); containsEnhancedTable = isContainsEnhancedTable(metaData, databaseName, getTablesContext().getTableNames()); @@ -166,7 +165,7 @@ private Map createSubqueryContexts(final Shardi * @return whether contains join query or not */ public boolean isContainsJoinQuery() { - return getSqlStatement().getFrom() instanceof JoinTableSegment; + return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof JoinTableSegment; } /** @@ -389,7 +388,7 @@ private Collection getAllTableSegments() { * @return whether sql statement contains table subquery segment or not */ public boolean containsTableSubquery() { - return getSqlStatement().getFrom() instanceof SubqueryTableSegment; + return getSqlStatement().getFrom().isPresent() && getSqlStatement().getFrom().get() instanceof SubqueryTableSegment; } @Override diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/SelectStatementBinder.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/SelectStatementBinder.java index 00a44c63e9b4d..c86a9366387cb 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/SelectStatementBinder.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/SelectStatementBinder.java @@ -35,6 +35,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Optional; /** * Select statement binder. @@ -56,9 +57,9 @@ private SelectStatement bind(final SelectStatement sqlStatement, final ShardingS statementBinderContext.getExternalTableBinderContexts().putAll(externalTableBinderContexts); SelectStatementHandler.getWithSegment(sqlStatement).ifPresent(optional -> SelectStatementHandler.setWithSegment(result, WithSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts, statementBinderContext.getExternalTableBinderContexts()))); - TableSegment boundedTableSegment = TableSegmentBinder.bind(sqlStatement.getFrom(), statementBinderContext, tableBinderContexts, outerTableBinderContexts); - result.setFrom(boundedTableSegment); - result.setProjections(ProjectionsSegmentBinder.bind(sqlStatement.getProjections(), statementBinderContext, boundedTableSegment, tableBinderContexts, outerTableBinderContexts)); + Optional boundedTableSegment = sqlStatement.getFrom().map(optional -> TableSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts, outerTableBinderContexts)); + boundedTableSegment.ifPresent(result::setFrom); + result.setProjections(ProjectionsSegmentBinder.bind(sqlStatement.getProjections(), statementBinderContext, boundedTableSegment.orElse(null), tableBinderContexts, outerTableBinderContexts)); sqlStatement.getWhere().ifPresent(optional -> result.setWhere(WhereSegmentBinder.bind(optional, statementBinderContext, tableBinderContexts, outerTableBinderContexts))); // TODO support other segment bind in select statement sqlStatement.getGroupBy().ifPresent(result::setGroupBy); diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngineTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngineTest.java index d67cabdb3608d..329a6beb1e707 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngineTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionEngineTest.java @@ -36,7 +36,6 @@ import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment; import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.OwnerSegment; -import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -49,7 +48,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.Mockito.mock; @ExtendWith(MockitoExtension.class) class ProjectionEngineTest { @@ -58,16 +56,14 @@ class ProjectionEngineTest { @Test void assertCreateProjectionWhenProjectionSegmentNotMatched() { - assertFalse(new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), null).isPresent()); + assertFalse(new ProjectionEngine(databaseType).createProjection(null).isPresent()); } @Test void assertCreateProjectionWhenProjectionSegmentInstanceOfShorthandProjectionSegment() { ShorthandProjectionSegment shorthandProjectionSegment = new ShorthandProjectionSegment(0, 0); shorthandProjectionSegment.setOwner(new OwnerSegment(0, 0, new IdentifierValue("tbl"))); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), shorthandProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(shorthandProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(ShorthandProjection.class)); } @@ -76,8 +72,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfShorthandProjectionSeg void assertCreateProjectionWhenProjectionSegmentInstanceOfColumnProjectionSegment() { ColumnProjectionSegment columnProjectionSegment = new ColumnProjectionSegment(new ColumnSegment(0, 10, new IdentifierValue("name"))); columnProjectionSegment.setAlias(new AliasSegment(0, 0, new IdentifierValue("alias"))); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), columnProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(columnProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(ColumnProjection.class)); } @@ -85,8 +80,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfColumnProjectionSegmen @Test void assertCreateProjectionWhenProjectionSegmentInstanceOfExpressionProjectionSegment() { ExpressionProjectionSegment expressionProjectionSegment = new ExpressionProjectionSegment(0, 10, "text"); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), expressionProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(expressionProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(ExpressionProjection.class)); } @@ -94,8 +88,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfExpressionProjectionSe @Test void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationDistinctProjectionSegment() { AggregationDistinctProjectionSegment aggregationDistinctProjectionSegment = new AggregationDistinctProjectionSegment(0, 10, AggregationType.COUNT, "(1)", "distinctExpression"); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), aggregationDistinctProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(aggregationDistinctProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(AggregationDistinctProjection.class)); } @@ -103,8 +96,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationDistinctPro @Test void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationProjectionSegment() { AggregationProjectionSegment aggregationProjectionSegment = new AggregationProjectionSegment(0, 10, AggregationType.COUNT, "COUNT(1)"); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), aggregationProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(aggregationProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(AggregationProjection.class)); } @@ -112,8 +104,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationProjectionS @Test void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationDistinctProjectionSegmentAndAggregationTypeIsAvg() { AggregationDistinctProjectionSegment aggregationDistinctProjectionSegment = new AggregationDistinctProjectionSegment(0, 10, AggregationType.AVG, "(1)", "distinctExpression"); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), aggregationDistinctProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(aggregationDistinctProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(AggregationDistinctProjection.class)); } @@ -121,8 +112,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationDistinctPro @Test void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationProjectionSegmentAndAggregationTypeIsAvg() { AggregationProjectionSegment aggregationProjectionSegment = new AggregationProjectionSegment(0, 10, AggregationType.AVG, "AVG(1)"); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), aggregationProjectionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(aggregationProjectionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(AggregationProjection.class)); } @@ -131,8 +121,7 @@ void assertCreateProjectionWhenProjectionSegmentInstanceOfAggregationProjectionS void assertCreateProjectionWhenProjectionSegmentInstanceOfParameterMarkerExpressionSegment() { ParameterMarkerExpressionSegment parameterMarkerExpressionSegment = new ParameterMarkerExpressionSegment(7, 7, 0); parameterMarkerExpressionSegment.setAlias(new AliasSegment(0, 0, new IdentifierValue("alias"))); - Optional actual = new ProjectionEngine( - databaseType).createProjection(mock(TableSegment.class), parameterMarkerExpressionSegment); + Optional actual = new ProjectionEngine(databaseType).createProjection(parameterMarkerExpressionSegment); assertTrue(actual.isPresent()); assertThat(actual.get(), instanceOf(ParameterMarkerProjection.class)); assertThat(actual.get().getAlias().map(IdentifierValue::getValue).orElse(null), is("alias")); diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngineTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngineTest.java index 46029da585b55..f76a699c8c1f4 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngineTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/segment/select/projection/engine/ProjectionsContextEngineTest.java @@ -97,8 +97,8 @@ private void assertProjectionsContextCreatedProperly(final SelectStatement selec SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsSegment projectionsSegment = selectStatement.getProjections(); ProjectionsContextEngine projectionsContextEngine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = projectionsContextEngine - .createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), new OrderByContext(Collections.emptyList(), false)); + ProjectionsContext actual = + projectionsContextEngine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), new OrderByContext(Collections.emptyList(), false)); assertNotNull(actual); } @@ -136,8 +136,7 @@ private void assertProjectionsContextCreatedProperlyWhenProjectionPresent(final projectionsSegment.getProjections().add(shorthandProjectionSegment); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), - projectionsSegment, new GroupByContext(Collections.emptyList()), new OrderByContext(Collections.emptyList(), false)); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), new OrderByContext(Collections.emptyList(), false)); assertNotNull(actual); } @@ -177,7 +176,7 @@ private void createProjectionsContextWhenOrderByContextOrderItemsPresent(final S OrderByContext orderByContext = new OrderByContext(Collections.singletonList(orderByItem), true); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); assertNotNull(actual); } @@ -217,7 +216,7 @@ private void assertCreateProjectionsContextWithoutIndexOrderByItemSegment(final OrderByContext orderByContext = new OrderByContext(Collections.singletonList(orderByItem), true); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); assertNotNull(actual); } @@ -264,7 +263,7 @@ private void assertCreateProjectionsContextWhenColumnOrderByItemSegmentOwnerAbse OrderByContext orderByContext = new OrderByContext(Collections.singletonList(orderByItem), true); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); assertNotNull(actual); } @@ -305,7 +304,7 @@ private void assertCreateProjectionsContextWhenColumnOrderByItemSegmentOwnerPres OrderByContext orderByContext = new OrderByContext(Collections.singletonList(orderByItem), true); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); assertNotNull(actual); } @@ -354,7 +353,7 @@ private void assertCreateProjectionsContextWhenColumnOrderByItemSegmentOwnerPres OrderByContext orderByContext = new OrderByContext(Collections.singleton(orderByItem), false); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); assertNotNull(actual); } @@ -399,7 +398,7 @@ private void assertCreateProjectionsContextWithTemporaryTable(final SelectStatem GroupByContext groupByContext = new GroupByContext(Collections.singleton(groupByItem)); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, groupByContext, new OrderByContext(Collections.emptyList(), false)); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, groupByContext, new OrderByContext(Collections.emptyList(), false)); assertNotNull(actual); } @@ -444,7 +443,7 @@ private void assertCreateProjectionsContextWhenTableNameOrAliasIgnoreCase(final GroupByContext groupByContext = new GroupByContext(Collections.singleton(groupByItem)); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, groupByContext, new OrderByContext(Collections.emptyList(), false)); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, groupByContext, new OrderByContext(Collections.emptyList(), false)); assertNotNull(actual); } @@ -459,7 +458,7 @@ void assertCreateProjectionsContextWithOrderByExpressionForMySQL() { OrderByContext orderByContext = new OrderByContext(Collections.singleton(orderByItem), false); SelectStatementContext selectStatementContext = createSelectStatementContext(selectStatement); ProjectionsContextEngine engine = new ProjectionsContextEngine(selectStatementContext.getDatabaseType()); - ProjectionsContext actual = engine.createProjectionsContext(selectStatement.getFrom(), projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); + ProjectionsContext actual = engine.createProjectionsContext(projectionsSegment, new GroupByContext(Collections.emptyList()), orderByContext); assertThat(actual.getProjections().size(), is(2)); } } diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContextTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContextTest.java index 64cd0c568e62d..bdab08dd096e8 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContextTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/context/statement/ddl/AlterViewStatementContextTest.java @@ -49,7 +49,7 @@ void setUp() { @Test void assertMySQLNewInstance() { SelectStatement select = mock(MySQLSelectStatement.class); - when(select.getFrom()).thenReturn(view); + when(select.getFrom()).thenReturn(Optional.of(view)); MySQLAlterViewStatement alterViewStatement = mock(MySQLAlterViewStatement.class); when(alterViewStatement.getView()).thenReturn(view); when(alterViewStatement.getSelect()).thenReturn(select); diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinderTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinderTest.java index c20707a341cd1..02da20dbc3fca 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinderTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/from/impl/SubqueryTableSegmentBinderTest.java @@ -46,6 +46,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; @@ -63,7 +64,7 @@ class SubqueryTableSegmentBinderTest { void assertBindWithSubqueryTableAlias() { MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); when(selectStatement.getDatabaseType()).thenReturn(databaseType); - when(selectStatement.getFrom()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))); + when(selectStatement.getFrom()).thenReturn(Optional.of(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))))); ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0); projectionsSegment.getProjections().add(new ShorthandProjectionSegment(0, 0)); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -95,7 +96,7 @@ void assertBindWithSubqueryTableAlias() { void assertBindWithSubqueryProjectionAlias() { MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); when(selectStatement.getDatabaseType()).thenReturn(databaseType); - when(selectStatement.getFrom()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))); + when(selectStatement.getFrom()).thenReturn(Optional.of(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))))); ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0); ColumnProjectionSegment columnProjectionSegment = new ColumnProjectionSegment(new ColumnSegment(0, 0, new IdentifierValue("order_id"))); columnProjectionSegment.setAlias(new AliasSegment(0, 0, new IdentifierValue("order_id_alias"))); @@ -121,7 +122,7 @@ void assertBindWithSubqueryProjectionAlias() { void assertBindWithoutSubqueryTableAlias() { MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); when(selectStatement.getDatabaseType()).thenReturn(databaseType); - when(selectStatement.getFrom()).thenReturn(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))); + when(selectStatement.getFrom()).thenReturn(Optional.of(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))))); ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0); projectionsSegment.getProjections().add(new ShorthandProjectionSegment(0, 0)); when(selectStatement.getProjections()).thenReturn(projectionsSegment); diff --git a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/SelectStatementBinderTest.java b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/SelectStatementBinderTest.java index 412183ce0aa33..39f75232284b0 100644 --- a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/SelectStatementBinderTest.java +++ b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/SelectStatementBinderTest.java @@ -83,9 +83,10 @@ void assertBind() { selectStatement.setWhere(mockWhereSegment()); SelectStatement actual = new SelectStatementBinder().bind(selectStatement, createMetaData(), DefaultDatabase.LOGIC_NAME); assertThat(actual, not(selectStatement)); - assertThat(actual.getFrom(), not(selectStatement.getFrom())); - assertThat(actual.getFrom(), instanceOf(SimpleTableSegment.class)); - assertThat(((SimpleTableSegment) actual.getFrom()).getTableName(), not(simpleTableSegment.getTableName())); + assertTrue(actual.getFrom().isPresent()); + assertThat(actual.getFrom().get(), not(simpleTableSegment)); + assertThat(actual.getFrom().get(), instanceOf(SimpleTableSegment.class)); + assertThat(((SimpleTableSegment) actual.getFrom().get()).getTableName(), not(simpleTableSegment.getTableName())); assertThat(actual.getProjections(), not(selectStatement.getProjections())); List actualProjections = new ArrayList<>(actual.getProjections().getProjections()); assertThat(actualProjections, not(selectStatement.getProjections())); diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/from/impl/SubqueryTableConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/from/impl/SubqueryTableConverter.java index 92b73d0687f66..27a6e405635e7 100644 --- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/from/impl/SubqueryTableConverter.java +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/segment/from/impl/SubqueryTableConverter.java @@ -53,7 +53,7 @@ public static Optional convert(final SubqueryTableSegment segment) { } Collection sqlNodes = new LinkedList<>(); if (null == segment.getSubquery().getSelect().getProjections()) { - List tables = TableConverter.convert(segment.getSubquery().getSelect().getFrom()).map(Collections::singletonList).orElseGet(Collections::emptyList); + List tables = segment.getSubquery().getSelect().getFrom().flatMap(TableConverter::convert).map(Collections::singletonList).orElseGet(Collections::emptyList); sqlNodes.add(new SqlBasicCall(SqlStdOperatorTable.EXPLICIT_TABLE, tables, SqlParserPos.ZERO)); } else { sqlNodes.add(new SelectStatementConverter().convert(segment.getSubquery().getSelect())); diff --git a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java index a975d0a023271..9fd87e15c02fb 100644 --- a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java +++ b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/select/SelectStatementConverter.java @@ -70,7 +70,7 @@ private SqlNode convertWith(final SqlNode sqlSelect, final SelectStatement selec private SqlSelect convertSelect(final SelectStatement selectStatement) { SqlNodeList distinct = DistinctConverter.convert(selectStatement.getProjections()).orElse(null); SqlNodeList projection = ProjectionsConverter.convert(selectStatement.getProjections()).orElseThrow(IllegalStateException::new); - SqlNode from = TableConverter.convert(selectStatement.getFrom()).orElse(null); + SqlNode from = selectStatement.getFrom().flatMap(TableConverter::convert).orElse(null); SqlNode where = selectStatement.getWhere().flatMap(WhereConverter::convert).orElse(null); SqlNodeList groupBy = selectStatement.getGroupBy().flatMap(GroupByConverter::convert).orElse(null); SqlNode having = selectStatement.getHaving().flatMap(HavingConverter::convert).orElse(null); diff --git a/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/util/AutoCommitUtils.java b/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/util/AutoCommitUtils.java index 6c7c8d88b047e..31c4d5e1b732e 100644 --- a/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/util/AutoCommitUtils.java +++ b/kernel/transaction/core/src/main/java/org/apache/shardingsphere/transaction/util/AutoCommitUtils.java @@ -37,7 +37,7 @@ public final class AutoCommitUtils { * @return need to open a new transaction. */ public static boolean needOpenTransaction(final SQLStatement sqlStatement) { - if (sqlStatement instanceof SelectStatement && null == ((SelectStatement) sqlStatement).getFrom()) { + if (sqlStatement instanceof SelectStatement && !((SelectStatement) sqlStatement).getFrom().isPresent()) { return false; } return sqlStatement instanceof DDLStatement || sqlStatement instanceof DMLStatement; diff --git a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java index 3e40e39a27968..632f4e58a2dc4 100644 --- a/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java +++ b/parser/sql/dialect/mysql/src/main/java/org/apache/shardingsphere/sql/parser/mysql/visitor/statement/MySQLStatementVisitor.java @@ -768,7 +768,7 @@ public ASTNode visitQueryExpressionBody(final QueryExpressionBodyContext ctx) { MySQLSelectStatement result = new MySQLSelectStatement(); MySQLSelectStatement left = (MySQLSelectStatement) visit(ctx.queryExpressionBody()); result.setProjections(left.getProjections()); - result.setFrom(left.getFrom()); + left.getFrom().ifPresent(result::setFrom); left.getTable().ifPresent(result::setTable); result.setCombine(createCombineSegment(ctx.combineClause(), left)); return result; @@ -777,7 +777,7 @@ public ASTNode visitQueryExpressionBody(final QueryExpressionBodyContext ctx) { MySQLSelectStatement result = new MySQLSelectStatement(); MySQLSelectStatement left = (MySQLSelectStatement) visit(ctx.queryExpressionParens()); result.setProjections(left.getProjections()); - result.setFrom(left.getFrom()); + left.getFrom().ifPresent(result::setFrom); left.getTable().ifPresent(result::setTable); result.setCombine(createCombineSegment(ctx.combineClause(), left)); return result; diff --git a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java index b3335d9a74197..82118214a14e9 100644 --- a/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java +++ b/parser/sql/dialect/opengauss/src/main/java/org/apache/shardingsphere/sql/parser/opengauss/visitor/statement/OpenGaussStatementVisitor.java @@ -17,13 +17,8 @@ package org.apache.shardingsphere.sql.parser.opengauss.visitor.statement; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.LinkedList; -import java.util.List; -import java.util.Optional; - +import lombok.AccessLevel; +import lombok.Getter; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.misc.Interval; import org.antlr.v4.runtime.tree.ParseTree; @@ -72,6 +67,7 @@ import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.InsertContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.InsertRestContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.InsertTargetContext; +import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.IntoClauseContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.JoinQualContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.JoinedTableContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.LimitClauseContext; @@ -116,7 +112,6 @@ import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WindowClauseContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WindowDefinitionContext; import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.WindowDefinitionListContext; -import org.apache.shardingsphere.sql.parser.autogen.OpenGaussStatementParser.IntoClauseContext; import org.apache.shardingsphere.sql.parser.sql.common.enums.AggregationType; import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType; import org.apache.shardingsphere.sql.parser.sql.common.enums.JoinType; @@ -201,8 +196,12 @@ import org.apache.shardingsphere.sql.parser.sql.dialect.statement.opengauss.dml.OpenGaussSelectStatement; import org.apache.shardingsphere.sql.parser.sql.dialect.statement.opengauss.dml.OpenGaussUpdateStatement; -import lombok.AccessLevel; -import lombok.Getter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; /** * Statement visitor for openGauss. @@ -969,7 +968,7 @@ public ASTNode visitSelectClauseN(final SelectClauseNContext ctx) { OpenGaussSelectStatement result = new OpenGaussSelectStatement(); OpenGaussSelectStatement left = (OpenGaussSelectStatement) visit(ctx.selectClauseN(0)); result.setProjections(left.getProjections()); - result.setFrom(left.getFrom()); + left.getFrom().ifPresent(result::setFrom); CombineSegment combineSegment = new CombineSegment(((TerminalNode) ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(), left, getCombineType(ctx), (OpenGaussSelectStatement) visit(ctx.selectClauseN(1))); result.setCombine(combineSegment); diff --git a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java index 1a2a05b9fccca..b86a9bbf0f660 100644 --- a/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java +++ b/parser/sql/dialect/oracle/src/main/java/org/apache/shardingsphere/sql/parser/oracle/visitor/statement/type/OracleDMLStatementVisitor.java @@ -591,7 +591,7 @@ public ASTNode visitSelectSubquery(final SelectSubqueryContext ctx) { result = new OracleSelectStatement(); OracleSelectStatement left = (OracleSelectStatement) visit(ctx.selectSubquery(0)); result.setProjections(left.getProjections()); - result.setFrom(left.getFrom()); + left.getFrom().ifPresent(result::setFrom); createSelectCombineClause(ctx, result, left); } else { result = null != ctx.queryBlock() ? (OracleSelectStatement) visit(ctx.queryBlock()) : (OracleSelectStatement) visit(ctx.parenthesisSelectSubquery()); diff --git a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java index f02081615a880..a12503d0da569 100644 --- a/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java +++ b/parser/sql/dialect/postgresql/src/main/java/org/apache/shardingsphere/sql/parser/postgresql/visitor/statement/PostgreSQLStatementVisitor.java @@ -65,6 +65,7 @@ import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.InsertContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.InsertRestContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.InsertTargetContext; +import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.IntoClauseContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.JoinQualContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.JoinedTableContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.LimitClauseContext; @@ -110,7 +111,6 @@ import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WindowClauseContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WindowDefinitionContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.WindowDefinitionListContext; -import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParser.IntoClauseContext; import org.apache.shardingsphere.sql.parser.autogen.PostgreSQLStatementParserBaseVisitor; import org.apache.shardingsphere.sql.parser.sql.common.enums.AggregationType; import org.apache.shardingsphere.sql.parser.sql.common.enums.CombineType; @@ -938,7 +938,7 @@ public ASTNode visitSelectClauseN(final SelectClauseNContext ctx) { PostgreSQLSelectStatement result = new PostgreSQLSelectStatement(); PostgreSQLSelectStatement left = (PostgreSQLSelectStatement) visit(ctx.selectClauseN(0)); result.setProjections(left.getProjections()); - result.setFrom(left.getFrom()); + left.getFrom().ifPresent(result::setFrom); CombineSegment combineSegment = new CombineSegment(((TerminalNode) ctx.getChild(1)).getSymbol().getStartIndex(), ctx.getStop().getStopIndex(), left, getCombineType(ctx), (PostgreSQLSelectStatement) visit(ctx.selectClauseN(1))); result.setCombine(combineSegment); diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java index 551df80356f3c..89e88b8f458dd 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/extractor/TableExtractor.java @@ -83,8 +83,8 @@ public void extractTablesFromSelect(final SelectStatement selectStatement) { extractTablesFromSelect(combineSegment.getLeft()); extractTablesFromSelect(combineSegment.getRight()); } - if (null != selectStatement.getFrom() && !selectStatement.getCombine().isPresent()) { - extractTablesFromTableSegment(selectStatement.getFrom()); + if (selectStatement.getFrom().isPresent() && !selectStatement.getCombine().isPresent()) { + extractTablesFromTableSegment(selectStatement.getFrom().get()); } if (selectStatement.getWhere().isPresent()) { extractTablesFromExpression(selectStatement.getWhere().get().getExpr()); diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/SelectStatement.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/SelectStatement.java index b3532bc43104a..0e6aba5607aee 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/SelectStatement.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/SelectStatement.java @@ -51,6 +51,15 @@ public abstract class SelectStatement extends AbstractSQLStatement implements DM private CombineSegment combine; + /** + * Get from. + * + * @return from table segment + */ + public Optional getFrom() { + return Optional.ofNullable(from); + } + /** * Get where. * diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java index 20efbccffcc01..19af0d652e258 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ColumnExtractor.java @@ -126,7 +126,7 @@ public static void extractFromSelectStatement(final Collection co * @param containsSubQuery whether contains sub query */ public static void extractFromSelectStatementWithoutProjection(final Collection columnSegments, final SelectStatement statement, final boolean containsSubQuery) { - extractFromTable(columnSegments, statement.getFrom(), containsSubQuery); + statement.getFrom().ifPresent(optional -> extractFromTable(columnSegments, optional, containsSubQuery)); statement.getWhere().ifPresent(optional -> extractFromWhere(columnSegments, optional, containsSubQuery)); statement.getGroupBy().ifPresent(optional -> extractFromGroupBy(columnSegments, optional, containsSubQuery)); statement.getHaving().ifPresent(optional -> extractFromHaving(columnSegments, optional, containsSubQuery)); @@ -170,9 +170,6 @@ public static void extractFromProjections(final Collection column } private static void extractFromTable(final Collection columnSegments, final TableSegment tableSegment, final boolean containsSubQuery) { - if (null == tableSegment) { - return; - } if (tableSegment instanceof CollectionTableSegment) { columnSegments.addAll(ExpressionExtractUtils.extractColumns(((CollectionTableSegment) tableSegment).getExpressionSegment(), containsSubQuery)); } diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java index e83f5f4f5443b..c3be2352d33ed 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtils.java @@ -64,7 +64,7 @@ public static Collection getSubquerySegments(final SelectStatem private static void extractSubquerySegments(final List result, final SelectStatement selectStatement) { extractSubquerySegmentsFromProjections(result, selectStatement.getProjections()); - extractSubquerySegmentsFromTableSegment(result, selectStatement.getFrom()); + selectStatement.getFrom().ifPresent(optional -> extractSubquerySegmentsFromTableSegment(result, optional)); if (selectStatement.getWhere().isPresent()) { extractSubquerySegmentsFromWhere(result, selectStatement.getWhere().get().getExpr()); } diff --git a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java index 9ff381a3d4607..e633a54f67566 100644 --- a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java +++ b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtils.java @@ -43,7 +43,7 @@ public final class WhereExtractUtils { * @return join where segment collection */ public static Collection getJoinWhereSegments(final SelectStatement selectStatement) { - return null == selectStatement.getFrom() ? Collections.emptyList() : getJoinWhereSegments(selectStatement.getFrom()); + return selectStatement.getFrom().map(WhereExtractUtils::getJoinWhereSegments).orElseGet(Collections::emptyList); } private static Collection getJoinWhereSegments(final TableSegment tableSegment) { diff --git a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java index 0cf240e4d2d15..425f52e370ae3 100644 --- a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java +++ b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/SubqueryExtractUtilsTest.java @@ -99,15 +99,14 @@ void assertGetSubquerySegmentsInFrom1() { subquery.setFrom(new SimpleTableSegment(new TableNameSegment(45, 51, new IdentifierValue("t_order")))); subquery.setProjections(new ProjectionsSegment(31, 38)); subquery.getProjections().getProjections().add(new ColumnProjectionSegment(new ColumnSegment(31, 38, new IdentifierValue("order_id")))); - MySQLSelectStatement selectStatement = new MySQLSelectStatement(); selectStatement.setProjections(new ProjectionsSegment(7, 16)); selectStatement.getProjections().getProjections().add(new ColumnProjectionSegment(new ColumnSegment(7, 16, new IdentifierValue("order_id")))); - selectStatement.setFrom(new SubqueryTableSegment(new SubquerySegment(23, 71, subquery, ""))); - + SubqueryTableSegment subqueryTableSegment = new SubqueryTableSegment(new SubquerySegment(23, 71, subquery, "")); + selectStatement.setFrom(subqueryTableSegment); Collection result = SubqueryExtractUtils.getSubquerySegments(selectStatement); assertThat(result.size(), is(1)); - assertThat(result.iterator().next(), is(((SubqueryTableSegment) selectStatement.getFrom()).getSubquery())); + assertThat(result.iterator().next(), is(subqueryTableSegment.getSubquery())); } @Test diff --git a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java index a59890e02c7a7..9766e40a6c443 100644 --- a/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java +++ b/parser/sql/statement/src/test/java/org/apache/shardingsphere/sql/parser/sql/common/util/WhereExtractUtilsTest.java @@ -91,6 +91,6 @@ void assertGetWhereSegmentsFromSubQueryJoin() { selectStatement.setFrom(new SubqueryTableSegment(new SubquerySegment(20, 84, subQuerySelectStatement, ""))); Collection subqueryWhereSegments = WhereExtractUtils.getSubqueryWhereSegments(selectStatement); WhereSegment actual = subqueryWhereSegments.iterator().next(); - assertThat(actual.getExpr(), is(((JoinTableSegment) subQuerySelectStatement.getFrom()).getCondition())); + assertThat(actual.getExpr(), is(joinTableSegment.getCondition())); } } diff --git a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/data/DatabaseBackendHandlerFactory.java b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/data/DatabaseBackendHandlerFactory.java index bd534957afb1b..61daa518121f4 100644 --- a/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/data/DatabaseBackendHandlerFactory.java +++ b/proxy/backend/core/src/main/java/org/apache/shardingsphere/proxy/backend/handler/data/DatabaseBackendHandlerFactory.java @@ -57,7 +57,8 @@ public static DatabaseBackendHandler newInstance(final QueryContext queryContext if (sqlStatement instanceof SetStatement && null == connectionSession.getDatabaseName()) { return () -> new UpdateResponseHeader(sqlStatement); } - if (sqlStatement instanceof DALStatement && !isDatabaseRequiredDALStatement(sqlStatement) || sqlStatement instanceof SelectStatement && null == ((SelectStatement) sqlStatement).getFrom()) { + if (sqlStatement instanceof DALStatement && !isDatabaseRequiredDALStatement(sqlStatement) + || sqlStatement instanceof SelectStatement && !((SelectStatement) sqlStatement).getFrom().isPresent()) { return new UnicastDatabaseBackendHandler(queryContext, connectionSession); } return DatabaseConnectorFactory.getInstance().newInstance(queryContext, connectionSession.getDatabaseConnectionManager(), preferPreparedStatement); diff --git a/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/checker/HeterogeneousSelectStatementChecker.java b/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/checker/HeterogeneousSelectStatementChecker.java index 75cbf31174415..555d991cbce1b 100644 --- a/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/checker/HeterogeneousSelectStatementChecker.java +++ b/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/checker/HeterogeneousSelectStatementChecker.java @@ -62,7 +62,7 @@ public void execute() { } private void checkDoNotSupportedSegment() { - Preconditions.checkArgument(sqlStatement.getFrom() instanceof SimpleTableSegment, "Only supported simple table segment."); + Preconditions.checkArgument(sqlStatement.getFrom().isPresent() && sqlStatement.getFrom().get() instanceof SimpleTableSegment, "Only supported simple table segment."); Preconditions.checkArgument(!sqlStatement.getHaving().isPresent(), "Do not supported having segment."); Preconditions.checkArgument(!sqlStatement.getGroupBy().isPresent(), "Do not supported group by segment."); MySQLSelectStatement selectStatement = (MySQLSelectStatement) sqlStatement; diff --git a/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/result/query/HBaseGetResultSet.java b/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/result/query/HBaseGetResultSet.java index 66baad9622727..025d84e734aea 100644 --- a/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/result/query/HBaseGetResultSet.java +++ b/proxy/backend/type/hbase/src/main/java/org/apache/shardingsphere/proxy/backend/hbase/result/query/HBaseGetResultSet.java @@ -173,9 +173,7 @@ private Map parseResult(final Result result) { private void logExecuteTime(final long startMills) { long endMills = System.currentTimeMillis(); - String tableName = statementContext.getSqlStatement().getFrom() instanceof SimpleTableSegment - ? ((SimpleTableSegment) statementContext.getSqlStatement().getFrom()).getTableName().getIdentifier().getValue() - : statementContext.getSqlStatement().getFrom().toString(); + String tableName = getTableName(); String whereClause = getWhereClause(); if (endMills - startMills > HBaseContext.getInstance().getProps().getValue(HBasePropertyKey.EXECUTE_TIME_OUT)) { log.info(String.format("query hbase table: %s, where case: %s , query %dms time out", tableName, whereClause, endMills - startMills)); @@ -184,6 +182,15 @@ private void logExecuteTime(final long startMills) { } } + private String getTableName() { + if (statementContext.getSqlStatement().getFrom().isPresent()) { + return statementContext.getSqlStatement().getFrom().get() instanceof SimpleTableSegment + ? ((SimpleTableSegment) statementContext.getSqlStatement().getFrom().get()).getTableName().getIdentifier().getValue() + : statementContext.getSqlStatement().getFrom().toString(); + } + return "DUAL"; + } + private String getWhereClause() { if (!statementContext.getSqlStatement().getWhere().isPresent()) { return ""; diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/connector/sane/MySQLDialectSaneQueryResultEngine.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/connector/sane/MySQLDialectSaneQueryResultEngine.java index c8f5de5185644..812494c57c198 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/connector/sane/MySQLDialectSaneQueryResultEngine.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/connector/sane/MySQLDialectSaneQueryResultEngine.java @@ -66,7 +66,7 @@ public Optional getSaneQueryResult(final SQLStatement sqlStatemen } private Optional createQueryResult(final SelectStatement sqlStatement) { - if (null != sqlStatement.getFrom()) { + if (sqlStatement.getFrom().isPresent()) { return Optional.empty(); } List queryResultColumnMetaDataList = new ArrayList<>(sqlStatement.getProjections().getProjections().size()); diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreator.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreator.java index 38f9c96c70c63..25fda1aa9df56 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreator.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreator.java @@ -112,7 +112,7 @@ public Optional create(final SQLStatementContext sqlState } private Optional create(final SelectStatement selectStatement, final String sql, final String databaseName, final List parameters) { - if (null == selectStatement.getFrom()) { + if (!selectStatement.getFrom().isPresent()) { return findAdminExecutorForSelectWithoutFrom(sql, databaseName, selectStatement); } if (isQueryInformationSchema(databaseName)) { @@ -180,7 +180,7 @@ private Optional mockExecutor(final String databaseName, if (hasNoResource()) { return Optional.of(new NoResourceShowExecutor(sqlStatement)); } - boolean isNotUseSchema = null == databaseName && null == sqlStatement.getFrom(); + boolean isNotUseSchema = null == databaseName && !sqlStatement.getFrom().isPresent(); return isNotUseSchema ? Optional.of(new UnicastResourceShowExecutor(sqlStatement, sql)) : Optional.empty(); } diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLInformationSchemaExecutorFactory.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLInformationSchemaExecutorFactory.java index 9a4e713c18897..7fb1ee8d316c0 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLInformationSchemaExecutorFactory.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLInformationSchemaExecutorFactory.java @@ -46,10 +46,10 @@ public final class MySQLInformationSchemaExecutorFactory { * @return executor */ public static Optional newInstance(final SelectStatement sqlStatement, final String sql, final List parameters) { - if (!(sqlStatement.getFrom() instanceof SimpleTableSegment)) { + if (!sqlStatement.getFrom().isPresent() || !(sqlStatement.getFrom().get() instanceof SimpleTableSegment)) { return Optional.empty(); } - String tableName = ((SimpleTableSegment) sqlStatement.getFrom()).getTableName().getIdentifier().getValue(); + String tableName = ((SimpleTableSegment) sqlStatement.getFrom().get()).getTableName().getIdentifier().getValue(); if (SCHEMATA_TABLE.equalsIgnoreCase(tableName)) { return Optional.of(new SelectInformationSchemataExecutor(sqlStatement, sql, parameters)); } diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLMySQLSchemaExecutorFactory.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLMySQLSchemaExecutorFactory.java index 3efc4fcbb4486..3b1da1ea50972 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLMySQLSchemaExecutorFactory.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLMySQLSchemaExecutorFactory.java @@ -43,10 +43,10 @@ public final class MySQLMySQLSchemaExecutorFactory { * @return executor */ public static Optional newInstance(final SelectStatement sqlStatement, final String sql, final List parameters) { - if (!(sqlStatement.getFrom() instanceof SimpleTableSegment)) { + if (!sqlStatement.getFrom().isPresent() || !(sqlStatement.getFrom().get() instanceof SimpleTableSegment)) { return Optional.empty(); } - String tableName = ((SimpleTableSegment) sqlStatement.getFrom()).getTableName().getIdentifier().getValue(); + String tableName = ((SimpleTableSegment) sqlStatement.getFrom().get()).getTableName().getIdentifier().getValue(); if (SystemSchemaManager.isSystemTable("mysql", "mysql", tableName)) { return Optional.of(new DefaultDatabaseMetaDataExecutor(sql, parameters)); } diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLPerformanceSchemaExecutorFactory.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLPerformanceSchemaExecutorFactory.java index addb2f1231415..2f4eed48ae687 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLPerformanceSchemaExecutorFactory.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLPerformanceSchemaExecutorFactory.java @@ -43,10 +43,10 @@ public final class MySQLPerformanceSchemaExecutorFactory { * @return executor */ public static Optional newInstance(final SelectStatement sqlStatement, final String sql, final List parameters) { - if (!(sqlStatement.getFrom() instanceof SimpleTableSegment)) { + if (!sqlStatement.getFrom().isPresent() || !(sqlStatement.getFrom().get() instanceof SimpleTableSegment)) { return Optional.empty(); } - String tableName = ((SimpleTableSegment) sqlStatement.getFrom()).getTableName().getIdentifier().getValue(); + String tableName = ((SimpleTableSegment) sqlStatement.getFrom().get()).getTableName().getIdentifier().getValue(); if (SystemSchemaManager.isSystemTable("mysql", "performance_schema", tableName)) { return Optional.of(new DefaultDatabaseMetaDataExecutor(sql, parameters)); } diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLSysSchemaExecutorFactory.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLSysSchemaExecutorFactory.java index 782a0ff9fa11f..74366f24490d0 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLSysSchemaExecutorFactory.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLSysSchemaExecutorFactory.java @@ -43,10 +43,10 @@ public final class MySQLSysSchemaExecutorFactory { * @return executor */ public static Optional newInstance(final SelectStatement sqlStatement, final String sql, final List parameters) { - if (!(sqlStatement.getFrom() instanceof SimpleTableSegment)) { + if (!sqlStatement.getFrom().isPresent() || !(sqlStatement.getFrom().get() instanceof SimpleTableSegment)) { return Optional.empty(); } - String tableName = ((SimpleTableSegment) sqlStatement.getFrom()).getTableName().getIdentifier().getValue(); + String tableName = ((SimpleTableSegment) sqlStatement.getFrom().get()).getTableName().getIdentifier().getValue(); if (SystemSchemaManager.isSystemTable("mysql", "sys", tableName)) { return Optional.of(new DefaultDatabaseMetaDataExecutor(sql, parameters)); } diff --git a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/NoResourceShowExecutor.java b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/NoResourceShowExecutor.java index 588408fc2d214..d3810be8ce4e1 100644 --- a/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/NoResourceShowExecutor.java +++ b/proxy/backend/type/mysql/src/main/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/executor/NoResourceShowExecutor.java @@ -31,7 +31,6 @@ import org.apache.shardingsphere.proxy.backend.handler.admin.executor.DatabaseAdminQueryExecutor; import org.apache.shardingsphere.proxy.backend.session.ConnectionSession; import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ShorthandProjectionSegment; -import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment; import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.SelectStatement; import java.sql.Types; @@ -58,9 +57,8 @@ public final class NoResourceShowExecutor implements DatabaseAdminQueryExecutor @Override public void execute(final ConnectionSession connectionSession) { - TableSegment tableSegment = sqlStatement.getFrom(); expressions = sqlStatement.getProjections().getProjections().stream().filter(each -> !(each instanceof ShorthandProjectionSegment)) - .map(each -> new ProjectionEngine(null).createProjection(tableSegment, each)) + .map(each -> new ProjectionEngine(null).createProjection(each)) .filter(Optional::isPresent).map(each -> each.get().getAlias().isPresent() ? each.get().getAlias().get() : each.get().getExpression()).collect(Collectors.toList()); mergedResult = new TransparentMergedResult(getQueryResult()); } diff --git a/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreatorTest.java b/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreatorTest.java index 9078321c0dd23..f73f2b24edad3 100644 --- a/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreatorTest.java +++ b/proxy/backend/type/mysql/src/test/java/org/apache/shardingsphere/proxy/backend/mysql/handler/admin/MySQLAdminExecutorCreatorTest.java @@ -166,7 +166,7 @@ void assertCreateWithSetStatement() { @Test void assertCreateWithSelectStatementForShowConnectionId() { MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "CONNECTION_ID()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -179,7 +179,7 @@ void assertCreateWithSelectStatementForShowConnectionId() { @Test void assertCreateWithSelectStatementForShowVersion() { MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "version()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -192,7 +192,7 @@ void assertCreateWithSelectStatementForShowVersion() { @Test void assertCreateWithSelectStatementForCurrentUser() { MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "CURRENT_USER()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -206,7 +206,7 @@ void assertCreateWithSelectStatementForCurrentUser() { void assertCreateWithSelectStatementForTransactionReadOnly() { initProxyContext(Collections.emptyMap()); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); VariableSegment variableSegment = new VariableSegment(0, 0, "transaction_read_only"); variableSegment.setScope("SESSION"); @@ -222,7 +222,7 @@ void assertCreateWithSelectStatementForTransactionReadOnly() { void assertCreateWithSelectStatementForTransactionIsolation() { initProxyContext(Collections.emptyMap()); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); VariableSegment variableSegment = new VariableSegment(0, 0, "transaction_isolation"); variableSegment.setScope("SESSION"); @@ -238,7 +238,7 @@ void assertCreateWithSelectStatementForTransactionIsolation() { void assertCreateWithSelectStatementForShowDatabase() { initProxyContext(Collections.emptyMap()); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "DATABASE()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -252,7 +252,7 @@ void assertCreateWithSelectStatementForShowDatabase() { void assertCreateWithOtherSelectStatementForNoResource() { initProxyContext(Collections.emptyMap()); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "CURRENT_DATE()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -271,7 +271,7 @@ void assertCreateWithOtherSelectStatementForDatabaseName() { when(ProxyContext.getInstance().getAllDatabaseNames()).thenReturn(Collections.singleton("db_0")); when(ProxyContext.getInstance().getContextManager().getDatabase("db_0")).thenReturn(database); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "CURRENT_DATE()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -289,7 +289,7 @@ void assertCreateWithOtherSelectStatementForNullDatabaseName() { when(ProxyContext.getInstance().getAllDatabaseNames()).thenReturn(Collections.singleton("db_0")); when(ProxyContext.getInstance().getContextManager().getDatabase("db_0")).thenReturn(database); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(null); + when(selectStatement.getFrom()).thenReturn(Optional.empty()); ProjectionsSegment projectionsSegment = mock(ProjectionsSegment.class); when(projectionsSegment.getProjections()).thenReturn(Collections.singletonList(new ExpressionProjectionSegment(0, 10, "CURRENT_DATE()"))); when(selectStatement.getProjections()).thenReturn(projectionsSegment); @@ -305,7 +305,7 @@ void assertCreateWithSelectStatementFromInformationSchemaOfDefaultExecutorTables SimpleTableSegment tableSegment = new SimpleTableSegment(new TableNameSegment(10, 13, new IdentifierValue("ENGINES"))); tableSegment.setOwner(new OwnerSegment(7, 8, new IdentifierValue("information_schema"))); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(tableSegment); + when(selectStatement.getFrom()).thenReturn(Optional.of(tableSegment)); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); Optional actual = new MySQLAdminExecutorCreator().create(sqlStatementContext, "select ENGINE from ENGINES", "information_schema", Collections.emptyList()); assertTrue(actual.isPresent()); @@ -318,7 +318,7 @@ void assertCreateWithSelectStatementFromInformationSchemaOfSchemaTable() { SimpleTableSegment tableSegment = new SimpleTableSegment(new TableNameSegment(10, 13, new IdentifierValue("SCHEMATA"))); tableSegment.setOwner(new OwnerSegment(7, 8, new IdentifierValue("information_schema"))); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(tableSegment); + when(selectStatement.getFrom()).thenReturn(Optional.of(tableSegment)); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); Optional actual = new MySQLAdminExecutorCreator().create(sqlStatementContext, "select SCHEMA_NAME from SCHEMATA", "information_schema", Collections.emptyList()); assertTrue(actual.isPresent()); @@ -334,7 +334,7 @@ void assertCreateWithSelectStatementFromInformationSchemaOfOtherTable() { SimpleTableSegment tableSegment = new SimpleTableSegment(new TableNameSegment(10, 13, new IdentifierValue("CHARACTER_SETS"))); tableSegment.setOwner(new OwnerSegment(7, 8, new IdentifierValue("information_schema"))); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(tableSegment); + when(selectStatement.getFrom()).thenReturn(Optional.of(tableSegment)); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); Optional actual = new MySQLAdminExecutorCreator().create(sqlStatementContext, "select CHARACTER_SET_NAME from CHARACTER_SETS", "", Collections.emptyList()); assertFalse(actual.isPresent()); @@ -346,7 +346,7 @@ void assertCreateWithSelectStatementFromPerformanceSchema() { SimpleTableSegment tableSegment = new SimpleTableSegment(new TableNameSegment(10, 13, new IdentifierValue("accounts"))); tableSegment.setOwner(new OwnerSegment(7, 8, new IdentifierValue("performance_schema"))); MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class); - when(selectStatement.getFrom()).thenReturn(tableSegment); + when(selectStatement.getFrom()).thenReturn(Optional.of(tableSegment)); when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement); Optional actual = new MySQLAdminExecutorCreator().create(sqlStatementContext, "select * from accounts", "", Collections.emptyList()); assertFalse(actual.isPresent()); diff --git a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java index b412216e8c8a1..9d15e2c28430a 100644 --- a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java +++ b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/SelectStatementAssert.java @@ -50,7 +50,6 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; /** @@ -110,15 +109,17 @@ private static void assertProjection(final SQLCaseAssertContext assertContext, f private static void assertTable(final SQLCaseAssertContext assertContext, final SelectStatement actual, final SelectStatementTestCase expected) { if (null == expected.getFrom()) { - assertNull(actual.getFrom(), assertContext.getText("Actual simple-table should not exist.")); + assertFalse(actual.getFrom().isPresent(), assertContext.getText("Actual simple-table should not exist.")); } else { - TableAssert.assertIs(assertContext, actual.getFrom(), expected.getFrom()); + assertTrue(actual.getFrom().isPresent(), assertContext.getText("Actual from segment should exist.")); + TableAssert.assertIs(assertContext, actual.getFrom().get(), expected.getFrom()); } if (actual instanceof MySQLSelectStatement) { if (null == expected.getSimpleTable()) { assertFalse(((MySQLSelectStatement) actual).getTable().isPresent(), assertContext.getText("Actual simple-table should not exist.")); } else { Optional table = ((MySQLSelectStatement) actual).getTable(); + assertTrue(table.isPresent(), assertContext.getText("Actual table segment should exist.")); TableAssert.assertIs(assertContext, table.orElse(null), expected.getSimpleTable()); } }