Skip to content

Commit

Permalink
[Enhancement] Refactor hll_sktch and fix stack use after return bug (#…
Browse files Browse the repository at this point in the history
…50600)

Signed-off-by: shuming.li <[email protected]>
  • Loading branch information
LiShuMing authored Sep 3, 2024
1 parent 8921afd commit 8c4dc97
Show file tree
Hide file tree
Showing 15 changed files with 362 additions and 284 deletions.
115 changes: 50 additions & 65 deletions be/src/exec/aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,6 @@ void AggregatorParams::init() {
#define ALIGN_TO(size, align) ((size + align - 1) / align * align)
#define PAD(size, align) (align - (size % align)) % align;

class ThreadLocalStateAllocatorSetter {
public:
ThreadLocalStateAllocatorSetter(Allocator* allocator)
: _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {}
~ThreadLocalStateAllocatorSetter() = default;

private:
ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter;
ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter;
};

#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \
auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator)

Aggregator::Aggregator(AggregatorParamsPtr params) : _params(std::move(params)) {
_allocator = std::make_unique<CountingAllocatorWithHook>();
}
Expand Down Expand Up @@ -395,59 +381,22 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
const TFunction& fn = desc.nodes[0].fn;
const auto& agg_fn_type = _agg_fn_types[i];
_is_merge_funcs[i] = aggregate_functions[i].nodes[0].agg_expr.is_merge_agg;
// get function
if (fn.name.function_name == "count") {
bool is_input_nullable =
!fn.arg_types.empty() && (has_outer_join_child || desc.nodes[0].has_nullable_child);
auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_input_nullable);
_agg_functions[i] = func;
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);

TypeDescriptor arg_type = TypeDescriptor::from_thrift(fn.arg_types[0]);
// Because intersect_count have two input types.
// And intersect_count's first argument's type is alwasy Bitmap,
// so we use its second arguments type as input.
if (fn.name.function_name == "intersect_count") {
arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
}

// Because max_by and min_by function have two input types,
// so we use its second arguments type as input.
if (fn.name.function_name == "max_by" || fn.name.function_name == "min_by") {
arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
}

// Because windowfunnel have more two input types.
// functions registry use 2th args(datetime/date).
if (fn.name.function_name == "window_funnel") {
arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
}

// hack for accepting various arguments
if (fn.name.function_name == "exchange_bytes" || fn.name.function_name == "exchange_speed") {
arg_type = TypeDescriptor(TYPE_BIGINT);
}

if (fn.name.function_name == "array_union_agg" || fn.name.function_name == "array_unique_agg") {
// for array_union_agg use inner type as signature
arg_type = arg_type.children[0];
}

const bool use_nullable_fn = agg_fn_type.use_nullable_fn(_use_intermediate_as_output());
auto* func = get_aggregate_function(fn.name.function_name, arg_type.type, return_type.type, use_nullable_fn,
fn.binary_type, state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
"Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)",
fn.name.function_name, type_to_string(arg_type.type), type_to_string(serde_type.type),
type_to_string(return_type.type), use_nullable_fn ? "true" : "false"));
// get function
auto is_result_nullable_func = [&]() {
if (fn.name.function_name == "count") {
if (fn.arg_types.empty()) {
return false;
}
if (has_outer_join_child || desc.nodes[0].has_nullable_child) {
return true;
}
return false;
} else {
return agg_fn_type.use_nullable_fn(_use_intermediate_as_output());
}
VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
<< return_type;
_agg_functions[i] = func;
}
};
RETURN_IF_ERROR(_create_aggregate_function(state, fn, is_result_nullable_func(), &_agg_functions[i]));

int node_idx = 0;
for (int j = 0; j < desc.nodes[0].num_children; ++j) {
Expand Down Expand Up @@ -530,6 +479,42 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile
return Status::OK();
}

Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn,
bool is_result_nullable, const AggregateFunction** ret) {
std::vector<TypeDescriptor> arg_types;
for (auto& type : fn.arg_types) {
arg_types.push_back(TypeDescriptor::from_thrift(type));
}

// check whether it's _merge/_union combinator if it contains agg state type
auto& func_name = fn.name.function_name;
// get function
if (func_name == "count") {
auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_result_nullable);
if (func == nullptr) {
return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 ", func_name));
}
*ret = func;
} else {
TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type);
TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type);
DCHECK_LE(1, fn.arg_types.size());
TypeDescriptor arg_type = arg_types[0];
auto* func = get_aggregate_function(func_name, return_type, arg_types, is_result_nullable, fn.binary_type,
state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
"Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)",
func_name, type_to_string(arg_type.type), type_to_string(serde_type.type),
type_to_string(return_type.type), is_result_nullable ? "true" : "false"));
}
*ret = func;
VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type "
<< return_type;
}
return Status::OK();
}

Status Aggregator::reset_state(starrocks::RuntimeState* state, const std::vector<ChunkPtr>& refill_chunks,
pipeline::Operator* refill_op, bool reset_sink_complete) {
RETURN_IF_ERROR(_reset_state(state, reset_sink_complete));
Expand Down
3 changes: 3 additions & 0 deletions be/src/exec/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ class Aggregator : public pipeline::ContextWithDependency {

bool is_streaming_all_states() const { return _streaming_all_states; }

Status _create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn, bool is_result_nullable,
const AggregateFunction** ret);

HashTableKeyAllocator _state_allocator;

protected:
Expand Down
17 changes: 17 additions & 0 deletions be/src/exprs/agg/aggregate_state_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "column/hash_set.h"
#include "common/config.h"
#include "runtime/memory/allocator.h"
#include "runtime/memory/roaring_hook.h"

namespace starrocks {

Expand Down Expand Up @@ -92,4 +93,20 @@ using SliceHashSetWithAggStateAllocator = phmap::flat_hash_set<SliceWithHash, Ha

template <typename T>
using VectorWithAggStateAllocator = std::vector<T, AggregateStateAllocator<T>>;

// Thread local aggregate state allocator setter with roaring allocator
class ThreadLocalStateAllocatorSetter {
public:
ThreadLocalStateAllocatorSetter(Allocator* allocator)
: _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {}
~ThreadLocalStateAllocatorSetter() = default;

private:
ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter;
ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter;
};

#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \
auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator)

} // namespace starrocks
3 changes: 2 additions & 1 deletion be/src/exprs/agg/avg.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,14 @@ class AvgAggregateFunction final

void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override {
DCHECK(!to->is_nullable());
auto* column = down_cast<ResultColumnType*>(to);
// In fact, for StarRocks real query, we don't need this check.
// But for robust, we add this check.
if (this->data(state).count == 0) {
column->append_default();
return;
}

auto* column = down_cast<ResultColumnType*>(to);
ResultType result;
if constexpr (lt_is_decimalv2<LT>) {
result = this->data(state).sum / DecimalV2Value(this->data(state).count, 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
#include "column/vectorized_fwd.h"
#include "exprs/agg/aggregate.h"
#include "gutil/casts.h"
#include "types/hll_sketch.h"

namespace starrocks {

struct HLLSketchState {
std::unique_ptr<DataSketchesHll> hll_sketch = nullptr;
int64_t memory_usage = 0;
};

/**
Expand Down Expand Up @@ -101,10 +103,10 @@ class HllSketchAggregateFunction final
void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override {
DCHECK(column->is_binary());
const BinaryColumn* hll_column = down_cast<const BinaryColumn*>(column);
DataSketchesHll hll(hll_column->get(row_num).get_slice());
DataSketchesHll hll(hll_column->get(row_num).get_slice(), &(this->data(state).memory_usage));
if (UNLIKELY(this->data(state).hll_sketch == nullptr)) {
this->data(state).hll_sketch =
std::make_unique<DataSketchesHll>(hll.get_lg_config_k(), hll.get_target_type());
this->data(state).hll_sketch = std::make_unique<DataSketchesHll>(
hll.get_lg_config_k(), hll.get_target_type(), &(this->data(state).memory_usage));
}
int64_t prev_memory = this->data(state).hll_sketch->mem_usage();
this->data(state).hll_sketch->merge(hll);
Expand Down Expand Up @@ -159,7 +161,8 @@ class HllSketchAggregateFunction final
const Column** src_datas_ptr = src_datas.data();
std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, src_datas_ptr);
for (size_t i = 0; i < chunk_size; ++i) {
DataSketchesHll hll{log_k, tgt_type};
int64_t memory_usage = 0;
DataSketchesHll hll{log_k, tgt_type, &memory_usage};
if constexpr (lt_is_string<LT>) {
Slice s = column->get_slice(i);
value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED);
Expand Down Expand Up @@ -201,7 +204,7 @@ class HllSketchAggregateFunction final
uint8_t log_k;
datasketches::target_hll_type tgt_type;
std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, columns);
this->data(state).hll_sketch = _init_hll_sketch(log_k, tgt_type);
this->data(state).hll_sketch = _init_hll_sketch(log_k, tgt_type, &(this->data(state).memory_usage));
}
}

Expand All @@ -228,9 +231,9 @@ class HllSketchAggregateFunction final
}

// init hll sketch with default log_k and target type
std::unique_ptr<DataSketchesHll> _init_hll_sketch(
uint8_t log_k = DEFAULT_HLL_LOG_K, datasketches::target_hll_type tgt_type = datasketches::HLL_6) const {
return std::make_unique<DataSketchesHll>(log_k, tgt_type);
std::unique_ptr<DataSketchesHll> _init_hll_sketch(uint8_t log_k, datasketches::target_hll_type tgt_type,
int64_t* memory_usage) const {
return std::make_unique<DataSketchesHll>(log_k, tgt_type, memory_usage);
}
};

Expand Down
2 changes: 1 addition & 1 deletion be/src/exprs/agg/factory/aggregate_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
#include "exprs/agg/count.h"
#include "exprs/agg/covariance.h"
#include "exprs/agg/distinct.h"
#include "exprs/agg/ds_hll_count_distinct.h"
#include "exprs/agg/exchange_perf.h"
#include "exprs/agg/group_concat.h"
#include "exprs/agg/histogram.h"
#include "exprs/agg/hll_ndv.h"
#include "exprs/agg/hll_sketch.h"
#include "exprs/agg/hll_union.h"
#include "exprs/agg/hll_union_count.h"
#include "exprs/agg/intersect_count.h"
Expand Down
9 changes: 3 additions & 6 deletions be/src/storage/chunk_aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@
#include "exec/sorting/sorting.h"
#include "exprs/agg/aggregate_state_allocator.h"
#include "gutil/casts.h"
#include "runtime/memory/mem_hook_allocator.h"
#include "storage/column_aggregate_func.h"

namespace starrocks {

static MemHookAllocator kDefaultChunkAggregatorAllocator = MemHookAllocator{};

ChunkAggregator::ChunkAggregator(const starrocks::Schema* schema, uint32_t reserve_rows, uint32_t max_aggregate_rows,
double factor, bool is_vertical_merge, bool is_key)
: _schema(schema),
Expand Down Expand Up @@ -173,7 +170,7 @@ void ChunkAggregator::aggregate() {
_aggregate_loops[_aggregate_loops.size() - 1] += 1;
}
}
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator);
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
// 3. Copy the selected key rows
// 4. Aggregate the value rows
for (int i = 0; i < _key_fields; ++i) {
Expand All @@ -196,7 +193,7 @@ bool ChunkAggregator::is_finish() {
void ChunkAggregator::aggregate_reset() {
_aggregate_chunk = ChunkHelper::new_chunk(*_schema, _reserve_rows);
_aggregate_rows = 0;
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator);
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
for (int i = 0; i < _num_fields; ++i) {
auto p = _aggregate_chunk->get_column_by_index(i).get();
_column_aggregator[i]->update_aggregate(p);
Expand All @@ -210,7 +207,7 @@ void ChunkAggregator::aggregate_reset() {
}

ChunkPtr ChunkAggregator::aggregate_result() {
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator);
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
for (int i = 0; i < _num_fields; ++i) {
_column_aggregator[i]->finalize();
}
Expand Down
1 change: 1 addition & 0 deletions be/src/storage/column_aggregate_func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ class AggFuncBasedValueAggregator : public ValueColumnAggregatorBase {
}

~AggFuncBasedValueAggregator() override {
SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator);
if (_state != nullptr) {
_agg_func->destroy(_func_ctx, _state);
std::free(_state);
Expand Down
3 changes: 3 additions & 0 deletions be/src/storage/column_aggregate_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
#pragma once

#include "column/field.h"
#include "runtime/memory/mem_hook_allocator.h"
#include "storage/column_aggregator.h"

namespace starrocks {

static MemHookAllocator kDefaultColumnAggregatorAllocator = MemHookAllocator{};
class ColumnAggregatorFactory {
public:
static ColumnAggregatorPtr create_key_column_aggregator(const FieldPtr& field);
Expand Down
1 change: 1 addition & 0 deletions be/src/types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_library(Types STATIC
bitmap_value.cpp
date_value.cpp
hll.cpp
hll_sketch.cpp
logical_type.cpp
map_type_info.cpp
struct_type_info.cpp
Expand Down
Loading

0 comments on commit 8c4dc97

Please sign in to comment.