Skip to content

Commit

Permalink
be part
Browse files Browse the repository at this point in the history
  • Loading branch information
Mryange committed Nov 11, 2024
1 parent e7de4f3 commit 08ff692
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 55 deletions.
61 changes: 41 additions & 20 deletions be/src/vec/aggregate_functions/aggregate_function_percentile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "vec/aggregate_functions/aggregate_function_percentile.h"

#include "vec/aggregate_functions/aggregate_function_percentile_approx.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/core/types.h"
Expand All @@ -28,16 +29,24 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx(
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 2) {
return creator_without_type::create<AggregateFunctionPercentileApproxTwoParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxThreeParams>(
argument_types, result_is_nullable);
if (which.idx == TypeIndex::Float64) {
if (argument_types.size() == 2) {
return creator_without_type::create<AggregateFunctionPercentileApproxTwoParamsOld>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxThreeParamsOld>(
argument_types, result_is_nullable);
}
} else if (which.idx == TypeIndex::Float32) {
if (argument_types.size() == 2) {
return creator_without_type::create<AggregateFunctionPercentileApproxTwoParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxThreeParams>(
argument_types, result_is_nullable);
}
}
return nullptr;
}
Expand All @@ -47,16 +56,28 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(
const AggregateFunctionAttr& attr) {
const DataTypePtr& argument_type = remove_nullable(argument_types[0]);
WhichDataType which(argument_type);
if (which.idx != TypeIndex::Float64) {
return nullptr;
}
if (argument_types.size() == 3) {
return creator_without_type::create<AggregateFunctionPercentileApproxWeightedThreeParams>(
argument_types, result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<AggregateFunctionPercentileApproxWeightedFourParams>(
argument_types, result_is_nullable);
if (which.idx == TypeIndex::Float64) {
if (argument_types.size() == 3) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedThreeParamsOld>(argument_types,
result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedFourParamsOld>(argument_types,
result_is_nullable);
}
} else if (which.idx == TypeIndex::Float32) {
if (argument_types.size() == 3) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedThreeParams>(argument_types,
result_is_nullable);
}
if (argument_types.size() == 4) {
return creator_without_type::create<
AggregateFunctionPercentileApproxWeightedFourParams>(argument_types,
result_is_nullable);
}
}
return nullptr;
}
Expand Down
75 changes: 40 additions & 35 deletions be/src/vec/aggregate_functions/aggregate_function_percentile.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,14 @@ namespace doris::vectorized {
class Arena;
class BufferReadable;

struct PercentileApproxState {
struct PercentileApproxStateOld {
// TDigest performs its internal calculations in float32, but the function definitions use double,
// so there is an additional type conversion overhead. To ensure compatibility with both up- and down-casting,
// the original code is kept with the suffix "Old" appended to its names. The BE will support both float32 and float64 implementations.
// To ensure the result remains unchanged, the return value is still float64.
static constexpr double INIT_QUANTILE = -1.0;
PercentileApproxState() = default;
~PercentileApproxState() = default;
PercentileApproxStateOld() = default;
~PercentileApproxStateOld() = default;

void init(double compression = 10000) {
if (!init_flag) {
Expand Down Expand Up @@ -109,7 +113,7 @@ struct PercentileApproxState {
}
}

void merge(const PercentileApproxState& rhs) {
void merge(const PercentileApproxStateOld& rhs) {
if (!rhs.init_flag) {
return;
}
Expand All @@ -121,7 +125,7 @@ struct PercentileApproxState {
digest->merge(rhs.digest.get());
init_flag = true;
}
if (target_quantile == PercentileApproxState::INIT_QUANTILE) {
if (target_quantile == PercentileApproxStateOld::INIT_QUANTILE) {
target_quantile = rhs.target_quantile;
}
}
Expand Down Expand Up @@ -152,40 +156,40 @@ struct PercentileApproxState {
double compressions = 10000;
};

class AggregateFunctionPercentileApprox
: public IAggregateFunctionDataHelper<PercentileApproxState,
AggregateFunctionPercentileApprox> {
class AggregateFunctionPercentileApproxOld
: public IAggregateFunctionDataHelper<PercentileApproxStateOld,
AggregateFunctionPercentileApproxOld> {
public:
AggregateFunctionPercentileApprox(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<PercentileApproxState,
AggregateFunctionPercentileApprox>(argument_types_) {}
AggregateFunctionPercentileApproxOld(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<PercentileApproxStateOld,
AggregateFunctionPercentileApproxOld>(argument_types_) {}

String get_name() const override { return "percentile_approx"; }

void reset(AggregateDataPtr __restrict place) const override {
AggregateFunctionPercentileApprox::data(place).reset();
AggregateFunctionPercentileApproxOld::data(place).reset();
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
AggregateFunctionPercentileApprox::data(place).merge(
AggregateFunctionPercentileApprox::data(rhs));
AggregateFunctionPercentileApproxOld::data(place).merge(
AggregateFunctionPercentileApproxOld::data(rhs));
}

void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
AggregateFunctionPercentileApprox::data(place).write(buf);
AggregateFunctionPercentileApproxOld::data(place).write(buf);
}

void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
AggregateFunctionPercentileApprox::data(place).read(buf);
AggregateFunctionPercentileApproxOld::data(place).read(buf);
}
};

class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPercentileApprox {
class AggregateFunctionPercentileApproxTwoParamsOld : public AggregateFunctionPercentileApproxOld {
public:
AggregateFunctionPercentileApproxTwoParams(const DataTypes& argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
AggregateFunctionPercentileApproxTwoParamsOld(const DataTypes& argument_types_)
: AggregateFunctionPercentileApproxOld(argument_types_) {}
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& sources =
Expand All @@ -200,7 +204,7 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& col = assert_cast<ColumnFloat64&>(to);
double result = AggregateFunctionPercentileApprox::data(place).get();
double result = AggregateFunctionPercentileApproxOld::data(place).get();

if (std::isnan(result)) {
col.insert_default();
Expand All @@ -210,10 +214,11 @@ class AggregateFunctionPercentileApproxTwoParams : public AggregateFunctionPerce
}
};

class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPercentileApprox {
class AggregateFunctionPercentileApproxThreeParamsOld
: public AggregateFunctionPercentileApproxOld {
public:
AggregateFunctionPercentileApproxThreeParams(const DataTypes& argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
AggregateFunctionPercentileApproxThreeParamsOld(const DataTypes& argument_types_)
: AggregateFunctionPercentileApproxOld(argument_types_) {}
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& sources =
Expand All @@ -231,7 +236,7 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& col = assert_cast<ColumnFloat64&>(to);
double result = AggregateFunctionPercentileApprox::data(place).get();
double result = AggregateFunctionPercentileApproxOld::data(place).get();

if (std::isnan(result)) {
col.insert_default();
Expand All @@ -241,11 +246,11 @@ class AggregateFunctionPercentileApproxThreeParams : public AggregateFunctionPer
}
};

class AggregateFunctionPercentileApproxWeightedThreeParams
: public AggregateFunctionPercentileApprox {
class AggregateFunctionPercentileApproxWeightedThreeParamsOld
: public AggregateFunctionPercentileApproxOld {
public:
AggregateFunctionPercentileApproxWeightedThreeParams(const DataTypes& argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
AggregateFunctionPercentileApproxWeightedThreeParamsOld(const DataTypes& argument_types_)
: AggregateFunctionPercentileApproxOld(argument_types_) {}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
Expand All @@ -265,7 +270,7 @@ class AggregateFunctionPercentileApproxWeightedThreeParams

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& col = assert_cast<ColumnFloat64&>(to);
double result = AggregateFunctionPercentileApprox::data(place).get();
double result = AggregateFunctionPercentileApproxOld::data(place).get();

if (std::isnan(result)) {
col.insert_default();
Expand All @@ -275,11 +280,11 @@ class AggregateFunctionPercentileApproxWeightedThreeParams
}
};

class AggregateFunctionPercentileApproxWeightedFourParams
: public AggregateFunctionPercentileApprox {
class AggregateFunctionPercentileApproxWeightedFourParamsOld
: public AggregateFunctionPercentileApproxOld {
public:
AggregateFunctionPercentileApproxWeightedFourParams(const DataTypes& argument_types_)
: AggregateFunctionPercentileApprox(argument_types_) {}
AggregateFunctionPercentileApproxWeightedFourParamsOld(const DataTypes& argument_types_)
: AggregateFunctionPercentileApproxOld(argument_types_) {}
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
const auto& sources =
Expand All @@ -300,7 +305,7 @@ class AggregateFunctionPercentileApproxWeightedFourParams

void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
auto& col = assert_cast<ColumnFloat64&>(to);
double result = AggregateFunctionPercentileApprox::data(place).get();
double result = AggregateFunctionPercentileApproxOld::data(place).get();

if (std::isnan(result)) {
col.insert_default();
Expand Down Expand Up @@ -351,7 +356,7 @@ struct PercentileState {
}
}

void add(T source, const PaddedPODArray<Float64>& quantiles, int arg_size) {
void add(T source, const PaddedPODArray<Float64>& quantiles, int64_t arg_size) {
if (!inited_flag) {
vec_counts.resize(arg_size);
vec_quantile.resize(arg_size, -1);
Expand Down
Loading

0 comments on commit 08ff692

Please sign in to comment.