Skip to content

Commit

Permalink
[Feature] (Part2) Support _state/_union/_merge aggregate function com…
Browse files Browse the repository at this point in the history
…binator (#50425)

Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing authored Sep 18, 2024
1 parent 15f6eda commit 5b5c277
Show file tree
Hide file tree
Showing 38 changed files with 4,715 additions and 201 deletions.
160 changes: 110 additions & 50 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include "exec/limited_pipeline_chunk_buffer.h"
#include "exec/pipeline/operator.h"
#include "exec/spill/spiller.hpp"
#include "exprs/agg/agg_state_merge.h"
#include "exprs/agg/agg_state_union.h"
#include "exprs/agg/aggregate_state_allocator.h"
#include "exprs/anyval_util.h"
#include "gen_cpp/PlanNodes_types.h"
Expand All @@ -44,6 +46,10 @@ namespace starrocks {
// Names of aggregate functions that get `is_always_nullable_result` set on their
// AggFunctionTypes entry when the aggregator parameters are initialized.
static const std::unordered_set<std::string> ALWAYS_NULLABLE_RESULT_AGG_FUNCS = {"variance_samp", "var_samp",
"stddev_samp", "covar_samp", "corr"};

// Suffixes of the aggregate state combinators. A function carrying an agg state
// descriptor is recognized as a combinator when its name equals
// `<nested function name> + suffix`.
static const std::string AGG_STATE_UNION_SUFFIX = "_union";
static const std::string AGG_STATE_MERGE_SUFFIX = "_merge";
// Name of the count function, which is special-cased when resolving aggregate
// function types and nullability.
static const std::string FUNCTION_COUNT = "count";

template <bool UseIntermediateAsOutput>
bool AggFunctionTypes::is_result_nullable() const {
if constexpr (UseIntermediateAsOutput) {
Expand Down Expand Up @@ -151,24 +157,25 @@ void AggregatorParams::init() {
for (size_t i = 0; i < agg_size; ++i) {
const TExpr& desc = aggregate_functions[i];
const TFunction& fn = desc.nodes[0].fn;
VLOG_ROW << fn.name.function_name << " is arg nullable " << desc.nodes[0].has_nullable_child;
VLOG_ROW << fn.name.function_name << " is result nullable " << desc.nodes[0].is_nullable;
VLOG_ROW << fn.name.function_name << ", arg nullable " << desc.nodes[0].has_nullable_child
<< ", result nullable " << desc.nodes[0].is_nullable;

if (fn.name.function_name == "count") {
std::vector<FunctionContext::TypeDesc> arg_typedescs;
agg_fn_types[i] = {TypeDescriptor(TYPE_BIGINT), TypeDescriptor(TYPE_BIGINT), arg_typedescs, false, false};
if (fn.name.function_name == FUNCTION_COUNT) {
// count function is always not nullable
agg_fn_types[i] = {TypeDescriptor(TYPE_BIGINT), TypeDescriptor(TYPE_BIGINT), {}, false, false};
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);

// whether agg function has nullable child
const bool has_nullable_child = has_outer_join_child || desc.nodes[0].has_nullable_child;
// whether agg function is nullable
bool is_nullable = desc.nodes[0].is_nullable;
// collect arg_typedescs for aggregate function.
std::vector<FunctionContext::TypeDesc> arg_typedescs;
for (auto& type : fn.arg_types) {
arg_typedescs.push_back(AnyValUtil::column_type_to_type_desc(TypeDescriptor::from_thrift(type)));
}

const bool is_input_nullable = has_outer_join_child || desc.nodes[0].has_nullable_child;
agg_fn_types[i] = {return_type, serde_type, arg_typedescs, is_input_nullable, desc.nodes[0].is_nullable};
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);
agg_fn_types[i] = {return_type, serde_type, arg_typedescs, has_nullable_child, is_nullable};
agg_fn_types[i].is_always_nullable_result =
ALWAYS_NULLABLE_RESULT_AGG_FUNCS.contains(fn.name.function_name);
if (fn.name.function_name == "array_agg" || fn.name.function_name == "group_concat") {
Expand Down Expand Up @@ -355,7 +362,6 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
}

bool has_outer_join_child = _params->has_outer_join_child;
VLOG_ROW << "has_outer_join_child " << has_outer_join_child;

size_t group_by_size = _group_by_expr_ctxs.size();
_group_by_columns.resize(group_by_size);
Expand Down Expand Up @@ -383,20 +389,9 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
_is_merge_funcs[i] = aggregate_functions[i].nodes[0].agg_expr.is_merge_agg;

// get function
auto is_result_nullable_func = [&]() {
if (fn.name.function_name == "count") {
if (fn.arg_types.empty()) {
return false;
}
if (has_outer_join_child || desc.nodes[0].has_nullable_child) {
return true;
}
return false;
} else {
return agg_fn_type.use_nullable_fn(_use_intermediate_as_output());
}
};
RETURN_IF_ERROR(_create_aggregate_function(state, fn, is_result_nullable_func(), &_agg_functions[i]));
bool is_result_nullable = _is_agg_result_nullable(desc, agg_fn_type);
RETURN_IF_ERROR(_create_aggregate_function(state, fn, is_result_nullable, &_agg_functions[i]));
VLOG_ROW << "has_outer_join_child " << has_outer_join_child << ", is_result_nullable " << is_result_nullable;

int node_idx = 0;
for (int j = 0; j < desc.nodes[0].num_children; ++j) {
Expand Down Expand Up @@ -454,10 +449,27 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile

// Initial for FunctionContext of every aggregate functions
for (int i = 0; i < _agg_fn_ctxs.size(); ++i) {
_agg_fn_ctxs[i] = FunctionContext::create_context(
state, _mem_pool.get(), AnyValUtil::column_type_to_type_desc(_agg_fn_types[i].result_type),
_agg_fn_types[i].arg_typedescs, _agg_fn_types[i].is_distinct, _agg_fn_types[i].is_asc_order,
_agg_fn_types[i].nulls_first);
auto& agg_fn_type = _agg_fn_types[i];
auto& agg_func = _agg_functions[i];
TypeDescriptor return_type = AnyValUtil::column_type_to_type_desc(agg_fn_type.result_type);
std::vector<TypeDescriptor> arg_types = agg_fn_type.arg_typedescs;

const AggStateDesc* agg_state_desc = nullptr;
if (dynamic_cast<const AggStateUnion*>(agg_func)) {
auto* agg_state_union = down_cast<const AggStateUnion*>(agg_func);
agg_state_desc = agg_state_union->get_agg_state_desc();
} else if (dynamic_cast<const AggStateMerge*>(agg_func)) {
auto* agg_state_merge = down_cast<const AggStateMerge*>(agg_func);
agg_state_desc = agg_state_merge->get_agg_state_desc();
}
if (agg_state_desc != nullptr) {
return_type = agg_state_desc->get_return_type();
arg_types = agg_state_desc->get_arg_types();
}

_agg_fn_ctxs[i] =
FunctionContext::create_context(state, _mem_pool.get(), return_type, arg_types, agg_fn_type.is_distinct,
agg_fn_type.is_asc_order, agg_fn_type.nulls_first);
if (state->query_options().__isset.group_concat_max_len) {
_agg_fn_ctxs[i]->set_group_concat_max_len(state->query_options().group_concat_max_len);
}
Expand All @@ -479,6 +491,19 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
return Status::OK();
}

// Decides the `is_result_nullable` flag that is handed to _create_aggregate_function()
// for the aggregate expression described by `desc`.
//
// @param desc           thrift expression of the aggregate call; only its root node is read.
// @param agg_func_type  type information previously collected for this aggregate call.
// @return true when the nullable flavour of the aggregate function should be used.
bool Aggregator::_is_agg_result_nullable(const TExpr& desc, const AggFunctionTypes& agg_func_type) {
    const auto& root_node = desc.nodes[0];
    const bool is_count = (root_node.fn.name.function_name == FUNCTION_COUNT);

    // Everything except count relies on the collected type information.
    if (!is_count) {
        return agg_func_type.use_nullable_fn(_use_intermediate_as_output());
    }

    // NOTE: For count, agg_func_type cannot be used since it only holds mocked values,
    // so the answer is derived from the plan node itself: count without arguments is
    // never nullable, otherwise it is nullable when any input may be null.
    const bool has_args = !root_node.fn.arg_types.empty();
    return has_args && (_params->has_outer_join_child || root_node.has_nullable_child);
}

Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn,
bool is_result_nullable, const AggregateFunction** ret) {
std::vector<TypeDescriptor> arg_types;
Expand All @@ -488,29 +513,64 @@ Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, co

// check whether it's _merge/_union combinator if it contains agg state type
auto& func_name = fn.name.function_name;
// get function
if (func_name == "count") {
auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_result_nullable);
if (func == nullptr) {
return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 ", func_name));
if (fn.__isset.agg_state_desc) {
if (arg_types.size() != 1) {
return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 with (arg type $1)",
func_name, arg_types.size()));
}
auto agg_state_desc = AggStateDesc::from_thrift(fn.agg_state_desc);
auto nested_func_name = agg_state_desc.get_func_name();
if (nested_func_name + AGG_STATE_MERGE_SUFFIX == func_name) {
// aggregate _merge combinator
auto* nested_func = AggStateDesc::get_agg_state_func(&agg_state_desc);
if (nested_func == nullptr) {
return Status::InternalError(
strings::Substitute("Merge combinator function $0 fails to get the nested agg func: $1 ",
func_name, nested_func_name));
}
auto merge_agg_func = std::make_shared<AggStateMerge>(std::move(agg_state_desc), nested_func);
*ret = merge_agg_func.get();
_combinator_function.emplace_back(std::move(merge_agg_func));
} else if (nested_func_name + AGG_STATE_UNION_SUFFIX == func_name) {
// aggregate _union combinator
auto* nested_func = AggStateDesc::get_agg_state_func(&agg_state_desc);
if (nested_func == nullptr) {
return Status::InternalError(
strings::Substitute("Union combinator function $0 fails to get the nested agg func: $1 ",
func_name, nested_func_name));
}
auto union_agg_func = std::make_shared<AggStateUnion>(std::move(agg_state_desc), nested_func);
*ret = union_agg_func.get();
_combinator_function.emplace_back(std::move(union_agg_func));
} else {
return Status::InternalError(
strings::Substitute("Agg function combinator is not implemented: $0 ", func_name));
}
*ret = func;
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);
DCHECK_LE(1, fn.arg_types.size());
TypeDescriptor arg_type = arg_types[0];
auto* func = get_aggregate_function(func_name, return_type, arg_types, is_result_nullable, fn.binary_type,
state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
"Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)",
func_name, type_to_string(arg_type.type), type_to_string(serde_type.type),
type_to_string(return_type.type), is_result_nullable ? "true" : "false"));
// get function
if (func_name == FUNCTION_COUNT) {
auto* func = get_aggregate_function(FUNCTION_COUNT, TYPE_BIGINT, TYPE_BIGINT, is_result_nullable);
if (func == nullptr) {
return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 ", func_name));
}
*ret = func;
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);
DCHECK_LE(1, fn.arg_types.size());
TypeDescriptor arg_type = arg_types[0];
auto* func = get_aggregate_function(func_name, return_type, arg_types, is_result_nullable, fn.binary_type,
state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
"Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)",
func_name, type_to_string(arg_type.type), type_to_string(serde_type.type),
type_to_string(return_type.type), is_result_nullable ? "true" : "false"));
}
*ret = func;
VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
<< return_type;
}
*ret = func;
VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
<< return_type;
}
return Status::OK();
}
Expand Down
11 changes: 8 additions & 3 deletions be/src/exec/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,6 @@ class Aggregator : public pipeline::ContextWithDependency {

bool is_streaming_all_states() const { return _streaming_all_states; }

Status _create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn, bool is_result_nullable,
const AggregateFunction** ret);

HashTableKeyAllocator _state_allocator;

protected:
Expand Down Expand Up @@ -510,6 +507,9 @@ class Aggregator : public pipeline::ContextWithDependency {
bool _is_prepared = false;
int64_t _agg_state_mem_usage = 0;

// aggregate combinator functions since they are not persisted in agg hash map
std::vector<AggregateFunctionPtr> _combinator_function;

public:
void build_hash_map(size_t chunk_size, bool agg_group_by_with_limit = false);
void build_hash_map(size_t chunk_size, std::atomic<int64_t>& shared_limit_countdown, bool agg_group_by_with_limit);
Expand Down Expand Up @@ -575,6 +575,11 @@ class Aggregator : public pipeline::ContextWithDependency {

void _release_agg_memory();

bool _is_agg_result_nullable(const TExpr& desc, const AggFunctionTypes& agg_func_type);

Status _create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn, bool is_result_nullable,
const AggregateFunction** ret);

template <class HashMapWithKey>
friend struct AllocateState;
};
Expand Down
90 changes: 90 additions & 0 deletions be/src/exprs/agg/agg_state_merge.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "column/vectorized_fwd.h"
#include "exprs/agg/aggregate.h"

namespace starrocks {
// Tag type passed to AggregateFunctionBatchHelper. AggStateMerge itself forwards
// size()/alignof_size()/create()/destroy() to the nested function, so the state
// that is actually allocated is the nested function's state.
struct AggStateMergeState {};

/**
 * @brief Merge combinator for aggregate function to merge the agg state to return the final result of aggregate function.
 * DESC: return_type {agg_func}_merge(intermediate_type)
 *  input type          : aggregate function's intermediate_type
 *  intermediate type   : aggregate function's intermediate_type
 *  return type         : aggregate function's return type
 *
 * Every operation is delegated to the nested aggregate function; the only difference
 * is that `update` consumes already-serialized states and therefore calls the nested
 * function's `merge`.
 */
class AggStateMerge final : public AggregateFunctionBatchHelper<AggStateMergeState, AggStateMerge> {
public:
    // @param agg_state_desc  descriptor of the nested aggregate function; stored by value.
    // @param function        nested aggregate function every call is forwarded to. Must not be
    //                        null. It is not owned: this class never deletes it.
    //                        NOTE(review): it presumably has to outlive this object -- confirm
    //                        against the caller.
    AggStateMerge(AggStateDesc agg_state_desc, const AggregateFunction* function)
            : _agg_state_desc(std::move(agg_state_desc)), _function(function) {
        DCHECK(_function != nullptr);
    }

    // Returns the descriptor this combinator was built from (never null).
    const AggStateDesc* get_agg_state_desc() const { return &_agg_state_desc; }

    // State lifecycle and layout are those of the nested function.
    void create(FunctionContext* ctx, AggDataPtr __restrict ptr) const override { _function->create(ctx, ptr); }

    void destroy(FunctionContext* ctx, AggDataPtr __restrict ptr) const override { _function->destroy(ctx, ptr); }

    size_t size() const override { return _function->size(); }

    size_t alignof_size() const override { return _function->alignof_size(); }

    bool is_pod_state() const override { return _function->is_pod_state(); }

    void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override {
        _function->reset(ctx, args, state);
    }

    // The input of the combinator is the nested function's intermediate state, so an
    // update is a merge of that state. Only the first input column is read.
    void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state,
                size_t row_num) const override {
        _function->merge(ctx, columns[0], state, row_num);
    }

    void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
        _function->merge(ctx, column, state, row_num);
    }

    // Debug builds assert that the requested range [start, end) is not empty.
    void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start,
                    size_t end) const override {
        DCHECK_GT(end, start);
        _function->get_values(ctx, state, dst, start, end);
    }

    void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
        _function->serialize_to_column(ctx, state, to);
    }

    // The single input column already has the intermediate (serialized) format, so it
    // is handed through unchanged instead of being converted.
    void convert_to_serialize_format([[maybe_unused]] FunctionContext* ctx, const Columns& srcs,
                                     [[maybe_unused]] size_t chunk_size, ColumnPtr* dst) const override {
        DCHECK_EQ(1, srcs.size());
        *dst = srcs[0];
    }

    // Produces the final result of the nested aggregate function.
    void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
        _function->finalize_to_column(ctx, state, to);
    }

    std::string get_name() const override { return "agg_state_merge"; }

private:
    // Descriptor of the nested aggregate function.
    const AggStateDesc _agg_state_desc;
    // Nested aggregate function (non-owning, checked non-null in the constructor).
    const AggregateFunction* _function;
};

} // namespace starrocks
Loading

0 comments on commit 5b5c277

Please sign in to comment.