Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Refactor hll_sketch and fix stack use after return bug #50600

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 50 additions & 65 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,6 @@ void AggregatorParams::init() {
#define ALIGN_TO(size, align) ((size + align - 1) / align * align)
#define PAD(size, align) (align - (size % align)) % align;

class ThreadLocalStateAllocatorSetter {
public:
ThreadLocalStateAllocatorSetter(Allocator* allocator)
: _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {}
~ThreadLocalStateAllocatorSetter() = default;

private:
ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter;
ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter;
};

#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \
auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator)

Aggregator::Aggregator(AggregatorParamsPtr params) : _params(std::move(params)) {
_allocator = std::make_unique<CountingAllocatorWithHook>();
}
Expand Down Expand Up @@ -395,59 +381,22 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
const TFunction& fn = desc.nodes[0].fn;
const auto& agg_fn_type = _agg_fn_types[i];
_is_merge_funcs[i] = aggregate_functions[i].nodes[0].agg_expr.is_merge_agg;
// get function
if (fn.name.function_name == "count") {
bool is_input_nullable =
!fn.arg_types.empty() && (has_outer_join_child || desc.nodes[0].has_nullable_child);
auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_input_nullable);
_agg_functions[i] = func;
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);

TypeDescriptor arg_type = TypeDescriptor::from_thrift(fn.arg_types[0]);
// Because intersect_count have two input types.
// And intersect_count's first argument's type is alwasy Bitmap,
// so we use its second arguments type as input.
if (fn.name.function_name == "intersect_count") {
arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
}

// Because max_by and min_by function have two input types,
// so we use its second arguments type as input.
if (fn.name.function_name == "max_by" || fn.name.function_name == "min_by") {
arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
}

// Because windowfunnel have more two input types.
// functions registry use 2th args(datetime/date).
if (fn.name.function_name == "window_funnel") {
arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
}

// hack for accepting various arguments
if (fn.name.function_name == "exchange_bytes" || fn.name.function_name == "exchange_speed") {
arg_type = TypeDescriptor(TYPE_BIGINT);
}

if (fn.name.function_name == "array_union_agg" || fn.name.function_name == "array_unique_agg") {
// for array_union_agg use inner type as signature
arg_type = arg_type.children[0];
}

const bool use_nullable_fn = agg_fn_type.use_nullable_fn(_use_intermediate_as_output());
auto* func = get_aggregate_function(fn.name.function_name, arg_type.type, return_type.type, use_nullable_fn,
fn.binary_type, state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
"Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)",
fn.name.function_name, type_to_string(arg_type.type), type_to_string(serde_type.type),
type_to_string(return_type.type), use_nullable_fn ? "true" : "false"));
// get function
auto is_result_nullable_func = [&]() {
if (fn.name.function_name == "count") {
if (fn.arg_types.empty()) {
return false;
}
if (has_outer_join_child || desc.nodes[0].has_nullable_child) {
return true;
}
return false;
} else {
return agg_fn_type.use_nullable_fn(_use_intermediate_as_output());
}
VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
<< return_type;
_agg_functions[i] = func;
}
};
RETURN_IF_ERROR(_create_aggregate_function(state, fn, is_result_nullable_func(), &_agg_functions[i]));

int node_idx = 0;
for (int j = 0; j < desc.nodes[0].num_children; ++j) {
Expand Down Expand Up @@ -530,6 +479,42 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
return Status::OK();
}

Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn,
bool is_result_nullable, const AggregateFunction** ret) {
std::vector<TypeDescriptor> arg_types;
for (auto& type : fn.arg_types) {
arg_types.push_back(TypeDescriptor::from_thrift(type));
}

// check whether it's _merge/_union combinator if it contains agg state type
auto& func_name = fn.name.function_name;
// get function
if (func_name == "count") {
auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_result_nullable);
if (func == nullptr) {
return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 ", func_name));
}
*ret = func;
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);
DCHECK_LE(1, fn.arg_types.size());
TypeDescriptor arg_type = arg_types[0];
auto* func = get_aggregate_function(func_name, return_type, arg_types, is_result_nullable, fn.binary_type,
state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
"Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)",
func_name, type_to_string(arg_type.type), type_to_string(serde_type.type),
type_to_string(return_type.type), is_result_nullable ? "true" : "false"));
}
*ret = func;
VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
<< return_type;
}
return Status::OK();
}

Status Aggregator::reset_state(starrocks::RuntimeState* state, const std::vector<ChunkPtr>& refill_chunks,
pipeline::Operator* refill_op, bool reset_sink_complete) {
RETURN_IF_ERROR(_reset_state(state, reset_sink_complete));
LiShuMing marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
3 changes: 3 additions & 0 deletions be/src/exec/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ class Aggregator : public pipeline::ContextWithDependency {

bool is_streaming_all_states() const { return _streaming_all_states; }

Status _create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn, bool is_result_nullable,
const AggregateFunction** ret);

HashTableKeyAllocator _state_allocator;

protected:
Expand Down
17 changes: 17 additions & 0 deletions be/src/exprs/agg/aggregate_state_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "column/hash_set.h"
#include "common/config.h"
#include "runtime/memory/allocator.h"
#include "runtime/memory/roaring_hook.h"

namespace starrocks {

Expand Down Expand Up @@ -92,4 +93,20 @@ using SliceHashSetWithAggStateAllocator = phmap::flat_hash_set<SliceWithHash, Ha

template <typename T>
using VectorWithAggStateAllocator = std::vector<T, AggregateStateAllocator<T>>;

// Thread local aggregate state allocator setter with roaring allocator
class ThreadLocalStateAllocatorSetter {
public:
ThreadLocalStateAllocatorSetter(Allocator* allocator)
: _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {}
~ThreadLocalStateAllocatorSetter() = default;

private:
ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter;
ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter;
};

#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \
auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator)

} // namespace starrocks
3 changes: 2 additions & 1 deletion be/src/exprs/agg/avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,14 @@ class AvgAggregateFunction final

void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
DCHECK(!to->is_nullable());
auto* column = down_cast<ResultColumnType*>(to);
// In fact, for StarRocks real query, we don't need this check.
// But for robust, we add this check.
if (this->data(state).count == 0) {
column->append_default();
return;
}

auto* column = down_cast<ResultColumnType*>(to);
ResultType result;
if constexpr (lt_is_decimalv2<LT>) {
result = this->data(state).sum / DecimalV2Value(this->data(state).count, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
#include "column/vectorized_fwd.h"
#include "exprs/agg/aggregate.h"
#include "gutil/casts.h"
#include "types/hll_sketch.h"

namespace starrocks {

struct HLLSketchState {
std::unique_ptr<DataSketchesHll> hll_sketch = nullptr;
int64_t memory_usage = 0;
};

/**
Expand Down Expand Up @@ -101,10 +103,10 @@ class HllSketchAggregateFunction final
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
DCHECK(column->is_binary());
const BinaryColumn* hll_column = down_cast<const BinaryColumn*>(column);
DataSketchesHll hll(hll_column->get(row_num).get_slice());
DataSketchesHll hll(hll_column->get(row_num).get_slice(), &(this->data(state).memory_usage));
if (UNLIKELY(this->data(state).hll_sketch == nullptr)) {
this->data(state).hll_sketch =
std::make_unique<DataSketchesHll>(hll.get_lg_config_k(), hll.get_target_type());
this->data(state).hll_sketch = std::make_unique<DataSketchesHll>(
hll.get_lg_config_k(), hll.get_target_type(), &(this->data(state).memory_usage));
}
int64_t prev_memory = this->data(state).hll_sketch->mem_usage();
this->data(state).hll_sketch->merge(hll);
Expand Down Expand Up @@ -159,7 +161,8 @@ class HllSketchAggregateFunction final
const Column** src_datas_ptr = src_datas.data();
std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, src_datas_ptr);
for (size_t i = 0; i < chunk_size; ++i) {
DataSketchesHll hll{log_k, tgt_type};
int64_t memory_usage = 0;
DataSketchesHll hll{log_k, tgt_type, &memory_usage};
if constexpr (lt_is_string<LT>) {
Slice s = column->get_slice(i);
value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED);
Expand Down Expand Up @@ -201,7 +204,7 @@ class HllSketchAggregateFunction final
uint8_t log_k;
datasketches::target_hll_type tgt_type;
std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, columns);
this->data(state).hll_sketch = _init_hll_sketch(log_k, tgt_type);
this->data(state).hll_sketch = _init_hll_sketch(log_k, tgt_type, &(this->data(state).memory_usage));
}
}

Expand All @@ -228,9 +231,9 @@ class HllSketchAggregateFunction final
}

// init hll sketch with default log_k and target type
std::unique_ptr<DataSketchesHll> _init_hll_sketch(
uint8_t log_k = DEFAULT_HLL_LOG_K, datasketches::target_hll_type tgt_type = datasketches::HLL_6) const {
return std::make_unique<DataSketchesHll>(log_k, tgt_type);
std::unique_ptr<DataSketchesHll> _init_hll_sketch(uint8_t log_k, datasketches::target_hll_type tgt_type,
int64_t* memory_usage) const {
return std::make_unique<DataSketchesHll>(log_k, tgt_type, memory_usage);
}
};

Expand Down
2 changes: 1 addition & 1 deletion be/src/exprs/agg/factory/aggregate_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
#include "exprs/agg/count.h"
#include "exprs/agg/covariance.h"
#include "exprs/agg/distinct.h"
#include "exprs/agg/ds_hll_count_distinct.h"
#include "exprs/agg/exchange_perf.h"
#include "exprs/agg/group_concat.h"
#include "exprs/agg/histogram.h"
#include "exprs/agg/hll_ndv.h"
#include "exprs/agg/hll_sketch.h"
#include "exprs/agg/hll_union.h"
#include "exprs/agg/hll_union_count.h"
#include "exprs/agg/intersect_count.h"
Expand Down
9 changes: 3 additions & 6 deletions be/src/storage/chunk_aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
#include "exec/sorting/sorting.h"
#include "exprs/agg/aggregate_state_allocator.h"
#include "gutil/casts.h"
#include "runtime/memory/mem_hook_allocator.h"
#include "storage/column_aggregate_func.h"

namespace starrocks {

static MemHookAllocator kDefaultChunkAggregatorAllocator = MemHookAllocator{};

ChunkAggregator::ChunkAggregator(const starrocks::Schema* schema, uint32_t reserve_rows, uint32_t max_aggregate_rows,
double factor, bool is_vertical_merge, bool is_key)
: _schema(schema),
Expand Down Expand Up @@ -173,7 +170,7 @@ void ChunkAggregator::aggregate() {
_aggregate_loops[_aggregate_loops.size() - 1] += 1;
}
}
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator);
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
// 3. Copy the selected key rows
// 4. Aggregate the value rows
for (int i = 0; i < _key_fields; ++i) {
Expand All @@ -196,7 +193,7 @@ bool ChunkAggregator::is_finish() {
void ChunkAggregator::aggregate_reset() {
_aggregate_chunk = ChunkHelper::new_chunk(*_schema, _reserve_rows);
_aggregate_rows = 0;
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator);
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
for (int i = 0; i < _num_fields; ++i) {
auto p = _aggregate_chunk->get_column_by_index(i).get();
_column_aggregator[i]->update_aggregate(p);
Expand All @@ -210,7 +207,7 @@ void ChunkAggregator::aggregate_reset() {
}

ChunkPtr ChunkAggregator::aggregate_result() {
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator);
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
for (int i = 0; i < _num_fields; ++i) {
_column_aggregator[i]->finalize();
}
Expand Down
1 change: 1 addition & 0 deletions be/src/storage/column_aggregate_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class AggFuncBasedValueAggregator : public ValueColumnAggregatorBase {
}

~AggFuncBasedValueAggregator() override {
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
if (_state != nullptr) {
_agg_func->destroy(_func_ctx, _state);
std::free(_state);
Expand Down
3 changes: 3 additions & 0 deletions be/src/storage/column_aggregate_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
#pragma once

#include "column/field.h"
#include "runtime/memory/mem_hook_allocator.h"
#include "storage/column_aggregator.h"

namespace starrocks {

static MemHookAllocator kDefaultColumnAggregatorAllocator = MemHookAllocator{};
class ColumnAggregatorFactory {
public:
static ColumnAggregatorPtr create_key_column_aggregator(const FieldPtr& field);
Expand Down
1 change: 1 addition & 0 deletions be/src/types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(Types STATIC
bitmap_value.cpp
date_value.cpp
hll.cpp
hll_sketch.cpp
logical_type.cpp
map_type_info.cpp
struct_type_info.cpp
Expand Down
Loading
Loading