diff --git a/be/CMakeLists.txt b/be/CMakeLists.txt index df4b036347ba1f..917a2008d97e21 100644 --- a/be/CMakeLists.txt +++ b/be/CMakeLists.txt @@ -625,6 +625,10 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-switch-default") endif () else () + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "14.0.0") + # ignore error from apache-orc + set(CXX_COMMON_FLAGS "${CXX_COMMON_FLAGS} -Wno-error=dangling-reference") + endif () set(CXX_GCC_FLAGS "${CXX_GCC_FLAGS} -fcoroutines") endif() @@ -773,6 +777,8 @@ endif() set(WL_START_GROUP "-Wl,--start-group") set(WL_END_GROUP "-Wl,--end-group") +set(WL_LINK_STATIC "-Wl,-Bstatic") +set(WL_LINK_DYNAMIC "-Wl,-Bdynamic") # Set starrocks libraries set(STARROCKS_LINK_LIBS @@ -963,7 +969,8 @@ if (NOT ("${MAKE_TEST}" STREQUAL "ON" AND "${BUILD_FOR_SANITIZE}" STREQUAL "ON") endif() set(STARROCKS_LINK_LIBS ${STARROCKS_LINK_LIBS} - -lresolv -lbfd -liberty -lc -lm -ldl -rdynamic -pthread -Wl,-wrap=__cxa_throw + ${WL_LINK_STATIC} -lbfd + ${WL_LINK_DYNAMIC} -lresolv -liberty -lc -lm -ldl -rdynamic -pthread -Wl,-wrap=__cxa_throw ) # link gcov if WITH_GCOV is on diff --git a/be/src/column/adaptive_nullable_column.h b/be/src/column/adaptive_nullable_column.h index b247bdce6e18fa..0360f8c696a54d 100644 --- a/be/src/column/adaptive_nullable_column.h +++ b/be/src/column/adaptive_nullable_column.h @@ -276,9 +276,7 @@ class AdaptiveNullableColumn final : public ColumnFactory upgrade_if_overflow() override { materialized_nullable(); - if (_null_column->capacity_limit_reached()) { - return Status::InternalError("Size of NullableColumn exceed the limit"); - } + RETURN_IF_ERROR(_null_column->capacity_limit_reached()); return upgrade_helper_func(&_data_column); } @@ -556,9 +554,9 @@ class AdaptiveNullableColumn final : public ColumnFactory { std::string debug_string() const override; - bool capacity_limit_reached(std::string* msg = nullptr) const override { - return _elements->capacity_limit_reached(msg) || _offsets->capacity_limit_reached(msg); + Status capacity_limit_reached() const override { + RETURN_IF_ERROR(_elements->capacity_limit_reached()); + return _offsets->capacity_limit_reached(); } StatusOr upgrade_if_overflow() override; diff --git a/be/src/column/binary_column.cpp b/be/src/column/binary_column.cpp index e209e548defa01..b4e5bedc48e856 100644 --- a/be/src/column/binary_column.cpp +++ b/be/src/column/binary_column.cpp @@ -24,6 +24,7 @@ #include "gutil/bits.h" #include "gutil/casts.h" #include "gutil/strings/fastmem.h" +#include "gutil/strings/substitute.h" #include "util/hash_util.hpp" #include "util/mysql_row_buffer.h" #include "util/raw_container.h" @@ -738,45 +739,37 @@ bool BinaryColumnBase::has_large_column() const { } template -bool BinaryColumnBase::capacity_limit_reached(std::string* msg) const { +Status BinaryColumnBase::capacity_limit_reached() const { static_assert(std::is_same_v || std::is_same_v); if constexpr (std::is_same_v) { // The size limit of a single element is 2^32 - 1. // The size limit of all elements is 2^32 - 1. // The number limit of elements is 2^32 - 1. if (_bytes.size() >= Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("Total byte size of binary column exceed the limit: " + - std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed( + strings::Substitute("Total byte size of binary column exceed the limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } else if (_offsets.size() >= Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("Total row count of binary column exceed the limit: " + - std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed( + strings::Substitute("Total row count of binary column exceed the limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } else { - return false; + return Status::OK(); } } else { // The size limit of a single element is 2^32 - 1. // The size limit of all elements is 2^64 - 1. // The number limit of elements is 2^32 - 1. if (_bytes.size() >= Column::MAX_LARGE_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("Total byte size of large binary column exceed the limit: " + - std::to_string(Column::MAX_LARGE_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed( + strings::Substitute("Total byte size of large binary column exceed the limit: $0", + std::to_string(Column::MAX_LARGE_CAPACITY_LIMIT))); } else if (_offsets.size() >= Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("Total row count of large binary column exceed the limit: " + - std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed( + strings::Substitute("Total row count of large binary column exceed the limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } else { - return false; + return Status::OK(); } } } diff --git a/be/src/column/binary_column.h b/be/src/column/binary_column.h index 09bfac117a058b..6e0b59b374daed 100644 --- a/be/src/column/binary_column.h +++ b/be/src/column/binary_column.h @@ -339,7 +339,7 @@ class BinaryColumnBase final : public ColumnFactory> return ss.str(); } - bool capacity_limit_reached(std::string* msg = nullptr) const override; + Status capacity_limit_reached() const override; private: void _build_slices() const; diff --git a/be/src/column/chunk.h b/be/src/column/chunk.h index cf382610771115..67de2c8768540c 100644 --- a/be/src/column/chunk.h +++ b/be/src/column/chunk.h @@ -277,13 +277,11 @@ class Chunk { std::string rebuild_csv_row(size_t index, const std::string& delimiter) const; - bool capacity_limit_reached(std::string* msg = nullptr) const { + Status capacity_limit_reached() const { for (const auto& column : _columns) { - if (column->capacity_limit_reached(msg)) { - return true; - } + RETURN_IF_ERROR(column->capacity_limit_reached()); } - return false; + return Status::OK(); } query_cache::owner_info& owner_info() { return _owner_info; } diff --git a/be/src/column/column.h b/be/src/column/column.h index ec730265a0fc8e..3d86550fb09a71 100644 --- a/be/src/column/column.h +++ b/be/src/column/column.h @@ -394,7 +394,7 @@ class Column { // The interface will not free memory!!! virtual void reset_column() { _delete_state = DEL_NOT_SATISFIED; } - virtual bool capacity_limit_reached(std::string* msg = nullptr) const = 0; + virtual Status capacity_limit_reached() const = 0; virtual Status accept(ColumnVisitor* visitor) const = 0; diff --git a/be/src/column/column_builder.h b/be/src/column/column_builder.h index 42534d9ab689bc..221b11ec34a2de 100644 --- a/be/src/column/column_builder.h +++ b/be/src/column/column_builder.h @@ -57,7 +57,7 @@ class ColumnBuilder { ColumnBuilder(DataColumnPtr column, NullColumnPtr null_column, bool has_null) : _column(std::move(column)), _null_column(std::move(null_column)), _has_null(has_null) {} //do nothing ctor, members are initialized by its offsprings. - explicit ColumnBuilder(void*) {} + explicit ColumnBuilder(void*) {} void append(const DatumType& value) { _null_column->append(DATUM_NOT_NULL); diff --git a/be/src/column/column_helper.h b/be/src/column/column_helper.h index ea1a6f0a60dbbb..97d53bb2d353d7 100644 --- a/be/src/column/column_helper.h +++ b/be/src/column/column_helper.h @@ -525,8 +525,10 @@ struct ChunkSliceTemplate { template struct GetContainer { using ColumnType = typename RunTimeTypeTraits::ColumnType; - const auto& get_data(const Column* column) { return ColumnHelper::as_raw_column(column)->get_data(); } - const auto& get_data(const ColumnPtr& column) { + static const auto& get_data(const Column* column) { + return ColumnHelper::as_raw_column(column)->get_data(); + } + static const auto& get_data(const ColumnPtr& column) { return ColumnHelper::as_raw_column(column.get())->get_data(); } }; @@ -534,10 +536,10 @@ struct GetContainer { #define GET_CONTAINER(ltype) \ template <> \ struct GetContainer { \ - const auto& get_data(const Column* column) { \ + static const auto& get_data(const Column* column) { \ return ColumnHelper::as_raw_column(column)->get_proxy_data(); \ } \ - const auto& get_data(const ColumnPtr& column) { \ + static const auto& get_data(const ColumnPtr& column) { \ return ColumnHelper::as_raw_column(column.get())->get_proxy_data(); \ } \ }; diff --git a/be/src/column/const_column.h b/be/src/column/const_column.h index 6a2b27000b321c..4edd7449a9958b 100644 --- a/be/src/column/const_column.h +++ b/be/src/column/const_column.h @@ -18,6 +18,7 @@ #include "column/datum.h" #include "column/vectorized_fwd.h" #include "common/logging.h" +#include "gutil/strings/substitute.h" namespace starrocks { @@ -251,15 +252,13 @@ class ConstColumn final : public ColumnFactory { return ss.str(); } - bool capacity_limit_reached(std::string* msg = nullptr) const override { - RETURN_IF_UNLIKELY(_data->capacity_limit_reached(msg), true); + Status capacity_limit_reached() const override { + RETURN_IF_ERROR(_data->capacity_limit_reached()); if (_size > Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("Row count of const column reach limit: " + std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed(strings::Substitute("Row count of const column reach limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } - return false; + return Status::OK(); } void check_or_die() const override; diff --git a/be/src/column/field.h b/be/src/column/field.h index 6203a4fda0f070..aa057db47648f6 100644 --- a/be/src/column/field.h +++ b/be/src/column/field.h @@ -29,22 +29,36 @@ namespace starrocks { class Datum; +class AggStateDesc; class Field { public: Field(ColumnId id, std::string_view name, TypeInfoPtr type, starrocks::StorageAggregateType agg, - uint8_t short_key_length, bool is_key, bool nullable) + AggStateDesc* agg_state_desc, uint8_t short_key_length, bool is_key, bool nullable) : _id(id), _agg_method(agg), + _agg_state_desc(agg_state_desc), _name(name), _type(std::move(type)), _sub_fields(nullptr), _short_key_length(short_key_length), - _flags(static_cast((is_key << kIsKeyShift) | (nullable << kNullableShift))) {} + _flags(static_cast((is_key << kIsKeyShift) | (nullable << kNullableShift))) { + if (_agg_method == STORAGE_AGGREGATE_AGG_STATE_UNION) { + DCHECK(_agg_state_desc != nullptr); + } + } + + // AggMethod is not STORAGE_AGGREGATE_AGG_STATE_UNION + Field(ColumnId id, std::string_view name, TypeInfoPtr type, starrocks::StorageAggregateType agg, + uint8_t short_key_length, bool is_key, bool nullable) + : Field(id, name, std::move(type), agg, nullptr, short_key_length, is_key, nullable) { + DCHECK(_agg_method != STORAGE_AGGREGATE_AGG_STATE_UNION); + } // Non-key field of any type except for ARRAY Field(ColumnId id, std::string_view name, LogicalType type, int precision, int scale, bool nullable) - : Field(id, name, get_type_info(type, precision, scale), STORAGE_AGGREGATE_NONE, 0, false, nullable) {} + : Field(id, name, get_type_info(type, precision, scale), STORAGE_AGGREGATE_NONE, nullptr, 0, false, + nullable) {} // Non-key field of any type except for DECIMAL32, DECIMAL64, DECIMAL128, and ARRAY Field(ColumnId id, std::string_view name, LogicalType type, bool nullable) @@ -57,7 +71,7 @@ class Field { // Non-key field of any type Field(ColumnId id, std::string_view name, TypeInfoPtr type, bool nullable = true) - : Field(id, name, std::move(type), STORAGE_AGGREGATE_NONE, 0, false, nullable) {} + : Field(id, name, std::move(type), STORAGE_AGGREGATE_NONE, nullptr, 0, false, nullable) {} ~Field() { delete _sub_fields; } @@ -66,6 +80,7 @@ class Field { Field(const Field& rhs) : _id(rhs._id), _agg_method(rhs._agg_method), + _agg_state_desc(rhs._agg_state_desc), _name(rhs._name), _type(rhs._type), _sub_fields(rhs._sub_fields ? new Buffer(*rhs._sub_fields) : nullptr), @@ -76,6 +91,7 @@ class Field { Field(Field&& rhs) noexcept : _id(rhs._id), _agg_method(rhs._agg_method), + _agg_state_desc(rhs._agg_state_desc), _name(std::move(rhs._name)), _type(std::move(rhs._type)), _sub_fields(rhs._sub_fields), @@ -92,6 +108,7 @@ class Field { _name = rhs._name; _type = rhs._type; _agg_method = rhs._agg_method; + _agg_state_desc = rhs._agg_state_desc; _short_key_length = rhs._short_key_length; _flags = rhs._flags; _sub_fields = rhs._sub_fields ? new Buffer(*rhs._sub_fields) : nullptr; @@ -106,6 +123,7 @@ class Field { _name = std::move(rhs._name); _type = std::move(rhs._type); _agg_method = rhs._agg_method; + _agg_state_desc = rhs._agg_state_desc; _short_key_length = rhs._short_key_length; _flags = rhs._flags; _uid = rhs._uid; @@ -171,6 +189,9 @@ class Field { void set_uid(ColumnUID uid) { _uid = uid; } const ColumnUID& uid() const { return _uid; } + void set_agg_state_desc(AggStateDesc* agg_state_desc) { _agg_state_desc = agg_state_desc; } + AggStateDesc* get_agg_state_desc() const { return _agg_state_desc; } + static FieldPtr convert_to_dict_field(const Field& field); private: @@ -179,6 +200,8 @@ class Field { ColumnId _id = 0; starrocks::StorageAggregateType _agg_method; + // agg_state_desc if agg_method is STORAGE_AGGREGATE_AGG_STATE_UNION + AggStateDesc* _agg_state_desc; CString _name; TypeInfoPtr _type = nullptr; std::vector* _sub_fields; @@ -225,16 +248,17 @@ inline const Field& Field::sub_field(int i) const { inline FieldPtr Field::with_type(const TypeInfoPtr& type) { return std::make_shared(_id, std::string_view(_name.data(), _name.size()), type, _agg_method, - _short_key_length, is_key(), is_nullable()); + _agg_state_desc, _short_key_length, is_key(), is_nullable()); } inline FieldPtr Field::with_name(std::string_view name) { - return std::make_shared(_id, name, _type, _agg_method, _short_key_length, is_key(), is_nullable()); + return std::make_shared(_id, name, _type, _agg_method, _agg_state_desc, _short_key_length, is_key(), + is_nullable()); } inline FieldPtr Field::with_nullable(bool nullable) { return std::make_shared(_id, std::string_view(_name.data(), _name.size()), _type, _agg_method, - _short_key_length, is_key(), nullable); + _agg_state_desc, _short_key_length, is_key(), nullable); } inline std::ostream& operator<<(std::ostream& os, const Field& field) { diff --git a/be/src/column/fixed_length_column_base.cpp b/be/src/column/fixed_length_column_base.cpp index b29ab265f0e828..308feaa4ba2e13 100644 --- a/be/src/column/fixed_length_column_base.cpp +++ b/be/src/column/fixed_length_column_base.cpp @@ -29,11 +29,8 @@ namespace starrocks { template StatusOr FixedLengthColumnBase::upgrade_if_overflow() { - if (capacity_limit_reached()) { - return Status::InternalError("Size of FixedLengthColumn exceed the limit"); - } else { - return nullptr; - } + RETURN_IF_ERROR(capacity_limit_reached()); + return nullptr; } template diff --git a/be/src/column/fixed_length_column_base.h b/be/src/column/fixed_length_column_base.h index e976cce366da36..7ac5a17f5fd672 100644 --- a/be/src/column/fixed_length_column_base.h +++ b/be/src/column/fixed_length_column_base.h @@ -21,6 +21,7 @@ #include "column/datum.h" #include "column/vectorized_fwd.h" #include "common/statusor.h" +#include "gutil/strings/substitute.h" #include "runtime/decimalv2_value.h" #include "types/date_value.hpp" #include "types/timestamp_value.h" @@ -236,15 +237,13 @@ class FixedLengthColumnBase : public ColumnFactory 2^32), but some interface such as update_rows() will use index of uint32_t to // access the item, so we should use 2^32 as the limit - bool capacity_limit_reached(std::string* msg = nullptr) const override { + Status capacity_limit_reached() const override { if (_data.size() > Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("row count of fixed length column exceend the limit: " + - std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed( + strings::Substitute("row count of fixed length column exceend the limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } - return false; + return Status::OK(); } void check_or_die() const override {} diff --git a/be/src/column/json_column.cpp b/be/src/column/json_column.cpp index 46bdb2dc88f91c..f8841a1eec71ee 100644 --- a/be/src/column/json_column.cpp +++ b/be/src/column/json_column.cpp @@ -24,6 +24,7 @@ #include "common/compiler_util.h" #include "glog/logging.h" #include "gutil/casts.h" +#include "gutil/strings/substitute.h" #include "simd/simd.h" #include "types/logical_type.h" #include "util/hash_util.hpp" @@ -430,14 +431,12 @@ void JsonColumn::reset_column() { _path_to_index.clear(); } -bool JsonColumn::capacity_limit_reached(std::string* msg) const { +Status JsonColumn::capacity_limit_reached() const { if (size() > Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("row count of object column exceed the limit: " + std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed(strings::Substitute("row count of object column exceed the limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } - return false; + return Status::OK(); } void JsonColumn::check_or_die() const { diff --git a/be/src/column/json_column.h b/be/src/column/json_column.h index ca71de64940d16..935bb84333ed8c 100644 --- a/be/src/column/json_column.h +++ b/be/src/column/json_column.h @@ -101,7 +101,7 @@ class JsonColumn final : public ColumnFactory, JsonColum void swap_column(Column& rhs) override; void reset_column() override; - bool capacity_limit_reached(std::string* msg = nullptr) const override; + Status capacity_limit_reached() const override; void check_or_die() const override; // support flat json on storage diff --git a/be/src/column/map_column.h b/be/src/column/map_column.h index 4617b225ff0741..15c3fd9a4e65a9 100644 --- a/be/src/column/map_column.h +++ b/be/src/column/map_column.h @@ -164,9 +164,10 @@ class MapColumn final : public ColumnFactory { std::string debug_string() const override; - bool capacity_limit_reached(std::string* msg = nullptr) const override { - return _keys->capacity_limit_reached(msg) || _values->capacity_limit_reached(msg) || - _offsets->capacity_limit_reached(msg); + Status capacity_limit_reached() const override { + RETURN_IF_ERROR(_keys->capacity_limit_reached()); + RETURN_IF_ERROR(_values->capacity_limit_reached()); + return _offsets->capacity_limit_reached(); } StatusOr upgrade_if_overflow() override; diff --git a/be/src/column/nullable_column.cpp b/be/src/column/nullable_column.cpp index ae601bfe6d5245..0aae77ce4297ea 100644 --- a/be/src/column/nullable_column.cpp +++ b/be/src/column/nullable_column.cpp @@ -430,9 +430,7 @@ void NullableColumn::check_or_die() const { } StatusOr NullableColumn::upgrade_if_overflow() { - if (_null_column->capacity_limit_reached()) { - return Status::InternalError("Size of NullableColumn exceed the limit"); - } + RETURN_IF_ERROR(_null_column->capacity_limit_reached()); return upgrade_helper_func(&_data_column); } diff --git a/be/src/column/nullable_column.h b/be/src/column/nullable_column.h index 3fa7f64dac2fbb..304f57a4d38f04 100644 --- a/be/src/column/nullable_column.h +++ b/be/src/column/nullable_column.h @@ -311,8 +311,9 @@ class NullableColumn : public ColumnFactory { return ss.str(); } - bool capacity_limit_reached(std::string* msg = nullptr) const override { - return _data_column->capacity_limit_reached(msg) || _null_column->capacity_limit_reached(msg); + Status capacity_limit_reached() const override { + RETURN_IF_ERROR(_data_column->capacity_limit_reached()); + return _null_column->capacity_limit_reached(); } void check_or_die() const override; diff --git a/be/src/column/object_column.cpp b/be/src/column/object_column.cpp index da13ca585a30ec..3e3b061d1666b7 100644 --- a/be/src/column/object_column.cpp +++ b/be/src/column/object_column.cpp @@ -328,9 +328,7 @@ std::string ObjectColumn::debug_item(size_t idx) const { template StatusOr ObjectColumn::upgrade_if_overflow() { - if (capacity_limit_reached()) { - return Status::InternalError("Size of ObjectColumn exceed the limit"); - } + RETURN_IF_ERROR(capacity_limit_reached()); return nullptr; } diff --git a/be/src/column/object_column.h b/be/src/column/object_column.h index 25ec0e851ecc73..482b68ec7f1a02 100644 --- a/be/src/column/object_column.h +++ b/be/src/column/object_column.h @@ -20,6 +20,7 @@ #include "column/datum.h" #include "column/vectorized_fwd.h" #include "common/object_pool.h" +#include "gutil/strings/substitute.h" #include "types/bitmap_value.h" #include "types/hll.h" #include "util/json.h" @@ -213,15 +214,12 @@ class ObjectColumn : public ColumnFactory> { return ss.str(); } - bool capacity_limit_reached(std::string* msg = nullptr) const override { + Status capacity_limit_reached() const override { if (_pool.size() > Column::MAX_CAPACITY_LIMIT) { - if (msg != nullptr) { - msg->append("row count of object column exceed the limit: " + - std::to_string(Column::MAX_CAPACITY_LIMIT)); - } - return true; + return Status::CapacityLimitExceed(strings::Substitute("row count of object column exceed the limit: $0", + std::to_string(Column::MAX_CAPACITY_LIMIT))); } - return false; + return Status::OK(); } StatusOr upgrade_if_overflow() override; @@ -230,7 +228,7 @@ class ObjectColumn : public ColumnFactory> { bool has_large_column() const override { return false; } - void check_or_die() const {} + void check_or_die() const override {} private: // add this to avoid warning clang-diagnostic-overloaded-virtual diff --git a/be/src/column/struct_column.cpp b/be/src/column/struct_column.cpp index a921e0ed147368..a1482d5ef2cf86 100644 --- a/be/src/column/struct_column.cpp +++ b/be/src/column/struct_column.cpp @@ -447,12 +447,11 @@ void StructColumn::swap_column(Column& rhs) { // _field_names dont need swap } -bool StructColumn::capacity_limit_reached(std::string* msg) const { - bool res = false; +Status StructColumn::capacity_limit_reached() const { for (const auto& column : _fields) { - res = res || column->capacity_limit_reached(msg); + RETURN_IF_ERROR(column->capacity_limit_reached()); } - return res; + return Status::OK(); } void StructColumn::check_or_die() const { diff --git a/be/src/column/struct_column.h b/be/src/column/struct_column.h index 77532683cc7a44..1bf73ceba349bc 100644 --- a/be/src/column/struct_column.h +++ b/be/src/column/struct_column.h @@ -171,7 +171,7 @@ class StructColumn final : public ColumnFactory { void reset_column() override; - bool capacity_limit_reached(std::string* msg = nullptr) const override; + Status capacity_limit_reached() const override; void check_or_die() const override; diff --git a/be/src/common/config.h b/be/src/common/config.h index 308873f7127be1..0dce0fee0901fd 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -1251,6 +1251,8 @@ CONF_String(rocksdb_db_options_string, "create_if_missing=true;create_missing_co CONF_Int64(local_exchange_buffer_mem_limit_per_driver, "134217728"); // 128MB // only used for test. default: 128M CONF_mInt64(streaming_agg_limited_memory_size, "134217728"); +// mem limit for partition hash join probe side buffer +CONF_mInt64(partition_hash_join_probe_limit_size, "134217728"); // pipeline streaming aggregate chunk buffer size CONF_mInt32(streaming_agg_chunk_buffer_size, "1024"); CONF_mInt64(wait_apply_time, "6000"); // 6s @@ -1297,7 +1299,7 @@ CONF_mInt32(finish_publish_version_internal, "100"); CONF_mBool(enable_stream_load_verbose_log, "false"); -CONF_mInt32(get_txn_status_internal_sec, "30"); +CONF_mInt32(get_txn_status_internal_sec, "10"); CONF_mBool(dump_metrics_with_bvar, "true"); @@ -1375,10 +1377,6 @@ CONF_mDouble(connector_sink_mem_urgent_space_ratio, "0.1"); // .crm file can be removed after 1day. CONF_mInt32(unused_crm_file_threshold_second, "86400" /** 1day **/); -// When the keys that we want to delete, number of them is larger than this config, -// we will fallback and using `DeleteRange` in rocksdb. -CONF_mInt32(rocksdb_opt_delete_range_limit, "500"); - // python envs config // create time worker timeout CONF_mInt32(create_child_worker_timeout_ms, "1000"); diff --git a/be/src/connector/lake_connector.cpp b/be/src/connector/lake_connector.cpp index 4ebf6cabfd77fd..62344602785635 100644 --- a/be/src/connector/lake_connector.cpp +++ b/be/src/connector/lake_connector.cpp @@ -591,7 +591,6 @@ void LakeDataSource::init_counter(RuntimeState* state) { _prefetch_wait_finish_timer = ADD_CHILD_TIMER(_runtime_profile, "PrefetchWaitFinishTime", io_statistics_name); _prefetch_pending_timer = ADD_CHILD_TIMER(_runtime_profile, "PrefetchPendingTime", io_statistics_name); - _json_flatten_timer = ADD_CHILD_TIMER(_runtime_profile, "JsonFlattern", segment_read_name); _access_path_hits_counter = ADD_COUNTER(_runtime_profile, "AccessPathHits", TUnit::UNIT); _access_path_unhits_counter = ADD_COUNTER(_runtime_profile, "AccessPathUnhits", TUnit::UNIT); } @@ -723,17 +722,19 @@ void LakeDataSource::update_counter() { std::string access_path_hits = "AccessPathHits"; int64_t total = 0; for (auto& [k, v] : _reader->stats().flat_json_hits) { - auto* path_counter = _runtime_profile->get_counter(fmt::format("[Hit]{}", k)); + std::string path = fmt::format("[Hit]{}", k); + auto* path_counter = _runtime_profile->get_counter(path); if (path_counter == nullptr) { - path_counter = ADD_CHILD_COUNTER(_runtime_profile, k, TUnit::UNIT, access_path_hits); + path_counter = ADD_CHILD_COUNTER(_runtime_profile, path, TUnit::UNIT, access_path_hits); } total += v; COUNTER_UPDATE(path_counter, v); } for (auto& [k, v] : _reader->stats().merge_json_hits) { - auto* path_counter = _runtime_profile->get_counter(fmt::format("[HitMerge]{}", k)); + std::string merge_path = fmt::format("[HitMerge]{}", k); + auto* path_counter = _runtime_profile->get_counter(merge_path); if (path_counter == nullptr) { - path_counter = ADD_CHILD_COUNTER(_runtime_profile, k, TUnit::UNIT, access_path_hits); + path_counter = ADD_CHILD_COUNTER(_runtime_profile, merge_path, TUnit::UNIT, access_path_hits); } total += v; COUNTER_UPDATE(path_counter, v); @@ -744,9 +745,10 @@ void LakeDataSource::update_counter() { std::string access_path_unhits = "AccessPathUnhits"; int64_t total = 0; for (auto& [k, v] : _reader->stats().dynamic_json_hits) { - auto* path_counter = _runtime_profile->get_counter(fmt::format("[Unhit]{}", k)); + std::string path = fmt::format("[Unhit]{}", k); + auto* path_counter = _runtime_profile->get_counter(path); if (path_counter == nullptr) { - path_counter = ADD_CHILD_COUNTER(_runtime_profile, k, TUnit::UNIT, access_path_unhits); + path_counter = ADD_CHILD_COUNTER(_runtime_profile, path, TUnit::UNIT, access_path_unhits); } total += v; COUNTER_UPDATE(path_counter, v); @@ -754,7 +756,23 @@ void LakeDataSource::update_counter() { COUNTER_UPDATE(_access_path_unhits_counter, total); } - COUNTER_UPDATE(_json_flatten_timer, _reader->stats().json_flatten_ns); + std::string parent_name = "SegmentRead"; + if (_reader->stats().json_init_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonInit", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_init_ns); + } + if (_reader->stats().json_cast_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonCast", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_cast_ns); + } + if (_reader->stats().json_merge_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonMerge", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_merge_ns); + } + if (_reader->stats().json_flatten_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonFlatten", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_flatten_ns); + } } // ================================ diff --git a/be/src/connector/lake_connector.h b/be/src/connector/lake_connector.h index a89ae49c8983b1..b946b7a1f4886e 100644 --- a/be/src/connector/lake_connector.h +++ b/be/src/connector/lake_connector.h @@ -181,7 +181,6 @@ class LakeDataSource final : public DataSource { RuntimeProfile::Counter* _prefetch_pending_timer = nullptr; RuntimeProfile::Counter* _pushdown_access_paths_counter = nullptr; - RuntimeProfile::Counter* _json_flatten_timer = nullptr; RuntimeProfile::Counter* _access_path_hits_counter = nullptr; RuntimeProfile::Counter* _access_path_unhits_counter = nullptr; }; diff --git a/be/src/exec/aggregator.cpp b/be/src/exec/aggregator.cpp index 8564a05174c02e..999ffbdfaa92b7 100644 --- a/be/src/exec/aggregator.cpp +++ b/be/src/exec/aggregator.cpp @@ -203,20 +203,6 @@ void AggregatorParams::init() { #define ALIGN_TO(size, align) ((size + align - 1) / align * align) #define PAD(size, align) (align - (size % align)) % align; -class ThreadLocalStateAllocatorSetter { -public: - ThreadLocalStateAllocatorSetter(Allocator* allocator) - : _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {} - ~ThreadLocalStateAllocatorSetter() = default; - -private: - ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter; - ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter; -}; - -#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \ - auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator) - Aggregator::Aggregator(AggregatorParamsPtr params) : _params(std::move(params)) { _allocator = std::make_unique(); } @@ -395,59 +381,22 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile const TFunction& fn = desc.nodes[0].fn; const auto& agg_fn_type = _agg_fn_types[i]; _is_merge_funcs[i] = aggregate_functions[i].nodes[0].agg_expr.is_merge_agg; - // get function - if (fn.name.function_name == "count") { - bool is_input_nullable = - !fn.arg_types.empty() && (has_outer_join_child || desc.nodes[0].has_nullable_child); - auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_input_nullable); - _agg_functions[i] = func; - } else { - TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type); - TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type); - TypeDescriptor arg_type = TypeDescriptor::from_thrift(fn.arg_types[0]); - // Because intersect_count have two input types. - // And intersect_count's first argument's type is alwasy Bitmap, - // so we use its second arguments type as input. - if (fn.name.function_name == "intersect_count") { - arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]); - } - - // Because max_by and min_by function have two input types, - // so we use its second arguments type as input. - if (fn.name.function_name == "max_by" || fn.name.function_name == "min_by") { - arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]); - } - - // Because windowfunnel have more two input types. - // functions registry use 2th args(datetime/date). - if (fn.name.function_name == "window_funnel") { - arg_type = TypeDescriptor::from_thrift(fn.arg_types[1]); - } - - // hack for accepting various arguments - if (fn.name.function_name == "exchange_bytes" || fn.name.function_name == "exchange_speed") { - arg_type = TypeDescriptor(TYPE_BIGINT); - } - - if (fn.name.function_name == "array_union_agg" || fn.name.function_name == "array_unique_agg") { - // for array_union_agg use inner type as signature - arg_type = arg_type.children[0]; - } - - const bool use_nullable_fn = agg_fn_type.use_nullable_fn(_use_intermediate_as_output()); - auto* func = get_aggregate_function(fn.name.function_name, arg_type.type, return_type.type, use_nullable_fn, - fn.binary_type, state->func_version()); - if (func == nullptr) { - return Status::InternalError(strings::Substitute( - "Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)", - fn.name.function_name, type_to_string(arg_type.type), type_to_string(serde_type.type), - type_to_string(return_type.type), use_nullable_fn ? "true" : "false")); + // get function + auto is_result_nullable_func = [&]() { + if (fn.name.function_name == "count") { + if (fn.arg_types.empty()) { + return false; + } + if (has_outer_join_child || desc.nodes[0].has_nullable_child) { + return true; + } + return false; + } else { + return agg_fn_type.use_nullable_fn(_use_intermediate_as_output()); } - VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type " - << return_type; - _agg_functions[i] = func; - } + }; + RETURN_IF_ERROR(_create_aggregate_function(state, fn, is_result_nullable_func(), &_agg_functions[i])); int node_idx = 0; for (int j = 0; j < desc.nodes[0].num_children; ++j) { @@ -530,6 +479,42 @@ Status Aggregator::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile return Status::OK(); } +Status Aggregator::_create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn, + bool is_result_nullable, const AggregateFunction** ret) { + std::vector arg_types; + for (auto& type : fn.arg_types) { + arg_types.push_back(TypeDescriptor::from_thrift(type)); + } + + // check whether it's _merge/_union combinator if it contains agg state type + auto& func_name = fn.name.function_name; + // get function + if (func_name == "count") { + auto* func = get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_result_nullable); + if (func == nullptr) { + return Status::InternalError(strings::Substitute("Invalid agg function plan: $0 ", func_name)); + } + *ret = func; + } else { + TypeDescriptor return_type = TypeDescriptor::from_thrift(fn.ret_type); + TypeDescriptor serde_type = TypeDescriptor::from_thrift(fn.aggregate_fn.intermediate_type); + DCHECK_LE(1, fn.arg_types.size()); + TypeDescriptor arg_type = arg_types[0]; + auto* func = get_aggregate_function(func_name, return_type, arg_types, is_result_nullable, fn.binary_type, + state->func_version()); + if (func == nullptr) { + return Status::InternalError(strings::Substitute( + "Invalid agg function plan: $0 with (arg type $1, serde type $2, result type $3, nullable $4)", + func_name, type_to_string(arg_type.type), type_to_string(serde_type.type), + type_to_string(return_type.type), is_result_nullable ? "true" : "false")); + } + *ret = func; + VLOG_ROW << "get agg function " << func->get_name() << " serde_type " << serde_type << " return_type " + << return_type; + } + return Status::OK(); +} + Status Aggregator::reset_state(starrocks::RuntimeState* state, const std::vector& refill_chunks, pipeline::Operator* refill_op, bool reset_sink_complete) { RETURN_IF_ERROR(_reset_state(state, reset_sink_complete)); diff --git a/be/src/exec/aggregator.h b/be/src/exec/aggregator.h index 2de2c183a7c935..da78d1f036038e 100644 --- a/be/src/exec/aggregator.h +++ b/be/src/exec/aggregator.h @@ -400,6 +400,9 @@ class Aggregator : public pipeline::ContextWithDependency { bool is_streaming_all_states() const { return _streaming_all_states; } + Status _create_aggregate_function(starrocks::RuntimeState* state, const TFunction& fn, bool is_result_nullable, + const AggregateFunction** ret); + HashTableKeyAllocator _state_allocator; protected: diff --git a/be/src/exec/analytor.cpp b/be/src/exec/analytor.cpp index b4ee5c87d60c01..95a273d8250222 100644 --- a/be/src/exec/analytor.cpp +++ b/be/src/exec/analytor.cpp @@ -586,13 +586,6 @@ Status Analytor::_add_chunk(const ChunkPtr& chunk) { const size_t chunk_size = chunk->num_rows(); { - auto check_if_overflow = [](Column* column) { - std::string msg; - if (column->capacity_limit_reached(&msg)) { - return Status::InternalError(msg); - } - return Status::OK(); - }; SCOPED_TIMER(_column_resize_timer); for (size_t i = 0; i < _agg_fn_ctxs.size(); i++) { for (size_t j = 0; j < _agg_expr_ctxs[i].size(); j++) { @@ -601,20 +594,20 @@ Status Analytor::_add_chunk(const ChunkPtr& chunk) { // When chunk's column is const, maybe need to unpack it. TRY_CATCH_BAD_ALLOC(_append_column(chunk_size, _agg_intput_columns[i][j].get(), column)); - RETURN_IF_ERROR(check_if_overflow(_agg_intput_columns[i][j].get())); + RETURN_IF_ERROR(_agg_intput_columns[i][j]->capacity_limit_reached()); } } for (size_t i = 0; i < _partition_ctxs.size(); i++) { ASSIGN_OR_RETURN(ColumnPtr column, _partition_ctxs[i]->evaluate(chunk.get())); TRY_CATCH_BAD_ALLOC(_append_column(chunk_size, _partition_columns[i].get(), column)); - RETURN_IF_ERROR(check_if_overflow(_partition_columns[i].get())); + RETURN_IF_ERROR(_partition_columns[i]->capacity_limit_reached()); } for (size_t i = 0; i < _order_ctxs.size(); i++) { ASSIGN_OR_RETURN(ColumnPtr column, _order_ctxs[i]->evaluate(chunk.get())); TRY_CATCH_BAD_ALLOC(_append_column(chunk_size, _order_columns[i].get(), column)); - RETURN_IF_ERROR(check_if_overflow(_order_columns[i].get())); + RETURN_IF_ERROR(_order_columns[i]->capacity_limit_reached()); } } diff --git a/be/src/exec/hash_join_components.cpp b/be/src/exec/hash_join_components.cpp index b7755416f7f895..86991a632c1058 100644 --- a/be/src/exec/hash_join_components.cpp +++ b/be/src/exec/hash_join_components.cpp @@ -14,12 +14,23 @@ #include "exec/hash_join_components.h" +#include #include +#include #include "column/vectorized_fwd.h" +#include "common/config.h" +#include "common/logging.h" +#include "common/object_pool.h" #include "exec/hash_joiner.h" #include "exec/join_hash_map.h" +#include "exprs/agg/distinct.h" +#include "exprs/expr_context.h" #include "gutil/casts.h" +#include "runtime/descriptors.h" +#include "runtime/mem_tracker.h" +#include "util/cpu_info.h" +#include "util/runtime_profile.h" namespace starrocks { @@ -28,12 +39,16 @@ class SingleHashJoinProberImpl final : public HashJoinProberImpl { SingleHashJoinProberImpl(HashJoiner& hash_joiner) : HashJoinProberImpl(hash_joiner) {} ~SingleHashJoinProberImpl() override = default; bool probe_chunk_empty() const override { return _probe_chunk == nullptr; } + Status on_input_finished(RuntimeState* state) override { return Status::OK(); } Status push_probe_chunk(RuntimeState* state, ChunkPtr&& chunk) override; StatusOr probe_chunk(RuntimeState* state) override; StatusOr probe_remain(RuntimeState* state, bool* has_remain) override; - void reset() override { + void reset(RuntimeState* runtime_state) override { _probe_chunk.reset(); _current_probe_has_remain = false; + if (_hash_table != nullptr) { + _hash_table->reset_probe_state(runtime_state); + } } void set_ht(JoinHashTable* hash_table) { _hash_table = hash_table; } @@ -85,6 +100,241 @@ void HashJoinProber::attach(HashJoinBuilder* builder, const HashJoinProbeMetrics _impl = builder->create_prober(); } +class PartitionChunkChannel { +public: + PartitionChunkChannel(MemTracker* tracker) : _tracker(tracker) {} + bool processing() const { return _processing; } + void set_processing(bool processing) { _processing = processing; } + + ChunkPtr pull() { + auto chunk = std::move(_chunks.front()); + _tracker->release(chunk->memory_usage()); + _chunks.pop_front(); + return chunk; + } + + void push(ChunkPtr&& chunk) { + _tracker->consume(chunk->memory_usage()); + _chunks.emplace_back(std::move(chunk)); + } + + const ChunkPtr& back() { return _chunks.back(); } + + bool is_full() const { + return _chunks.size() >= 4 || _tracker->consumption() > config::partition_hash_join_probe_limit_size; + } + + size_t size() const { return _chunks.size(); } + + bool is_empty() const { return _chunks.empty() || _chunks.front()->is_empty(); } + + bool not_empty() const { return !is_empty(); } + +private: + MemTracker* _tracker; + std::deque _chunks; + bool _processing = false; +}; + +class PartitionedHashJoinProberImpl final : public HashJoinProberImpl { +public: + PartitionedHashJoinProberImpl(HashJoiner& hash_joiner) : HashJoinProberImpl(hash_joiner) {} + ~PartitionedHashJoinProberImpl() override = default; + bool probe_chunk_empty() const override; + Status on_input_finished(RuntimeState* state) override; + Status push_probe_chunk(RuntimeState* state, ChunkPtr&& chunk) override; + StatusOr probe_chunk(RuntimeState* state) override; + StatusOr probe_remain(RuntimeState* state, bool* has_remain) override; + void reset(RuntimeState* runtime_state) override; + void set_probers(std::vector>&& probers) { + _probers = std::move(probers); + _partition_input_channels.resize(_probers.size(), PartitionChunkChannel(&_mem_tracker)); + } + +private: + MemTracker _mem_tracker; + bool _all_input_finished = false; + int32_t _remain_partition_idx = 0; + std::vector> _probers; + std::vector _partition_input_channels; +}; + +bool PartitionedHashJoinProberImpl::probe_chunk_empty() const { + auto& probers = _probers; + size_t num_partitions = probers.size(); + + if (!_all_input_finished) { + for (size_t i = 0; i < num_partitions; ++i) { + if (!probers[i]->probe_chunk_empty() || _partition_input_channels[i].processing()) { + return false; + } + } + } else { + for (size_t i = 0; i < num_partitions; ++i) { + if (!probers[i]->probe_chunk_empty() || _partition_input_channels[i].not_empty()) { + return false; + } + } + } + + return true; +} + +Status PartitionedHashJoinProberImpl::on_input_finished(RuntimeState* runtime_state) { + SCOPED_TIMER(_hash_joiner.probe_metrics().partition_probe_overhead); + _all_input_finished = true; + auto& probers = _probers; + size_t num_partitions = probers.size(); + + for (size_t i = 0; i < num_partitions; ++i) { + if (_partition_input_channels[i].is_empty()) { + continue; + } + if (!probers[i]->probe_chunk_empty()) { + continue; + } + RETURN_IF_ERROR(probers[i]->push_probe_chunk(runtime_state, _partition_input_channels[i].pull())); + } + return Status::OK(); +} + +Status PartitionedHashJoinProberImpl::push_probe_chunk(RuntimeState* state, ChunkPtr&& chunk) { + SCOPED_TIMER(_hash_joiner.probe_metrics().partition_probe_overhead); + auto& probers = _probers; + auto& partition_keys = _hash_joiner.probe_expr_ctxs(); + + size_t num_rows = chunk->num_rows(); + size_t num_partitions = probers.size(); + size_t num_partition_cols = partition_keys.size(); + + std::vector partition_columns(num_partition_cols); + for (size_t i = 0; i < num_partition_cols; ++i) { + ASSIGN_OR_RETURN(partition_columns[i], partition_keys[i]->evaluate(chunk.get())); + } + std::vector hash_values; + { + hash_values.assign(num_rows, HashUtil::FNV_SEED); + + for (const ColumnPtr& column : partition_columns) { + column->fnv_hash(hash_values.data(), 0, num_rows); + } + // find partition id + for (size_t i = 0; i < hash_values.size(); ++i) { + hash_values[i] = HashUtil::fmix32(hash_values[i]) & (num_partitions - 1); + } + } + + const auto& partitions = hash_values; + + std::vector selection; + selection.resize(chunk->num_rows()); + + std::vector channel_row_idx_start_points; + channel_row_idx_start_points.assign(num_partitions + 1, 0); + + for (uint32_t i : partitions) { + channel_row_idx_start_points[i]++; + } + + for (int32_t i = 1; i <= channel_row_idx_start_points.size() - 1; ++i) { + channel_row_idx_start_points[i] += channel_row_idx_start_points[i - 1]; + } + + for (int32_t i = chunk->num_rows() - 1; i >= 0; --i) { + selection[channel_row_idx_start_points[partitions[i]] - 1] = i; + channel_row_idx_start_points[partitions[i]]--; + } + _partition_input_channels.resize(num_partitions, PartitionChunkChannel(&_mem_tracker)); + + for (size_t i = 0; i < num_partitions; ++i) { + auto from = channel_row_idx_start_points[i]; + auto size = channel_row_idx_start_points[i + 1] - from; + if (size == 0) { + continue; + } + + if (_partition_input_channels[i].is_empty()) { + _partition_input_channels[i].push(chunk->clone_empty()); + } + + if (_partition_input_channels[i].back()->num_rows() + size <= 4096) { + _partition_input_channels[i].back()->append_selective(*chunk, selection.data(), from, size); + } else { + _partition_input_channels[i].push(chunk->clone_empty()); + _partition_input_channels[i].back()->append_selective(*chunk, selection.data(), from, size); + } + + if (_partition_input_channels[i].is_full()) { + _partition_input_channels[i].set_processing(true); + RETURN_IF_ERROR(probers[i]->push_probe_chunk(state, _partition_input_channels[i].pull())); + } + } + + return Status::OK(); +} + +StatusOr PartitionedHashJoinProberImpl::probe_chunk(RuntimeState* state) { + auto& probers = _probers; + size_t num_partitions = probers.size(); + if (_all_input_finished) { + for (size_t i = 0; i < num_partitions; ++i) { + if (probers[i]->probe_chunk_empty() && _partition_input_channels[i].is_empty()) { + continue; + } + if (probers[i]->probe_chunk_empty()) { + RETURN_IF_ERROR(probers[i]->push_probe_chunk(state, _partition_input_channels[i].pull())); + } + auto chunk = std::make_shared(); + ASSIGN_OR_RETURN(chunk, probers[i]->probe_chunk(state)) + return chunk; + } + } else { + for (size_t i = 0; i < num_partitions; ++i) { + if (probers[i]->probe_chunk_empty() && !_partition_input_channels[i].processing()) { + continue; + } + if (probers[i]->probe_chunk_empty()) { + RETURN_IF_ERROR(probers[i]->push_probe_chunk(state, _partition_input_channels[i].pull())); + } + _partition_input_channels[i].set_processing(_partition_input_channels[i].size() > 1); + auto chunk = std::make_shared(); + ASSIGN_OR_RETURN(chunk, probers[i]->probe_chunk(state)) + return chunk; + } + } + CHECK(false); + + return nullptr; +} + +StatusOr PartitionedHashJoinProberImpl::probe_remain(RuntimeState* state, bool* has_remain) { + auto& probers = _probers; + size_t num_partitions = probers.size(); + while (_remain_partition_idx < num_partitions) { + auto chunk = std::make_shared(); + bool sub_map_has_remain = false; + ASSIGN_OR_RETURN(chunk, probers[_remain_partition_idx]->probe_remain(state, &sub_map_has_remain)); + if (!sub_map_has_remain) { + _remain_partition_idx++; + } + if (chunk->is_empty()) { + continue; + } + *has_remain = true; + return chunk; + } + + *has_remain = false; + return nullptr; +} + +void PartitionedHashJoinProberImpl::reset(RuntimeState* runtime_state) { + _probers.clear(); + _partition_input_channels.clear(); + _all_input_finished = false; + _remain_partition_idx = 0; +} + void SingleHashJoinBuilder::create(const HashTableParam& param) { _ht.create(param); } @@ -99,11 +349,6 @@ void SingleHashJoinBuilder::reset(const HashTableParam& param) { create(param); } -void SingleHashJoinBuilder::reset_probe(RuntimeState* state) { - _key_columns.clear(); - _ht.reset_probe_state(state); -} - bool SingleHashJoinBuilder::anti_join_key_column_has_null() const { if (_ht.get_key_columns().size() != 1) { return false; @@ -155,4 +400,388 @@ ChunkPtr SingleHashJoinBuilder::convert_to_spill_schema(const ChunkPtr& chunk) c return _ht.convert_to_spill_schema(chunk); } +enum class CacheLevel { L2, L3, MEMORY }; + +class AdaptivePartitionHashJoinBuilder final : public HashJoinBuilder { +public: + AdaptivePartitionHashJoinBuilder(HashJoiner& hash_joiner); + ~AdaptivePartitionHashJoinBuilder() override = default; + + void create(const HashTableParam& param) override; + + void close() override; + + void reset(const HashTableParam& param) override; + + Status do_append_chunk(const ChunkPtr& chunk) override; + + Status build(RuntimeState* state) override; + + bool anti_join_key_column_has_null() const override; + + int64_t ht_mem_usage() const override; + + void get_build_info(size_t* bucket_size, float* avg_keys_per_bucket) override; + + size_t get_output_probe_column_count() const override; + size_t get_output_build_column_count() const override; + + void visitHt(const std::function& visitor) override; + + std::unique_ptr create_prober() override; + + void clone_readable(HashJoinBuilder* builder) override; + + ChunkPtr convert_to_spill_schema(const ChunkPtr& chunk) const override; + +private: + size_t _estimated_row_size(const HashTableParam& param) const; + size_t _estimated_probe_cost(const HashTableParam& param) const; + template + size_t _estimated_build_cost(size_t build_row_size) const; + void _adjust_partition_rows(size_t build_row_size); + + void _init_partition_nums(const HashTableParam& param); + Status _convert_to_single_partition(); + Status _append_chunk_to_partitions(const ChunkPtr& chunk); + +private: + std::vector> _builders; + + size_t _partition_num = 0; + size_t _partition_join_min_rows = 0; + size_t _partition_join_max_rows = 0; + + size_t _probe_estimated_costs = 0; + + size_t _fit_L2_cache_max_rows = 0; + size_t _fit_L3_cache_max_rows = 0; + + size_t _L2_cache_size = 0; + size_t _L3_cache_size = 0; + + size_t _pushed_chunks = 0; +}; + +AdaptivePartitionHashJoinBuilder::AdaptivePartitionHashJoinBuilder(HashJoiner& hash_joiner) + : HashJoinBuilder(hash_joiner) { + static constexpr size_t DEFAULT_L2_CACHE_SIZE = 1 * 1024 * 1024; + static constexpr size_t DEFAULT_L3_CACHE_SIZE = 32 * 1024 * 1024; + const auto& cache_sizes = CpuInfo::get_cache_sizes(); + _L2_cache_size = cache_sizes[CpuInfo::L2_CACHE]; + _L3_cache_size = cache_sizes[CpuInfo::L3_CACHE]; + _L2_cache_size = _L2_cache_size ? _L2_cache_size : DEFAULT_L2_CACHE_SIZE; + _L3_cache_size = _L3_cache_size ? _L3_cache_size : DEFAULT_L3_CACHE_SIZE; +} + +size_t AdaptivePartitionHashJoinBuilder::_estimated_row_size(const HashTableParam& param) const { + size_t estimated_each_row = 0; + + for (auto* tuple : param.build_row_desc->tuple_descriptors()) { + for (auto slot : tuple->slots()) { + if (param.build_output_slots.contains(slot->id())) { + estimated_each_row += get_size_of_fixed_length_type(slot->type().type); + estimated_each_row += type_estimated_overhead_bytes(slot->type().type); + } + } + } + + // for hash table bucket + estimated_each_row += 4; + + return estimated_each_row; +} + +// We could use a better estimation model. +size_t AdaptivePartitionHashJoinBuilder::_estimated_probe_cost(const HashTableParam& param) const { + size_t size = 0; + + for (auto* tuple : param.probe_row_desc->tuple_descriptors()) { + for (auto slot : tuple->slots()) { + if (param.probe_output_slots.contains(slot->id())) { + size += get_size_of_fixed_length_type(slot->type().type); + size += type_estimated_overhead_bytes(slot->type().type); + } + } + } + // we define probe cost is bytes size * 6 + return size * 6; +} + +template <> +size_t AdaptivePartitionHashJoinBuilder::_estimated_build_cost(size_t build_row_size) const { + return build_row_size / 2; +} + +template <> +size_t AdaptivePartitionHashJoinBuilder::_estimated_build_cost(size_t build_row_size) const { + return build_row_size; +} + +template <> +size_t AdaptivePartitionHashJoinBuilder::_estimated_build_cost(size_t build_row_size) const { + return build_row_size * 2; +} + +void AdaptivePartitionHashJoinBuilder::_adjust_partition_rows(size_t build_row_size) { + build_row_size = std::max(build_row_size, 4UL); + _fit_L2_cache_max_rows = _L2_cache_size / build_row_size; + _fit_L3_cache_max_rows = _L3_cache_size / build_row_size; + + // If the hash table is smaller than the L2 cache. we don't think partition hash join is needed. + _partition_join_min_rows = _fit_L2_cache_max_rows; + // If the hash table after partition can't be loaded to L3. we don't think partition hash join is needed. + _partition_join_max_rows = _fit_L3_cache_max_rows * _partition_num; + + if (_probe_estimated_costs + _estimated_build_cost(build_row_size) < + _estimated_build_cost(build_row_size)) { + // overhead after hash table partitioning + probe extra cost < cost before partitioning + // nothing to do + } else if (_probe_estimated_costs + _estimated_build_cost(build_row_size) < + _estimated_build_cost(build_row_size)) { + // It is only after this that performance gains can be realized beyond the L3 cache. + _partition_join_min_rows = _fit_L3_cache_max_rows; + } else { + // Partitioned joins don't have performance gains. Not using partition hash join. + _partition_num = 1; + } + + VLOG_OPERATOR << "TRACE:" + << "partition_num=" << _partition_num << " partition_join_min_rows=" << _partition_join_min_rows + << " partition_join_max_rows=" << _partition_join_max_rows << " probe cost=" << _probe_estimated_costs + << " build cost L2=" << _estimated_build_cost(build_row_size) + << " build cost L3=" << _estimated_build_cost(build_row_size) + << " build cost Mem=" << _estimated_build_cost(build_row_size); +} + +void AdaptivePartitionHashJoinBuilder::_init_partition_nums(const HashTableParam& param) { + _partition_num = 16; + + size_t estimated_bytes_each_row = _estimated_row_size(param); + + _probe_estimated_costs = _estimated_probe_cost(param); + + _adjust_partition_rows(estimated_bytes_each_row); + + COUNTER_SET(_hash_joiner.build_metrics().partition_nums, (int64_t)_partition_num); +} + +void AdaptivePartitionHashJoinBuilder::create(const HashTableParam& param) { + _init_partition_nums(param); + for (size_t i = 0; i < _partition_num; ++i) { + _builders.emplace_back(std::make_unique(_hash_joiner)); + _builders.back()->create(param); + } +} + +void AdaptivePartitionHashJoinBuilder::close() { + for (const auto& builder : _builders) { + builder->close(); + } + _builders.clear(); + _partition_num = 0; + _partition_join_min_rows = 0; + _partition_join_max_rows = 0; + _probe_estimated_costs = 0; + _fit_L2_cache_max_rows = 0; + _fit_L3_cache_max_rows = 0; + _pushed_chunks = 0; +} + +void AdaptivePartitionHashJoinBuilder::reset(const HashTableParam& param) { + close(); + create(param); +} + +bool AdaptivePartitionHashJoinBuilder::anti_join_key_column_has_null() const { + return std::any_of(_builders.begin(), _builders.end(), + [](const auto& builder) { return builder->anti_join_key_column_has_null(); }); +} + +void AdaptivePartitionHashJoinBuilder::get_build_info(size_t* bucket_size, float* avg_keys_per_bucket) { + size_t total_bucket_size = 0; + float total_keys_per_bucket = 0; + for (const auto& builder : _builders) { + size_t bucket_size = 0; + float keys_per_bucket = 0; + builder->get_build_info(&bucket_size, &keys_per_bucket); + total_bucket_size += bucket_size; + total_keys_per_bucket += keys_per_bucket; + } + *bucket_size = total_bucket_size; + *avg_keys_per_bucket = total_keys_per_bucket / _builders.size(); +} + +size_t AdaptivePartitionHashJoinBuilder::get_output_probe_column_count() const { + return _builders[0]->get_output_probe_column_count(); +} + +size_t AdaptivePartitionHashJoinBuilder::get_output_build_column_count() const { + return _builders[0]->get_output_build_column_count(); +} + +int64_t AdaptivePartitionHashJoinBuilder::ht_mem_usage() const { + return std::accumulate(_builders.begin(), _builders.end(), 0L, + [](int64_t sum, const auto& builder) { return sum + builder->ht_mem_usage(); }); +} + +Status AdaptivePartitionHashJoinBuilder::_convert_to_single_partition() { + // merge all partition data to the first partition + for (size_t i = 1; i < _partition_num; ++i) { + _builders[0]->hash_table().merge_ht(_builders[i]->hash_table()); + } + _builders.resize(1); + _partition_num = 1; + return Status::OK(); +} + +Status AdaptivePartitionHashJoinBuilder::_append_chunk_to_partitions(const ChunkPtr& chunk) { + const std::vector& build_partition_keys = _hash_joiner.build_expr_ctxs(); + + size_t num_rows = chunk->num_rows(); + size_t num_partitions = _builders.size(); + size_t num_partition_cols = build_partition_keys.size(); + + std::vector partition_columns(num_partition_cols); + for (size_t i = 0; i < num_partition_cols; ++i) { + ASSIGN_OR_RETURN(partition_columns[i], build_partition_keys[i]->evaluate(chunk.get())); + } + std::vector hash_values; + { + hash_values.assign(num_rows, HashUtil::FNV_SEED); + + for (const ColumnPtr& column : partition_columns) { + column->fnv_hash(hash_values.data(), 0, num_rows); + } + // find partition id + for (size_t i = 0; i < hash_values.size(); ++i) { + hash_values[i] = HashUtil::fmix32(hash_values[i]) & (num_partitions - 1); + } + } + + const auto& partitions = hash_values; + + std::vector selection; + selection.resize(chunk->num_rows()); + + std::vector channel_row_idx_start_points; + channel_row_idx_start_points.assign(num_partitions + 1, 0); + + for (uint32_t i : partitions) { + channel_row_idx_start_points[i]++; + } + + for (int32_t i = 1; i <= channel_row_idx_start_points.size() - 1; ++i) { + channel_row_idx_start_points[i] += channel_row_idx_start_points[i - 1]; + } + + for (int32_t i = chunk->num_rows() - 1; i >= 0; --i) { + selection[channel_row_idx_start_points[partitions[i]] - 1] = i; + channel_row_idx_start_points[partitions[i]]--; + } + + for (size_t i = 0; i < num_partitions; ++i) { + auto from = channel_row_idx_start_points[i]; + auto size = channel_row_idx_start_points[i + 1] - from; + if (size == 0) { + continue; + } + // TODO: make builder implements append with selective + auto partition_chunk = chunk->clone_empty(); + partition_chunk->append_selective(*chunk, selection.data(), from, size); + RETURN_IF_ERROR(_builders[i]->append_chunk(std::move(partition_chunk))); + } + return Status::OK(); +} + +Status AdaptivePartitionHashJoinBuilder::do_append_chunk(const ChunkPtr& chunk) { + if (_partition_num > 1 && hash_table_row_count() > _partition_join_max_rows) { + RETURN_IF_ERROR(_convert_to_single_partition()); + } + + if (_partition_num > 1 && ++_pushed_chunks % 8 == 0) { + size_t build_row_size = ht_mem_usage() / hash_table_row_count(); + _adjust_partition_rows(build_row_size); + if (_partition_num == 1) { + RETURN_IF_ERROR(_convert_to_single_partition()); + } + } + + if (_partition_num > 1) { + RETURN_IF_ERROR(_append_chunk_to_partitions(chunk)); + } else { + RETURN_IF_ERROR(_builders[0]->do_append_chunk(chunk)); + } + + return Status::OK(); +} + +ChunkPtr AdaptivePartitionHashJoinBuilder::convert_to_spill_schema(const ChunkPtr& chunk) const { + return _builders[0]->convert_to_spill_schema(chunk); +} + +Status AdaptivePartitionHashJoinBuilder::build(RuntimeState* state) { + DCHECK_EQ(_partition_num, _builders.size()); + + if (_partition_num > 1 && hash_table_row_count() < _partition_join_min_rows) { + RETURN_IF_ERROR(_convert_to_single_partition()); + } + + for (auto& builder : _builders) { + RETURN_IF_ERROR(builder->build(state)); + } + _ready = true; + return Status::OK(); +} + +void AdaptivePartitionHashJoinBuilder::visitHt(const std::function& visitor) { + for (auto& builder : _builders) { + builder->visitHt(visitor); + } +} + +std::unique_ptr AdaptivePartitionHashJoinBuilder::create_prober() { + DCHECK_EQ(_partition_num, _builders.size()); + + if (_partition_num == 1) { + return _builders[0]->create_prober(); + } else { + std::vector> sub_probers; + auto prober = std::make_unique(_hash_joiner); + sub_probers.resize(_partition_num); + for (size_t i = 0; i < _builders.size(); ++i) { + sub_probers[i].reset(down_cast(_builders[i]->create_prober().release())); + } + prober->set_probers(std::move(sub_probers)); + return prober; + } +} + +void AdaptivePartitionHashJoinBuilder::clone_readable(HashJoinBuilder* builder) { + for (auto& builder : _builders) { + DCHECK(builder->ready()); + } + DCHECK(_ready); + DCHECK_EQ(_partition_num, _builders.size()); + auto other = down_cast(builder); + other->_builders.clear(); + other->_partition_num = _partition_num; + other->_partition_join_max_rows = _partition_join_max_rows; + other->_partition_join_min_rows = _partition_join_min_rows; + other->_ready = _ready; + for (size_t i = 0; i < _partition_num; ++i) { + other->_builders.emplace_back(std::make_unique(_hash_joiner)); + _builders[i]->clone_readable(other->_builders[i].get()); + } +} + +HashJoinBuilder* HashJoinBuilderFactory::create(ObjectPool* pool, const HashJoinBuildOptions& options, + HashJoiner& hash_joiner) { + if (options.enable_partitioned_hash_join) { + return pool->add(new AdaptivePartitionHashJoinBuilder(hash_joiner)); + } else { + return pool->add(new SingleHashJoinBuilder(hash_joiner)); + } +} + } // namespace starrocks diff --git a/be/src/exec/hash_join_components.h b/be/src/exec/hash_join_components.h index 3d1a5b08c2a8d7..a582fb14e7bdb9 100644 --- a/be/src/exec/hash_join_components.h +++ b/be/src/exec/hash_join_components.h @@ -31,10 +31,11 @@ class HashJoinProberImpl { public: virtual ~HashJoinProberImpl() = default; virtual bool probe_chunk_empty() const = 0; + virtual Status on_input_finished(RuntimeState* state) = 0; virtual Status push_probe_chunk(RuntimeState* state, ChunkPtr&& chunk) = 0; virtual StatusOr probe_chunk(RuntimeState* state) = 0; virtual StatusOr probe_remain(RuntimeState* state, bool* has_remain) = 0; - virtual void reset() = 0; + virtual void reset(RuntimeState* runtime_state) = 0; protected: HashJoinProberImpl(HashJoiner& hash_joiner) : _hash_joiner(hash_joiner) {} @@ -51,6 +52,13 @@ class HashJoinProber { return _impl->push_probe_chunk(state, std::move(chunk)); } + Status on_input_finished(RuntimeState* state) { + if (_impl == nullptr) { + return Status::OK(); + } + return _impl->on_input_finished(state); + } + // probe hash table StatusOr probe_chunk(RuntimeState* state) { return _impl->probe_chunk(state); } @@ -58,7 +66,7 @@ class HashJoinProber { return _impl->probe_remain(state, has_remain); } - void reset() { return _impl->reset(); } + void reset(RuntimeState* runtime_state) { return _impl->reset(runtime_state); } HashJoinProber* clone_empty(ObjectPool* pool) { return pool->add(new HashJoinProber(_hash_joiner)); } @@ -96,10 +104,6 @@ class HashJoinBuilder { virtual void reset(const HashTableParam& param) = 0; - virtual void reset_probe(RuntimeState* state) = 0; - - virtual HashJoinBuilder* clone_empty(ObjectPool* pool) = 0; - virtual int64_t ht_mem_usage() const = 0; // used for check NULL_AWARE_LEFT_ANTI_JOIN build side has null @@ -118,6 +122,7 @@ class HashJoinBuilder { virtual std::unique_ptr create_prober() = 0; + // clone readable to to builder virtual void clone_readable(HashJoinBuilder* builder) = 0; virtual ChunkPtr convert_to_spill_schema(const ChunkPtr& chunk) const = 0; @@ -148,12 +153,6 @@ class SingleHashJoinBuilder final : public HashJoinBuilder { Status build(RuntimeState* state) override; - void reset_probe(RuntimeState* state) override; - - SingleHashJoinBuilder* clone_empty(ObjectPool* pool) override { - return pool->add(new SingleHashJoinBuilder(_hash_joiner)); - } - bool anti_join_key_column_has_null() const override; int64_t ht_mem_usage() const override { return _ht.mem_usage(); } @@ -179,4 +178,12 @@ class SingleHashJoinBuilder final : public HashJoinBuilder { Columns _key_columns; }; +struct HashJoinBuildOptions { + bool enable_partitioned_hash_join = false; +}; + +class HashJoinBuilderFactory { +public: + static HashJoinBuilder* create(ObjectPool* pool, const HashJoinBuildOptions& options, HashJoiner& hash_joiner); +}; } // namespace starrocks \ No newline at end of file diff --git a/be/src/exec/hash_join_node.cpp b/be/src/exec/hash_join_node.cpp index 10689bbccd9b25..fd1712e3cf51b8 100644 --- a/be/src/exec/hash_join_node.cpp +++ b/be/src/exec/hash_join_node.cpp @@ -145,6 +145,9 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { if (tnode.hash_join_node.__isset.late_materialization) { _enable_late_materialization = tnode.hash_join_node.late_materialization; } + if (tnode.hash_join_node.__isset.enable_partition_hash_join) { + _enable_partition_hash_join = tnode.hash_join_node.enable_partition_hash_join; + } return Status::OK(); } @@ -203,6 +206,7 @@ void HashJoinNode::_init_hash_table_param(HashTableParam* param) { param->build_output_slots = _output_slots; param->probe_output_slots = _output_slots; param->enable_late_materialization = _enable_late_materialization; + param->enable_partition_hash_join = _enable_partition_hash_join; std::set predicate_slots; for (ExprContext* expr_context : _conjunct_ctxs) { @@ -472,7 +476,8 @@ pipeline::OpFactories HashJoinNode::_decompose_to_pipeline(pipeline::PipelineBui HashJoinerParam param(pool, _hash_join_node, _is_null_safes, _build_expr_ctxs, _probe_expr_ctxs, _other_join_conjunct_ctxs, _conjunct_ctxs, child(1)->row_desc(), child(0)->row_desc(), child(1)->type(), child(0)->type(), child(1)->conjunct_ctxs().empty(), _build_runtime_filters, - _output_slots, _output_slots, _distribution_mode, false, _enable_late_materialization); + _output_slots, _output_slots, _distribution_mode, false, _enable_late_materialization, + _enable_partition_hash_join); auto hash_joiner_factory = std::make_shared(param); // Create a shared RefCountedRuntimeFilterCollector diff --git a/be/src/exec/hash_join_node.h b/be/src/exec/hash_join_node.h index d83c0b9fed0bec..f7ea19015b24f8 100644 --- a/be/src/exec/hash_join_node.h +++ b/be/src/exec/hash_join_node.h @@ -115,6 +115,7 @@ class HashJoinNode final : public ExecNode { bool _is_push_down = false; bool _enable_late_materialization = false; + bool _enable_partition_hash_join = false; JoinHashTable _ht; diff --git a/be/src/exec/hash_joiner.cpp b/be/src/exec/hash_joiner.cpp index e7ea8a9f37d133..7101254c86b3c3 100644 --- a/be/src/exec/hash_joiner.cpp +++ b/be/src/exec/hash_joiner.cpp @@ -44,6 +44,7 @@ void HashJoinProbeMetrics::prepare(RuntimeProfile* runtime_profile) { other_join_conjunct_evaluate_timer = ADD_TIMER(runtime_profile, "OtherJoinConjunctEvaluateTime"); where_conjunct_evaluate_timer = ADD_TIMER(runtime_profile, "WhereConjunctEvaluateTime"); probe_counter = ADD_COUNTER(runtime_profile, "probeCount", TUnit::UNIT); + partition_probe_overhead = ADD_TIMER(runtime_profile, "PartitionProbeOverhead"); } void HashJoinBuildMetrics::prepare(RuntimeProfile* runtime_profile) { @@ -55,8 +56,8 @@ void HashJoinBuildMetrics::prepare(RuntimeProfile* runtime_profile) { runtime_filter_num = ADD_COUNTER(runtime_profile, "RuntimeFilterNum", TUnit::UNIT); build_keys_per_bucket = ADD_COUNTER(runtime_profile, "BuildKeysPerBucket%", TUnit::UNIT); hash_table_memory_usage = ADD_COUNTER(runtime_profile, "HashTableMemoryUsage", TUnit::BYTES); - partial_runtime_bloom_filter_bytes = ADD_COUNTER(runtime_profile, "PartialRuntimeBloomFilterBytes", TUnit::BYTES); + partition_nums = ADD_COUNTER(runtime_profile, "PartitionNums", TUnit::UNIT); } HashJoiner::HashJoiner(const HashJoinerParam& param) @@ -87,7 +88,11 @@ HashJoiner::HashJoiner(const HashJoinerParam& param) if (param._hash_join_node.__isset.build_runtime_filters_from_planner) { _build_runtime_filters_from_planner = param._hash_join_node.build_runtime_filters_from_planner; } - _hash_join_builder = _pool->add(new SingleHashJoinBuilder(*this)); + + HashJoinBuildOptions build_options; + build_options.enable_partitioned_hash_join = param._enable_partition_hash_join; + + _hash_join_builder = HashJoinBuilderFactory::create(_pool, build_options, *this); _hash_join_prober = _pool->add(new HashJoinProber(*this)); _build_metrics = _pool->add(new HashJoinBuildMetrics()); _probe_metrics = _pool->add(new HashJoinProbeMetrics()); @@ -247,8 +252,8 @@ Status HashJoiner::push_chunk(RuntimeState* state, ChunkPtr&& chunk) { return _hash_join_prober->push_probe_chunk(state, std::move(chunk)); } -Status HashJoiner::probe_input_finished() { - return Status::OK(); +Status HashJoiner::probe_input_finished(RuntimeState* state) { + return _hash_join_prober->on_input_finished(state); } StatusOr HashJoiner::pull_chunk(RuntimeState* state) { @@ -365,9 +370,7 @@ Status HashJoiner::reset_probe(starrocks::RuntimeState* state) { return Status::OK(); } - _hash_join_prober->reset(); - - _hash_join_builder->reset_probe(state); + _hash_join_prober->reset(state); return Status::OK(); } diff --git a/be/src/exec/hash_joiner.h b/be/src/exec/hash_joiner.h index 5d2fa1cc75124b..e671bd6c4b45fa 100644 --- a/be/src/exec/hash_joiner.h +++ b/be/src/exec/hash_joiner.h @@ -72,7 +72,7 @@ struct HashJoinerParam { bool build_conjunct_ctxs_is_empty, std::list build_runtime_filters, std::set build_output_slots, std::set probe_output_slots, const TJoinDistributionMode::type distribution_mode, bool mor_reader_mode, - bool enable_late_materialization) + bool enable_late_materialization, bool enable_partition_hash_join) : _pool(pool), _hash_join_node(hash_join_node), _is_null_safes(std::move(is_null_safes)), @@ -90,7 +90,8 @@ struct HashJoinerParam { _probe_output_slots(std::move(probe_output_slots)), _distribution_mode(distribution_mode), _mor_reader_mode(mor_reader_mode), - _enable_late_materialization(enable_late_materialization) {} + _enable_late_materialization(enable_late_materialization), + _enable_partition_hash_join(enable_partition_hash_join) {} HashJoinerParam(HashJoinerParam&&) = default; HashJoinerParam(HashJoinerParam&) = default; @@ -115,6 +116,7 @@ struct HashJoinerParam { const TJoinDistributionMode::type _distribution_mode; const bool _mor_reader_mode; const bool _enable_late_materialization; + const bool _enable_partition_hash_join; }; inline bool could_short_circuit(TJoinOp::type join_type) { @@ -143,6 +145,7 @@ struct HashJoinProbeMetrics { RuntimeProfile::Counter* where_conjunct_evaluate_timer = nullptr; RuntimeProfile::Counter* output_build_column_timer = nullptr; RuntimeProfile::Counter* probe_counter = nullptr; + RuntimeProfile::Counter* partition_probe_overhead = nullptr; void prepare(RuntimeProfile* runtime_profile); }; @@ -156,8 +159,8 @@ struct HashJoinBuildMetrics { RuntimeProfile::Counter* runtime_filter_num = nullptr; RuntimeProfile::Counter* build_keys_per_bucket = nullptr; RuntimeProfile::Counter* hash_table_memory_usage = nullptr; - RuntimeProfile::Counter* partial_runtime_bloom_filter_bytes = nullptr; + RuntimeProfile::Counter* partition_nums = nullptr; void prepare(RuntimeProfile* runtime_profile); }; @@ -210,7 +213,7 @@ class HashJoiner final : public pipeline::ContextWithDependency { Status build_ht(RuntimeState* state); // probe phase Status push_chunk(RuntimeState* state, ChunkPtr&& chunk); - Status probe_input_finished(); + Status probe_input_finished(RuntimeState* state); StatusOr pull_chunk(RuntimeState* state); pipeline::RuntimeInFilters& get_runtime_in_filters() { return _runtime_in_filters; } @@ -274,10 +277,16 @@ class HashJoiner final : public pipeline::ContextWithDependency { return Status::OK(); } - const std::vector probe_expr_ctxs() { return _probe_expr_ctxs; } + const std::vector& probe_expr_ctxs() { return _probe_expr_ctxs; } + const std::vector& build_expr_ctxs() { return _build_expr_ctxs; } HashJoinProber* new_prober(ObjectPool* pool) { return _hash_join_prober->clone_empty(pool); } - HashJoinBuilder* new_builder(ObjectPool* pool) { return _hash_join_builder->clone_empty(pool); } + HashJoinBuilder* new_builder(ObjectPool* pool) { + // We don't support spill partition hash join now. + HashJoinBuildOptions options; + options.enable_partitioned_hash_join = false; + return HashJoinBuilderFactory::create(pool, options, *this); + } Status filter_probe_output_chunk(ChunkPtr& chunk, JoinHashTable& hash_table) { // Probe in JoinHashMap is divided into probe with other_conjuncts and without other_conjuncts. diff --git a/be/src/exec/join_hash_map.cpp b/be/src/exec/join_hash_map.cpp index afbe8107586b31..9389cb90e6c66f 100644 --- a/be/src/exec/join_hash_map.cpp +++ b/be/src/exec/join_hash_map.cpp @@ -656,6 +656,21 @@ void JoinHashTable::append_chunk(const ChunkPtr& chunk, const Columns& key_colum _table_items->row_count += chunk->num_rows(); } +void JoinHashTable::merge_ht(const JoinHashTable& ht) { + _table_items->row_count += ht._table_items->row_count; + + Columns& columns = _table_items->build_chunk->columns(); + Columns& other_columns = ht._table_items->build_chunk->columns(); + + for (size_t i = 0; i < _table_items->build_column_count; i++) { + if (!columns[i]->is_nullable() && other_columns[i]->is_nullable()) { + // upgrade to nullable column + columns[i] = NullableColumn::create(columns[i], NullColumn::create(columns[i]->size(), 0)); + } + columns[i]->append(*other_columns[i], 1, other_columns[i]->size() - 1); + } +} + ChunkPtr JoinHashTable::convert_to_spill_schema(const ChunkPtr& chunk) const { DCHECK(chunk != nullptr && chunk->num_rows() > 0); ChunkPtr output = std::make_shared(); diff --git a/be/src/exec/join_hash_map.h b/be/src/exec/join_hash_map.h index 9f4e5eb05b9396..781a466cd63d62 100644 --- a/be/src/exec/join_hash_map.h +++ b/be/src/exec/join_hash_map.h @@ -280,6 +280,7 @@ struct HashTableProbeState { struct HashTableParam { bool with_other_conjunct = false; bool enable_late_materialization = false; + bool enable_partition_hash_join = false; TJoinOp::type join_type = TJoinOp::INNER_JOIN; const RowDescriptor* build_row_desc = nullptr; const RowDescriptor* probe_row_desc = nullptr; @@ -833,6 +834,7 @@ class JoinHashTable { Status lazy_output(RuntimeState* state, ChunkPtr* probe_chunk, ChunkPtr* result_chunk); void append_chunk(const ChunkPtr& chunk, const Columns& key_columns); + void merge_ht(const JoinHashTable& ht); // convert input column to spill schema order ChunkPtr convert_to_spill_schema(const ChunkPtr& chunk) const; diff --git a/be/src/exec/mor_processor.cpp b/be/src/exec/mor_processor.cpp index dd6a881f58b6d4..63a05bbc1a22a0 100644 --- a/be/src/exec/mor_processor.cpp +++ b/be/src/exec/mor_processor.cpp @@ -48,7 +48,7 @@ Status IcebergMORProcessor::init(RuntimeState* runtime_state, const MORParams& p std::vector(), std::vector(), *_build_row_desc, *_probe_row_desc, TPlanNodeType::HDFS_SCAN_NODE, TPlanNodeType::HDFS_SCAN_NODE, true, std::list(), std::set(), probe_output_slot_ids, - TJoinDistributionMode::PARTITIONED, true, false)); + TJoinDistributionMode::PARTITIONED, true, false, false)); _hash_joiner = _pool.add(new HashJoiner(*param)); RETURN_IF_ERROR(_hash_joiner->prepare_builder(runtime_state, _runtime_profile)); diff --git a/be/src/exec/parquet_scanner.cpp b/be/src/exec/parquet_scanner.cpp index 5e6a0945dec2ea..500cba44b14b1d 100644 --- a/be/src/exec/parquet_scanner.cpp +++ b/be/src/exec/parquet_scanner.cpp @@ -108,14 +108,6 @@ Status ParquetScanner::append_batch_to_src_chunk(ChunkPtr* chunk) { _conv_ctx.current_slot = slot_desc; auto* array = _batch->column(column_pos++).get(); auto& column = (*chunk)->get_column_by_slot_id(slot_desc->id()); - // for timestamp type, _state->timezone which is specified by user. convert function - // obtains timezone from array. thus timezone in array should be rectified to - // _state->timezone. - if (array->type_id() == ArrowTypeId::TIMESTAMP) { - auto* timestamp_type = down_cast(array->type().get()); - auto& mutable_timezone = (std::string&)timestamp_type->timezone(); - mutable_timezone = _state->timezone(); - } RETURN_IF_ERROR(convert_array_to_column(_conv_funcs[i].get(), num_elements, array, column, _batch_start_idx, _chunk_start_idx, &_chunk_filter, &_conv_ctx)); } @@ -321,6 +313,15 @@ Status ParquetScanner::convert_array_to_column(ConvertFuncTree* conv_func, size_ const arrow::Array* array, const ColumnPtr& column, size_t batch_start_idx, size_t chunk_start_idx, Filter* chunk_filter, ArrowConvertContext* conv_ctx) { + // for timestamp type, state->timezone which is specified by user. convert function + // obtains timezone from array. thus timezone in array should be rectified to + // state->timezone. + if (array->type_id() == ArrowTypeId::TIMESTAMP) { + auto* timestamp_type = down_cast(array->type().get()); + auto& mutable_timezone = (std::string&)timestamp_type->timezone(); + mutable_timezone = conv_ctx->state->timezone(); + } + uint8_t* null_data; Column* data_column; if (column->is_nullable()) { diff --git a/be/src/exec/parquet_schema_builder.cpp b/be/src/exec/parquet_schema_builder.cpp index 8a3ac6bc1f9ead..f582520ed9da59 100644 --- a/be/src/exec/parquet_schema_builder.cpp +++ b/be/src/exec/parquet_schema_builder.cpp @@ -23,6 +23,7 @@ static Status get_parquet_type_from_group(const ::parquet::schema::NodePtr& node static Status get_parquet_type_from_primitive(const ::parquet::schema::NodePtr& node, TypeDescriptor* type_desc); static Status get_parquet_type_from_list(const ::parquet::schema::NodePtr& node, TypeDescriptor* type_desc); static Status get_parquet_type_from_map(const ::parquet::schema::NodePtr& node, TypeDescriptor* type_desc); +static Status try_to_infer_struct_type(const ::parquet::schema::NodePtr& node, TypeDescriptor* type_desc); Status get_parquet_type(const ::parquet::schema::NodePtr& node, TypeDescriptor* type_desc) { if (node->is_group()) { @@ -121,6 +122,11 @@ static Status get_parquet_type_from_group(const ::parquet::schema::NodePtr& node return get_parquet_type_from_map(node, type_desc); } + auto st = try_to_infer_struct_type(node, type_desc); + if (st.ok()) { + return Status::OK(); + } + // Treat unsupported types as VARCHAR. *type_desc = TypeDescriptor::create_varchar_type(TypeDescriptor::MAX_VARCHAR_LENGTH); return Status::OK(); @@ -217,4 +223,44 @@ static Status get_parquet_type_from_map(const ::parquet::schema::NodePtr& node, return Status::OK(); } +/* +try to infer struct type from group node. + +parquet does not have struct type, there is no struct definition in parquet. +try to infer like this. +group { + type field0; + type field1; + ... +} +*/ +static Status try_to_infer_struct_type(const ::parquet::schema::NodePtr& node, TypeDescriptor* type_desc) { + // 1st level. + // group name + DCHECK(node->is_group()); + + auto group_node = std::static_pointer_cast<::parquet::schema::GroupNode>(node); + int field_count = group_node->field_count(); + if (field_count == 0) { + return Status::Unknown("unknown type"); + } + + // 2nd level. + // field + std::vector field_names; + std::vector field_types; + field_names.reserve(field_count); + field_types.reserve(field_count); + for (auto i = 0; i < group_node->field_count(); ++i) { + const auto& field = group_node->field(i); + field_names.emplace_back(field->name()); + auto& field_type_desc = field_types.emplace_back(); + RETURN_IF_ERROR(get_parquet_type(field, &field_type_desc)); + } + + *type_desc = TypeDescriptor::create_struct_type(field_names, field_types); + + return Status::OK(); +} + } //namespace starrocks \ No newline at end of file diff --git a/be/src/exec/pipeline/exec_state_reporter.cpp b/be/src/exec/pipeline/exec_state_reporter.cpp index 8d8407f4ad6def..996f8985ef2ca6 100644 --- a/be/src/exec/pipeline/exec_state_reporter.cpp +++ b/be/src/exec/pipeline/exec_state_reporter.cpp @@ -182,6 +182,7 @@ Status ExecStateReporter::report_exec_status(const TReportExecStatusParams& para } coord->reportExecStatus(res, params); } else { + (void)coord.reopen(config::thrift_rpc_timeout_ms); std::stringstream msg; msg << "ReportExecStatus() to " << fe_addr << " failed:\n" << e.what(); LOG(WARNING) << msg.str(); @@ -296,6 +297,7 @@ Status ExecStateReporter::report_epoch(const TMVMaintenanceTasks& params, ExecEn } coord->mvReport(res, params); } else { + (void)coord.reopen(config::thrift_rpc_timeout_ms); std::stringstream msg; msg << "mvReport() to " << fe_addr << " failed:\n" << e.what(); LOG(WARNING) << msg.str(); diff --git a/be/src/exec/pipeline/hashjoin/hash_join_probe_operator.cpp b/be/src/exec/pipeline/hashjoin/hash_join_probe_operator.cpp index 61158cf0e69ad9..7b4225054fef03 100644 --- a/be/src/exec/pipeline/hashjoin/hash_join_probe_operator.cpp +++ b/be/src/exec/pipeline/hashjoin/hash_join_probe_operator.cpp @@ -87,7 +87,7 @@ StatusOr HashJoinProbeOperator::pull_chunk(RuntimeState* state) { } Status HashJoinProbeOperator::set_finishing(RuntimeState* state) { - RETURN_IF_ERROR(_join_prober->probe_input_finished()); + RETURN_IF_ERROR(_join_prober->probe_input_finished(state)); _join_prober->enter_post_probe_phase(); return Status::OK(); } diff --git a/be/src/exec/pipeline/pipeline_driver.cpp b/be/src/exec/pipeline/pipeline_driver.cpp index df6ff53a5fd88a..114769e0c5f3d6 100644 --- a/be/src/exec/pipeline/pipeline_driver.cpp +++ b/be/src/exec/pipeline/pipeline_driver.cpp @@ -555,7 +555,9 @@ void PipelineDriver::finish_operators(RuntimeState* runtime_state) { void PipelineDriver::cancel_operators(RuntimeState* runtime_state) { if (this->query_ctx()->is_query_expired()) { - LOG(WARNING) << "begin to cancel operators for " << to_readable_string(); + if (_has_log_cancelled.exchange(true) == false) { + VLOG_ROW << "begin to cancel operators for " << to_readable_string(); + } } for (auto& op : _operators) { WARN_IF_ERROR(_mark_operator_cancelled(op, runtime_state), @@ -805,10 +807,11 @@ Status PipelineDriver::_mark_operator_finished(OperatorPtr& op, RuntimeState* st Status PipelineDriver::_mark_operator_cancelled(OperatorPtr& op, RuntimeState* state) { Status res = _mark_operator_finished(op, state); - if (!res.ok()) { - LOG(WARNING) << fmt::format("fragment_id {} driver {} cancels operator {} with finished error {}", - print_id(state->fragment_instance_id()), to_readable_string(), op->get_name(), - res.message()); + if (!res.ok() && !res.is_cancelled()) { + LOG(WARNING) << fmt::format( + "[Driver] failed to finish operator called by cancelling operator [fragment_id={}] [driver={}] " + "[operator={}] [error={}]", + print_id(state->fragment_instance_id()), to_readable_string(), op->get_name(), res.message()); } auto& op_state = _operator_stages[op->get_id()]; if (op_state >= OperatorStage::CANCELLED) { diff --git a/be/src/exec/pipeline/pipeline_driver.h b/be/src/exec/pipeline/pipeline_driver.h index 4ffa5a56985cde..5f9d1c8b0dcd15 100644 --- a/be/src/exec/pipeline/pipeline_driver.h +++ b/be/src/exec/pipeline/pipeline_driver.h @@ -526,6 +526,8 @@ class PipelineDriver { size_t _driver_queue_level = 0; std::atomic _in_ready_queue{false}; + std::atomic _has_log_cancelled{false}; + // metrics RuntimeProfile::Counter* _total_timer = nullptr; RuntimeProfile::Counter* _active_timer = nullptr; diff --git a/be/src/exec/pipeline/query_context.cpp b/be/src/exec/pipeline/query_context.cpp index d70c18d40b5c8d..fc8aa6598cca3a 100644 --- a/be/src/exec/pipeline/query_context.cpp +++ b/be/src/exec/pipeline/query_context.cpp @@ -375,8 +375,8 @@ QueryContext* QueryContextManager::get_or_register(const TUniqueId& query_id) { // lookup query context for the second chance in sc_map if (sc_it != sc_map.end()) { auto ctx = std::move(sc_it->second); - RETURN_NULL_IF_CTX_CANCELLED(ctx); sc_map.erase(sc_it); + RETURN_NULL_IF_CTX_CANCELLED(ctx); auto* raw_ctx_ptr = ctx.get(); context_map.emplace(query_id, std::move(ctx)); return raw_ctx_ptr; diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.cpp b/be/src/exec/pipeline/scan/olap_chunk_source.cpp index 6edf39b5ca9319..d0c469985099aa 100644 --- a/be/src/exec/pipeline/scan/olap_chunk_source.cpp +++ b/be/src/exec/pipeline/scan/olap_chunk_source.cpp @@ -173,7 +173,6 @@ void OlapChunkSource::_init_counter(RuntimeState* state) { // IOTime _io_timer = ADD_CHILD_TIMER(_runtime_profile, "IOTime", IO_TASK_EXEC_TIMER_NAME); - _json_flatten_timer = ADD_CHILD_TIMER(_runtime_profile, "JsonFlattern", segment_read_name); _access_path_hits_counter = ADD_COUNTER(_runtime_profile, "AccessPathHits", TUnit::UNIT); _access_path_unhits_counter = ADD_COUNTER(_runtime_profile, "AccessPathUnhits", TUnit::UNIT); } @@ -654,21 +653,23 @@ void OlapChunkSource::_update_counter() { COUNTER_UPDATE(c2, _reader->stats().rows_del_filtered); } - if (_reader->stats().flat_json_hits.size() > 0) { + if (_reader->stats().flat_json_hits.size() > 0 || _reader->stats().merge_json_hits.size() > 0) { std::string access_path_hits = "AccessPathHits"; int64_t total = 0; for (auto& [k, v] : _reader->stats().flat_json_hits) { - auto* path_counter = _runtime_profile->get_counter(fmt::format("[Hit]{}", k)); + std::string path = fmt::format("[Hit]{}", k); + auto* path_counter = _runtime_profile->get_counter(path); if (path_counter == nullptr) { - path_counter = ADD_CHILD_COUNTER(_runtime_profile, k, TUnit::UNIT, access_path_hits); + path_counter = ADD_CHILD_COUNTER(_runtime_profile, path, TUnit::UNIT, access_path_hits); } total += v; COUNTER_UPDATE(path_counter, v); } for (auto& [k, v] : _reader->stats().merge_json_hits) { - auto* path_counter = _runtime_profile->get_counter(fmt::format("[HitMerge]{}", k)); + std::string merge_path = fmt::format("[HitMerge]{}", k); + auto* path_counter = _runtime_profile->get_counter(merge_path); if (path_counter == nullptr) { - path_counter = ADD_CHILD_COUNTER(_runtime_profile, k, TUnit::UNIT, access_path_hits); + path_counter = ADD_CHILD_COUNTER(_runtime_profile, merge_path, TUnit::UNIT, access_path_hits); } total += v; COUNTER_UPDATE(path_counter, v); @@ -679,9 +680,10 @@ void OlapChunkSource::_update_counter() { std::string access_path_unhits = "AccessPathUnhits"; int64_t total = 0; for (auto& [k, v] : _reader->stats().dynamic_json_hits) { - auto* path_counter = _runtime_profile->get_counter(fmt::format("[Unhit]{}", k)); + std::string path = fmt::format("[Unhit]{}", k); + auto* path_counter = _runtime_profile->get_counter(path); if (path_counter == nullptr) { - path_counter = ADD_CHILD_COUNTER(_runtime_profile, k, TUnit::UNIT, access_path_unhits); + path_counter = ADD_CHILD_COUNTER(_runtime_profile, path, TUnit::UNIT, access_path_unhits); } total += v; COUNTER_UPDATE(path_counter, v); @@ -689,7 +691,23 @@ void OlapChunkSource::_update_counter() { COUNTER_UPDATE(_access_path_unhits_counter, total); } - COUNTER_UPDATE(_json_flatten_timer, _reader->stats().json_flatten_ns); + std::string parent_name = "SegmentRead"; + if (_reader->stats().json_init_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonInit", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_init_ns); + } + if (_reader->stats().json_cast_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonCast", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_cast_ns); + } + if (_reader->stats().json_merge_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonMerge", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_merge_ns); + } + if (_reader->stats().json_flatten_ns > 0) { + RuntimeProfile::Counter* c = ADD_CHILD_TIMER(_runtime_profile, "FlatJsonFlatten", parent_name); + COUNTER_UPDATE(c, _reader->stats().json_flatten_ns); + } } } // namespace starrocks::pipeline diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.h b/be/src/exec/pipeline/scan/olap_chunk_source.h index 3348c61dd17426..e52513ac34843e 100644 --- a/be/src/exec/pipeline/scan/olap_chunk_source.h +++ b/be/src/exec/pipeline/scan/olap_chunk_source.h @@ -155,7 +155,6 @@ class OlapChunkSource final : public ChunkSource { RuntimeProfile::Counter* _total_columns_data_page_count = nullptr; RuntimeProfile::Counter* _read_pk_index_timer = nullptr; RuntimeProfile::Counter* _pushdown_access_paths_counter = nullptr; - RuntimeProfile::Counter* _json_flatten_timer = nullptr; RuntimeProfile::Counter* _access_path_hits_counter = nullptr; RuntimeProfile::Counter* _access_path_unhits_counter = nullptr; }; diff --git a/be/src/exec/schema_scan_node.cpp b/be/src/exec/schema_scan_node.cpp index 0188f2496c4906..d521be6ce57530 100644 --- a/be/src/exec/schema_scan_node.cpp +++ b/be/src/exec/schema_scan_node.cpp @@ -302,8 +302,7 @@ std::vector> SchemaScanNode::decompos pipeline::PipelineBuilderContext* context) { auto exec_group = context->find_exec_group_by_plan_node_id(_id); context->set_current_execution_group(exec_group); - // the dop of SchemaScanOperator should always be 1. - size_t dop = 1; + size_t dop = context->dop_of_source_operator(_id); size_t buffer_capacity = pipeline::ScanOperator::max_buffer_capacity() * dop; pipeline::ChunkBufferLimiterPtr buffer_limiter = std::make_unique( diff --git a/be/src/exec/schema_scanner/schema_materialized_views_scanner.cpp b/be/src/exec/schema_scanner/schema_materialized_views_scanner.cpp index 790cf9c0bca338..75976a90e8c0cf 100644 --- a/be/src/exec/schema_scanner/schema_materialized_views_scanner.cpp +++ b/be/src/exec/schema_scanner/schema_materialized_views_scanner.cpp @@ -60,6 +60,7 @@ SchemaScanner::ColumnDesc SchemaMaterializedViewsScanner::_s_tbls_columns[] = { false}, {"EXTRA_MESSAGE", TypeDescriptor::create_varchar_type(sizeof(StringValue)), sizeof(StringValue), false}, {"QUERY_REWRITE_STATUS", TypeDescriptor::create_varchar_type(sizeof(StringValue)), sizeof(StringValue), false}, + {"CREATOR", TypeDescriptor::create_varchar_type(sizeof(StringValue)), sizeof(StringValue), false}, }; SchemaMaterializedViewsScanner::SchemaMaterializedViewsScanner() @@ -118,7 +119,8 @@ Status SchemaMaterializedViewsScanner::fill_chunk(ChunkPtr* chunk) { Slice(info.rows), Slice(info.text), Slice(info.extra_message), - Slice(info.query_rewrite_status)}; + Slice(info.query_rewrite_status), + Slice(info.creator)}; for (const auto& [slot_id, index] : slot_id_map) { Column* column = (*chunk)->get_column_by_slot_id(slot_id).get(); diff --git a/be/src/exec/schema_scanner/schema_tasks_scanner.cpp b/be/src/exec/schema_scanner/schema_tasks_scanner.cpp index 3bc190bd0af335..99738251bf7ca4 100644 --- a/be/src/exec/schema_scanner/schema_tasks_scanner.cpp +++ b/be/src/exec/schema_scanner/schema_tasks_scanner.cpp @@ -30,6 +30,7 @@ SchemaScanner::ColumnDesc SchemaTasksScanner::_s_tbls_columns[] = { {"DEFINITION", TypeDescriptor::create_varchar_type(sizeof(StringValue)), sizeof(StringValue), false}, {"EXPIRE_TIME", TypeDescriptor::from_logical_type(TYPE_DATETIME), sizeof(StringValue), true}, {"PROPERTIES", TypeDescriptor::create_varchar_type(sizeof(StringValue)), sizeof(StringValue), false}, + {"CREATOR", TypeDescriptor::create_varchar_type(sizeof(StringValue)), sizeof(StringValue), false}, }; SchemaTasksScanner::SchemaTasksScanner() @@ -63,10 +64,8 @@ DatumArray SchemaTasksScanner::_build_row() { ? TimestampValue::create_from_unixtime(task.create_time, _runtime_state->timezone_obj()) : kNullDatum; - return { - Slice(task.task_name), create_time, Slice(task.schedule), Slice(task.catalog), Slice(task.database), - Slice(task.definition), expire_time, Slice(task.properties), - }; + return {Slice(task.task_name), create_time, Slice(task.schedule), Slice(task.catalog), Slice(task.database), + Slice(task.definition), expire_time, Slice(task.properties), Slice(task.creator)}; } Status SchemaTasksScanner::fill_chunk(ChunkPtr* chunk) { diff --git a/be/src/exec/topn_node.cpp b/be/src/exec/topn_node.cpp index 0a9a0bcbd2b469..dbe96de9139f3c 100644 --- a/be/src/exec/topn_node.cpp +++ b/be/src/exec/topn_node.cpp @@ -372,8 +372,9 @@ pipeline::OpFactories TopNNode::decompose_to_pipeline(pipeline::PipelineBuilderC bool is_partition_skewed = _tnode.sort_node.__isset.analytic_partition_skewed && _tnode.sort_node.analytic_partition_skewed; bool need_merge = _analytic_partition_exprs.empty() || is_partition_skewed; - bool enable_parallel_merge = - _tnode.sort_node.__isset.enable_parallel_merge && _tnode.sort_node.enable_parallel_merge; + bool enable_parallel_merge = _tnode.sort_node.__isset.enable_parallel_merge && + _tnode.sort_node.enable_parallel_merge && + !_sort_exec_exprs.lhs_ordering_expr_ctxs().empty(); OpFactories operators_source_with_sort; diff --git a/be/src/exprs/CMakeLists.txt b/be/src/exprs/CMakeLists.txt index 2f9b0404ffa8b9..cb1bdb7cda8372 100644 --- a/be/src/exprs/CMakeLists.txt +++ b/be/src/exprs/CMakeLists.txt @@ -21,6 +21,7 @@ set(EXPR_FILES agg/factory/aggregate_factory.cpp agg/factory/aggregate_resolver_approx.cpp agg/factory/aggregate_resolver_avg.cpp + agg/factory/aggregate_resolver_hypothesis_testing.cpp agg/factory/aggregate_resolver_minmaxany.cpp agg/factory/aggregate_resolver_others.cpp agg/factory/aggregate_resolver_sumcount.cpp diff --git a/be/src/exprs/agg/aggregate_factory.h b/be/src/exprs/agg/aggregate_factory.h index 65ac3deda2352b..90b6f7d91c6454 100644 --- a/be/src/exprs/agg/aggregate_factory.h +++ b/be/src/exprs/agg/aggregate_factory.h @@ -29,4 +29,7 @@ const AggregateFunction* get_window_function(const std::string& name, LogicalTyp TFunctionBinaryType::type binary_type = TFunctionBinaryType::BUILTIN, int func_version = 1); +const AggregateFunction* get_aggregate_function(const std::string& agg_func_name, const TypeDescriptor& return_type, + const std::vector& arg_types, bool is_input_nullable, + TFunctionBinaryType::type binary_type, int func_version = 1); } // namespace starrocks diff --git a/be/src/exprs/agg/aggregate_state_allocator.h b/be/src/exprs/agg/aggregate_state_allocator.h index b980ee8e909c05..66d7f85dc0933b 100644 --- a/be/src/exprs/agg/aggregate_state_allocator.h +++ b/be/src/exprs/agg/aggregate_state_allocator.h @@ -17,6 +17,7 @@ #include "column/hash_set.h" #include "common/config.h" #include "runtime/memory/allocator.h" +#include "runtime/memory/roaring_hook.h" namespace starrocks { @@ -92,4 +93,20 @@ using SliceHashSetWithAggStateAllocator = phmap::flat_hash_set using VectorWithAggStateAllocator = std::vector>; + +// Thread local aggregate state allocator setter with roaring allocator +class ThreadLocalStateAllocatorSetter { +public: + ThreadLocalStateAllocatorSetter(Allocator* allocator) + : _agg_state_allocator_setter(allocator), _roaring_allocator_setter(allocator) {} + ~ThreadLocalStateAllocatorSetter() = default; + +private: + ThreadLocalAggregateStateAllocatorSetter _agg_state_allocator_setter; + ThreadLocalRoaringAllocatorSetter _roaring_allocator_setter; +}; + +#define SCOPED_THREAD_LOCAL_STATE_ALLOCATOR_SETTER(allocator) \ + auto VARNAME_LINENUM(alloc_setter) = ThreadLocalStateAllocatorSetter(allocator) + } // namespace starrocks diff --git a/be/src/exprs/agg/array_agg.h b/be/src/exprs/agg/array_agg.h index e4520550a1e299..ac15f44573d2ba 100644 --- a/be/src/exprs/agg/array_agg.h +++ b/be/src/exprs/agg/array_agg.h @@ -94,19 +94,12 @@ struct ArrayAggAggregateState { return &data_column; } - bool check_overflow(FunctionContext* ctx) const { - std::string err_msg; - if (UNLIKELY(data_column.capacity_limit_reached(&err_msg))) { - ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", err_msg).c_str()); - return true; - } - return false; - } + bool check_overflow(FunctionContext* ctx) const { return check_overflow(data_column, ctx); } static bool check_overflow(const Column& col, FunctionContext* ctx) { - std::string err_msg; - if (UNLIKELY(col.capacity_limit_reached(&err_msg))) { - ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", err_msg).c_str()); + Status st = col.capacity_limit_reached(); + if (!st.ok()) { + ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", st.message()).c_str()); return true; } return false; @@ -195,8 +188,9 @@ struct ArrayAggAggregateStateV2 { bool check_overflow(FunctionContext* ctx) const { std::string err_msg; for (size_t i = 0; i < data_columns.size(); i++) { - if (UNLIKELY(data_columns[i]->capacity_limit_reached(&err_msg))) { - ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", err_msg).c_str()); + Status st = data_columns[i]->capacity_limit_reached(); + if (!st.ok()) { + ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", st.message()).c_str()); return true; } } @@ -204,9 +198,9 @@ struct ArrayAggAggregateStateV2 { } static bool check_overflow(const Column& col, FunctionContext* ctx) { - std::string err_msg; - if (UNLIKELY(col.capacity_limit_reached(&err_msg))) { - ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", err_msg).c_str()); + Status st = col.capacity_limit_reached(); + if (!st.ok()) { + ctx->set_error(fmt::format("The column generated by array_agg is overflow: {}", st.message()).c_str()); return true; } return false; diff --git a/be/src/exprs/agg/avg.h b/be/src/exprs/agg/avg.h index d3608501c302df..cfdf6d5663df02 100644 --- a/be/src/exprs/agg/avg.h +++ b/be/src/exprs/agg/avg.h @@ -211,13 +211,14 @@ class AvgAggregateFunction final void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { DCHECK(!to->is_nullable()); + auto* column = down_cast(to); // In fact, for StarRocks real query, we don't need this check. // But for robust, we add this check. if (this->data(state).count == 0) { + column->append_default(); return; } - auto* column = down_cast(to); ResultType result; if constexpr (lt_is_decimalv2) { result = this->data(state).sum / DecimalV2Value(this->data(state).count, 0); diff --git a/be/src/exprs/agg/ds_hll_count_distinct.h b/be/src/exprs/agg/ds_hll_count_distinct.h new file mode 100644 index 00000000000000..e83914e0de2cc2 --- /dev/null +++ b/be/src/exprs/agg/ds_hll_count_distinct.h @@ -0,0 +1,240 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "column/binary_column.h" +#include "column/object_column.h" +#include "column/type_traits.h" +#include "column/vectorized_fwd.h" +#include "exprs/agg/aggregate.h" +#include "gutil/casts.h" +#include "types/hll_sketch.h" + +namespace starrocks { + +struct HLLSketchState { + std::unique_ptr hll_sketch = nullptr; + int64_t memory_usage = 0; +}; + +/** + * RETURN_TYPE: TYPE_BIGINT + * ARGS_TYPE: ALL TYPE + * SERIALIZED_TYPE: TYPE_VARCHAR + */ +template > +class HllSketchAggregateFunction final + : public AggregateFunctionBatchHelper> { +public: + using ColumnType = RunTimeColumnType; + + void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override { + if (this->data(state).hll_sketch != nullptr) { + ctx->add_mem_usage(-this->data(state).hll_sketch->mem_usage()); + this->data(state).hll_sketch->clear(); + } + } + + void update_state(FunctionContext* ctx, AggDataPtr state, uint64_t value) const { + int64_t prev_memory = this->data(state).hll_sketch->mem_usage(); + this->data(state).hll_sketch->update(value); + ctx->add_mem_usage(this->data(state).hll_sketch->mem_usage() - prev_memory); + } + + void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, + size_t row_num) const override { + // init state if needed + _init_if_needed(ctx, columns, state); + + uint64_t value = 0; + const ColumnType* column = down_cast(columns[0]); + + if constexpr (lt_is_string) { + Slice s = column->get_slice(row_num); + value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED); + } else { + const auto& v = column->get_data(); + value = HashUtil::murmur_hash64A(&v[row_num], sizeof(v[row_num]), HashUtil::MURMUR_SEED); + } + update_state(ctx, state, value); + } + + void update_batch_single_state_with_frame(FunctionContext* ctx, AggDataPtr __restrict state, const Column** columns, + int64_t peer_group_start, int64_t peer_group_end, int64_t frame_start, + int64_t frame_end) const override { + // init state if needed + _init_if_needed(ctx, columns, state); + const ColumnType* column = down_cast(columns[0]); + if constexpr (lt_is_string) { + uint64_t value = 0; + for (size_t i = frame_start; i < frame_end; ++i) { + Slice s = column->get_slice(i); + value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED); + + if (value != 0) { + update_state(ctx, state, value); + } + } + } else { + uint64_t value = 0; + const auto& v = column->get_data(); + for (size_t i = frame_start; i < frame_end; ++i) { + value = HashUtil::murmur_hash64A(&v[i], sizeof(v[i]), HashUtil::MURMUR_SEED); + + if (value != 0) { + update_state(ctx, state, value); + } + } + } + } + + void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override { + DCHECK(column->is_binary()); + const BinaryColumn* hll_column = down_cast(column); + DataSketchesHll hll(hll_column->get(row_num).get_slice(), &(this->data(state).memory_usage)); + if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { + this->data(state).hll_sketch = std::make_unique( + hll.get_lg_config_k(), hll.get_target_type(), &(this->data(state).memory_usage)); + } + int64_t prev_memory = this->data(state).hll_sketch->mem_usage(); + this->data(state).hll_sketch->merge(hll); + ctx->add_mem_usage(this->data(state).hll_sketch->mem_usage() - prev_memory); + } + + void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start, + size_t end) const override { + DCHECK_GT(end, start); + Int64Column* column = down_cast(dst); + int64_t result = 0L; + if (LIKELY(this->data(state).hll_sketch != nullptr)) { + result = this->data(state).hll_sketch->estimate_cardinality(); + } + for (size_t i = start; i < end; ++i) { + column->get_data()[i] = result; + } + } + + void serialize_to_column([[maybe_unused]] FunctionContext* ctx, ConstAggDataPtr __restrict state, + Column* to) const override { + DCHECK(to->is_binary()); + auto* column = down_cast(to); + if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { + column->append_default(); + } else { + size_t size = this->data(state).hll_sketch->serialize_size(); + uint8_t result[size]; + size = this->data(state).hll_sketch->serialize(result); + column->append(Slice(result, size)); + } + } + + void convert_to_serialize_format([[maybe_unused]] FunctionContext* ctx, const Columns& src, size_t chunk_size, + ColumnPtr* dst) const override { + const ColumnType* column = down_cast(src[0].get()); + auto* result = down_cast((*dst).get()); + + Bytes& bytes = result->get_bytes(); + bytes.reserve(chunk_size * 10); + result->get_offset().resize(chunk_size + 1); + + size_t old_size = bytes.size(); + uint64_t value = 0; + uint8_t log_k; + datasketches::target_hll_type tgt_type; + // convert to const Column* + std::vector src_datas; + src_datas.reserve(src.size()); + std::transform(src.begin(), src.end(), std::back_inserter(src_datas), + [](const ColumnPtr& col) { return col.get(); }); + const Column** src_datas_ptr = src_datas.data(); + std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, src_datas_ptr); + for (size_t i = 0; i < chunk_size; ++i) { + int64_t memory_usage = 0; + DataSketchesHll hll{log_k, tgt_type, &memory_usage}; + if constexpr (lt_is_string) { + Slice s = column->get_slice(i); + value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED); + } else { + auto v = column->get_data()[i]; + value = HashUtil::murmur_hash64A(&v, sizeof(v), HashUtil::MURMUR_SEED); + } + if (value != 0) { + hll.update(value); + } + + size_t new_size = old_size + hll.serialize_size(); + bytes.resize(new_size); + hll.serialize(bytes.data() + old_size); + + result->get_offset()[i + 1] = new_size; + old_size = new_size; + } + } + + void finalize_to_column(FunctionContext* ctx __attribute__((unused)), ConstAggDataPtr __restrict state, + Column* to) const override { + DCHECK(to->is_numeric()); + + auto* column = down_cast(to); + if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { + column->append(0L); + } else { + column->append(this->data(state).hll_sketch->estimate_cardinality()); + } + } + + std::string get_name() const override { return "ds_hll_count_distinct"; } + +private: + // init hll sketch if needed + void _init_if_needed(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state) const { + if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { + uint8_t log_k; + datasketches::target_hll_type tgt_type; + std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, columns); + this->data(state).hll_sketch = _init_hll_sketch(log_k, tgt_type, &(this->data(state).memory_usage)); + } + } + + // parse log_k and target type from args + std::tuple _parse_hll_sketch_args(FunctionContext* ctx, + const Column** columns) const { + uint8_t log_k = DEFAULT_HLL_LOG_K; + datasketches::target_hll_type tgt_type = datasketches::HLL_6; + if (ctx->get_num_args() == 2) { + log_k = (uint8_t)(columns[1]->get(0).get_int32()); + } else if (ctx->get_num_args() == 3) { + log_k = (uint8_t)(columns[1]->get(0).get_int32()); + std::string tgt_type_str = columns[2]->get(0).get_slice().to_string(); + std::transform(tgt_type_str.begin(), tgt_type_str.end(), tgt_type_str.begin(), ::toupper); + if (tgt_type_str == "HLL_4") { + tgt_type = datasketches::HLL_4; + } else if (tgt_type_str == "HLL_8") { + tgt_type = datasketches::HLL_8; + } else { + tgt_type = datasketches::HLL_6; + } + } + return {log_k, tgt_type}; + } + + // init hll sketch with default log_k and target type + std::unique_ptr _init_hll_sketch(uint8_t log_k, datasketches::target_hll_type tgt_type, + int64_t* memory_usage) const { + return std::make_unique(log_k, tgt_type, memory_usage); + } +}; + +} // namespace starrocks diff --git a/be/src/exprs/agg/factory/aggregate_factory.cpp b/be/src/exprs/agg/factory/aggregate_factory.cpp index dcc63993f26c7e..d790204ba91eb3 100644 --- a/be/src/exprs/agg/factory/aggregate_factory.cpp +++ b/be/src/exprs/agg/factory/aggregate_factory.cpp @@ -15,15 +15,11 @@ #include "exprs/agg/aggregate_factory.h" #include -#include -#include -#include "column/type_traits.h" #include "exprs/agg/aggregate.h" #include "exprs/agg/factory/aggregate_factory.hpp" #include "exprs/agg/factory/aggregate_resolver.hpp" #include "types/logical_type.h" -#include "types/logical_type_infra.h" #include "udf/java/java_function_fwd.h" namespace starrocks { @@ -40,6 +36,7 @@ AggregateFuncResolver::AggregateFuncResolver() { register_approx(); register_others(); register_retract_functions(); + register_hypothesis_testing(); } AggregateFuncResolver::~AggregateFuncResolver() = default; @@ -168,4 +165,45 @@ const AggregateFunction* get_window_function(const std::string& name, LogicalTyp return nullptr; } +const AggregateFunction* get_aggregate_function(const std::string& agg_func_name, const TypeDescriptor& return_type, + const std::vector& arg_types, bool is_result_nullable, + TFunctionBinaryType::type binary_type, int func_version) { + // get function + if (agg_func_name == "count") { + return get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_result_nullable); + } else { + DCHECK_GE(arg_types.size(), 1); + TypeDescriptor arg_type = arg_types[0]; + // Because intersect_count have two input types. + // And intersect_count's first argument's type is alwasy Bitmap, + // so we use its second arguments type as input. + if (agg_func_name == "intersect_count") { + arg_type = arg_types[1]; + } + + // Because max_by and min_by function have two input types, + // so we use its second arguments type as input. + if (agg_func_name == "max_by" || agg_func_name == "min_by") { + arg_type = arg_types[1]; + } + + // Because windowfunnel have more two input types. + // functions registry use 2th args(datetime/date). + if (agg_func_name == "window_funnel") { + arg_type = arg_types[1]; + } + + // hack for accepting various arguments + if (agg_func_name == "exchange_bytes" || agg_func_name == "exchange_speed") { + arg_type = TypeDescriptor(TYPE_BIGINT); + } + + if (agg_func_name == "array_union_agg" || agg_func_name == "array_unique_agg") { + arg_type = arg_type.children[0]; + } + return get_aggregate_function(agg_func_name, arg_type.type, return_type.type, is_result_nullable, binary_type, + func_version); + } +} + } // namespace starrocks diff --git a/be/src/exprs/agg/factory/aggregate_factory.hpp b/be/src/exprs/agg/factory/aggregate_factory.hpp index a38fb2fa18644b..f8f605bfa7eceb 100644 --- a/be/src/exprs/agg/factory/aggregate_factory.hpp +++ b/be/src/exprs/agg/factory/aggregate_factory.hpp @@ -15,8 +15,6 @@ #pragma once #include -#include -#include #include "column/type_traits.h" #include "exprs/agg/aggregate.h" @@ -33,14 +31,15 @@ #include "exprs/agg/count.h" #include "exprs/agg/covariance.h" #include "exprs/agg/distinct.h" +#include "exprs/agg/ds_hll_count_distinct.h" #include "exprs/agg/exchange_perf.h" #include "exprs/agg/group_concat.h" #include "exprs/agg/histogram.h" #include "exprs/agg/hll_ndv.h" -#include "exprs/agg/hll_sketch.h" #include "exprs/agg/hll_union.h" #include "exprs/agg/hll_union_count.h" #include "exprs/agg/intersect_count.h" +#include "exprs/agg/mann_whitney.h" #include "exprs/agg/maxmin.h" #include "exprs/agg/maxmin_by.h" #include "exprs/agg/nullable_aggregate.h" @@ -120,6 +119,10 @@ class AggregateFactory { return std::make_shared(); } + static auto MakeMannWhitneyUTestAggregateFunction() { + return std::make_shared(); + } + template static auto MakeMaxAggregateFunction(); diff --git a/be/src/exprs/agg/factory/aggregate_resolver.hpp b/be/src/exprs/agg/factory/aggregate_resolver.hpp index a9721c5ec7a07e..9bdc9bb862ce23 100644 --- a/be/src/exprs/agg/factory/aggregate_resolver.hpp +++ b/be/src/exprs/agg/factory/aggregate_resolver.hpp @@ -22,7 +22,6 @@ #include "exprs/agg/aggregate.h" #include "exprs/agg/factory/aggregate_factory.hpp" #include "types/logical_type.h" -#include "types/logical_type_infra.h" #include "udf/java/java_function_fwd.h" namespace starrocks { @@ -68,6 +67,7 @@ class AggregateFuncResolver { void register_approx(); void register_others(); void register_retract_functions(); + void register_hypothesis_testing(); const std::vector& aggregate_types() const { const static std::vector kTypes{ diff --git a/be/src/exprs/agg/factory/aggregate_resolver_hypothesis_testing.cpp b/be/src/exprs/agg/factory/aggregate_resolver_hypothesis_testing.cpp new file mode 100644 index 00000000000000..66f88817876340 --- /dev/null +++ b/be/src/exprs/agg/factory/aggregate_resolver_hypothesis_testing.cpp @@ -0,0 +1,26 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "exprs/agg/factory/aggregate_factory.hpp" +#include "exprs/agg/factory/aggregate_resolver.hpp" +#include "types/logical_type.h" + +namespace starrocks { + +void AggregateFuncResolver::register_hypothesis_testing() { + add_aggregate_mapping_variadic( + "mann_whitney_u_test", false, AggregateFactory::MakeMannWhitneyUTestAggregateFunction()); +} + +} // namespace starrocks diff --git a/be/src/exprs/agg/helpers/serialize_helpers.hpp b/be/src/exprs/agg/helpers/serialize_helpers.hpp new file mode 100644 index 00000000000000..395cc4c6376b75 --- /dev/null +++ b/be/src/exprs/agg/helpers/serialize_helpers.hpp @@ -0,0 +1,258 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace starrocks { + +class SerializeHelpers { +public: + template + static void serialize(const T* data, uint8_t*& buffer, size_t num_elems = 1) { + using T_ = std::remove_cv_t>; + std::copy_n(reinterpret_cast(data), num_elems * sizeof(T), buffer); + buffer += num_elems * sizeof(T); + } + + template + static void deserialize(const uint8_t*& buffer, T* data, size_t num_elems = 1) { + using T_ = std::remove_cv_t>; + std::copy_n(buffer, num_elems * sizeof(T), reinterpret_cast(data)); + buffer += num_elems * sizeof(T); + } + + template + static void serialize(const T& data, uint8_t*& buffer) { + serialize(&data, buffer, 1); + } + + template + static void deserialize(const uint8_t*& buffer, T& data) { + deserialize(buffer, &data, 1); + } + + static void serialize(const std::string& data, uint8_t*& buffer) { + size_t size = data.length(); + serialize(size, buffer); + serialize(data.data(), buffer, size); + } + + static void deserialize(const uint8_t*& buffer, std::string& data) { + size_t size = 0; + deserialize(buffer, size); + data.resize(size); + deserialize(buffer, data.data(), size); + } + + template + static void serialize(std::vector const& data, uint8_t*& buffer) { + size_t size = data.size(); + serialize(size, buffer); + for (auto const& elem : data) { + serialize(elem, buffer); + } + } + + template + static void deserialize(const uint8_t*& buffer, std::vector& data) { + size_t size = 0; + deserialize(buffer, size); + data.resize(size); + for (auto& elem : data) { + deserialize(buffer, elem); + } + } + + template + static typename std::enable_if_t< + std::is_same_v> || + std::is_same_v>, + void> + serialize(MapType const& data, uint8_t*& buffer) { + size_t size = data.size(); + serialize(size, buffer); + for (auto const& [key, value] : data) { + serialize(key, buffer); + serialize(value, buffer); + } + } + + template + static typename std::enable_if_t< + std::is_same_v> || + std::is_same_v>, + void> + deserialize(const uint8_t*& buffer, MapType& data) { + size_t size = 0; + deserialize(buffer, size); + for (size_t i = 0; i < size; ++i) { + typename MapType::key_type key; + typename MapType::mapped_type value; + deserialize(buffer, key); + deserialize(buffer, value); + data[std::move(key)] = std::move(value); + } + } + + template + static typename std::enable_if_t> || + std::is_same_v>, + void> + serialize(SetType const& data, uint8_t*& buffer) { + size_t size = data.size(); + serialize(size, buffer); + for (auto const& key : data) { + serialize(key, buffer); + } + } + + template + static typename std::enable_if_t> || + std::is_same_v>, + void> + deserialize(const uint8_t*& buffer, SetType& data) { + size_t size = 0; + deserialize(buffer, size); + for (size_t i = 0; i < size; ++i) { + typename SetType::key_type key; + deserialize(buffer, key); + data.emplace(std::move(key)); + } + } + + template + static void serialize(std::array const& data, uint8_t*& buffer) { + for (auto const& elem : data) { + serialize(elem, buffer); + } + } + + template + static void deserialize(const uint8_t*& buffer, std::array& data) { + for (auto& elem : data) { + deserialize(buffer, elem); + } + } + + template + static size_t serialized_size(T const&) { + return sizeof(T); + } + + static size_t serialized_size(std::string const& str) { return sizeof(size_t) + sizeof(char) * str.length(); } + + template + static size_t serialized_size(std::vector const& vec) { + size_t size = sizeof(size_t); + for (auto const& elem : vec) { + size += serialized_size(elem); + } + return size; + } + + template + static size_t serialized_size(std::array const& vec) { + size_t size = 0; + for (auto const& elem : vec) { + size += serialized_size(elem); + } + return size; + } + + template + static typename std::enable_if_t< + std::is_same_v> || + std::is_same_v>, + size_t> + serialized_size(MapType const& data) { + size_t size = sizeof(size_t); + for (auto const& [key, value] : data) { + size += serialized_size(key) + serialized_size(value); + } + return size; + } + + template + static typename std::enable_if_t> || + std::is_same_v>, + size_t> + serialized_size(SetType const& data) { + size_t size = sizeof(size_t); + for (auto const& elem : data) { + size += serialized_size(elem); + } + return size; + } + + template + static size_t serialized_size(std::tuple const& tuple) { + size_t size = 0; + std::apply([&size](const auto&... element) { ((size += serialized_size(element)), ...); }, tuple); + return size; + } + + template + static void serialize(std::tuple const& tuple, uint8_t*& buffer) { + std::apply([&buffer](const auto&... element) { (serialize(element, buffer), ...); }, tuple); + } + + template + static void deserialize(const uint8_t*& buffer, std::tuple& tuple) { + std::apply([&buffer](auto&... element) { (deserialize(buffer, element), ...); }, tuple); + } + + template + static size_t serialized_size_all(First&& first, Args&&... rest) { + if constexpr (sizeof...(rest) == 0) { + return serialized_size(first); + } else { + return serialized_size(first) + serialized_size_all(std::forward(rest)...); + } + } + + template + static void serialize_all(uint8_t*& data, First const& first, Args&&... rest) { + serialize(first, data); + if constexpr (sizeof...(rest) != 0) { + serialize_all(data, std::forward(rest)...); + } + } + + template + static void deserialize_all(const uint8_t*& data, First& first, Args&&... rest) { + deserialize(data, first); + if constexpr (sizeof...(rest) != 0) { + deserialize_all(data, std::forward(rest)...); + } + } +}; + +} // namespace starrocks diff --git a/be/src/exprs/agg/hll_sketch.h b/be/src/exprs/agg/hll_sketch.h deleted file mode 100644 index 8fd51fc3947461..00000000000000 --- a/be/src/exprs/agg/hll_sketch.h +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2021-present StarRocks, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "column/binary_column.h" -#include "column/object_column.h" -#include "column/type_traits.h" -#include "column/vectorized_fwd.h" -#include "exprs/agg/aggregate.h" -#include "gutil/casts.h" - -namespace starrocks { - -struct HLLSketchState { - std::unique_ptr hll_sketch = nullptr; -}; - -/** - * RETURN_TYPE: TYPE_BIGINT - * ARGS_TYPE: ALL TYPE - * SERIALIZED_TYPE: TYPE_VARCHAR - */ -template > -class HllSketchAggregateFunction final - : public AggregateFunctionBatchHelper> { -public: - using ColumnType = RunTimeColumnType; - - void reset(FunctionContext* ctx, const Columns& args, AggDataPtr state) const override { - if (this->data(state).hll_sketch != nullptr) { - ctx->add_mem_usage(-this->data(state).hll_sketch->mem_usage()); - this->data(state).hll_sketch->clear(); - } - } - - void update_state(FunctionContext* ctx, AggDataPtr state, uint64_t value) const { - int64_t prev_memory = this->data(state).hll_sketch->mem_usage(); - this->data(state).hll_sketch->update(value); - ctx->add_mem_usage(this->data(state).hll_sketch->mem_usage() - prev_memory); - } - - void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, - size_t row_num) const override { - // init state if needed - _init_if_needed(ctx, columns, state); - - uint64_t value = 0; - const ColumnType* column = down_cast(columns[0]); - - if constexpr (lt_is_string) { - Slice s = column->get_slice(row_num); - value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED); - } else { - const auto& v = column->get_data(); - value = HashUtil::murmur_hash64A(&v[row_num], sizeof(v[row_num]), HashUtil::MURMUR_SEED); - } - update_state(ctx, state, value); - } - - void update_batch_single_state_with_frame(FunctionContext* ctx, AggDataPtr __restrict state, const Column** columns, - int64_t peer_group_start, int64_t peer_group_end, int64_t frame_start, - int64_t frame_end) const override { - // init state if needed - _init_if_needed(ctx, columns, state); - const ColumnType* column = down_cast(columns[0]); - if constexpr (lt_is_string) { - uint64_t value = 0; - for (size_t i = frame_start; i < frame_end; ++i) { - Slice s = column->get_slice(i); - value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED); - - if (value != 0) { - update_state(ctx, state, value); - } - } - } else { - uint64_t value = 0; - const auto& v = column->get_data(); - for (size_t i = frame_start; i < frame_end; ++i) { - value = HashUtil::murmur_hash64A(&v[i], sizeof(v[i]), HashUtil::MURMUR_SEED); - - if (value != 0) { - update_state(ctx, state, value); - } - } - } - } - - void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override { - DCHECK(column->is_binary()); - const BinaryColumn* hll_column = down_cast(column); - DataSketchesHll hll(hll_column->get(row_num).get_slice()); - if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { - this->data(state).hll_sketch = - std::make_unique(hll.get_lg_config_k(), hll.get_target_type()); - } - int64_t prev_memory = this->data(state).hll_sketch->mem_usage(); - this->data(state).hll_sketch->merge(hll); - ctx->add_mem_usage(this->data(state).hll_sketch->mem_usage() - prev_memory); - } - - void get_values(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* dst, size_t start, - size_t end) const override { - DCHECK_GT(end, start); - Int64Column* column = down_cast(dst); - int64_t result = 0L; - if (LIKELY(this->data(state).hll_sketch != nullptr)) { - result = this->data(state).hll_sketch->estimate_cardinality(); - } - for (size_t i = start; i < end; ++i) { - column->get_data()[i] = result; - } - } - - void serialize_to_column([[maybe_unused]] FunctionContext* ctx, ConstAggDataPtr __restrict state, - Column* to) const override { - DCHECK(to->is_binary()); - auto* column = down_cast(to); - if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { - column->append_default(); - } else { - size_t size = this->data(state).hll_sketch->serialize_size(); - uint8_t result[size]; - size = this->data(state).hll_sketch->serialize(result); - column->append(Slice(result, size)); - } - } - - void convert_to_serialize_format([[maybe_unused]] FunctionContext* ctx, const Columns& src, size_t chunk_size, - ColumnPtr* dst) const override { - const ColumnType* column = down_cast(src[0].get()); - auto* result = down_cast((*dst).get()); - - Bytes& bytes = result->get_bytes(); - bytes.reserve(chunk_size * 10); - result->get_offset().resize(chunk_size + 1); - - size_t old_size = bytes.size(); - uint64_t value = 0; - uint8_t log_k; - datasketches::target_hll_type tgt_type; - // convert to const Column* - std::vector src_datas; - src_datas.reserve(src.size()); - std::transform(src.begin(), src.end(), std::back_inserter(src_datas), - [](const ColumnPtr& col) { return col.get(); }); - const Column** src_datas_ptr = src_datas.data(); - std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, src_datas_ptr); - for (size_t i = 0; i < chunk_size; ++i) { - DataSketchesHll hll{log_k, tgt_type}; - if constexpr (lt_is_string) { - Slice s = column->get_slice(i); - value = HashUtil::murmur_hash64A(s.data, s.size, HashUtil::MURMUR_SEED); - } else { - auto v = column->get_data()[i]; - value = HashUtil::murmur_hash64A(&v, sizeof(v), HashUtil::MURMUR_SEED); - } - if (value != 0) { - hll.update(value); - } - - size_t new_size = old_size + hll.serialize_size(); - bytes.resize(new_size); - hll.serialize(bytes.data() + old_size); - - result->get_offset()[i + 1] = new_size; - old_size = new_size; - } - } - - void finalize_to_column(FunctionContext* ctx __attribute__((unused)), ConstAggDataPtr __restrict state, - Column* to) const override { - DCHECK(to->is_numeric()); - - auto* column = down_cast(to); - if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { - column->append(0L); - } else { - column->append(this->data(state).hll_sketch->estimate_cardinality()); - } - } - - std::string get_name() const override { return "ds_hll_count_distinct"; } - -private: - // init hll sketch if needed - void _init_if_needed(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state) const { - if (UNLIKELY(this->data(state).hll_sketch == nullptr)) { - uint8_t log_k; - datasketches::target_hll_type tgt_type; - std::tie(log_k, tgt_type) = _parse_hll_sketch_args(ctx, columns); - this->data(state).hll_sketch = _init_hll_sketch(log_k, tgt_type); - } - } - - // parse log_k and target type from args - std::tuple _parse_hll_sketch_args(FunctionContext* ctx, - const Column** columns) const { - uint8_t log_k = DEFAULT_HLL_LOG_K; - datasketches::target_hll_type tgt_type = datasketches::HLL_6; - if (ctx->get_num_args() == 2) { - log_k = (uint8_t)(columns[1]->get(0).get_int32()); - } else if (ctx->get_num_args() == 3) { - log_k = (uint8_t)(columns[1]->get(0).get_int32()); - std::string tgt_type_str = columns[2]->get(0).get_slice().to_string(); - std::transform(tgt_type_str.begin(), tgt_type_str.end(), tgt_type_str.begin(), ::toupper); - if (tgt_type_str == "HLL_4") { - tgt_type = datasketches::HLL_4; - } else if (tgt_type_str == "HLL_8") { - tgt_type = datasketches::HLL_8; - } else { - tgt_type = datasketches::HLL_6; - } - } - return {log_k, tgt_type}; - } - - // init hll sketch with default log_k and target type - std::unique_ptr _init_hll_sketch( - uint8_t log_k = DEFAULT_HLL_LOG_K, datasketches::target_hll_type tgt_type = datasketches::HLL_6) const { - return std::make_unique(log_k, tgt_type); - } -}; - -} // namespace starrocks diff --git a/be/src/exprs/agg/hypothesis_testing_common.h b/be/src/exprs/agg/hypothesis_testing_common.h new file mode 100644 index 00000000000000..f371e57bd4f679 --- /dev/null +++ b/be/src/exprs/agg/hypothesis_testing_common.h @@ -0,0 +1,63 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace starrocks { + +class TestingAlternative { +public: + enum Alternative : uint8_t { + // The initial value of TestingAlternative is Unknown, + // which should be considered as an error + // if we encounter a TestingAlternative with a value of Unknown. + Unknown = 0, + TwoSided, + Less, + Greater, + }; + + TestingAlternative(Alternative alternative = Unknown) : _alternative(alternative) {} + TestingAlternative(uint8_t alternative) : _alternative(static_cast(alternative)) {} + + static TestingAlternative from_str(std::string_view alternative_str) { + TestingAlternative alternative; + if (alternative_str == "two-sided") { + alternative._alternative = TwoSided; + } else if (alternative_str == "less") { + alternative._alternative = Less; + } else if (alternative_str == "greater") { + alternative._alternative = Greater; + } else { + alternative._alternative = Unknown; + } + return alternative; + } + + friend bool operator==(TestingAlternative const& lhs, TestingAlternative const& rhs) { + return lhs._alternative == rhs._alternative; + } + + friend bool operator==(TestingAlternative const& lhs, Alternative const& rhs) { return lhs._alternative == rhs; } + + uint8_t value() const { return _alternative; } + +private: + Alternative _alternative{Unknown}; +}; + +} // namespace starrocks diff --git a/be/src/exprs/agg/mann_whitney.h b/be/src/exprs/agg/mann_whitney.h new file mode 100644 index 00000000000000..11b58663cb3163 --- /dev/null +++ b/be/src/exprs/agg/mann_whitney.h @@ -0,0 +1,355 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "boost/math/distributions/normal.hpp" +#include "column/type_traits.h" +#include "column/vectorized_fwd.h" +#include "exprs/agg/aggregate.h" +#include "exprs/agg/helpers/serialize_helpers.hpp" +#include "exprs/agg/hypothesis_testing_common.h" +#include "exprs/function_context.h" +#include "exprs/function_helper.h" +#include "gutil/casts.h" +#include "types/logical_type.h" +#include "util/json.h" + +namespace starrocks { + +class MannWhitneyAggregateState { +public: + MannWhitneyAggregateState() = default; + MannWhitneyAggregateState(const uint8_t*& data) { deserialize(data); } + + void init(TestingAlternative alternative, int64_t continuity_correction) { + _alternative = alternative; + _continuity_correction = continuity_correction; + } + + bool is_uninitialized() const { return _alternative == TestingAlternative::Unknown; } + + void update(double x, bool treatment) { + _stats[treatment].emplace_back(x); + _sorted = false; + } + + void merge(MannWhitneyAggregateState const& other) { + DCHECK(_alternative == other._alternative); + sort_if_not_sorted(); + other.sort_if_not_sorted(); + for (size_t idx = 0; idx < 2; ++idx) { + std::vector tmp; + std::merge(_stats[idx].begin(), _stats[idx].end(), other._stats[idx].begin(), other._stats[idx].end(), + std::back_inserter(tmp)); + _stats[idx] = tmp; + } + } + + void serialize(uint8_t*& data) const { + DCHECK(!is_uninitialized()); + sort_if_not_sorted(); + SerializeHelpers::serialize_all(data, _alternative.value(), _stats, _continuity_correction); + } + + void deserialize(const uint8_t*& data) { + uint8_t tmp; + SerializeHelpers::deserialize_all(data, tmp, _stats, _continuity_correction); + _alternative = TestingAlternative(tmp); + _sorted = true; + } + + size_t serialized_size() const { + return SerializeHelpers::serialized_size_all(_alternative.value(), _stats, _continuity_correction); + } + + void build_result(vpack::Builder& builder) const { + if (_alternative == TestingAlternative::Unknown) { + vpack::ObjectBuilder obj_builder(&builder); + builder.add("Logical Error", vpack::Value("state not initialized.")); + return; + } + // If there's only one state, it's not sorted yet, so sorted it. + sort_if_not_sorted(); + size_t size = _stats[0].size() + _stats[1].size(); + std::vector index(size); + std::iota(index.begin(), index.end(), 0); + auto data = [this](size_t idx) { + if (idx < this->_stats[0].size()) { + return this->_stats[0][idx]; + } + return this->_stats[1][idx - this->_stats[0].size()]; + }; + std::inplace_merge(index.begin(), index.begin() + _stats[0].size(), index.end(), + [data](size_t lhs, size_t rhs) { return data(lhs) < data(rhs); }); + DCHECK(std::is_sorted(index.begin(), index.end(), + [data](size_t lhs, size_t rhs) { return data(lhs) < data(rhs); })); + + const double n1 = _stats[0].size(); + const double n2 = _stats[1].size(); + double r1 = 0; + double tie_correction = 0; + size_t left = 0; + double tie_numenator = 0; + while (left < size) { + size_t right = left; + while (right < size && data(index[left]) == data(index[right])) { + ++right; + } + auto adjusted = (left + right + 1.) / 2.; + auto count_equal = right - left; + + // Scipy implementation throws exception in this case too. + if (count_equal == size) { + vpack::ObjectBuilder obj_builder(&builder); + builder.add("Error", vpack::Value("All numbers in both samples are identical.")); + return; + } + + tie_numenator += std::pow(count_equal, 3) - count_equal; + for (size_t iter = left; iter < right; ++iter) { + if (index[iter] < n1) { + r1 += adjusted; + } + } + left = right; + } + tie_correction = 1 - (tie_numenator / (std::pow(size, 3) - size)); + + const double u1 = n1 * n2 + (n1 * (n1 + 1.)) / 2. - r1; + const double u2 = n1 * n2 - u1; + + /// The distribution of U-statistic under null hypothesis H0 is symmetric with respect to meanrank. + const double meanrank = n1 * n2 / 2. + 0.5 * _continuity_correction; + const double sd = std::sqrt(tie_correction * n1 * n2 * (n1 + n2 + 1) / 12.0); + + if (std::isnan(sd) || std::isinf(sd) || std::abs(sd) < 1e-7) { + vpack::ObjectBuilder obj_builder(&builder); + builder.add("Logical Error", vpack::Value(fmt::format("sd({}) is not a valid value.", sd))); + return; + } + + double u = 0; + if (_alternative == TestingAlternative::TwoSided) { + u = std::max(u1, u2); + } else if (_alternative == TestingAlternative::Less) { + u = u1; + } else if (_alternative == TestingAlternative::Greater) { + u = u2; + } else { + DCHECK(false); + } + + double z = (u - meanrank) / sd; + if (_alternative == TestingAlternative::TwoSided) { + z = std::abs(z); + } + + auto standart_normal_distribution = boost::math::normal_distribution(); + auto cdf = boost::math::cdf(standart_normal_distribution, z); + + double p_value = 0; + if (_alternative == TestingAlternative::TwoSided) { + p_value = 2 - 2 * cdf; + } else { + p_value = 1 - cdf; + } + + vpack::ArrayBuilder array_builder(&builder); + builder.add(vpack::Value(u2)); + builder.add(vpack::Value(p_value)); + } + +private: + TestingAlternative _alternative; + mutable std::array, 2> _stats; + int64_t _continuity_correction{0}; + mutable bool _sorted{true}; + + void sort_if_not_sorted() const { + if (!_sorted) { + for (size_t idx = 0; idx < 2; ++idx) { + std::sort(_stats[idx].begin(), _stats[idx].end()); + } + _sorted = true; + } + for (size_t idx = 0; idx < 2; ++idx) { + DCHECK(std::is_sorted(_stats[idx].begin(), _stats[idx].end())); + } + } +}; + +class MannWhitneyUTestAggregateFunction + : public AggregateFunctionBatchHelper { +public: + using DataColumn = RunTimeColumnType; + using DataCppType = RunTimeCppType; + using IndexColumn = RunTimeColumnType; + using AlternativeColumn = RunTimeColumnType; + using ContinuityCorrectionColumn = RunTimeColumnType; + using ResultColumn = RunTimeColumnType; + + void init_state_if_needed(FunctionContext* ctx, const Column* alternative_col, + const Column* continuity_correction_col, MannWhitneyAggregateState& state) const { + if (!state.is_uninitialized()) { + return; + } + init_state(ctx, alternative_col, continuity_correction_col, state); + } + + void init_state(FunctionContext* ctx, const Column* alternative_col, const Column* continuity_correction_col, + MannWhitneyAggregateState& state) const { + TestingAlternative alternative{TestingAlternative::TwoSided}; + int64_t continuity_correction = 1; + + if (alternative_col != nullptr) { + Slice alternative_slice; + FunctionHelper::get_data_of_column(alternative_col, 0, alternative_slice); + auto alternative_str = boost::to_lower_copy(alternative_slice.to_string()); + TestingAlternative init_alternative = TestingAlternative::from_str(alternative_str); + if (init_alternative == TestingAlternative::Unknown) { + ctx->set_error(fmt::format("Logical Error: invalid alternative `{}`.", alternative_str).c_str()); + return; + } + alternative = init_alternative; + } + + if (continuity_correction_col != nullptr) { + FunctionHelper::get_data_of_column(continuity_correction_col, 0, + continuity_correction); + if (continuity_correction < 0) { + ctx->set_error("Logical Error: continuity_correction must be non-negative."); + return; + } + } + + state.init(alternative, continuity_correction); + } + + void update(FunctionContext* ctx, const Column** columns, AggDataPtr __restrict state, + size_t row_num) const override { + const Column* alternative_col = ctx->get_num_args() >= 3 ? columns[2] : nullptr; + const Column* continuity_correction_col = ctx->get_num_args() >= 4 ? columns[3] : nullptr; + init_state_if_needed(ctx, alternative_col, continuity_correction_col, this->data(state)); + + DataCppType x{}; + const Column* x_col = columns[0]; + FunctionHelper::get_data_of_column(x_col, row_num, x); + + if (std::isnan((double)x) || std::isinf((double)x)) { + return; + } + + bool treatment = false; + const Column* treatment_col = columns[1]; + FunctionHelper::get_data_of_column(treatment_col, row_num, treatment); + + this->data(state).update(x, treatment); + } + + void merge(FunctionContext* ctx, const Column* column, AggDataPtr __restrict state, size_t row_num) const override { + DCHECK(column->is_binary()); + const uint8_t* serialized_data = reinterpret_cast(column->get(row_num).get_slice().data); + if (this->data(state).is_uninitialized()) { + this->data(state).deserialize(serialized_data); + return; + } + MannWhitneyAggregateState other(serialized_data); + this->data(state).merge(other); + } + + void serialize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { + DCHECK(to->is_binary()); + auto* column = down_cast(to); + Bytes& bytes = column->get_bytes(); + size_t old_size = bytes.size(); + size_t new_size = old_size + this->data(state).serialized_size(); + bytes.resize(new_size); + column->get_offset().emplace_back(new_size); + uint8_t* serialized_data = bytes.data() + old_size; + this->data(state).serialize(serialized_data); + } + + void finalize_to_column(FunctionContext* ctx, ConstAggDataPtr __restrict state, Column* to) const override { + if (this->data(state).is_uninitialized()) { + ctx->set_error("Internal Error: state not initialized."); + return; + } + vpack::Builder result_builder; + this->data(state).build_result(result_builder); + JsonValue result_json(result_builder.slice()); + down_cast(to)->append(std::move(result_json)); + } + + void convert_to_serialize_format(FunctionContext* ctx, const Columns& src, size_t chunk_size, + ColumnPtr* dst) const override { + TestingAlternative alternative{TestingAlternative::TwoSided}; + int64_t continuity_correction = 1; + + const Column* alternative_col = ctx->get_num_args() >= 3 ? src[2].get() : nullptr; + const Column* continuity_correction_col = ctx->get_num_args() >= 4 ? src[3].get() : nullptr; + if (alternative_col != nullptr) { + Slice alternative_slice; + FunctionHelper::get_data_of_column(alternative_col, 0, alternative_slice); + auto alternative_str = boost::to_lower_copy(alternative_slice.to_string()); + TestingAlternative init_alternative = TestingAlternative::from_str(alternative_str); + if (init_alternative == TestingAlternative::Unknown) { + ctx->set_error(fmt::format("Logical Error: invalid alternative `{}`.", alternative_str).c_str()); + return; + } + alternative = init_alternative; + } + + if (continuity_correction_col != nullptr) { + FunctionHelper::get_data_of_column(continuity_correction_col, 0, + continuity_correction); + if (continuity_correction < 0) { + ctx->set_error("Logical Error: continuity_correction must be non-negative."); + return; + } + } + + for (size_t idx = 0; idx < chunk_size; ++idx) { + MannWhitneyAggregateState state; + state.init(alternative, continuity_correction); + + DataCppType x{}; + const Column* x_col = src[0].get(); + FunctionHelper::get_data_of_column(x_col, idx, x); + + bool treatment = false; + const Column* treatment_col = src[1].get(); + FunctionHelper::get_data_of_column(treatment_col, idx, treatment); + + if (!(std::isnan((double)x) || std::isinf((double)x))) { + state.update(x, treatment); + } + + serialize_to_column(ctx, reinterpret_cast(&state), dst->get()); + } + } + + std::string get_name() const override { return "mann_whitney_u_test"; } +}; + +} // namespace starrocks diff --git a/be/src/exprs/array_element_expr.cpp b/be/src/exprs/array_element_expr.cpp index cb042c1a5b0b2b..ca73edf4394ad3 100644 --- a/be/src/exprs/array_element_expr.cpp +++ b/be/src/exprs/array_element_expr.cpp @@ -35,7 +35,8 @@ class ArrayElementExpr final : public Expr { StatusOr evaluate_checked(ExprContext* context, Chunk* chunk) override { DCHECK_EQ(2, _children.size()); - DCHECK_EQ(_type, _children[0]->type().children[0]); + // After DLA's complex type prune, ArrayElement expr's type is different from children's type + // DCHECK_EQ(_type, _children[0]->type().children[0]); ASSIGN_OR_RETURN(ColumnPtr arg0, _children[0]->evaluate_checked(context, chunk)); ASSIGN_OR_RETURN(ColumnPtr arg1, _children[1]->evaluate_checked(context, chunk)); size_t num_rows = std::max(arg0->size(), arg1->size()); diff --git a/be/src/exprs/array_functions.tpp b/be/src/exprs/array_functions.tpp index a39e735c2590b3..5eee8a0b99d7bd 100644 --- a/be/src/exprs/array_functions.tpp +++ b/be/src/exprs/array_functions.tpp @@ -1171,7 +1171,7 @@ public: if constexpr (HasNull) { elements_nulls = elements_null_col->get_data().data(); } - const auto& elements_data = GetContainer().get_data(elements); + const auto& elements_data = GetContainer::get_data(elements); auto* offsets_ptr = offsets->get_data().data(); auto* null_ptr = null_cols->get_data().data(); diff --git a/be/src/exprs/bitmap_functions.cpp b/be/src/exprs/bitmap_functions.cpp index e4a2b97c1ffd9f..ff34607b279030 100644 --- a/be/src/exprs/bitmap_functions.cpp +++ b/be/src/exprs/bitmap_functions.cpp @@ -700,13 +700,8 @@ StatusOr BitmapFunctions::bitmap_to_binary(FunctionContext* context, } ColumnPtr col = builder.build(ColumnHelper::is_all_const(columns)); - std::string err_msg; - if (col->capacity_limit_reached(&err_msg)) { - return Status::InternalError( - strings::Substitute("Size of binary column generated by bitmap_to_binary reaches limit: $0", err_msg)); - } else { - return col; - } + RETURN_IF_ERROR(col->capacity_limit_reached()); + return col; } StatusOr BitmapFunctions::bitmap_from_binary(FunctionContext* context, const Columns& columns) { diff --git a/be/src/exprs/debug_expr.cpp b/be/src/exprs/debug_expr.cpp index b1d0098906f800..5d4b260c96c8ec 100644 --- a/be/src/exprs/debug_expr.cpp +++ b/be/src/exprs/debug_expr.cpp @@ -45,12 +45,7 @@ StatusOr DebugFunctions::chunk_check_valid(ExprContext* context, Chun size_t num_rows = ptr->num_rows(); for (const auto& column : ptr->columns()) { // check column size capacity - std::string msg; - column->capacity_limit_reached(&msg); - if (!msg.empty()) { - DCHECK(false) << "not expected"; - throw std::runtime_error(msg); - } + RETURN_IF_ERROR(column->capacity_limit_reached()); // check column size matched if (column->size() != num_rows) { DCHECK(false) << "not expected"; diff --git a/be/src/exprs/function_call_expr.cpp b/be/src/exprs/function_call_expr.cpp index 988c05d4896cc4..b27fe037bb6a61 100644 --- a/be/src/exprs/function_call_expr.cpp +++ b/be/src/exprs/function_call_expr.cpp @@ -192,11 +192,7 @@ StatusOr VectorizedFunctionCallExpr::evaluate_checked(starrocks::Expr } RETURN_IF_ERROR(result); if (_fn_desc->check_overflow) { - std::string err_msg; - if (UNLIKELY(result.value()->capacity_limit_reached(&err_msg))) { - return Status::InternalError( - fmt::format("Result column of function {} exceed limit: {}", _fn_desc->name, err_msg)); - } + RETURN_IF_ERROR(result.value()->capacity_limit_reached()); } // For no args function call (pi, e) diff --git a/be/src/exprs/function_helper.h b/be/src/exprs/function_helper.h index 6f3a15d080fc56..ac953f11cafa54 100644 --- a/be/src/exprs/function_helper.h +++ b/be/src/exprs/function_helper.h @@ -17,6 +17,7 @@ #include "column/column_helper.h" #include "column/const_column.h" #include "column/type_traits.h" +#include "gutil/casts.h" namespace starrocks { class FunctionContext; @@ -37,6 +38,21 @@ class FunctionHelper { return ptr; } + /** + * get data of column. + * @param col, row_num, data + */ + template + static void get_data_of_column(const Column* col, size_t row_num, CppType& data) { + if (col->is_constant()) { + auto const_col = down_cast(col); + col = const_col->data_column().get(); + row_num = 0; + } + const auto* column = down_cast(col); + data = column->get_data()[row_num]; + } + /** * if ptr is ConstColumn, return data column * else return ptr diff --git a/be/src/exprs/in_const_predicate.cpp b/be/src/exprs/in_const_predicate.cpp index 71d7a896553feb..1021f2ed16ae54 100644 --- a/be/src/exprs/in_const_predicate.cpp +++ b/be/src/exprs/in_const_predicate.cpp @@ -85,7 +85,7 @@ void VectorizedInConstPredicateBuilder::add_values(const ColumnPtr& column, size case LogicalType::FIELD_TYPE: { \ using ColumnType = typename RunTimeTypeTraits::ColumnType; \ auto* in_pred = (VectorizedInConstPredicate*)(expr); \ - const auto& data_ptr = GetContainer().get_data(column); \ + const auto& data_ptr = GetContainer::get_data(column); \ if (in_pred->is_use_array()) { \ for (size_t j = column_offset; j < data_ptr.size(); j++) { \ in_pred->insert_array(data_ptr[j]); \ @@ -103,34 +103,34 @@ void VectorizedInConstPredicateBuilder::add_values(const ColumnPtr& column, size } } else { switch (type) { -#define M(FIELD_TYPE) \ - case LogicalType::FIELD_TYPE: { \ - using ColumnType = typename RunTimeTypeTraits::ColumnType; \ - auto* in_pred = (VectorizedInConstPredicate*)(expr); \ - auto* nullable_column = ColumnHelper::as_raw_column(column); \ - const auto& data_array = GetContainer().get_data(nullable_column->data_column()); \ - if (in_pred->is_use_array()) { \ - for (size_t j = column_offset; j < data_array.size(); j++) { \ - if (!nullable_column->is_null(j)) { \ - in_pred->insert_array(data_array[j]); \ - } else { \ - if (_eq_null) { \ - in_pred->insert_null(); \ - } \ - } \ - } \ - } else { \ - for (size_t j = column_offset; j < data_array.size(); j++) { \ - if (!nullable_column->is_null(j)) { \ - in_pred->insert(data_array[j]); \ - } else { \ - if (_eq_null) { \ - in_pred->insert_null(); \ - } \ - } \ - } \ - } \ - break; \ +#define M(FIELD_TYPE) \ + case LogicalType::FIELD_TYPE: { \ + using ColumnType = typename RunTimeTypeTraits::ColumnType; \ + auto* in_pred = (VectorizedInConstPredicate*)(expr); \ + auto* nullable_column = ColumnHelper::as_raw_column(column); \ + const auto& data_array = GetContainer::get_data(nullable_column->data_column()); \ + if (in_pred->is_use_array()) { \ + for (size_t j = column_offset; j < data_array.size(); j++) { \ + if (!nullable_column->is_null(j)) { \ + in_pred->insert_array(data_array[j]); \ + } else { \ + if (_eq_null) { \ + in_pred->insert_null(); \ + } \ + } \ + } \ + } else { \ + for (size_t j = column_offset; j < data_array.size(); j++) { \ + if (!nullable_column->is_null(j)) { \ + in_pred->insert(data_array[j]); \ + } else { \ + if (_eq_null) { \ + in_pred->insert_null(); \ + } \ + } \ + } \ + } \ + break; \ } APPLY_FOR_ALL_SCALAR_TYPE(M) #undef M diff --git a/be/src/exprs/json_functions.cpp b/be/src/exprs/json_functions.cpp index 9e47bc48d6a6e6..e4158bc4457035 100644 --- a/be/src/exprs/json_functions.cpp +++ b/be/src/exprs/json_functions.cpp @@ -536,8 +536,6 @@ static StatusOr _extract_with_hyper(NativeJsonState* state, const std state->real_path.paths.emplace_back(p); } - state->init_flat = true; - state->flat_path = flat_path.substr(1); if (in_flat) { state->is_partial_match = false; state->flat_column_type = TargetType; @@ -545,6 +543,8 @@ static StatusOr _extract_with_hyper(NativeJsonState* state, const std state->is_partial_match = true; state->flat_column_type = TYPE_JSON; } + state->flat_path = flat_path.substr(1); + state->init_flat = true; }); } std::vector dst_path{state->flat_path}; diff --git a/be/src/exprs/like_predicate.cpp b/be/src/exprs/like_predicate.cpp index 84aecb671e7c37..813d3bfbf96460 100644 --- a/be/src/exprs/like_predicate.cpp +++ b/be/src/exprs/like_predicate.cpp @@ -447,6 +447,48 @@ StatusOr LikePredicate::_predicate_const_regex(FunctionContext* conte return result->build(value_column->is_constant()); } +enum class FastPathType { + EQUALS = 0, + START_WITH = 1, + END_WITH = 2, + SUBSTRING = 3, + REGEX = 4, +}; + +FastPathType extract_fast_path(const Slice& pattern) { + if (pattern.empty() || pattern.size < 2) { + return FastPathType::REGEX; + } + + if (pattern.data[0] == '_' || pattern.data[pattern.size - 1] == '_') { + return FastPathType::REGEX; + } + + bool is_end_with = pattern.data[0] == '%'; + bool is_start_with = pattern.data[pattern.size - 1] == '%'; + + for (size_t i = 1; i < pattern.size - 1;) { + if (pattern.data[i] == '\\') { + i += 2; + } else { + if (pattern.data[i] == '%' || pattern.data[i] == '_') { + return FastPathType::REGEX; + } + i++; + } + } + + if (is_end_with && is_start_with) { + return FastPathType::SUBSTRING; + } else if (is_end_with) { + return FastPathType::END_WITH; + } else if (is_start_with) { + return FastPathType::START_WITH; + } else { + return FastPathType::EQUALS; + } +} + StatusOr LikePredicate::regex_match_full(FunctionContext* context, const starrocks::Columns& columns) { const auto& value_column = VECTORIZED_FN_ARGS(0); const auto& pattern_column = VECTORIZED_FN_ARGS(1); @@ -478,18 +520,56 @@ StatusOr LikePredicate::regex_match_full(FunctionContext* context, co continue; } - auto re_pattern = LikePredicate::template convert_like_pattern(context, pattern_viewer.value(row)); - - re2::RE2 re(re_pattern, opts); - - if (!re.ok()) { - context->set_error(strings::Substitute("Invalid regex: $0", re_pattern).c_str()); - result.append_null(); - continue; + Slice pattern = pattern_viewer.value(row); + FastPathType val = extract_fast_path(pattern); + switch (val) { + case FastPathType::EQUALS: { + std::string str_pattern = pattern.to_string(); + remove_escape_character(&str_pattern); + result.append(value_viewer.value(row) == str_pattern); + break; } + case FastPathType::START_WITH: { + std::string str_pattern = pattern.to_string(); + remove_escape_character(&str_pattern); + auto pattern_slice = Slice(str_pattern); + pattern_slice.remove_suffix(1); + result.append(ConstantStartsImpl::apply(value_viewer.value(row), pattern_slice)); + break; + } + case FastPathType::END_WITH: { + std::string str_pattern = pattern.to_string(); + remove_escape_character(&str_pattern); + auto pattern_slice = Slice(str_pattern); + pattern_slice.remove_prefix(1); + result.append(ConstantEndsImpl::apply(value_viewer.value(row), pattern_slice)); + break; + } + case FastPathType::SUBSTRING: { + std::string str_pattern = pattern.to_string(); + remove_escape_character(&str_pattern); + auto pattern_slice = Slice(str_pattern); + pattern_slice.remove_prefix(1); + pattern_slice.remove_suffix(1); + auto searcher = LibcASCIICaseSensitiveStringSearcher(pattern_slice.get_data(), pattern_slice.get_size()); + /// searcher returns a pointer to the found substring or to the end of `haystack`. + const Slice& value = value_viewer.value(row); + const char* res_pointer = searcher.search(value.data, value.size); + result.append(!!res_pointer); + break; + } + case FastPathType::REGEX: { + auto re_pattern = LikePredicate::template convert_like_pattern(context, pattern); - auto v = RE2::FullMatch(re2::StringPiece(value_viewer.value(row).data, value_viewer.value(row).size), re); - result.append(v); + re2::RE2 re(re_pattern, opts); + if (!re.ok()) { + return Status::InvalidArgument(strings::Substitute("Invalid regex: $0", re_pattern)); + } + auto v = RE2::FullMatch(re2::StringPiece(value_viewer.value(row).data, value_viewer.value(row).size), re); + result.append(v); + break; + } + } } return result.build(all_const); diff --git a/be/src/exprs/runtime_filter.h b/be/src/exprs/runtime_filter.h index 6a4cd90a412f30..4323413f48a34a 100644 --- a/be/src/exprs/runtime_filter.h +++ b/be/src/exprs/runtime_filter.h @@ -894,7 +894,7 @@ class RuntimeBloomFilter final : public JoinRuntimeFilter { if (const_column->only_null()) { _selection[0] = _has_null; } else { - const auto& input_data = GetContainer().get_data(const_column->data_column()); + const auto& input_data = GetContainer::get_data(const_column->data_column()); _evaluate_min_max(input_data, _selection, 1); if constexpr (can_use_bf) { _rf_test_data(_selection, input_data, _hash_values, 0); @@ -904,7 +904,7 @@ class RuntimeBloomFilter final : public JoinRuntimeFilter { memset(_selection, sel, size); } else if (input_column->is_nullable()) { const auto* nullable_column = down_cast(input_column); - const auto& input_data = GetContainer().get_data(nullable_column->data_column()); + const auto& input_data = GetContainer::get_data(nullable_column->data_column()); _evaluate_min_max(input_data, _selection, size); if (nullable_column->has_null()) { const uint8_t* null_data = nullable_column->immutable_null_column_data().data(); @@ -925,7 +925,7 @@ class RuntimeBloomFilter final : public JoinRuntimeFilter { } } } else { - const auto& input_data = GetContainer().get_data(input_column); + const auto& input_data = GetContainer::get_data(input_column); _evaluate_min_max(input_data, _selection, size); if constexpr (can_use_bf) { for (int i = 0; i < size; ++i) { diff --git a/be/src/exprs/runtime_filter_bank.cpp b/be/src/exprs/runtime_filter_bank.cpp index 193c1ead7fbc4e..418ae366b202dd 100644 --- a/be/src/exprs/runtime_filter_bank.cpp +++ b/be/src/exprs/runtime_filter_bank.cpp @@ -125,7 +125,7 @@ struct FilterIniter { if (column->is_nullable()) { auto* nullable_column = ColumnHelper::as_raw_column(column); - const auto& data_array = GetContainer().get_data(nullable_column->data_column().get()); + const auto& data_array = GetContainer::get_data(nullable_column->data_column().get()); for (size_t j = column_offset; j < data_array.size(); j++) { if (!nullable_column->is_null(j)) { filter->insert(data_array[j]); @@ -136,7 +136,7 @@ struct FilterIniter { } } } else { - const auto& data_array = GetContainer().get_data(column.get()); + const auto& data_array = GetContainer::get_data(column.get()); for (size_t j = column_offset; j < data_array.size(); j++) { filter->insert(data_array[j]); } diff --git a/be/src/exprs/string_functions.cpp b/be/src/exprs/string_functions.cpp index 799953d67628ce..499549edc4a638 100644 --- a/be/src/exprs/string_functions.cpp +++ b/be/src/exprs/string_functions.cpp @@ -56,7 +56,7 @@ namespace starrocks { static const RE2 SUBSTRING_RE(R"((?:\.\*)*([^\.\^\{\[\(\|\)\]\}\+\*\?\$\\]+)(?:\.\*)*)", re2::RE2::Quiet); #define THROW_RUNTIME_ERROR_IF_EXCEED_LIMIT(col, func_name) \ - if (UNLIKELY(col->capacity_limit_reached())) { \ + if (UNLIKELY(!col->capacity_limit_reached().ok())) { \ col->reset_column(); \ throw std::runtime_error("binary column exceed 4G in function " #func_name); \ } @@ -2633,7 +2633,8 @@ StatusOr StringFunctions::ascii(FunctionContext* context, const Colum } DEFINE_UNARY_FN_WITH_IMPL(get_charImpl, value) { - return std::string((char*)&value, 1); + char* p = (char*)&value; + return std::string(p, 1); } StatusOr StringFunctions::get_char(FunctionContext* context, const Columns& columns) { diff --git a/be/src/exprs/subfield_expr.cpp b/be/src/exprs/subfield_expr.cpp index 30f973061de818..1f314e4394543c 100644 --- a/be/src/exprs/subfield_expr.cpp +++ b/be/src/exprs/subfield_expr.cpp @@ -20,6 +20,7 @@ #include "column/nullable_column.h" #include "column/struct_column.h" #include "common/object_pool.h" +#include "exprs/function_helper.h" namespace starrocks { @@ -39,9 +40,19 @@ class SubfieldExpr final : public Expr { ASSIGN_OR_RETURN(ColumnPtr col, _children.at(0)->evaluate_checked(context, chunk)); - // Enter multiple subfield for struct type, remain last subfield - for (size_t i = 0; i < _used_subfield_names.size() - 1; i++) { - std::string fieldname = _used_subfield_names[i]; + // handle nullable column + const size_t num_rows = col->size(); + NullColumnPtr union_null_column = NullColumn::create(num_rows, false); + + for (size_t i = 0; i < _used_subfield_names.size(); i++) { + const std::string& fieldname = _used_subfield_names[i]; + + // merge null flags for each level + if (col->is_nullable()) { + auto* nullable = down_cast(col.get()); + union_null_column = FunctionHelper::union_null_column(union_null_column, nullable->null_column()); + } + Column* tmp_col = ColumnHelper::get_data_column(col.get()); DCHECK(tmp_col->is_struct()); auto* struct_column = down_cast(tmp_col); @@ -51,40 +62,19 @@ class SubfieldExpr final : public Expr { } } - // handle nullable column - std::vector null_flags; - size_t num_rows = col->size(); - null_flags.resize(num_rows, false); if (col->is_nullable()) { auto* nullable = down_cast(col.get()); - const uint8_t* nulls = nullable->null_column()->raw_data(); - std::memcpy(&null_flags[0], &nulls[0], num_rows * sizeof(uint8_t)); - } - - Column* tmp_col = ColumnHelper::get_data_column(col.get()); - DCHECK(tmp_col->is_struct()); - auto* struct_column = down_cast(tmp_col); - - std::string fieldname = _used_subfield_names.back(); - ColumnPtr subfield_column = struct_column->field_column(fieldname); - if (subfield_column->is_nullable()) { - auto* nullable = down_cast(subfield_column.get()); - const uint8_t* nulls = nullable->null_column()->raw_data(); - for (size_t i = 0; i < num_rows; i++) { - null_flags[i] |= nulls[i]; - } - subfield_column = nullable->data_column(); + union_null_column = FunctionHelper::union_null_column(union_null_column, nullable->null_column()); + col = nullable->data_column(); } - NullColumnPtr result_null = NullColumn::create(); - result_null->get_data().swap(null_flags); - DCHECK_EQ(subfield_column->size(), result_null->size()); + DCHECK_EQ(col->size(), union_null_column->size()); - // We need clone a new subfield column + // We need to clone a new subfield column if (_copy_flag) { - return NullableColumn::create(subfield_column->clone_shared(), result_null); + return NullableColumn::create(col->clone_shared(), union_null_column); } else { - return NullableColumn::create(subfield_column, result_null); + return NullableColumn::create(col, union_null_column); } } diff --git a/be/src/exprs/table_function/java_udtf_function.cpp b/be/src/exprs/table_function/java_udtf_function.cpp index 9884ab79c1a135..33389ac60cde05 100644 --- a/be/src/exprs/table_function/java_udtf_function.cpp +++ b/be/src/exprs/table_function/java_udtf_function.cpp @@ -21,6 +21,7 @@ #include "column/column_helper.h" #include "column/nullable_column.h" #include "column/vectorized_fwd.h" +#include "common/compiler_util.h" #include "exprs/table_function/table_function.h" #include "gutil/casts.h" #include "jni.h" @@ -137,21 +138,28 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru std::vector call_stack; std::vector rets; - DeferOp defer = DeferOp([&]() { - // clean up arrays - for (auto& ret : rets) { - if (ret) { - env->DeleteLocalRef(ret); - } - } - }); size_t num_rows = cols[0]->size(); size_t num_cols = cols.size(); state->set_processed_rows(num_rows); call_stack.reserve(num_cols); rets.resize(num_rows); + + // reserve 16 local refs + DeferOp defer = DeferOp([&]() { + // clean up arrays + env->PopLocalFrame(nullptr); + }); + env->PushLocalFrame(num_cols * num_rows + 16); + for (int i = 0; i < num_rows; ++i) { + DeferOp defer = DeferOp([&]() { + for (int j = 0; j < num_cols; ++j) { + release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]); + } + call_stack.clear(); + }); + for (int j = 0; j < num_cols; ++j) { auto method_type = stateUDTF->method_process()->method_desc[j + 1]; jvalue val = cast_to_jvalue(method_type.type, method_type.is_box, cols[j].get(), i); @@ -160,11 +168,13 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru rets[i] = env->CallObjectMethodA(stateUDTF->handle(), methodID, call_stack.data()); - for (int j = 0; j < num_cols; ++j) { - release_jvalue(stateUDTF->method_process()->method_desc[j + 1].is_box, call_stack[j]); + if (auto jthr = helper.getEnv()->ExceptionOccurred(); jthr != nullptr) { + std::string err = fmt::format("execute UDF Function meet Exception:{}", helper.dumpExceptionString(jthr)); + LOG(WARNING) << err; + helper.getEnv()->ExceptionClear(); + state->set_status(Status::InternalError(err)); + return std::make_pair(Columns{}, nullptr); } - - call_stack.clear(); } // Build Return Type @@ -185,8 +195,12 @@ std::pair JavaUDTFFunction::process(RuntimeState* ru for (int j = 0; j < len; ++j) { jobject vi = env->GetObjectArrayElement((jobjectArray)rets[i], j); LOCAL_REF_GUARD_ENV(env, vi); + auto st = check_type_matched(method_desc, vi); + if (UNLIKELY(!st.ok())) { + state->set_status(st); + return std::make_pair(Columns{}, nullptr); + } append_jvalue(method_desc, col.get(), {.l = vi}); - release_jvalue(method_desc.is_box, {.l = vi}); } } diff --git a/be/src/exprs/table_function/subdivide_bitmap.h b/be/src/exprs/table_function/subdivide_bitmap.h index b55f63422c40c7..7617f6421c0b76 100644 --- a/be/src/exprs/table_function/subdivide_bitmap.h +++ b/be/src/exprs/table_function/subdivide_bitmap.h @@ -105,17 +105,16 @@ class SubdivideBitmap final : public TableFunction { &compact_offset); } } - std::string err_msg; - auto ret = dst_bitmap_col->capacity_limit_reached(&err_msg); - if (ret) { + Status st = dst_bitmap_col->capacity_limit_reached(); + if (!st.ok()) { state->set_status(Status::InternalError( - fmt::format("Bitmap column generate by subdivide_bitmap reach limit, {}", err_msg))); + fmt::format("Bitmap column generate by subdivide_bitmap reach limit, {}", st.message()))); return {}; } - ret = dst_offset_col->capacity_limit_reached(&err_msg); - if (ret) { + st = dst_offset_col->capacity_limit_reached(); + if (!st.ok()) { state->set_status(Status::InternalError( - fmt::format("Offset column generate by subdivide_bitmap reach limit, {}", err_msg))); + fmt::format("Offset column generate by subdivide_bitmap reach limit, {}", st.message()))); return {}; } dst_columns.emplace_back(std::move(dst_bitmap_col)); diff --git a/be/src/exprs/unary_function.h b/be/src/exprs/unary_function.h index f705a9afcf5d28..c8ceaf99ebc2c8 100644 --- a/be/src/exprs/unary_function.h +++ b/be/src/exprs/unary_function.h @@ -101,7 +101,7 @@ class UnaryFunction { result->resize(size); auto* r3 = result->get_data().data(); - const auto& data_array = GetContainer().get_data(v1); + const auto& data_array = GetContainer::get_data(v1); if constexpr (lt_is_string || lt_is_binary) { for (int i = 0; i < size; ++i) { @@ -136,7 +136,6 @@ struct StringUnaryFunction { int size = v1->size(); for (int i = 0; i < size; ++i) { std::string ret = OP::template apply, std::string>(r1[i], std::forward(args)...); - bytes.reserve(ret.size()); bytes.insert(bytes.end(), (uint8_t*)ret.data(), (uint8_t*)ret.data() + ret.size()); offset.emplace_back(bytes.size()); } diff --git a/be/src/formats/disk_range.hpp b/be/src/formats/disk_range.hpp index 90641ee320566b..ced8c97c8bdbc2 100644 --- a/be/src/formats/disk_range.hpp +++ b/be/src/formats/disk_range.hpp @@ -24,7 +24,7 @@ class DiskRange { public: DiskRange(const int64_t off, const int64_t len) : _offset(off), _length(len) { DCHECK(off >= 0); - DCHECK(len > 0); + DCHECK(len >= 0); } /** diff --git a/be/src/formats/parquet/column_reader.cpp b/be/src/formats/parquet/column_reader.cpp index 47435a32608729..a5f23fd6694931 100644 --- a/be/src/formats/parquet/column_reader.cpp +++ b/be/src/formats/parquet/column_reader.cpp @@ -109,12 +109,8 @@ Status ColumnDictFilterContext::rewrite_conjunct_ctxs_to_predicate(StoredColumnR predicate = obj_pool.add( new_column_eq_predicate(get_type_info(kDictCodeFieldType), slot_id, std::to_string(dict_codes[0]))); } else { - std::vector str_codes; - str_codes.reserve(dict_codes.size()); - for (int code : dict_codes) { - str_codes.emplace_back(std::to_string(code)); - } - predicate = obj_pool.add(new_column_in_predicate(get_type_info(kDictCodeFieldType), slot_id, str_codes)); + predicate = obj_pool.add(new_dictionary_code_in_predicate(get_type_info(kDictCodeFieldType), slot_id, + dict_codes, dict_value_column->size())); } // deal with if NULL works or not. diff --git a/be/src/formats/parquet/schema.cpp b/be/src/formats/parquet/schema.cpp index 9799389cb3faae..8311891ed23ab7 100644 --- a/be/src/formats/parquet/schema.cpp +++ b/be/src/formats/parquet/schema.cpp @@ -224,6 +224,13 @@ Status SchemaDescriptor::map_to_field(const std::vector // } // } // + + // check map's key must be primitive type + ASSIGN_OR_RETURN(const auto* key_schema, _get_schema_element(t_schemas, pos + 2)); + if (is_group(key_schema)) { + return Status::InvalidArgument("Map keys must be primitive type."); + } + RETURN_IF_ERROR(node_to_field(t_schemas, pos + 2, cur_level_info, key_field, next_pos)); RETURN_IF_ERROR(node_to_field(t_schemas, pos + 3, cur_level_info, value_field, next_pos)); diff --git a/be/src/runtime/CMakeLists.txt b/be/src/runtime/CMakeLists.txt index 03e77e436a313f..c8cea88e2d6405 100644 --- a/be/src/runtime/CMakeLists.txt +++ b/be/src/runtime/CMakeLists.txt @@ -43,6 +43,7 @@ set(RUNTIME_FILES load_path_mgr.cpp dummy_load_path_mgr.cpp types.cpp + agg_state_desc.cpp mem_tracker.cpp data_stream_recvr.cpp export_sink.cpp diff --git a/be/src/runtime/agg_state_desc.cpp b/be/src/runtime/agg_state_desc.cpp new file mode 100644 index 00000000000000..a10db044f6e7d3 --- /dev/null +++ b/be/src/runtime/agg_state_desc.cpp @@ -0,0 +1,124 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "runtime/agg_state_desc.h" + +#include + +#include "exprs/agg/aggregate.h" +#include "exprs/agg/aggregate_factory.h" + +namespace starrocks { + +// Create a new AggStateDesc from a thrift TTypeDesc. +AggStateDesc AggStateDesc::from_thrift(const TAggStateDesc& desc) { + VLOG(2) << "TAggStateDesc:" << apache::thrift::ThriftDebugString(desc); + std::string agg_func_name = desc.agg_func_name; + // return type + auto return_type = TypeDescriptor::from_thrift(desc.ret_type); + // arg types + std::vector arg_types; + for (auto& arg_type : desc.arg_types) { + arg_types.emplace_back(TypeDescriptor::from_thrift(arg_type)); + } + bool result_nullable = desc.result_nullable; + int func_version = desc.func_version; + return AggStateDesc{agg_func_name, std::move(return_type), std::move(arg_types), result_nullable, func_version}; +} + +// Transform this AggStateDesc to a thrift TTypeDesc. +void AggStateDesc::to_thrift(TAggStateDesc* t) { + t->agg_func_name = _func_name; + t->result_nullable = _is_result_nullable; + t->func_version = _func_version; + // return type + t->ret_type = _return_type.to_thrift(); + // arg types + for (auto& arg_type : _arg_types) { + t->arg_types.push_back(arg_type.to_thrift()); + } +} + +AggStateDesc AggStateDesc::from_protobuf(const AggStateDescPB& desc) { + auto& agg_func_name = desc.agg_func_name(); + bool is_result_nullable = desc.is_result_nullable(); + int func_version = desc.func_version(); + std::vector arg_types; + // arg types + for (auto& arg_type : desc.arg_types()) { + arg_types.emplace_back(TypeDescriptor::from_protobuf(arg_type)); + } + // ret type + auto ret_type = TypeDescriptor::from_protobuf(desc.ret_type()); + return AggStateDesc{agg_func_name, std::move(ret_type), std::move(arg_types), is_result_nullable, func_version}; +} + +void AggStateDesc::to_protobuf(AggStateDescPB* desc) { + desc->set_agg_func_name(this->get_func_name()); + desc->set_is_result_nullable(this->is_result_nullable()); + desc->set_func_version(this->get_func_version()); + // arg types + for (auto& arg_type : this->get_arg_types()) { + auto* arg_type_pb = desc->add_arg_types(); + *arg_type_pb = arg_type.to_protobuf(); + } + // ret type + auto ret_type_desc = this->get_return_type(); + auto* ret_type_pb = desc->mutable_ret_type(); + *ret_type_pb = ret_type_desc.to_protobuf(); +} + +void AggStateDesc::thrift_to_protobuf(const TAggStateDesc& desc, AggStateDescPB* pb) { + pb->set_agg_func_name(desc.agg_func_name); + pb->set_is_result_nullable(desc.result_nullable); + pb->set_func_version(desc.func_version); + // arg types + for (auto& arg_type : desc.arg_types) { + auto arg_type_desc = TypeDescriptor::from_thrift(arg_type); + auto* arg_type_pb = pb->add_arg_types(); + *arg_type_pb = arg_type_desc.to_protobuf(); + } + // ret type + auto ret_type_desc = TypeDescriptor::from_thrift(desc.ret_type); + auto* ret_type_pb = pb->mutable_ret_type(); + *ret_type_pb = ret_type_desc.to_protobuf(); +} + +std::string AggStateDesc::debug_string() const { + std::stringstream ss; + ss << "[" << _func_name << ", args:<"; + for (size_t i = 0; i < _arg_types.size(); i++) { + if (i != _arg_types.size() - 1) { + ss << _arg_types[i] << ", "; + } else { + ss << _arg_types[i] << ">"; + } + } + ss << ", ret:" << _return_type << ", result_nullable:" << _is_result_nullable << ", func_version:" << _func_version + << "]"; + return ss.str(); +} + +const AggregateFunction* AggStateDesc::get_agg_state_func(AggStateDesc* agg_state_desc) { + DCHECK(agg_state_desc); + auto* agg_function = get_aggregate_function(agg_state_desc->get_func_name(), agg_state_desc->get_return_type(), + agg_state_desc->get_arg_types(), agg_state_desc->is_result_nullable(), + TFunctionBinaryType::BUILTIN, agg_state_desc->get_func_version()); + if (agg_function == nullptr) { + LOG(WARNING) << "Failed to get aggregate function for " << agg_state_desc->debug_string(); + } + return agg_function; +} + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/runtime/agg_state_desc.h b/be/src/runtime/agg_state_desc.h new file mode 100644 index 00000000000000..592efea31728e8 --- /dev/null +++ b/be/src/runtime/agg_state_desc.h @@ -0,0 +1,114 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include + +#include +#include +#include + +#include "gen_cpp/Types_types.h" +#include "runtime/types.h" + +namespace starrocks { + +struct TypeDescriptor; + +class AggStateDesc; +using AggStateDescPtr = std::shared_ptr; + +class AggregateFunction; + +class AggStateDesc { +public: + AggStateDesc(std::string func_name, TypeDescriptor return_type, std::vector arg_types, + bool is_result_nullable, int func_version) + : _func_name(std::move(func_name)), + _return_type(std::move(return_type)), + _arg_types(std::move(arg_types)), + _is_result_nullable(is_result_nullable), + _func_version(func_version) {} + + // copy assignment operator + AggStateDesc(const AggStateDesc& other) + : _func_name(other._func_name), + _return_type(other._return_type), + _arg_types(other._arg_types), + _is_result_nullable(other._is_result_nullable), + _func_version(other._func_version) {} + AggStateDesc& operator=(const AggStateDesc& other) { + if (this != &other) { + this->_func_name = other._func_name; + this->_return_type = other._return_type; + this->_arg_types = other._arg_types; + this->_is_result_nullable = other._is_result_nullable; + this->_func_version = other._func_version; + } + return *this; + } + + // move assignment operator + AggStateDesc(AggStateDesc&& other) noexcept + : _func_name(std::move(other._func_name)), + _return_type(std::move(other._return_type)), + _arg_types(std::move(other._arg_types)), + _is_result_nullable(other._is_result_nullable), + _func_version(other._func_version) {} + AggStateDesc& operator=(AggStateDesc&& other) noexcept { + if (this != &other) { + this->_func_name = std::move(other._func_name); + this->_return_type = std::move(other._return_type); + this->_arg_types = std::move(other._arg_types); + this->_is_result_nullable = other._is_result_nullable; + this->_func_version = other._func_version; + } + return *this; + } + + const std::string& get_func_name() const { return _func_name; } + const TypeDescriptor& get_return_type() const { return _return_type; } + const std::vector& get_arg_types() const { return _arg_types; } + bool is_result_nullable() const { return _is_result_nullable; } + int get_func_version() const { return _func_version; } + std::string debug_string() const; + + // Transform this AggStateDesc to a thrift TTypeDesc. + void to_thrift(TAggStateDesc* t); + // Transform this AggStateDesc to a protobuf AggStateDescPB. + void to_protobuf(AggStateDescPB* desc); + + // Create a new AggStateDesc from a thrift TTypeDesc. + static AggStateDesc from_thrift(const TAggStateDesc& desc); + // Create a new AggStateDesc from a protobuf AggStateDescPB. + static AggStateDesc from_protobuf(const AggStateDescPB& desc); + // Convert thrift TAggStateDesc to protobuf AggStateDescPB. + static void thrift_to_protobuf(const TAggStateDesc& desc, AggStateDescPB* pb); + // Get the aggregate function state descriptor. + static const AggregateFunction* get_agg_state_func(AggStateDesc* agg_state_desc); + +private: + // nested aggregate function name + std::string _func_name; + // nested aggregate function return type + TypeDescriptor _return_type; + // nested aggregate function argument types + std::vector _arg_types; + // nested aggregate function input is nullable + bool _is_result_nullable; + // nested aggregate function version + int _func_version; +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/runtime/current_thread.h b/be/src/runtime/current_thread.h index 9f56f4da9e8309..5c4512484f11c4 100644 --- a/be/src/runtime/current_thread.h +++ b/be/src/runtime/current_thread.h @@ -93,6 +93,7 @@ class CurrentThread { _reserved_bytes = prev_reserved; _cache_size -= size; _allocated_cache_size -= size; + _total_consumed_bytes -= size; _try_consume_mem_size = size; }; if (_cache_size >= BATCH_SIZE) { @@ -132,6 +133,7 @@ class CurrentThread { } else { _cache_size -= size; _allocated_cache_size -= size; + _total_consumed_bytes -= size; _try_consume_mem_size = size; tls_exceed_mem_tracker = limit_tracker; return false; @@ -162,6 +164,7 @@ class CurrentThread { void release(int64_t size) { _cache_size -= size; _deallocated_cache_size += size; + _total_consumed_bytes -= size; if (_cache_size <= -BATCH_SIZE) { commit(false); } diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp index e2802e6cf634cd..8997293930d7a4 100644 --- a/be/src/runtime/runtime_state.cpp +++ b/be/src/runtime/runtime_state.cpp @@ -262,7 +262,7 @@ void RuntimeState::log_error(const Status& status) { return; } - log_error(status); + log_error(status.message()); } void RuntimeState::get_unreported_errors(std::vector* new_errors) { diff --git a/be/src/runtime/stream_load/stream_load_executor.cpp b/be/src/runtime/stream_load/stream_load_executor.cpp index 68070aa0ad1193..132f191b180adb 100644 --- a/be/src/runtime/stream_load/stream_load_executor.cpp +++ b/be/src/runtime/stream_load/stream_load_executor.cpp @@ -301,7 +301,11 @@ StatusOr get_txn_status(const AuthInfo& auth, std::str bool wait_txn_visible_until(const AuthInfo& auth, std::string_view db, std::string_view table, int64_t txn_id, int64_t deadline) { while (deadline > UnixSeconds()) { - sleep(std::min((int64_t)config::get_txn_status_internal_sec, deadline - UnixSeconds())); + auto wait_seconds = std::min((int64_t)config::get_txn_status_internal_sec, deadline - UnixSeconds()); + LOG(WARNING) << "transaction is not visible now, will wait " << wait_seconds + << " seconds before retrieving the status again, txn_id: " << txn_id; + // The following sleep might introduce delay to the commit and publish total time + sleep(wait_seconds); auto status_or = get_txn_status(auth, db, table, txn_id); if (!status_or.ok()) { return false; diff --git a/be/src/runtime/types.cpp b/be/src/runtime/types.cpp index deaebdedf26962..7a82dca090519c 100644 --- a/be/src/runtime/types.cpp +++ b/be/src/runtime/types.cpp @@ -123,9 +123,7 @@ void TypeDescriptor::to_thrift(TTypeDesc* thrift_type) const { curr_node.__set_scalar_type(TScalarType()); TScalarType& scalar_type = curr_node.scalar_type; scalar_type.__set_type(starrocks::to_thrift(type)); - if (len != -1) { - scalar_type.__set_len(len); - } + scalar_type.__set_len(len); if (scale != -1) { scalar_type.__set_scale(scale); } @@ -159,9 +157,7 @@ void TypeDescriptor::to_protobuf(PTypeDesc* proto_type) const { node->set_type(TTypeNodeType::SCALAR); PScalarType* scalar_type = node->mutable_scalar_type(); scalar_type->set_type(starrocks::to_thrift(type)); - if (len != -1) { - scalar_type->set_len(len); - } + scalar_type->set_len(len); if (scale != -1) { scalar_type->set_scale(scale); } diff --git a/be/src/storage/aggregate_type.cpp b/be/src/storage/aggregate_type.cpp index 65d71bc1053a65..b98618702b89b5 100644 --- a/be/src/storage/aggregate_type.cpp +++ b/be/src/storage/aggregate_type.cpp @@ -21,18 +21,19 @@ namespace starrocks { StorageAggregateType get_aggregation_type_by_string(const std::string& str) { - std::string upper_str = str; - std::transform(str.begin(), str.end(), upper_str.begin(), ::tolower); + std::string lower_str = str; + std::transform(str.begin(), str.end(), lower_str.begin(), ::tolower); - if (upper_str == "none") return STORAGE_AGGREGATE_NONE; - if (upper_str == "sum") return STORAGE_AGGREGATE_SUM; - if (upper_str == "min") return STORAGE_AGGREGATE_MIN; - if (upper_str == "max") return STORAGE_AGGREGATE_MAX; - if (upper_str == "replace") return STORAGE_AGGREGATE_REPLACE; - if (upper_str == "replace_if_not_null") return STORAGE_AGGREGATE_REPLACE_IF_NOT_NULL; - if (upper_str == "hll_union") return STORAGE_AGGREGATE_HLL_UNION; - if (upper_str == "bitmap_union") return STORAGE_AGGREGATE_BITMAP_UNION; - if (upper_str == "percentile_union") return STORAGE_AGGREGATE_PERCENTILE_UNION; + if (lower_str == "none") return STORAGE_AGGREGATE_NONE; + if (lower_str == "sum") return STORAGE_AGGREGATE_SUM; + if (lower_str == "min") return STORAGE_AGGREGATE_MIN; + if (lower_str == "max") return STORAGE_AGGREGATE_MAX; + if (lower_str == "replace") return STORAGE_AGGREGATE_REPLACE; + if (lower_str == "replace_if_not_null") return STORAGE_AGGREGATE_REPLACE_IF_NOT_NULL; + if (lower_str == "hll_union") return STORAGE_AGGREGATE_HLL_UNION; + if (lower_str == "bitmap_union") return STORAGE_AGGREGATE_BITMAP_UNION; + if (lower_str == "percentile_union") return STORAGE_AGGREGATE_PERCENTILE_UNION; + if (lower_str == "agg_state_union") return STORAGE_AGGREGATE_AGG_STATE_UNION; LOG(WARNING) << "invalid aggregation type string. [aggregation='" << str << "']"; return STORAGE_AGGREGATE_UNKNOWN; } @@ -59,6 +60,8 @@ std::string get_string_by_aggregation_type(StorageAggregateType type) { return "percentile_union"; case STORAGE_AGGREGATE_UNKNOWN: return "unknown"; + case STORAGE_AGGREGATE_AGG_STATE_UNION: + return "agg_state_union"; } return ""; } diff --git a/be/src/storage/aggregate_type.h b/be/src/storage/aggregate_type.h index 5c71c4670d720c..5cc372fb733b76 100644 --- a/be/src/storage/aggregate_type.h +++ b/be/src/storage/aggregate_type.h @@ -29,7 +29,9 @@ enum StorageAggregateType { STORAGE_AGGREGATE_BITMAP_UNION = 7, // Replace if and only if added value is not null STORAGE_AGGREGATE_REPLACE_IF_NOT_NULL = 8, - STORAGE_AGGREGATE_PERCENTILE_UNION = 9 + STORAGE_AGGREGATE_PERCENTILE_UNION = 9, + // Generic aggregate function's state union + STORAGE_AGGREGATE_AGG_STATE_UNION = 10 }; StorageAggregateType get_aggregation_type_by_string(const std::string& str); diff --git a/be/src/storage/chunk_aggregator.cpp b/be/src/storage/chunk_aggregator.cpp index 83ae8a6dae45de..cafd09ed43d6c5 100644 --- a/be/src/storage/chunk_aggregator.cpp +++ b/be/src/storage/chunk_aggregator.cpp @@ -18,13 +18,10 @@ #include "exec/sorting/sorting.h" #include "exprs/agg/aggregate_state_allocator.h" #include "gutil/casts.h" -#include "runtime/memory/mem_hook_allocator.h" #include "storage/column_aggregate_func.h" namespace starrocks { -static MemHookAllocator kDefaultChunkAggregatorAllocator = MemHookAllocator{}; - ChunkAggregator::ChunkAggregator(const starrocks::Schema* schema, uint32_t reserve_rows, uint32_t max_aggregate_rows, double factor, bool is_vertical_merge, bool is_key) : _schema(schema), @@ -173,7 +170,7 @@ void ChunkAggregator::aggregate() { _aggregate_loops[_aggregate_loops.size() - 1] += 1; } } - SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator); + SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator); // 3. Copy the selected key rows // 4. Aggregate the value rows for (int i = 0; i < _key_fields; ++i) { @@ -196,7 +193,7 @@ bool ChunkAggregator::is_finish() { void ChunkAggregator::aggregate_reset() { _aggregate_chunk = ChunkHelper::new_chunk(*_schema, _reserve_rows); _aggregate_rows = 0; - SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator); + SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator); for (int i = 0; i < _num_fields; ++i) { auto p = _aggregate_chunk->get_column_by_index(i).get(); _column_aggregator[i]->update_aggregate(p); @@ -210,7 +207,7 @@ void ChunkAggregator::aggregate_reset() { } ChunkPtr ChunkAggregator::aggregate_result() { - SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultChunkAggregatorAllocator); + SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator); for (int i = 0; i < _num_fields; ++i) { _column_aggregator[i]->finalize(); } diff --git a/be/src/storage/chunk_helper.cpp b/be/src/storage/chunk_helper.cpp index 788990a7277fb9..e948e9ca94afd9 100644 --- a/be/src/storage/chunk_helper.cpp +++ b/be/src/storage/chunk_helper.cpp @@ -141,6 +141,7 @@ Field ChunkHelper::convert_field(ColumnId id, const TabletColumn& c) { f.set_short_key_length(c.index_length()); f.set_aggregate_method(c.aggregation()); + f.set_agg_state_desc(c.get_agg_state_desc()); return f; } diff --git a/be/src/storage/column_aggregate_func.cpp b/be/src/storage/column_aggregate_func.cpp index 49bd2c2346f600..470887c3945e6d 100644 --- a/be/src/storage/column_aggregate_func.cpp +++ b/be/src/storage/column_aggregate_func.cpp @@ -243,19 +243,34 @@ class AggFuncBasedValueAggregator : public ValueColumnAggregatorBase { public: AggFuncBasedValueAggregator(const AggregateFunction* agg_func) : _agg_func(agg_func) { _state = static_cast(std::aligned_alloc(_agg_func->alignof_size(), _agg_func->size())); - _agg_func->create(&_func_ctx, _state); + // TODO: create a new FunctionContext by using specific FunctionContext::create_context + _func_ctx = new FunctionContext(); + _agg_func->create(_func_ctx, _state); + } + + AggFuncBasedValueAggregator(AggStateDesc* agg_state_desc, std::unique_ptr agg_state_unoin) + : _agg_func(agg_state_unoin.get()) { + _agg_state_unoin = std::move(agg_state_unoin); + _func_ctx = FunctionContext::create_context(nullptr, nullptr, agg_state_desc->get_return_type(), + agg_state_desc->get_arg_types()); + _state = static_cast(std::aligned_alloc(_agg_func->alignof_size(), _agg_func->size())); + _agg_func->create(_func_ctx, _state); } ~AggFuncBasedValueAggregator() override { + SCOPED_THREAD_LOCAL_AGG_STATE_ALLOCATOR_SETTER(&kDefaultColumnAggregatorAllocator); if (_state != nullptr) { - _agg_func->destroy(&_func_ctx, _state); + _agg_func->destroy(_func_ctx, _state); std::free(_state); } + if (_func_ctx != nullptr) { + delete _func_ctx; + } } void reset() override { - _agg_func->destroy(&_func_ctx, _state); - _agg_func->create(&_func_ctx, _state); + _agg_func->destroy(_func_ctx, _state); + _agg_func->create(_func_ctx, _state); } void update_aggregate(Column* agg) override { @@ -263,16 +278,16 @@ class AggFuncBasedValueAggregator : public ValueColumnAggregatorBase { reset(); } - void append_data(Column* agg) override { _agg_func->finalize_to_column(&_func_ctx, _state, agg); } + void append_data(Column* agg) override { _agg_func->finalize_to_column(_func_ctx, _state, agg); } // |data| is readonly. void aggregate_impl(int row, const ColumnPtr& data) override { - _agg_func->merge(&_func_ctx, data.get(), _state, row); + _agg_func->merge(_func_ctx, data.get(), _state, row); } // |data| is readonly. void aggregate_batch_impl(int start, int end, const ColumnPtr& input) override { - _agg_func->merge_batch_single_state(&_func_ctx, _state, input.get(), start, end - start); + _agg_func->merge_batch_single_state(_func_ctx, _state, input.get(), start, end - start); } bool need_deep_copy() const override { return false; }; @@ -305,9 +320,10 @@ class AggFuncBasedValueAggregator : public ValueColumnAggregatorBase { } private: - FunctionContext _func_ctx; const AggregateFunction* _agg_func; + FunctionContext* _func_ctx = nullptr; AggDataPtr _state{nullptr}; + std::unique_ptr _agg_state_unoin = nullptr; }; #define CASE_DEFAULT_WARNING(TYPE) \ @@ -417,6 +433,21 @@ ColumnAggregatorPtr ColumnAggregatorFactory::create_value_column_aggregator(cons } else { return p; } + } else if (method == STORAGE_AGGREGATE_AGG_STATE_UNION) { + if (field->get_agg_state_desc() == nullptr) { + CHECK(false) << "Bad agg state union method for column: " << field->name() + << " for its agg state type is null"; + return nullptr; + } + auto* agg_state_desc = field->get_agg_state_desc(); + auto func_name = agg_state_desc->get_func_name(); + DCHECK_EQ(field->is_nullable(), agg_state_desc->is_result_nullable()); + auto* agg_func = AggStateDesc::get_agg_state_func(agg_state_desc); + CHECK(agg_func != nullptr) << "Unknown aggregate function, name=" << func_name << ", type=" << type + << ", is_nullable=" << field->is_nullable() + << ", agg_state_desc=" << agg_state_desc->debug_string(); + // TODO(fixme): use agg_state_union instead of agg_func + return std::make_unique(std::move(agg_func)); } else { auto func_name = get_string_by_aggregation_type(method); // TODO(alvin): To keep compatible with old code, when type must not be the legacy type, diff --git a/be/src/storage/column_aggregate_func.h b/be/src/storage/column_aggregate_func.h index a0305d654e45bf..02bfbf11aa38db 100644 --- a/be/src/storage/column_aggregate_func.h +++ b/be/src/storage/column_aggregate_func.h @@ -15,9 +15,12 @@ #pragma once #include "column/field.h" +#include "runtime/memory/mem_hook_allocator.h" #include "storage/column_aggregator.h" namespace starrocks { + +static MemHookAllocator kDefaultColumnAggregatorAllocator = MemHookAllocator{}; class ColumnAggregatorFactory { public: static ColumnAggregatorPtr create_key_column_aggregator(const FieldPtr& field); diff --git a/be/src/storage/column_in_predicate.cpp b/be/src/storage/column_in_predicate.cpp index e6f5ec4cbde200..7f6613c8ff53c7 100644 --- a/be/src/storage/column_in_predicate.cpp +++ b/be/src/storage/column_in_predicate.cpp @@ -15,13 +15,16 @@ #include #include "column/column.h" +#include "column/column_helper.h" #include "column/nullable_column.h" +#include "column/vectorized_fwd.h" #include "gutil/casts.h" #include "roaring/roaring.hh" #include "storage/column_predicate.h" #include "storage/in_predicate_utils.h" #include "storage/rowset/bitmap_index_reader.h" #include "storage/rowset/bloom_filter.h" +#include "types/logical_type.h" namespace starrocks { @@ -365,6 +368,86 @@ class BinaryColumnInPredicate : public ColumnPredicate { ItemHashSet _slices; }; +class DictionaryCodeInPredicate : public ColumnPredicate { +private: + enum LogicOp { ASSIGN, AND, OR }; + +public: + DictionaryCodeInPredicate(const TypeInfoPtr& type_info, ColumnId id, const std::vector& operands, + size_t size) + : ColumnPredicate(type_info, id), _bit_mask(size) { + for (auto item : operands) { + DCHECK(item < size); + _bit_mask[item] = 1; + } + } + + ~DictionaryCodeInPredicate() override = default; + + template + inline void t_evaluate(const Column* column, uint8_t* sel, uint16_t from, uint16_t to) const { + const Int32Column* dict_code_column = down_cast(ColumnHelper::get_data_column(column)); + std::vector data = dict_code_column->get_data(); + Filter filter(to - from, 1); + + if (column->has_null()) { + const NullColumn* null_column = down_cast(column)->null_column().get(); + auto null_data = null_column->get_data(); + for (auto i = from; i < to; i++) { + auto index = data[i] >= _bit_mask.size() ? 0 : data[i]; + filter[i - from] = (!null_data[i]) & _bit_mask[index]; + } + } else { + for (auto i = from; i < to; i++) { + filter[i - from] = _bit_mask[data[i]]; + } + } + + for (auto i = from; i < to; i++) { + if constexpr (Op == ASSIGN) { + sel[i] = filter[i - from]; + } else if constexpr (Op == AND) { + sel[i] &= filter[i - from]; + } else { + sel[i] |= filter[i - from]; + } + } + } + + Status evaluate(const Column* column, uint8_t* selection, uint16_t from, uint16_t to) const override { + t_evaluate(column, selection, from, to); + return Status::OK(); + } + + Status evaluate_and(const Column* column, uint8_t* selection, uint16_t from, uint16_t to) const override { + t_evaluate(column, selection, from, to); + return Status::OK(); + } + + Status evaluate_or(const Column* column, uint8_t* selection, uint16_t from, uint16_t to) const override { + t_evaluate(column, selection, from, to); + return Status::OK(); + } + + bool can_vectorized() const override { return false; } + + PredicateType type() const override { return PredicateType::kInList; } + + Status convert_to(const ColumnPredicate** output, const TypeInfoPtr& target_type_info, + ObjectPool* obj_pool) const override { + const auto to_type = target_type_info->type(); + if (to_type == LogicalType::TYPE_INT) { + *output = this; + return Status::OK(); + } + CHECK(false) << "Not support, from_type=" << LogicalType::TYPE_INT << ", to_type=" << to_type; + return Status::OK(); + } + +private: + std::vector _bit_mask; +}; + template