From d7e93c5c9ee37b6ab99ad9c5843f5a0dff39bcda Mon Sep 17 00:00:00 2001 From: Joshua Chen <27291761@qq.com> Date: Sat, 18 Jan 2025 23:45:08 +0800 Subject: [PATCH] Fix the same issue with broadcast --- .../TablelessRouteEngineFactory.java | 22 +++++++------- ...blelessDataSourceBroadcastRouteEngine.java | 8 ++++- ...TablelessDataSourceUnicastRouteEngine.java | 2 +- ...essDataSourceBroadcastRouteEngineTest.java | 29 +++++++++++++++---- ...elessDataSourceUnicastRouteEngineTest.java | 20 ++++++------- 5 files changed, 52 insertions(+), 29 deletions(-) diff --git a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/TablelessRouteEngineFactory.java b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/TablelessRouteEngineFactory.java index 0d39734c0d966..c81733eb3b889 100644 --- a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/TablelessRouteEngineFactory.java +++ b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/TablelessRouteEngineFactory.java @@ -72,14 +72,14 @@ public static TablelessRouteEngine newInstance(final QueryContext queryContext, SQLStatement sqlStatement = queryContext.getSqlStatementContext().getSqlStatement(); // TODO remove this logic when proxy and jdbc support all dal statement @duanzhengqiang if (sqlStatement instanceof DALStatement) { - return getDALRouteEngine(sqlStatement, database); + return getDALRouteEngine(sqlStatement, database, queryContext.getConnectionContext()); } // TODO Support more TCL statements by transaction module, then remove this. if (sqlStatement instanceof TCLStatement) { - return new TablelessDataSourceBroadcastRouteEngine(); + return new TablelessDataSourceBroadcastRouteEngine(queryContext.getConnectionContext()); } if (sqlStatement instanceof DDLStatement) { - return getDDLRouteEngine(queryContext.getSqlStatementContext(), database); + return getDDLRouteEngine(queryContext.getSqlStatementContext(), database, queryContext.getConnectionContext()); } if (sqlStatement instanceof DMLStatement) { return getDMLRouteEngine(queryContext.getSqlStatementContext(), queryContext.getConnectionContext()); @@ -87,12 +87,12 @@ public static TablelessRouteEngine newInstance(final QueryContext queryContext, return new TablelessIgnoreRouteEngine(); } - private static TablelessRouteEngine getDALRouteEngine(final SQLStatement sqlStatement, final ShardingSphereDatabase database) { + private static TablelessRouteEngine getDALRouteEngine(final SQLStatement sqlStatement, final ShardingSphereDatabase database, final ConnectionContext connectionContext) { if (sqlStatement instanceof ShowTablesStatement || sqlStatement instanceof ShowTableStatusStatement || sqlStatement instanceof SetStatement) { - return new TablelessDataSourceBroadcastRouteEngine(); + return new TablelessDataSourceBroadcastRouteEngine(connectionContext); } if (sqlStatement instanceof ResetParameterStatement || sqlStatement instanceof ShowDatabasesStatement || sqlStatement instanceof LoadStatement) { - return new TablelessDataSourceBroadcastRouteEngine(); + return new TablelessDataSourceBroadcastRouteEngine(connectionContext); } if (isResourceGroupStatement(sqlStatement)) { return new TablelessInstanceBroadcastRouteEngine(database); @@ -105,13 +105,13 @@ private static boolean isResourceGroupStatement(final SQLStatement sqlStatement) || sqlStatement instanceof SetResourceGroupStatement; } - private static TablelessRouteEngine getDDLRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database) { + private static TablelessRouteEngine getDDLRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database, final ConnectionContext connectionContext) { if (sqlStatementContext instanceof CursorAvailable) { - return getCursorRouteEngine(sqlStatementContext, database); + return getCursorRouteEngine(sqlStatementContext, database, connectionContext); } SQLStatement sqlStatement = sqlStatementContext.getSqlStatement(); if (isFunctionDDLStatement(sqlStatement) || isSchemaDDLStatement(sqlStatement)) { - return new TablelessDataSourceBroadcastRouteEngine(); + return new TablelessDataSourceBroadcastRouteEngine(connectionContext); } return new TablelessIgnoreRouteEngine(); } @@ -124,9 +124,9 @@ private static boolean isSchemaDDLStatement(final SQLStatement sqlStatement) { return sqlStatement instanceof CreateSchemaStatement || sqlStatement instanceof AlterSchemaStatement || sqlStatement instanceof DropSchemaStatement; } - private static TablelessRouteEngine getCursorRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database) { + private static TablelessRouteEngine getCursorRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database, final ConnectionContext connectionContext) { if (sqlStatementContext instanceof CloseStatementContext && ((CloseStatementContext) sqlStatementContext).getSqlStatement().isCloseAll()) { - return new TablelessDataSourceBroadcastRouteEngine(); + return new TablelessDataSourceBroadcastRouteEngine(connectionContext); } SQLStatement sqlStatement = sqlStatementContext.getSqlStatement(); if (sqlStatement instanceof CreateTablespaceStatement || sqlStatement instanceof AlterTablespaceStatement || sqlStatement instanceof DropTablespaceStatement) { diff --git a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngine.java b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngine.java index 307f6bdf7849e..49a6155630384 100644 --- a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngine.java +++ b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngine.java @@ -17,11 +17,13 @@ package org.apache.shardingsphere.infra.route.engine.tableless.type.broadcast; +import lombok.RequiredArgsConstructor; import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.route.context.RouteMapper; import org.apache.shardingsphere.infra.route.context.RouteUnit; import org.apache.shardingsphere.infra.route.engine.tableless.TablelessRouteEngine; +import org.apache.shardingsphere.infra.session.connection.ConnectionContext; import java.util.Collection; import java.util.Collections; @@ -29,12 +31,16 @@ /** * Tableless datasource broadcast route engine. */ +@RequiredArgsConstructor public final class TablelessDataSourceBroadcastRouteEngine implements TablelessRouteEngine { + private final ConnectionContext connectionContext; + @Override public RouteContext route(final RuleMetaData globalRuleMetaData, final Collection aggregatedDataSources) { RouteContext result = new RouteContext(); - for (String each : aggregatedDataSources) { + Collection usedDataSourceNames = null == aggregatedDataSources || aggregatedDataSources.isEmpty() ? connectionContext.getUsedDataSourceNames() : aggregatedDataSources; + for (String each : usedDataSourceNames) { result.getRouteUnits().add(new RouteUnit(new RouteMapper(each, each), Collections.emptyList())); } return result; diff --git a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngine.java b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngine.java index ffd587839dc4c..fdb4dcc9740ed 100644 --- a/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngine.java +++ b/infra/route/src/main/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngine.java @@ -51,7 +51,7 @@ private RouteMapper getDataSourceRouteMapper(final Collection dataSource } private String getRandomDataSourceName(final Collection dataSourceNames) { - Collection usedDataSourceNames = dataSourceNames == null || dataSourceNames.isEmpty() ? connectionContext.getUsedDataSourceNames() : dataSourceNames; + Collection usedDataSourceNames = null == dataSourceNames || dataSourceNames.isEmpty() ? connectionContext.getUsedDataSourceNames() : dataSourceNames; return new ArrayList<>(usedDataSourceNames).get(ThreadLocalRandom.current().nextInt(usedDataSourceNames.size())); } } diff --git a/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngineTest.java b/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngineTest.java index b1d84018277a8..421d6128c9131 100644 --- a/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngineTest.java +++ b/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/broadcast/TablelessDataSourceBroadcastRouteEngineTest.java @@ -20,12 +20,14 @@ import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData; import org.apache.shardingsphere.infra.route.context.RouteContext; import org.apache.shardingsphere.infra.route.context.RouteUnit; +import org.apache.shardingsphere.infra.session.connection.ConnectionContext; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import static org.hamcrest.CoreMatchers.is; @@ -36,15 +38,30 @@ class TablelessDataSourceBroadcastRouteEngineTest { @Test - void assertRoute() { - RouteContext actual = new TablelessDataSourceBroadcastRouteEngine().route(mock(RuleMetaData.class), Arrays.asList("foo_ds_1", "foo_ds_2")); + void assertRouteWithNoAggregatedDataSources() { + ConnectionContext connectionContext = new ConnectionContext(() -> Arrays.asList("foo_ds_1", "foo_ds_2")); + RouteContext actual = new TablelessDataSourceBroadcastRouteEngine(connectionContext).route(mock(RuleMetaData.class), Collections.emptyList()); assertThat(actual.getRouteUnits().size(), is(2)); List routeUnits = new ArrayList<>(actual.getRouteUnits()); - assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is("foo_ds_1")); - assertThat(routeUnits.get(0).getDataSourceMapper().getActualName(), is("foo_ds_1")); + assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is("foo_ds_2")); + assertThat(routeUnits.get(0).getDataSourceMapper().getActualName(), is("foo_ds_2")); assertThat(routeUnits.get(0).getTableMappers().size(), is(0)); - assertThat(routeUnits.get(1).getDataSourceMapper().getLogicName(), is("foo_ds_2")); - assertThat(routeUnits.get(1).getDataSourceMapper().getActualName(), is("foo_ds_2")); + assertThat(routeUnits.get(1).getDataSourceMapper().getLogicName(), is("foo_ds_1")); + assertThat(routeUnits.get(1).getDataSourceMapper().getActualName(), is("foo_ds_1")); + assertThat(routeUnits.get(1).getTableMappers().size(), is(0)); + } + + @Test + void assertRouteWithAggregatedDataSources() { + ConnectionContext connectionContext = new ConnectionContext(() -> Arrays.asList("foo_ds_1", "foo_ds_2")); + RouteContext actual = new TablelessDataSourceBroadcastRouteEngine(connectionContext).route(mock(RuleMetaData.class), Arrays.asList("foo_ds_3", "foo_ds_4")); + assertThat(actual.getRouteUnits().size(), is(2)); + List routeUnits = new ArrayList<>(actual.getRouteUnits()); + assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is("foo_ds_3")); + assertThat(routeUnits.get(0).getDataSourceMapper().getActualName(), is("foo_ds_3")); + assertThat(routeUnits.get(0).getTableMappers().size(), is(0)); + assertThat(routeUnits.get(1).getDataSourceMapper().getLogicName(), is("foo_ds_4")); + assertThat(routeUnits.get(1).getDataSourceMapper().getActualName(), is("foo_ds_4")); assertThat(routeUnits.get(1).getTableMappers().size(), is(0)); } } diff --git a/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngineTest.java b/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngineTest.java index e2f26d8c7231e..132d5cc2e6057 100644 --- a/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngineTest.java +++ b/infra/route/src/test/java/org/apache/shardingsphere/infra/route/engine/tableless/type/unicast/TablelessDataSourceUnicastRouteEngineTest.java @@ -31,33 +31,33 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.hamcrest.Matchers.anyOf; import static org.mockito.Mockito.mock; @ExtendWith(MockitoExtension.class) class TablelessDataSourceUnicastRouteEngineTest { @Test - void assertRouteWithoutUsedDataSourceNames() { - ConnectionContext connectionContext = new ConnectionContext(Collections::emptyList); - Collection aggregatedDataSources = Arrays.asList("foo_ds_1", "foo_ds_2"); + void assertRouteWithNoAggregatedDataSources() { + ConnectionContext connectionContext = new ConnectionContext(() -> Arrays.asList("foo_ds_1", "foo_ds_2")); + Collection aggregatedDataSources = Collections.emptyList(); RouteContext actual = new TablelessDataSourceUnicastRouteEngine(connectionContext).route(mock(RuleMetaData.class), aggregatedDataSources); assertThat(actual.getRouteUnits().size(), is(1)); RouteUnit routeUnit = actual.getRouteUnits().iterator().next(); - assertTrue(aggregatedDataSources.contains(routeUnit.getDataSourceMapper().getLogicName())); - assertTrue(aggregatedDataSources.contains(routeUnit.getDataSourceMapper().getActualName())); + assertThat(routeUnit.getDataSourceMapper().getLogicName(), anyOf(is("foo_ds_1"), is("foo_ds_2"))); + assertThat(routeUnit.getDataSourceMapper().getActualName(), anyOf(is("foo_ds_1"), is("foo_ds_2"))); assertThat(routeUnit.getTableMappers().size(), is(0)); } @Test - void assertRouteWithUsedDataSourceNames() { + void assertRouteWithAggregatedDataSources() { ConnectionContext connectionContext = new ConnectionContext(() -> Collections.singleton("foo_ds_1")); - Collection aggregatedDataSources = Arrays.asList("foo_ds_1", "foo_ds_2"); + Collection aggregatedDataSources = Arrays.asList("foo_ds_2", "foo_ds_3"); RouteContext actual = new TablelessDataSourceUnicastRouteEngine(connectionContext).route(mock(RuleMetaData.class), aggregatedDataSources); assertThat(actual.getRouteUnits().size(), is(1)); RouteUnit routeUnit = actual.getRouteUnits().iterator().next(); - assertThat(routeUnit.getDataSourceMapper().getLogicName(), is("foo_ds_1")); - assertThat(routeUnit.getDataSourceMapper().getActualName(), is("foo_ds_1")); + assertThat(routeUnit.getDataSourceMapper().getLogicName(), anyOf(is("foo_ds_2"), is("foo_ds_3"))); + assertThat(routeUnit.getDataSourceMapper().getActualName(), anyOf(is("foo_ds_2"), is("foo_ds_3"))); assertThat(routeUnit.getTableMappers().size(), is(0)); } }