Skip to content

Commit

Permalink
Fix the same issue with broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaChen committed Jan 18, 2025
1 parent 48cc200 commit d7e93c5
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,27 @@ public static TablelessRouteEngine newInstance(final QueryContext queryContext,
SQLStatement sqlStatement = queryContext.getSqlStatementContext().getSqlStatement();
// TODO remove this logic when proxy and jdbc support all dal statement @duanzhengqiang
if (sqlStatement instanceof DALStatement) {
return getDALRouteEngine(sqlStatement, database);
return getDALRouteEngine(sqlStatement, database, queryContext.getConnectionContext());
}
// TODO Support more TCL statements by transaction module, then remove this.
if (sqlStatement instanceof TCLStatement) {
return new TablelessDataSourceBroadcastRouteEngine();
return new TablelessDataSourceBroadcastRouteEngine(queryContext.getConnectionContext());
}
if (sqlStatement instanceof DDLStatement) {
return getDDLRouteEngine(queryContext.getSqlStatementContext(), database);
return getDDLRouteEngine(queryContext.getSqlStatementContext(), database, queryContext.getConnectionContext());
}
if (sqlStatement instanceof DMLStatement) {
return getDMLRouteEngine(queryContext.getSqlStatementContext(), queryContext.getConnectionContext());
}
return new TablelessIgnoreRouteEngine();
}

private static TablelessRouteEngine getDALRouteEngine(final SQLStatement sqlStatement, final ShardingSphereDatabase database) {
private static TablelessRouteEngine getDALRouteEngine(final SQLStatement sqlStatement, final ShardingSphereDatabase database, final ConnectionContext connectionContext) {
if (sqlStatement instanceof ShowTablesStatement || sqlStatement instanceof ShowTableStatusStatement || sqlStatement instanceof SetStatement) {
return new TablelessDataSourceBroadcastRouteEngine();
return new TablelessDataSourceBroadcastRouteEngine(connectionContext);
}
if (sqlStatement instanceof ResetParameterStatement || sqlStatement instanceof ShowDatabasesStatement || sqlStatement instanceof LoadStatement) {
return new TablelessDataSourceBroadcastRouteEngine();
return new TablelessDataSourceBroadcastRouteEngine(connectionContext);
}
if (isResourceGroupStatement(sqlStatement)) {
return new TablelessInstanceBroadcastRouteEngine(database);
Expand All @@ -105,13 +105,13 @@ private static boolean isResourceGroupStatement(final SQLStatement sqlStatement)
|| sqlStatement instanceof SetResourceGroupStatement;
}

private static TablelessRouteEngine getDDLRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database) {
private static TablelessRouteEngine getDDLRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database, final ConnectionContext connectionContext) {
if (sqlStatementContext instanceof CursorAvailable) {
return getCursorRouteEngine(sqlStatementContext, database);
return getCursorRouteEngine(sqlStatementContext, database, connectionContext);
}
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (isFunctionDDLStatement(sqlStatement) || isSchemaDDLStatement(sqlStatement)) {
return new TablelessDataSourceBroadcastRouteEngine();
return new TablelessDataSourceBroadcastRouteEngine(connectionContext);
}
return new TablelessIgnoreRouteEngine();
}
Expand All @@ -124,9 +124,9 @@ private static boolean isSchemaDDLStatement(final SQLStatement sqlStatement) {
return sqlStatement instanceof CreateSchemaStatement || sqlStatement instanceof AlterSchemaStatement || sqlStatement instanceof DropSchemaStatement;
}

private static TablelessRouteEngine getCursorRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database) {
private static TablelessRouteEngine getCursorRouteEngine(final SQLStatementContext sqlStatementContext, final ShardingSphereDatabase database, final ConnectionContext connectionContext) {
if (sqlStatementContext instanceof CloseStatementContext && ((CloseStatementContext) sqlStatementContext).getSqlStatement().isCloseAll()) {
return new TablelessDataSourceBroadcastRouteEngine();
return new TablelessDataSourceBroadcastRouteEngine(connectionContext);
}
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (sqlStatement instanceof CreateTablespaceStatement || sqlStatement instanceof AlterTablespaceStatement || sqlStatement instanceof DropTablespaceStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,30 @@

package org.apache.shardingsphere.infra.route.engine.tableless.type.broadcast;

import lombok.RequiredArgsConstructor;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.route.engine.tableless.TablelessRouteEngine;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;

import java.util.Collection;
import java.util.Collections;

/**
* Tableless datasource broadcast route engine.
*/
@RequiredArgsConstructor
public final class TablelessDataSourceBroadcastRouteEngine implements TablelessRouteEngine {

private final ConnectionContext connectionContext;

@Override
public RouteContext route(final RuleMetaData globalRuleMetaData, final Collection<String> aggregatedDataSources) {
RouteContext result = new RouteContext();
for (String each : aggregatedDataSources) {
Collection<String> usedDataSourceNames = null == aggregatedDataSources || aggregatedDataSources.isEmpty() ? connectionContext.getUsedDataSourceNames() : aggregatedDataSources;
for (String each : usedDataSourceNames) {
result.getRouteUnits().add(new RouteUnit(new RouteMapper(each, each), Collections.emptyList()));
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private RouteMapper getDataSourceRouteMapper(final Collection<String> dataSource
}

private String getRandomDataSourceName(final Collection<String> dataSourceNames) {
Collection<String> usedDataSourceNames = dataSourceNames == null || dataSourceNames.isEmpty() ? connectionContext.getUsedDataSourceNames() : dataSourceNames;
Collection<String> usedDataSourceNames = null == dataSourceNames || dataSourceNames.isEmpty() ? connectionContext.getUsedDataSourceNames() : dataSourceNames;
return new ArrayList<>(usedDataSourceNames).get(ThreadLocalRandom.current().nextInt(usedDataSourceNames.size()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.hamcrest.CoreMatchers.is;
Expand All @@ -36,15 +38,30 @@
class TablelessDataSourceBroadcastRouteEngineTest {

@Test
void assertRoute() {
RouteContext actual = new TablelessDataSourceBroadcastRouteEngine().route(mock(RuleMetaData.class), Arrays.asList("foo_ds_1", "foo_ds_2"));
void assertRouteWithNoAggregatedDataSources() {
ConnectionContext connectionContext = new ConnectionContext(() -> Arrays.asList("foo_ds_1", "foo_ds_2"));
RouteContext actual = new TablelessDataSourceBroadcastRouteEngine(connectionContext).route(mock(RuleMetaData.class), Collections.emptyList());
assertThat(actual.getRouteUnits().size(), is(2));
List<RouteUnit> routeUnits = new ArrayList<>(actual.getRouteUnits());
assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is("foo_ds_1"));
assertThat(routeUnits.get(0).getDataSourceMapper().getActualName(), is("foo_ds_1"));
assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is("foo_ds_2"));
assertThat(routeUnits.get(0).getDataSourceMapper().getActualName(), is("foo_ds_2"));
assertThat(routeUnits.get(0).getTableMappers().size(), is(0));
assertThat(routeUnits.get(1).getDataSourceMapper().getLogicName(), is("foo_ds_2"));
assertThat(routeUnits.get(1).getDataSourceMapper().getActualName(), is("foo_ds_2"));
assertThat(routeUnits.get(1).getDataSourceMapper().getLogicName(), is("foo_ds_1"));
assertThat(routeUnits.get(1).getDataSourceMapper().getActualName(), is("foo_ds_1"));
assertThat(routeUnits.get(1).getTableMappers().size(), is(0));
}

@Test
void assertRouteWithAggregatedDataSources() {
ConnectionContext connectionContext = new ConnectionContext(() -> Arrays.asList("foo_ds_1", "foo_ds_2"));
RouteContext actual = new TablelessDataSourceBroadcastRouteEngine(connectionContext).route(mock(RuleMetaData.class), Arrays.asList("foo_ds_3", "foo_ds_4"));
assertThat(actual.getRouteUnits().size(), is(2));
List<RouteUnit> routeUnits = new ArrayList<>(actual.getRouteUnits());
assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is("foo_ds_3"));
assertThat(routeUnits.get(0).getDataSourceMapper().getActualName(), is("foo_ds_3"));
assertThat(routeUnits.get(0).getTableMappers().size(), is(0));
assertThat(routeUnits.get(1).getDataSourceMapper().getLogicName(), is("foo_ds_4"));
assertThat(routeUnits.get(1).getDataSourceMapper().getActualName(), is("foo_ds_4"));
assertThat(routeUnits.get(1).getTableMappers().size(), is(0));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,33 @@

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.hamcrest.Matchers.anyOf;
import static org.mockito.Mockito.mock;

@ExtendWith(MockitoExtension.class)
class TablelessDataSourceUnicastRouteEngineTest {

@Test
void assertRouteWithoutUsedDataSourceNames() {
ConnectionContext connectionContext = new ConnectionContext(Collections::emptyList);
Collection<String> aggregatedDataSources = Arrays.asList("foo_ds_1", "foo_ds_2");
void assertRouteWithNoAggregatedDataSources() {
ConnectionContext connectionContext = new ConnectionContext(() -> Arrays.asList("foo_ds_1", "foo_ds_2"));
Collection<String> aggregatedDataSources = Collections.emptyList();
RouteContext actual = new TablelessDataSourceUnicastRouteEngine(connectionContext).route(mock(RuleMetaData.class), aggregatedDataSources);
assertThat(actual.getRouteUnits().size(), is(1));
RouteUnit routeUnit = actual.getRouteUnits().iterator().next();
assertTrue(aggregatedDataSources.contains(routeUnit.getDataSourceMapper().getLogicName()));
assertTrue(aggregatedDataSources.contains(routeUnit.getDataSourceMapper().getActualName()));
assertThat(routeUnit.getDataSourceMapper().getLogicName(), anyOf(is("foo_ds_1"), is("foo_ds_2")));
assertThat(routeUnit.getDataSourceMapper().getActualName(), anyOf(is("foo_ds_1"), is("foo_ds_2")));
assertThat(routeUnit.getTableMappers().size(), is(0));
}

@Test
void assertRouteWithUsedDataSourceNames() {
void assertRouteWithAggregatedDataSources() {
ConnectionContext connectionContext = new ConnectionContext(() -> Collections.singleton("foo_ds_1"));
Collection<String> aggregatedDataSources = Arrays.asList("foo_ds_1", "foo_ds_2");
Collection<String> aggregatedDataSources = Arrays.asList("foo_ds_2", "foo_ds_3");
RouteContext actual = new TablelessDataSourceUnicastRouteEngine(connectionContext).route(mock(RuleMetaData.class), aggregatedDataSources);
assertThat(actual.getRouteUnits().size(), is(1));
RouteUnit routeUnit = actual.getRouteUnits().iterator().next();
assertThat(routeUnit.getDataSourceMapper().getLogicName(), is("foo_ds_1"));
assertThat(routeUnit.getDataSourceMapper().getActualName(), is("foo_ds_1"));
assertThat(routeUnit.getDataSourceMapper().getLogicName(), anyOf(is("foo_ds_2"), is("foo_ds_3")));
assertThat(routeUnit.getDataSourceMapper().getActualName(), anyOf(is("foo_ds_2"), is("foo_ds_3")));
assertThat(routeUnit.getTableMappers().size(), is(0));
}
}

0 comments on commit d7e93c5

Please sign in to comment.