Skip to content

Commit

Permalink
Support multi version of group_concat
Browse files Browse the repository at this point in the history
Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing committed Oct 18, 2024
1 parent 2b3d80a commit 970caa0
Show file tree
Hide file tree
Showing 25 changed files with 235 additions and 113 deletions.
16 changes: 13 additions & 3 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ void AggregatorParams::init() {
agg_fn_types[i] = {return_type, serde_type, arg_typedescs, has_nullable_child, is_nullable};
agg_fn_types[i].is_always_nullable_result =
ALWAYS_NULLABLE_RESULT_AGG_FUNCS.contains(fn.name.function_name);
if (fn.name.function_name == "array_agg" || fn.name.function_name == "group_concat") {
// TODO(fixme): move this to FE
if (fn.name.function_name == "array_agg" || fn.name.function_name == "group_concat2") {
// set order by info
if (fn.aggregate_fn.__isset.is_asc_order && fn.aggregate_fn.__isset.nulls_first &&
!fn.aggregate_fn.is_asc_order.empty()) {
Expand Down Expand Up @@ -469,8 +470,12 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
_agg_fn_ctxs[i] =
FunctionContext::create_context(state, _mem_pool.get(), return_type, arg_types, agg_fn_type.is_distinct,
agg_fn_type.is_asc_order, agg_fn_type.nulls_first);
auto& ctx_query_options = _agg_fn_ctxs[i]->get_ctx_query_options();
if (state->query_options().__isset.group_concat_max_len) {
_agg_fn_ctxs[i]->set_group_concat_max_len(state->query_options().group_concat_max_len);
ctx_query_options.set_group_concat_max_len(state->query_options().group_concat_max_len);
}
if (state->query_options().__isset.default_group_concat_separator) {
ctx_query_options.set_default_group_concat_separator(state->query_options().default_group_concat_separator);
}
state->obj_pool()->add(_agg_fn_ctxs[i]);
_agg_fn_ctxs[i]->set_mem_usage_counter(&_agg_state_mem_usage);
Expand Down Expand Up @@ -511,7 +516,7 @@ Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, co
}

// check whether it's _merge/_union combinator if it contains agg state type
auto& func_name = fn.name.function_name;
auto func_name = fn.name.function_name;
if (fn.__isset.agg_state_desc) {
if (arg_types.size() != 1) {
return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 with (arg type $1)",
Expand Down Expand Up @@ -558,6 +563,11 @@ Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, co
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);
DCHECK_LE(1, fn.arg_types.size());
TypeDescriptor arg_type = arg_types[0];

// To be compatible with old versions, change group_concat2 name to group_concat if the intermediate type is string.
if (fn.name.function_name == "group_concat" && serde_type.type == TYPE_STRUCT) {
func_name = "group_concat2";
}
auto* func = get_aggregate_function(func_name, return_type, arg_types, is_result_nullable, fn.binary_type,
state->func_version());
if (func == nullptr) {
Expand Down
6 changes: 0 additions & 6 deletions be/src/exprs/agg/factory/aggregate_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,6 @@ static const AggregateFunction* get_function(const std::string& name, LogicalTyp
}
}

if (func_version > 6) {
if (name == "group_concat") {
func_name = "group_concat2";
}
}

if (binary_type == TFunctionBinaryType::BUILTIN) {
auto func = AggregateFuncResolver::instance()->get_aggregate_info(func_name, arg_type, return_type,
is_window_function, is_null);
Expand Down
14 changes: 8 additions & 6 deletions be/src/exprs/agg/group_concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,18 @@ class GroupConcatAggregateFunction
std::string& result = this->data(state).intermediate_string;

Slice val = column_val->get_slice(row_num);
auto separator = ctx->get_ctx_query_options().default_group_concat_separator;
//DEFAULT sep_length.
if (!this->data(state).initial) {
this->data(state).initial = true;

// separator's length;
uint32_t size = 2;
uint32_t size = separator.size();
result.append(reinterpret_cast<const char*>(&size), sizeof(uint32_t))
.append(", ")
.append(separator)
.append(val.get_data(), val.get_size());
} else {
result.append(", ").append(val.get_data(), val.get_size());
result.append(separator).append(val.get_data(), val.get_size());
}
}
}
Expand Down Expand Up @@ -258,8 +259,9 @@ class GroupConcatAggregateFunction
const auto* column_value = down_cast<BinaryColumn*>(src[0].get());

if (chunk_size > 0) {
const char* sep = ", ";
const uint32_t size_sep = 2;
auto separator = ctx->get_ctx_query_options().default_group_concat_separator;
auto sep = separator.data();
const uint32_t size_sep = separator.size();

size_t old_size = bytes.size();
CHECK_EQ(old_size, 0);
Expand Down Expand Up @@ -666,7 +668,7 @@ class GroupConcatAggregateFunctionV2

bytes.resize(offset + length);
bool overflow = false;
size_t limit = ctx->get_group_concat_max_len() + offset;
size_t limit = ctx->get_ctx_query_options().group_concat_max_len + offset;
auto last_unique_row_id = elem_size - 1;
for (auto i = elem_size - 1; i >= 0; i--) {
auto idx = i;
Expand Down
19 changes: 15 additions & 4 deletions be/src/exprs/function_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,19 @@ class FunctionContext {
THREAD_LOCAL,
};

// Query options for query runtime.
struct QueryOptions {
// min value is 4, default is 1024
ssize_t group_concat_max_len = 1024;
std::string default_group_concat_separator = ",";

public:
void set_group_concat_max_len(ssize_t len) { this->group_concat_max_len = len < 4 ? 4 : len; }
void set_default_group_concat_separator(std::string separator) {
this->default_group_concat_separator = separator;
}
};

/// Create a FunctionContext for a UDF. Caller is responsible for deleting it.
static FunctionContext* create_context(RuntimeState* state, MemPool* pool,
const FunctionContext::TypeDesc& return_type,
Expand Down Expand Up @@ -168,9 +181,7 @@ class FunctionContext {

void release_mems();

ssize_t get_group_concat_max_len() { return group_concat_max_len; }
// min value is 4, default is 1024
void set_group_concat_max_len(ssize_t len) { group_concat_max_len = len < 4 ? 4 : len; }
FunctionContext::QueryOptions& get_ctx_query_options() { return _query_options; }

bool error_if_overflow() const;

Expand Down Expand Up @@ -223,7 +234,7 @@ class FunctionContext {
std::vector<bool> _is_asc_order;
std::vector<bool> _nulls_first;
bool _is_distinct = false;
ssize_t group_concat_max_len = 1024;
QueryOptions _query_options;

// used for ngram bloom filter to speed up some function
std::unique_ptr<NgramBloomFilterState> _ngramState;
Expand Down
30 changes: 29 additions & 1 deletion be/test/exprs/agg/aggregate_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1336,7 +1336,7 @@ TEST_F(AggregateTest, test_bitmap_nullable) {
ASSERT_EQ(50, result_data.get_data()[0]);
}

TEST_F(AggregateTest, test_group_concat) {
TEST_F(AggregateTest, test_group_concat1) {
const AggregateFunction* group_concat_function =
get_aggregate_function("group_concat", TYPE_VARCHAR, TYPE_VARCHAR, false);
auto state = ManagedAggrState::create(ctx, group_concat_function);
Expand All @@ -1352,6 +1352,8 @@ TEST_F(AggregateTest, test_group_concat) {
const Column* row_column = data_column.get();

// test update
auto& query_options = ctx->get_ctx_query_options();
query_options.set_default_group_concat_separator(", ");
group_concat_function->update_batch_single_state(ctx, data_column->size(), &row_column, state->state());

auto result_column = BinaryColumn::create();
Expand All @@ -1360,6 +1362,32 @@ TEST_F(AggregateTest, test_group_concat) {
ASSERT_EQ("starrocks0, starrocks1, starrocks2, starrocks3, starrocks4, starrocks5", result_column->get_data()[0]);
}

TEST_F(AggregateTest, test_group_concat2) {
const AggregateFunction* group_concat_function =
get_aggregate_function("group_concat", TYPE_VARCHAR, TYPE_VARCHAR, false);
auto state = ManagedAggrState::create(ctx, group_concat_function);

auto data_column = BinaryColumn::create();

for (int i = 0; i < 6; i++) {
std::string val("starrocks");
val.append(std::to_string(i));
data_column->append(val);
}

const Column* row_column = data_column.get();

// test update
auto& query_options = ctx->get_ctx_query_options();
query_options.set_default_group_concat_separator(",");
group_concat_function->update_batch_single_state(ctx, data_column->size(), &row_column, state->state());

auto result_column = BinaryColumn::create();
group_concat_function->finalize_to_column(ctx, state->state(), result_column.get());

ASSERT_EQ("starrocks0,starrocks1,starrocks2,starrocks3,starrocks4,starrocks5", result_column->get_data()[0]);
}

TEST_F(AggregateTest, test_group_concat_const_seperator) {
std::vector<TypeDescriptor> arg_types = {TypeDescriptor::from_logical_type(TYPE_VARCHAR),
TypeDescriptor::from_logical_type(TYPE_VARCHAR)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.google.common.collect.Lists;
import com.starrocks.catalog.Function;
import com.starrocks.common.io.Writable;
import org.apache.commons.collections.CollectionUtils;

import java.io.DataInput;
import java.io.DataOutput;
Expand Down Expand Up @@ -112,7 +113,7 @@ public int getOrderByElemNum() {
}

public String getOrderByStringToSql() {
if (orderByElements != null && !orderByElements.isEmpty()) {
if (!CollectionUtils.isEmpty(orderByElements)) {
StringBuilder sb = new StringBuilder();
sb.append(" ORDER BY ").append(orderByElements.stream().map(OrderByElement::toSql).
collect(Collectors.joining(" ")));
Expand All @@ -123,7 +124,7 @@ public String getOrderByStringToSql() {
}

public String getOrderByStringToExplain() {
if (orderByElements != null && !orderByElements.isEmpty()) {
if (!CollectionUtils.isEmpty(orderByElements)) {
StringBuilder sb = new StringBuilder();
sb.append(" ORDER BY ").append(orderByElements.stream().map(OrderByElement::explain).
collect(Collectors.joining(" ")));
Expand Down
20 changes: 20 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ public class FunctionSet {
public static final String CONCAT_WS = "concat_ws";
public static final String ENDS_WITH = "ends_with";
public static final String FIND_IN_SET = "find_in_set";
// multi version of group_concat:
// group_concat : no distinct or order by arguments
// group_concat2: with distinct or order by arguments
public static final String GROUP_CONCAT = "group_concat";
public static final String GROUP_CONCAT_V2 = "group_concat2";

public static final String INSTR = "instr";
public static final String LCASE = "lcase";
public static final String LEFT = "left";
Expand Down Expand Up @@ -578,6 +583,12 @@ public class FunctionSet {
ImmutableSet.<Type>builder()
.addAll(Type.INTEGER_TYPES)
.build();

public static final Set<String> GROUP_CONCAT_FUNCS =
ImmutableSortedSet.orderedBy(String.CASE_INSENSITIVE_ORDER)
.add(GROUP_CONCAT)
.add(GROUP_CONCAT_V2)
.build();
/**
* Use for vectorized engine, but we can't use vectorized function directly, because we
* need to check whether the expression tree can use vectorized function from bottom to
Expand Down Expand Up @@ -1050,7 +1061,16 @@ private void initAggregateBuiltins() {
Lists.newArrayList(Type.ANY_ELEMENT), Type.ANY_ARRAY, Type.ANY_STRUCT, true,
true, false, false));

// group_concat(string)
addBuiltin(AggregateFunction.createBuiltin(GROUP_CONCAT,
Lists.newArrayList(Type.VARCHAR), Type.VARCHAR, Type.VARBINARY,
false, false, false));
// group_concat(string, string)
addBuiltin(AggregateFunction.createBuiltin(GROUP_CONCAT,
Lists.newArrayList(Type.VARCHAR, Type.VARCHAR), Type.VARCHAR, Type.VARBINARY,
false, false, false));
// group_concat with distinct or order by arguments
addBuiltin(AggregateFunction.createBuiltin(GROUP_CONCAT_V2,
Lists.newArrayList(Type.ANY_ELEMENT), Type.VARCHAR, Type.ANY_STRUCT, true,
false, false, false));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4351,6 +4351,9 @@ public TQueryOptions toThrift() {

tResult.setTransmission_encode_level(transmissionEncodeLevel);
tResult.setGroup_concat_max_len(groupConcatMaxLen);
if (SqlModeHelper.check(sqlMode, SqlModeHelper.MODE_GROUP_CONCAT_LEGACY)) {
tResult.setDefault_group_concat_separator(", ");
}
tResult.setRpc_http_min_size(rpcHttpMinSize);
tResult.setInterleaving_group_size(interleavingGroupSize);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1056,7 +1056,11 @@ public String visitFunctionCall(FunctionCallExpr node, Void context) {
sb.append("`" + node.getFnName().getDb() + "`.");
}
String functionName = node.getFnName().getFunction();
sb.append(functionName);
if (functionName.equals(FunctionSet.GROUP_CONCAT_V2)) {
sb.append(FunctionSet.GROUP_CONCAT);
} else {
sb.append(functionName);
}

sb.append("(");
if (fnParams.isStar()) {
Expand All @@ -1075,9 +1079,9 @@ public String visitFunctionCall(FunctionCallExpr node, Void context) {
StringLiteral boundary = (StringLiteral) node.getChild(3);
sb.append(", ").append(boundary.getValue());
sb.append(")");
} else if (functionName.equals(FunctionSet.ARRAY_AGG) || functionName.equals(FunctionSet.GROUP_CONCAT)) {
} else if (functionName.equals(FunctionSet.ARRAY_AGG) || functionName.equals(FunctionSet.GROUP_CONCAT_V2)) {
int end = 1;
if (functionName.equals(FunctionSet.GROUP_CONCAT)) {
if (functionName.equals(FunctionSet.GROUP_CONCAT_V2)) {
end = fnParams.exprs().size() - fnParams.getOrderByElemNum() - 1;
}
for (int i = 0; i < end && i < node.getChildren().size(); ++i) {
Expand All @@ -1087,10 +1091,12 @@ public String visitFunctionCall(FunctionCallExpr node, Void context) {
sb.append(visit(node.getChild(i)));
}
List<OrderByElement> sortClause = fnParams.getOrderByElements();
if (sortClause != null) {
if (!CollectionUtils.isEmpty(sortClause)) {
sb.append(" ORDER BY ").append(visitAstList(sortClause));
}
if (functionName.equals(FunctionSet.GROUP_CONCAT) && end < node.getChildren().size() && end > 0) {
boolean isGroupConcatV2 = functionName.equals(FunctionSet.GROUP_CONCAT_V2) &&
(fnParams.isDistinct() || !CollectionUtils.isEmpty(sortClause));
if (isGroupConcatV2 && end < node.getChildren().size() && end > 0) {
sb.append(" SEPARATOR ");
sb.append(visit(node.getChild(end)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1153,13 +1153,38 @@ private void checkFunction(String fnName, FunctionCallExpr node, Type[] argument
" can't cast from " + node.getChild(1).getType().toSql() + " to ARRAY<BOOL>");
}
break;
case FunctionSet.GROUP_CONCAT:
case FunctionSet.GROUP_CONCAT: {
if (node.getChildren().size() > 2 || node.getChildren().isEmpty()) {
throw new SemanticException(
"group_concat requires one or two parameters: " + node.toSql(),
node.getPos());
}
if (node.getParams().isDistinct()) {
throw new SemanticException("group_concat does not support DISTINCT", node.getPos());
}
Expr arg0 = node.getChild(0);
if (!Type.canCastTo(arg0.getType(), Type.VARCHAR)) {
throw new SemanticException(
"group_concat requires first parameter to be of getType() STRING: " + node.toSql(),
arg0.getPos());
}
if (node.getChildren().size() == 2) {
Expr arg1 = node.getChild(1);
if (!Type.canCastTo(arg1.getType(), Type.VARCHAR)) {
throw new SemanticException(
"group_concat requires second parameter to be of getType() STRING: " +
node.toSql(), arg1.getPos());
}
}
break;
}
case FunctionSet.GROUP_CONCAT_V2:
case FunctionSet.ARRAY_AGG: {
if (node.getChildren().size() == 0) {
throw new SemanticException(fnName + " should have at least one input", node.getPos());
}
int start = argumentTypes.length - node.getParams().getOrderByElemNum();
if (fnName.equals(FunctionSet.GROUP_CONCAT)) {
if (fnName.equals(FunctionSet.GROUP_CONCAT_V2)) {
if (start < 2) {
throw new SemanticException(fnName + " should have output expressions before [ORDER BY]",
node.getPos());
Expand Down
Loading

0 comments on commit 970caa0

Please sign in to comment.