Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GroupConcat sql for aggregating multiple shards(#33797) #33808

Merged
merged 26 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ private void initForFirstGroupByValue(final SelectStatementContext selectStateme
dataMap.put(groupByValue, new MemoryQueryResultRow(queryResult));
}
aggregationMap.computeIfAbsent(groupByValue, unused -> selectStatementContext.getProjectionsContext().getAggregationProjections().stream()
.collect(Collectors.toMap(Function.identity(), input -> AggregationUnitFactory.create(input.getType(), input instanceof AggregationDistinctProjection))));
.collect(Collectors.toMap(Function.identity(), input -> AggregationUnitFactory.create(input.getType(), input instanceof AggregationDistinctProjection, input.getSeparator()))));
}

private void aggregate(final SelectStatementContext selectStatementContext, final QueryResult queryResult,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ private boolean aggregateCurrentGroupByRowAndNext() throws SQLException {
boolean result = false;
boolean cachedRow = false;
Map<AggregationProjection, AggregationUnit> aggregationUnitMap = Maps.toMap(
selectStatementContext.getProjectionsContext().getAggregationProjections(), input -> AggregationUnitFactory.create(input.getType(), input instanceof AggregationDistinctProjection));
selectStatementContext.getProjectionsContext().getAggregationProjections(),
input -> AggregationUnitFactory.create(input.getType(), input instanceof AggregationDistinctProjection, input.getSeparator()));
while (currentGroupByValues.equals(new GroupByValue(getCurrentQueryResult(), selectStatementContext.getGroupByContext().getItems()).getGroupValues())) {
aggregate(aggregationUnitMap);
if (!cachedRow) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ public final class AggregationUnitFactory {
* @throws UnsupportedSQLOperationException unsupported SQL operation exception
*/
public static AggregationUnit create(final AggregationType type, final boolean isDistinct) {
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
return create(type, isDistinct, null);
}

/**
* Create aggregation unit instance.
*
* @param type aggregation function type
* @param isDistinct is distinct
* @param separator is separator for group_concat
* @return aggregation unit instance
* @throws UnsupportedSQLOperationException unsupported SQL operation exception
*/
public static AggregationUnit create(final AggregationType type, final boolean isDistinct, final String separator) {
switch (type) {
case MAX:
return new ComparableAggregationUnit(false);
Expand All @@ -50,6 +63,8 @@ public static AggregationUnit create(final AggregationType type, final boolean i
return isDistinct ? new DistinctAverageAggregationUnit() : new AverageAggregationUnit();
case BIT_XOR:
return new BitXorAggregationUnit();
case GROUP_CONCAT:
return isDistinct ? new DistinctGroupConcatAggregationUnit(separator) : new GroupConcatAggregationUnit(separator);
default:
throw new UnsupportedSQLOperationException(type.name());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sharding.merge.dql.groupby.aggregation;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.stream.Collectors;

public class DistinctGroupConcatAggregationUnit implements AggregationUnit {
YaoFly marked this conversation as resolved.
Show resolved Hide resolved

private static final String DEFAULT_SEPARATOR = ",";

private final Collection<Comparable<?>> values = new LinkedHashSet<>();

private String separator;
YaoFly marked this conversation as resolved.
Show resolved Hide resolved

public DistinctGroupConcatAggregationUnit(final String separator) {
this.separator = separator;
}

@Override
public void merge(final List<Comparable<?>> values) {
if (null == values || null == values.get(0)) {
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
return;
}
this.values.add(values.get(0));
}

@Override
public Comparable<?> getResult() {
if (null == separator) {
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
separator = DEFAULT_SEPARATOR;
}
return values.stream().map(Object::toString).collect(Collectors.joining(separator));
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sharding.merge.dql.groupby.aggregation;

import lombok.NoArgsConstructor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

@NoArgsConstructor
public class GroupConcatAggregationUnit implements AggregationUnit {
YaoFly marked this conversation as resolved.
Show resolved Hide resolved

private static final String DEFAULT_SEPARATOR = ",";

private final Collection<Comparable<?>> values = new ArrayList<>();

private String separator;

public GroupConcatAggregationUnit(final String separator) {
this.separator = separator;
}

@Override
public void merge(final List<Comparable<?>> values) {
if (null == values || null == values.get(0)) {
return;
}
this.values.add(values.get(0));
}

@Override
public Comparable<?> getResult() {
if (null == separator) {
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
separator = DEFAULT_SEPARATOR;
}
return values.stream().map(Object::toString).collect(Collectors.joining(separator));
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,10 @@ void assertCreateDistinctAverageAggregationUnit() {
void assertCreateBitXorAggregationUnit() {
assertThat(AggregationUnitFactory.create(AggregationType.BIT_XOR, false), instanceOf(BitXorAggregationUnit.class));
}

@Test
void assertGroupConcatAggregationUnit() {
assertThat(AggregationUnitFactory.create(AggregationType.GROUP_CONCAT, true), instanceOf(DistinctGroupConcatAggregationUnit.class));
assertThat(AggregationUnitFactory.create(AggregationType.GROUP_CONCAT, true, " "), instanceOf(DistinctGroupConcatAggregationUnit.class));
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sharding.merge.dql.groupby.aggregation;

import org.junit.jupiter.api.Test;

import java.util.Collections;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

class GroupConcatAggregationUnitTest {

@Test
void assertGroupConcatAggregation() {
GroupConcatAggregationUnit groupConcatAggregationUnit = new GroupConcatAggregationUnit(" ");
groupConcatAggregationUnit.merge(null);
groupConcatAggregationUnit.merge(Collections.singletonList(null));
groupConcatAggregationUnit.merge(Collections.singletonList("001"));
groupConcatAggregationUnit.merge(Collections.singletonList("002"));
groupConcatAggregationUnit.merge(Collections.singletonList("002 003"));
assertThat(groupConcatAggregationUnit.getResult(), is("001 002 002 003"));
}

@Test
void assertDistinctGroupConcatAggregation() {
DistinctGroupConcatAggregationUnit distinctGroupConcatAggregationUnit = new DistinctGroupConcatAggregationUnit(" ");
distinctGroupConcatAggregationUnit.merge(null);
distinctGroupConcatAggregationUnit.merge(Collections.singletonList(null));
distinctGroupConcatAggregationUnit.merge(Collections.singletonList(""));
distinctGroupConcatAggregationUnit.merge(Collections.singletonList("001"));
distinctGroupConcatAggregationUnit.merge(Collections.singletonList("001"));
distinctGroupConcatAggregationUnit.merge(Collections.singletonList("003"));
assertThat(distinctGroupConcatAggregationUnit.getResult(), is(" 001 003"));
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,16 @@ private AggregationDistinctProjection createProjection(final AggregationDistinct
projectionSegment.getAlias().orElseGet(() -> new IdentifierValue(DerivedColumn.AGGREGATION_DISTINCT_DERIVED.getDerivedColumnAlias(aggregationDistinctDerivedColumnCount++)));
AggregationDistinctProjection result = new AggregationDistinctProjection(
projectionSegment.getStartIndex(), projectionSegment.getStopIndex(), projectionSegment.getType(), projectionSegment.getExpression(), alias,
projectionSegment.getDistinctInnerExpression(), databaseType);
projectionSegment.getDistinctInnerExpression(), databaseType, projectionSegment.getSeparator());
if (AggregationType.AVG == result.getType()) {
appendAverageDistinctDerivedProjection(result);
}
return result;
}

private AggregationProjection createProjection(final AggregationProjectionSegment projectionSegment) {
AggregationProjection result = new AggregationProjection(projectionSegment.getType(), projectionSegment.getExpression(), projectionSegment.getAlias().orElse(null), databaseType);
AggregationProjection result =
new AggregationProjection(projectionSegment.getType(), projectionSegment.getExpression(), projectionSegment.getAlias().orElse(null), databaseType, projectionSegment.getSeparator());
if (AggregationType.AVG == result.getType()) {
appendAverageDerivedProjection(result);
// TODO replace avg to constant, avoid calculate useless avg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,23 @@ public final class AggregationDistinctProjection extends AggregationProjection {

private final String distinctInnerExpression;

private final String separator;
YaoFly marked this conversation as resolved.
Show resolved Hide resolved

public AggregationDistinctProjection(final int startIndex, final int stopIndex, final AggregationType type, final String expression,
final IdentifierValue alias, final String distinctInnerExpression, final DatabaseType databaseType) {
super(type, expression, alias, databaseType);
this.startIndex = startIndex;
this.stopIndex = stopIndex;
this.distinctInnerExpression = distinctInnerExpression;
this.separator = null;
}

public AggregationDistinctProjection(final int startIndex, final int stopIndex, final AggregationType type, final String expression,
final IdentifierValue alias, final String distinctInnerExpression, final DatabaseType databaseType, final String separator) {
super(type, expression, alias, databaseType);
this.startIndex = startIndex;
this.stopIndex = stopIndex;
this.distinctInnerExpression = distinctInnerExpression;
this.separator = separator;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,21 @@ public class AggregationProjection implements Projection {

private final DatabaseType databaseType;

private final String separator;
YaoFly marked this conversation as resolved.
Show resolved Hide resolved

private final List<AggregationProjection> derivedAggregationProjections = new ArrayList<>(2);

@Setter
private int index = -1;

public AggregationProjection(final AggregationType type, final String expression, final IdentifierValue alias, final DatabaseType databaseType) {
this.type = type;
this.expression = expression;
this.alias = alias;
this.databaseType = databaseType;
this.separator = null;
}

@Override
public String getColumnName() {
return getColumnLabel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,17 @@ public final class AggregationProjectionConverter {
register(SqlStdOperatorTable.COUNT);
register(SqlStdOperatorTable.AVG);
register(SqlStdOperatorTable.BIT_XOR);
register(SqlStdOperatorTable.LISTAGG, "GROUP_CONCAT");
}

private static void register(final SqlAggFunction sqlAggFunction) {
REGISTRY.put(sqlAggFunction.getName(), sqlAggFunction);
}

private static void register(final SqlAggFunction sqlAggFunction, final String alias) {
REGISTRY.put(alias, sqlAggFunction);
}

/**
* Convert aggregation projection segment to sql node.
*
Expand All @@ -75,7 +80,7 @@ public static Optional<SqlNode> convert(final AggregationProjectionSegment segme
}
SqlLiteral functionQuantifier = segment instanceof AggregationDistinctProjectionSegment ? SqlLiteral.createSymbol(SqlSelectKeyword.DISTINCT, SqlParserPos.ZERO) : null;
SqlAggFunction operator = convertOperator(segment.getType().name());
List<SqlNode> params = convertParameters(segment.getParameters(), segment.getExpression());
List<SqlNode> params = convertParameters(segment.getParameters(), segment.getExpression(), segment.getSeparator());
SqlBasicCall sqlBasicCall = new SqlBasicCall(operator, params, SqlParserPos.ZERO, functionQuantifier);
if (segment.getAliasName().isPresent()) {
return Optional.of(new SqlBasicCall(SqlStdOperatorTable.AS, Arrays.asList(sqlBasicCall,
Expand All @@ -89,14 +94,17 @@ private static SqlAggFunction convertOperator(final String operator) {
return REGISTRY.get(operator);
}

private static List<SqlNode> convertParameters(final Collection<ExpressionSegment> params, final String expression) {
private static List<SqlNode> convertParameters(final Collection<ExpressionSegment> params, final String expression, final String separator) {
if (expression.contains("*")) {
return Collections.singletonList(SqlIdentifier.star(SqlParserPos.ZERO));
}
List<SqlNode> result = new LinkedList<>();
for (ExpressionSegment each : params) {
ExpressionConverter.convert(each).ifPresent(result::add);
}
if (null != separator) {
result.add(SqlLiteral.createCharString(separator, SqlParserPos.ZERO));
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -962,8 +962,16 @@ udfFunction
: functionName LP_ (expr? | expr (COMMA_ expr)*) RP_
;

separatorName
YaoFly marked this conversation as resolved.
Show resolved Hide resolved
: SEPARATOR string_
;

aggregationExpression
: expr (COMMA_ expr)* | ASTERISK_
;

aggregationFunction
: aggregationFunctionName LP_ distinct? (expr (COMMA_ expr)* | ASTERISK_)? collateClause? RP_ overClause?
: aggregationFunctionName LP_ distinct? aggregationExpression? collateClause? separatorName? RP_ overClause?
;

jsonFunction
Expand Down
Loading
Loading