diff --git a/be/src/column/array_column.cpp b/be/src/column/array_column.cpp index 1926f2070bab7..fdf99d881202e 100644 --- a/be/src/column/array_column.cpp +++ b/be/src/column/array_column.cpp @@ -18,7 +18,9 @@ #include "column/column_helper.h" #include "column/fixed_length_column.h" +#include "column/nullable_column.h" #include "column/vectorized_fwd.h" +#include "exprs/function_helper.h" #include "gutil/bits.h" #include "gutil/casts.h" #include "gutil/strings/fastmem.h" @@ -616,4 +618,84 @@ Status ArrayColumn::unfold_const_children(const starrocks::TypeDescriptor& type) return Status::OK(); } +size_t ArrayColumn::get_total_elements_num(const NullColumnPtr& null_column) const { + if (null_column == nullptr) { + return _elements->size(); + } + DCHECK_LE(_offsets->size() - 1, null_column->size()); + size_t elements_num = 0; + size_t num_rows = _offsets->size() - 1; + const auto& null_data = null_column->get_data(); + for (size_t i = 0; i < num_rows; i++) { + if (!null_data[i]) { + elements_num += _offsets->get_data()[i + 1] - _offsets->get_data()[i]; + } + } + return elements_num; +} + +template +bool ArrayColumn::compare_lengths_from_offsets(const UInt32Column& v1, const UInt32Column& v2, + const NullColumnPtr& null_column) { + [[maybe_unused]] uint8_t* null_data = nullptr; + if constexpr (!IgnoreNull) { + null_data = null_column->get_data().data(); + } + + size_t num_rows = v1.size() - 1; + if constexpr (ConstV1 && ConstV2) { + // if both are const column, we only compare the first row once + num_rows = 1; + } + bool result = true; + const auto& offsets_v1 = v1.get_data(); + const auto& offsets_v2 = v2.get_data(); + + for (size_t i = 0; i < num_rows && result; i++) { + [[maybe_unused]] uint32_t len1 = + (ConstV1) ? (offsets_v1[1] - offsets_v1[0]) : (offsets_v1[i + 1] - offsets_v1[i]); + [[maybe_unused]] uint32_t len2 = + (ConstV2) ? (offsets_v2[1] - offsets_v2[0]) : (offsets_v2[i + 1] - offsets_v2[i]); + if constexpr (IgnoreNull) { + result &= (len1 == len2); + } else { + if (!null_data[i]) { + result &= (len1 == len2); + } + } + } + return result; +} + +template +bool ArrayColumn::is_all_array_lengths_equal(const ColumnPtr& v1, const ColumnPtr& v2, + const NullColumnPtr& null_column) { + DCHECK(v1->is_array() && v2->is_array()); + DCHECK(!v1->is_nullable() && !v2->is_nullable()); + + if (v1->size() != v2->size()) { + return false; + } + auto data_v1 = FunctionHelper::get_data_column_of_const(v1); + auto data_v2 = FunctionHelper::get_data_column_of_const(v2); + auto* array_v1 = down_cast(data_v1.get()); + auto* array_v2 = down_cast(data_v2.get()); + const auto& offsets_v1 = array_v1->offsets(); + const auto& offsets_v2 = array_v2->offsets(); + if (v1->is_constant() && v2->is_constant()) { + return compare_lengths_from_offsets(offsets_v1, offsets_v2, null_column); + } else if (v1->is_constant() && !v2->is_constant()) { + return compare_lengths_from_offsets(offsets_v1, offsets_v2, null_column); + } else if (!v1->is_constant() && v2->is_constant()) { + return compare_lengths_from_offsets(offsets_v1, offsets_v2, null_column); + } + + return compare_lengths_from_offsets(offsets_v1, offsets_v2, null_column); +} + +template bool ArrayColumn::is_all_array_lengths_equal(const ColumnPtr& v1, const ColumnPtr& v2, + const NullColumnPtr& null_data); +template bool ArrayColumn::is_all_array_lengths_equal(const ColumnPtr& v1, const ColumnPtr& v2, + const NullColumnPtr& null_data); + } // namespace starrocks diff --git a/be/src/column/array_column.h b/be/src/column/array_column.h index 007b563e8878c..1398666d2c16f 100644 --- a/be/src/column/array_column.h +++ b/be/src/column/array_column.h @@ -145,7 +145,7 @@ class ArrayColumn final : public ColumnFactory { void put_mysql_row_buffer(MysqlRowBuffer* buf, size_t idx, bool is_binary_protocol = false) const override; - std::string get_name() const override { return "array"; } + std::string get_name() const override { return "array-" + _elements->get_name(); } Datum get(size_t idx) const override; @@ -173,6 +173,7 @@ class ArrayColumn final : public ColumnFactory { const UInt32Column& offsets() const { return *_offsets; } UInt32Column::Ptr& offsets_column() { return _offsets; } + UInt32Column::Ptr offsets_column() const { return _offsets; } bool is_nullable() const override { return false; } @@ -195,7 +196,19 @@ class ArrayColumn final : public ColumnFactory { Status unfold_const_children(const starrocks::TypeDescriptor& type) override; + // get the number of all non-null elements + size_t get_total_elements_num(const NullColumnPtr& null_column) const; + + // check if the length of each array in two columns is equal + // v1 and v2 must be one of ArrayColumn or Const(ArrayColumn) + template + static bool is_all_array_lengths_equal(const ColumnPtr& v1, const ColumnPtr& v2, const NullColumnPtr& null_data); + private: + template + static bool compare_lengths_from_offsets(const UInt32Column& v1, const UInt32Column& v2, + const NullColumnPtr& null_data); + // Elements must be NullableColumn to facilitate handling nested types. ColumnPtr _elements; // Offsets column will store the start position of every array element. @@ -205,4 +218,9 @@ class ArrayColumn final : public ColumnFactory { UInt32Column::Ptr _offsets; }; +extern template bool ArrayColumn::is_all_array_lengths_equal(const ColumnPtr& v1, const ColumnPtr& v2, + const NullColumnPtr& null_data); +extern template bool ArrayColumn::is_all_array_lengths_equal(const ColumnPtr& v1, const ColumnPtr& v2, + const NullColumnPtr& null_data); + } // namespace starrocks diff --git a/be/src/column/chunk.cpp b/be/src/column/chunk.cpp index 887ed8f480658..343b66c5ec6ec 100644 --- a/be/src/column/chunk.cpp +++ b/be/src/column/chunk.cpp @@ -142,6 +142,7 @@ void Chunk::append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId } void Chunk::append_column(ColumnPtr column, SlotId slot_id) { + DCHECK(!_slot_id_to_index.contains(slot_id)); _slot_id_to_index[slot_id] = _columns.size(); _columns.emplace_back(std::move(column)); check_or_die(); diff --git a/be/src/column/column.cpp b/be/src/column/column.cpp index 09c3da865e788..25c7db43901b7 100644 --- a/be/src/column/column.cpp +++ b/be/src/column/column.cpp @@ -67,7 +67,7 @@ StatusOr Column::upgrade_helper_func(ColumnPtr* col) { } bool Column::empty_null_in_complex_column(const Filter& null_data, const Buffer& offsets) { - DCHECK(null_data.size() == this->size()); + DCHECK_EQ(null_data.size(), this->size()); if (!is_array() && !is_map()) { throw std::runtime_error("empty_null_in_complex_column() only works for array and map column."); } diff --git a/be/src/column/column.h b/be/src/column/column.h index e889fa272f81f..1b6ad54ae91f5 100644 --- a/be/src/column/column.h +++ b/be/src/column/column.h @@ -177,6 +177,7 @@ class Column { } return dest; } + // Update elements to default value which hit by the filter virtual void fill_default(const Filter& filter) = 0; diff --git a/be/src/column/const_column.h b/be/src/column/const_column.h index 04adbf617ca97..9492d48d58612 100644 --- a/be/src/column/const_column.h +++ b/be/src/column/const_column.h @@ -49,6 +49,7 @@ class ConstColumn final : public ColumnFactory { bool is_nullable() const override { return _data->is_nullable(); } bool is_json() const override { return _data->is_json(); } + bool is_array() const override { return _data->is_array(); } bool is_null(size_t index) const override { return _data->is_null(0); } diff --git a/be/src/column/nullable_column.h b/be/src/column/nullable_column.h index 44906cc51cfc5..87701f2a41bfd 100644 --- a/be/src/column/nullable_column.h +++ b/be/src/column/nullable_column.h @@ -82,6 +82,7 @@ class NullableColumn : public ColumnFactory { bool is_nullable() const override { return true; } bool is_json() const override { return _data_column->is_json(); } + bool is_array() const override { return _data_column->is_array(); } bool is_null(size_t index) const override { DCHECK_EQ(_null_column->size(), _data_column->size()); diff --git a/be/src/exprs/array_functions.cpp b/be/src/exprs/array_functions.cpp index 8823d46b3f991..b9ea4b989b55e 100644 --- a/be/src/exprs/array_functions.cpp +++ b/be/src/exprs/array_functions.cpp @@ -1104,7 +1104,6 @@ StatusOr ArrayFunctions::concat(FunctionContext* ctx, const Columns& RETURN_IF_COLUMNS_ONLY_NULL(columns); auto num_rows = columns[0]->size(); - // compute nulls NullColumnPtr nulls; for (auto& column : columns) { diff --git a/be/src/exprs/array_map_expr.cpp b/be/src/exprs/array_map_expr.cpp index 9f97cc38fa2bd..2b19370b09719 100644 --- a/be/src/exprs/array_map_expr.cpp +++ b/be/src/exprs/array_map_expr.cpp @@ -16,17 +16,24 @@ #include +#include +#include + #include "column/array_column.h" #include "column/chunk.h" #include "column/column_helper.h" #include "column/const_column.h" #include "column/fixed_length_column.h" +#include "column/nullable_column.h" #include "column/vectorized_fwd.h" +#include "common/constexpr.h" +#include "common/statusor.h" #include "exprs/anyval_util.h" #include "exprs/expr_context.h" #include "exprs/function_helper.h" #include "exprs/lambda_function.h" #include "runtime/user_function_cache.h" +#include "simd/simd.h" #include "storage/chunk_helper.h" namespace starrocks { @@ -34,123 +41,343 @@ ArrayMapExpr::ArrayMapExpr(const TExprNode& node) : Expr(node, false) {} ArrayMapExpr::ArrayMapExpr(TypeDescriptor type) : Expr(std::move(type), false) {} +Status ArrayMapExpr::prepare(RuntimeState* state, ExprContext* context) { + for (int i = 1; i < _children.size(); ++i) { + RETURN_IF_ERROR(_children[i]->prepare(state, context)); + } + + auto lambda_expr = down_cast(_children[0]); + + LambdaFunction::ExtractContext extract_ctx; + // assign slot ids to outer common exprs starting with max_used_slot_id + 1 + extract_ctx.next_slot_id = lambda_expr->max_used_slot_id() + 1; + + RETURN_IF_ERROR(lambda_expr->extract_outer_common_exprs(state, &extract_ctx)); + _outer_common_exprs.swap(extract_ctx.outer_common_exprs); + for (auto [_, expr] : _outer_common_exprs) { + RETURN_IF_ERROR(expr->prepare(state, context)); + } + RETURN_IF_ERROR(lambda_expr->prepare(state, context)); + + return Status::OK(); +} + +template +StatusOr ArrayMapExpr::evaluate_lambda_expr(ExprContext* context, Chunk* chunk, + const std::vector& input_elements, + const NullColumnPtr& result_null_column) { + // create a new chunk to evaluate the lambda expression + auto cur_chunk = std::make_shared(); + // 1. evaluate outer common expressions + for (const auto& [slot_id, expr] : _outer_common_exprs) { + ASSIGN_OR_RETURN(auto col, context->evaluate(expr, chunk)); + chunk->append_column(col, slot_id); + } + + auto lambda_func = dynamic_cast(_children[0]); + std::vector capture_slot_ids; + lambda_func->get_captured_slot_ids(&capture_slot_ids); + + // 2. check captured columns' size + for (auto slot_id : capture_slot_ids) { + DCHECK(slot_id > 0); + auto captured_column = chunk->get_column_by_slot_id(slot_id); + if (UNLIKELY(captured_column->size() < input_elements[0]->size())) { + return Status::InternalError(fmt::format("The size of the captured column {} is less than array's size.", + captured_column->get_name())); + } + } + + UInt32Column::Ptr aligned_offsets = nullptr; + size_t null_rows = result_null_column ? SIMD::count_nonzero(result_null_column->get_data()) : 0; + + std::vector arguments_ids; + int argument_num = lambda_func->get_lambda_arguments_ids(&arguments_ids); + + // 3. prepare arguments of lambda expr, put all arguments into cur_chunk + for (int i = 0; i < argument_num; ++i) { + auto data_column = FunctionHelper::get_data_column_of_const(input_elements[i]); + auto array_column = down_cast(data_column.get()); + auto elements_column = array_column->elements_column(); + UInt32Column::Ptr offsets_column = array_column->offsets_column(); + if constexpr (!all_const_input) { + if (input_elements[i]->is_constant()) { + size_t elements_num = array_column->get_element_size(0); + elements_column = elements_column->clone(); + offsets_column = UInt32Column::create(); + // replicate N time and ignore null + size_t repeat_times = input_elements[i]->size() - null_rows; + size_t offset = elements_num; + offsets_column->append(0); + offsets_column->append(offset); + for (size_t i = 1; i < repeat_times; i++) { + elements_column->append(*elements_column, 0, elements_num); + offset += elements_num; + offsets_column->append(offset); + } + } else { + data_column->empty_null_in_complex_column(result_null_column->get_data(), + array_column->offsets().get_data()); + elements_column = down_cast(data_column.get())->elements_column(); + } + } + + if (aligned_offsets == nullptr) { + aligned_offsets = offsets_column; + } + + // if lambda expr doesn't rely on argument, we don't need to put it into cur_chunk + if constexpr (!independent_lambda_expr) { + cur_chunk->append_column(elements_column, arguments_ids[i]); + } + } + DCHECK(aligned_offsets != nullptr); + + // 4. prepare outer common exprs + for (const auto& [slot_id, expr] : _outer_common_exprs) { + auto column = chunk->get_column_by_slot_id(slot_id); + column = ColumnHelper::unpack_and_duplicate_const_column(column->size(), column); + if constexpr (independent_lambda_expr) { + // if lambda expr doesn't rely on arguments, we don't need to align offset + cur_chunk->append_column(column, slot_id); + } else { + cur_chunk->append_column(column->replicate(aligned_offsets->get_data()), slot_id); + } + } + + // 5. prepare capture columns + for (auto slot_id : capture_slot_ids) { + if (cur_chunk->is_slot_exist(slot_id)) { + continue; + } + auto captured_column = chunk->get_column_by_slot_id(slot_id); + if constexpr (independent_lambda_expr) { + cur_chunk->append_column(captured_column, slot_id); + } else { + cur_chunk->append_column(captured_column->replicate(aligned_offsets->get_data()), slot_id); + } + } + + // 6. evaluate lambda expr + ColumnPtr column = nullptr; + if constexpr (independent_lambda_expr) { + // if lambda expr doesn't rely on arguments, we evaluate it first, and then align offsets + ColumnPtr tmp_col; + if (!cur_chunk->has_columns()) { + ASSIGN_OR_RETURN(tmp_col, context->evaluate(_children[0], nullptr)); + } else { + ASSIGN_OR_RETURN(tmp_col, context->evaluate(_children[0], cur_chunk.get())); + } + tmp_col->check_or_die(); + column = tmp_col->replicate(aligned_offsets->get_data()); + column = ColumnHelper::align_return_type(column, type().children[0], column->size(), true); + } else { + // if all input arguments are const, + if constexpr (all_const_input) { + ASSIGN_OR_RETURN(auto tmp_col, context->evaluate(_children[0], cur_chunk.get())); + tmp_col->check_or_die(); + // if result is a const column, we should unpack it first and make it to be the elements column of array column + column = ColumnHelper::unpack_and_duplicate_const_column(tmp_col->size(), tmp_col); + column = ColumnHelper::align_return_type(column, type().children[0], column->size(), true); + } else { + ChunkAccumulator accumulator(DEFAULT_CHUNK_SIZE); + RETURN_IF_ERROR(accumulator.push(std::move(cur_chunk))); + accumulator.finalize(); + while (auto tmp_chunk = accumulator.pull()) { + tmp_chunk->check_or_die(); + ASSIGN_OR_RETURN(auto tmp_col, context->evaluate(_children[0], tmp_chunk.get())); + tmp_col->check_or_die(); + tmp_col = ColumnHelper::align_return_type(tmp_col, type().children[0], tmp_chunk->num_rows(), true); + if (column == nullptr) { + column = tmp_col; + } else { + column->append(*tmp_col); + } + } + } + } + DCHECK(column != nullptr); + column = ColumnHelper::cast_to_nullable_column(column); + + if constexpr (all_const_input) { + // if all input arguments are const, we can return a const column + auto data_column = FunctionHelper::get_data_column_of_const(column); + + aligned_offsets = UInt32Column::create(); + aligned_offsets->append(0); + aligned_offsets->append(data_column->size()); + auto array_column = + std::make_shared(data_column, ColumnHelper::as_column(aligned_offsets)); + array_column->check_or_die(); + ColumnPtr result_column = array_column; + if (result_null_column != nullptr) { + result_column = NullableColumn::create(std::move(array_column), result_null_column); + result_column->check_or_die(); + } + result_column = ConstColumn::create(result_column, chunk->num_rows()); + result_column->check_or_die(); + return result_column; + } else { + auto array_column = std::make_shared( + column, ColumnHelper::as_column(aligned_offsets->clone_shared())); + array_column->check_or_die(); + if (result_null_column != nullptr) { + return NullableColumn::create(std::move(array_column), result_null_column); + } + return array_column; + } +} + // The input array column maybe nullable, so first remove the wrap of nullable property. // The result of lambda expressions do not change the offsets of the current array and the null map. // NOTE the return column must be of the return type. StatusOr ArrayMapExpr::evaluate_checked(ExprContext* context, Chunk* chunk) { std::vector input_elements; - NullColumnPtr null_column = nullptr; bool is_single_nullable_child = false; - ArrayColumn* input_array = nullptr; - ColumnPtr input_array_ptr_ref = nullptr; // hold shared_ptr to avoid early deleted. - // for many valid arguments: - // if one of them is a null literal, the result is a null literal; - // if one of them is only null, then results are null; - // unfold const columns. - // make sure all inputs have the same offsets. - // TODO(fzh): support several arrays with different offsets and set null for non-equal size of arrays. + + NullColumnPtr result_null_column = nullptr; + bool all_input_is_constant = true; + for (int i = 1; i < _children.size(); ++i) { ASSIGN_OR_RETURN(auto child_col, context->evaluate(_children[i], chunk)); // the column is a null literal. if (child_col->only_null()) { return ColumnHelper::align_return_type(child_col, type(), chunk->num_rows(), true); } - // no optimization for const columns. - child_col = ColumnHelper::unpack_and_duplicate_const_column(child_col->size(), child_col); - - auto column = child_col; - if (child_col->is_nullable()) { - auto nullable = down_cast(child_col.get()); - DCHECK(nullable != nullptr); - column = nullable->data_column(); - // empty null array with non-zero elements - column->empty_null_in_complex_column(nullable->null_column()->get_data(), - down_cast(column.get())->offsets().get_data()); - if (null_column) { + + bool is_const = child_col->is_constant(); + bool is_nullable = child_col->is_nullable(); + size_t num_rows = child_col->size(); + all_input_is_constant &= is_const; + + auto data_column = child_col; + if (is_const) { + auto const_column = down_cast(child_col.get()); + data_column = const_column->data_column(); + } + + if (is_nullable) { + auto nullable_column = down_cast(data_column.get()); + DCHECK(nullable_column); + data_column = nullable_column->data_column(); + + auto null_column = nullable_column->null_column(); + if (is_const) { + // if null_column is from const_column, should unpack + null_column->assign(num_rows, 0); + } + + if (result_null_column) { is_single_nullable_child = false; - null_column = FunctionHelper::union_null_column(nullable->null_column(), null_column); // merge null + result_null_column = FunctionHelper::union_null_column(null_column, result_null_column); } else { is_single_nullable_child = true; - null_column = nullable->null_column(); + result_null_column = null_column; } } - DCHECK(column->is_array()); - auto cur_array = down_cast(column.get()); + DCHECK(data_column->is_array() && !data_column->is_nullable()); - if (input_array == nullptr) { - input_array = cur_array; - input_array_ptr_ref = column; - } else { - if (UNLIKELY(!ColumnHelper::offsets_equal(cur_array->offsets_column(), input_array->offsets_column()))) { + ColumnPtr column = data_column; + if (is_const) { + // keep it as a const array column in input_elements + column = ConstColumn::create(data_column, num_rows); + } + + // check each array's lengths in input_elements + if (!input_elements.empty()) { + const auto& first_input = input_elements[0]; + + bool is_array_lengths_valid = + result_null_column + ? ArrayColumn::is_all_array_lengths_equal(first_input, column, result_null_column) + : ArrayColumn::is_all_array_lengths_equal(first_input, column, result_null_column); + if (!is_array_lengths_valid) { return Status::InternalError("Input array element's size is not equal in array_map()."); } } - input_elements.push_back(cur_array->elements_column()); + + input_elements.emplace_back(column); } if (is_single_nullable_child) { - DCHECK(null_column != nullptr); + DCHECK(result_null_column != nullptr); // If there are more than one nullable children, the nullable column has been cloned when calling // union_null_column to merge, so only one nullable child needs to be cloned. - null_column = ColumnHelper::as_column(null_column->clone_shared()); + result_null_column = ColumnHelper::as_column(result_null_column->clone_shared()); } ColumnPtr column = nullptr; - if (input_array->elements_column()->empty()) { // arrays may be null or empty + size_t null_rows = result_null_column ? SIMD::count_nonzero(result_null_column->get_data()) : 0; + + if (null_rows == input_elements[0]->size()) { + // if all input rows are null, just return a const nullable array column as result column = ColumnHelper::create_column(type().children[0], true); // array->elements must be of return array->elements' type - } else { - // construct a new chunk to evaluate the lambda expression. - auto cur_chunk = std::make_shared(); - // put all arguments into the new chunk - std::vector arguments_ids; - auto lambda_func = dynamic_cast(_children[0]); - int argument_num = lambda_func->get_lambda_arguments_ids(&arguments_ids); - DCHECK(argument_num == input_elements.size()); - for (int i = 0; i < argument_num; ++i) { - cur_chunk->append_column(input_elements[i], arguments_ids[i]); // column ref - } - // put captured columns into the new chunk aligning with the first array's offsets - std::vector slot_ids; - _children[0]->get_slot_ids(&slot_ids); - for (auto id : slot_ids) { - DCHECK(id > 0); - auto captured = chunk->get_column_by_slot_id(id); - if (UNLIKELY(captured->size() < input_array->size())) { - return Status::InternalError(fmt::format( - "The size of the captured column {} is less than array's size.", captured->get_name())); - } - cur_chunk->append_column(captured->replicate(input_array->offsets_column()->get_data()), id); - } - if (cur_chunk->num_rows() <= chunk->num_rows() * 8) { - ASSIGN_OR_RETURN(column, context->evaluate(_children[0], cur_chunk.get())); - column = ColumnHelper::align_return_type(column, type().children[0], cur_chunk->num_rows(), true); - } else { // split large chunks into small ones to avoid too large or various batch_size - ChunkAccumulator accumulator(DEFAULT_CHUNK_SIZE); - RETURN_IF_ERROR(accumulator.push(std::move(cur_chunk))); - accumulator.finalize(); - while (auto tmp_chunk = accumulator.pull()) { - ASSIGN_OR_RETURN(auto tmp_col, context->evaluate(_children[0], tmp_chunk.get())); - tmp_col = ColumnHelper::align_return_type(tmp_col, type().children[0], tmp_chunk->num_rows(), true); - if (column == nullptr) { - column = tmp_col; - } else { - column->append(*tmp_col); - } - } + column->append_default(1); + auto aligned_offsets = UInt32Column::create(0); + aligned_offsets->append(0); + aligned_offsets->append(1); + auto array_col = std::make_shared(column, aligned_offsets); + array_col->check_or_die(); + if (result_null_column) { + result_null_column->resize(1); + auto result = ConstColumn::create(NullableColumn::create(std::move(array_col), result_null_column), + chunk->num_rows()); + result->check_or_die(); + return result; } + auto result = ConstColumn::create(std::move(array_col), chunk->num_rows()); + result->check_or_die(); + return result; + } - // construct the result array - DCHECK(column != nullptr); - column = ColumnHelper::cast_to_nullable_column(column); + size_t total_elements_num = + down_cast(FunctionHelper::get_data_column_of_const(input_elements[0]).get()) + ->get_total_elements_num(result_null_column); + + if (total_elements_num == 0) { + // if all input rows are empty arrays, return a const empty array column as result + column = ColumnHelper::create_column(type().children[0], true); + auto aligned_offsets = UInt32Column::create(0); + aligned_offsets->append_default(2); + auto array_col = std::make_shared(column, aligned_offsets); + array_col->check_or_die(); + auto result = ConstColumn::create(std::move(array_col), chunk->num_rows() - null_rows); + result->check_or_die(); + return result; } - // attach offsets - auto array_col = std::make_shared( - column, ColumnHelper::as_column(input_array->offsets_column()->clone_shared())); - if (null_column != nullptr) { - return NullableColumn::create(std::move(array_col), null_column); + + auto lambda_func = dynamic_cast(_children[0]); + bool is_lambda_expr_independent = lambda_func->is_lambda_expr_independent(); + if (all_input_is_constant && is_lambda_expr_independent) { + return evaluate_lambda_expr(context, chunk, input_elements, result_null_column); + } else if (all_input_is_constant && !is_lambda_expr_independent) { + return evaluate_lambda_expr(context, chunk, input_elements, result_null_column); + } else if (!all_input_is_constant && is_lambda_expr_independent) { + return evaluate_lambda_expr(context, chunk, input_elements, result_null_column); + } else { + return evaluate_lambda_expr(context, chunk, input_elements, result_null_column); + } +} + +std::string ArrayMapExpr::debug_string() const { + std::stringstream out; + auto expr_debug_string = Expr::debug_string(); + out << "array_map ("; + for (int i = 0; i < _children.size(); i++) { + out << (i == 0 ? "" : ", ") << _children[i]->debug_string(); + } + out << ")"; + return out.str(); +} + +int ArrayMapExpr::get_slot_ids(std::vector* slot_ids) const { + int num = Expr::get_slot_ids(slot_ids); + for (const auto& [slot_id, _] : _outer_common_exprs) { + slot_ids->push_back(slot_id); + num++; } - return array_col; + return num; } } // namespace starrocks diff --git a/be/src/exprs/array_map_expr.h b/be/src/exprs/array_map_expr.h index 98718103564ad..85e4c5435f3db 100644 --- a/be/src/exprs/array_map_expr.h +++ b/be/src/exprs/array_map_expr.h @@ -16,7 +16,9 @@ #include #include +#include +#include "column/nullable_column.h" #include "common/global_types.h" #include "common/object_pool.h" #include "exprs/column_ref.h" @@ -34,8 +36,19 @@ class ArrayMapExpr final : public Expr { // for tests explicit ArrayMapExpr(TypeDescriptor type); + Status prepare(RuntimeState* state, ExprContext* context) override; Expr* clone(ObjectPool* pool) const override { return pool->add(new ArrayMapExpr(*this)); } StatusOr evaluate_checked(ExprContext* context, Chunk* ptr) override; + std::string debug_string() const override; + int get_slot_ids(std::vector* slot_ids) const override; + +private: + template + StatusOr evaluate_lambda_expr(ExprContext* context, Chunk* chunk, + const std::vector& arguments, const NullColumnPtr& null_column); + + // use map to make sure the order of execution + std::map _outer_common_exprs; }; } // namespace starrocks diff --git a/be/src/exprs/expr.cpp b/be/src/exprs/expr.cpp index cd805dcff4e61..1e020edd1b01e 100644 --- a/be/src/exprs/expr.cpp +++ b/be/src/exprs/expr.cpp @@ -73,6 +73,7 @@ #include "exprs/match_expr.h" #include "exprs/placeholder_ref.h" #include "exprs/subfield_expr.h" +#include "gutil/casts.h" #include "gutil/strings/substitute.h" #include "runtime/runtime_state.h" #include "types/logical_type.h" diff --git a/be/src/exprs/expr.h b/be/src/exprs/expr.h index 65182c7c17495..6d01eadc6e736 100644 --- a/be/src/exprs/expr.h +++ b/be/src/exprs/expr.h @@ -66,6 +66,7 @@ class JITContext; class JITExpr; struct JitScore; struct LLVMDatum; +class LambdaFunction; // This is the superclass of all expr evaluation nodes. class Expr { @@ -119,6 +120,8 @@ class Expr { bool is_monotonic() const { return _is_monotonic; } bool is_cast_expr() const { return _node_type == TExprNodeType::CAST_EXPR; } + virtual bool is_lambda_function() const { return false; } + virtual bool is_literal() const { return false; } // In most time, this field is passed from FE // Sometimes we want to construct expr on BE implicitly and we have knowledge about `monotonicity` @@ -197,6 +200,7 @@ class Expr { static void close(const std::vector& exprs); virtual std::string debug_string() const; + static std::string debug_string(const std::vector& exprs); static std::string debug_string(const std::vector& ctxs); @@ -270,6 +274,8 @@ class Expr { friend class Literal; friend class ExprContext; friend class ColumnPredicateRewriter; + friend class LambdaFunction; + friend class ArrayMapExpr; explicit Expr(TypeDescriptor type); explicit Expr(const TExprNode& node); diff --git a/be/src/exprs/lambda_function.cpp b/be/src/exprs/lambda_function.cpp index 3dff0cee3bb3e..d2a186e0a521d 100644 --- a/be/src/exprs/lambda_function.cpp +++ b/be/src/exprs/lambda_function.cpp @@ -16,33 +16,116 @@ #include +#include #include #include "column/chunk.h" #include "column/column_helper.h" #include "column/vectorized_fwd.h" +#include "exec/exec_node.h" +#include "exprs/column_ref.h" +#include "exprs/expr.h" #include "exprs/expr_context.h" +#include "util/defer_op.h" namespace starrocks { LambdaFunction::LambdaFunction(const TExprNode& node) : Expr(node, false), _common_sub_expr_num(node.output_column) {} -Status LambdaFunction::prepare(starrocks::RuntimeState* state, starrocks::ExprContext* context) { - RETURN_IF_ERROR(Expr::prepare(state, context)); - if (_is_prepared) { +Status LambdaFunction::extract_outer_common_exprs(RuntimeState* state, Expr* expr, ExtractContext* ctx) { + if (expr->is_lambda_function()) { + auto lambda_function = static_cast(expr); + RETURN_IF_ERROR(lambda_function->collect_lambda_argument_ids()); + for (auto argument_id : lambda_function->get_lambda_arguments_ids()) { + ctx->lambda_arguments.insert(argument_id); + } + RETURN_IF_ERROR(lambda_function->collect_common_sub_exprs()); + for (auto slot_id : lambda_function->get_common_sub_expr_ids()) { + ctx->common_sub_expr_ids.insert(slot_id); + } + } + + DeferOp defer([&]() { + if (expr->is_lambda_function()) { + auto lambda_function = static_cast(expr); + for (auto argument_id : lambda_function->get_lambda_arguments_ids()) { + ctx->lambda_arguments.erase(argument_id); + } + for (auto slot_id : lambda_function->get_common_sub_expr_ids()) { + ctx->common_sub_expr_ids.erase(slot_id); + } + } + }); + + int child_num = expr->get_num_children(); + std::vector slot_ids; + + for (int i = 0; i < child_num; i++) { + auto child = expr->get_child(i); + + RETURN_IF_ERROR(extract_outer_common_exprs(state, child, ctx)); + // if child is a slotref or a lambda function or a literal, we can't replace it. + if (child->is_slotref() || child->is_lambda_function() || child->is_literal()) { + continue; + } + + slot_ids.clear(); + child->get_slot_ids(&slot_ids); + bool is_independent = std::all_of(slot_ids.begin(), slot_ids.end(), [ctx](const SlotId& id) { + return ctx->lambda_arguments.find(id) == ctx->lambda_arguments.end() && + ctx->common_sub_expr_ids.find(id) == ctx->common_sub_expr_ids.end(); + }); + + if (is_independent) { + SlotId slot_id = ctx->next_slot_id++; + ColumnRef* column_ref = state->obj_pool()->add(new ColumnRef(child->type(), slot_id)); + VLOG(1) << "add new common expr, slot_id: " << slot_id << ", new expr: " << column_ref->debug_string() + << ", old expr: " << child->debug_string(); + expr->_children[i] = column_ref; + ctx->outer_common_exprs.insert({slot_id, child}); + } + } + + return Status::OK(); +} + +Status LambdaFunction::extract_outer_common_exprs(RuntimeState* state, ExtractContext* ctx) { + RETURN_IF_ERROR(extract_outer_common_exprs(state, this, ctx)); + return Status::OK(); +} + +Status LambdaFunction::collect_lambda_argument_ids() { + if (!_arguments_ids.empty()) { return Status::OK(); } - _is_prepared = true; - // common sub expressions include 2 parts in a pair: (slot id, expression) const int child_num = get_num_children() - 2 * _common_sub_expr_num; - // collect the slot ids of lambda arguments - for (int i = 1; i < child_num; ++i) { - get_child(i)->get_slot_ids(&_arguments_ids); + for (int i = 1; i < child_num; i++) { + _children[i]->get_slot_ids(&_arguments_ids); } if (child_num - 1 != _arguments_ids.size()) { return Status::InternalError(fmt::format("Lambda arguments get ids failed, just get {} ids from {} arguments.", _arguments_ids.size(), child_num - 1)); } + return Status::OK(); +} + +SlotId LambdaFunction::max_used_slot_id() const { + std::vector ids; + for (auto child : _children) { + child->get_slot_ids(&ids); + } + DCHECK(!ids.empty()); + return *std::max_element(ids.begin(), ids.end()); +} + +Status LambdaFunction::collect_common_sub_exprs() { + if (!_common_sub_expr_ids.empty()) { + return Status::OK(); + } + + // common sub expressions include 2 parts in a pair: (slot id, expression) + const int child_num = get_num_children() - 2 * _common_sub_expr_num; + // sorted common sub expressions so that the later expressions can reference the previous ones. for (auto i = child_num; i < child_num + _common_sub_expr_num; ++i) { get_child(i)->get_slot_ids(&_common_sub_expr_ids); @@ -62,9 +145,36 @@ Status LambdaFunction::prepare(starrocks::RuntimeState* state, starrocks::ExprCo _common_sub_expr.size(), _common_sub_expr_num)); } + return Status::OK(); +} + +Status LambdaFunction::prepare(starrocks::RuntimeState* state, starrocks::ExprContext* context) { + RETURN_IF_ERROR(Expr::prepare(state, context)); + if (_is_prepared) { + return Status::OK(); + } + _is_prepared = true; + + RETURN_IF_ERROR(collect_lambda_argument_ids()); + RETURN_IF_ERROR(collect_common_sub_exprs()); // get slot ids from the lambda expression get_child(0)->get_slot_ids(&_captured_slot_ids); + _is_lambda_expr_independent = true; + + // if all captured slot ids are not in lambda arguments ids, then lambda expr is independent. + // for example, + // in array_map(x->id, arg1), the lambda expr `id` is independent. + // but in array_map(x->arg1+id, arg1), the lambda expr `arg1+id` is not independent. + for (size_t i = 0; i < _captured_slot_ids.size() && _is_lambda_expr_independent; ++i) { + for (const auto& arguments_id : _arguments_ids) { + if (_captured_slot_ids[i] == arguments_id) { + _is_lambda_expr_independent = false; + break; + } + } + } + // remove current argument ids and duplicated ids from captured_slot_ids std::map captured_mask; int valid_id = 0; @@ -105,4 +215,24 @@ StatusOr LambdaFunction::evaluate_checked(ExprContext* context, Chunk return get_child(0)->evaluate_checked(context, chunk); } +int LambdaFunction::get_slot_ids(std::vector* slot_ids) const { + if (_is_prepared) { + slot_ids->insert(slot_ids->end(), _captured_slot_ids.begin(), _captured_slot_ids.end()); + slot_ids->insert(slot_ids->end(), _arguments_ids.begin(), _arguments_ids.end()); + return _captured_slot_ids.size() + _arguments_ids.size(); + } else { + return Expr::get_slot_ids(slot_ids); + } +} + +std::string LambdaFunction::debug_string() const { + std::stringstream out; + auto expr_debug_string = Expr::debug_string(); + out << "LambaFunction ("; + for (int i = 0; i < _children.size(); i++) { + out << (i == 0 ? "lambda expr: " : " input argument: ") << _children[i]->debug_string(); + } + out << ")"; + return out.str(); +} } // namespace starrocks diff --git a/be/src/exprs/lambda_function.h b/be/src/exprs/lambda_function.h index 2be3d08e0ccd0..c6bf1928299f8 100644 --- a/be/src/exprs/lambda_function.h +++ b/be/src/exprs/lambda_function.h @@ -16,11 +16,11 @@ #include #include +#include #include #include "common/global_types.h" #include "common/object_pool.h" -#include "exprs/column_ref.h" #include "exprs/expr.h" #include "glog/logging.h" #include "gutil/casts.h" @@ -45,26 +45,63 @@ class LambdaFunction final : public Expr { StatusOr evaluate_checked(ExprContext* context, Chunk* ptr) override; - // the slot ids of lambda expression may be originally from the arguments of this lambda function - // or its parent lambda functions, or captured columns, remove the first one. - int get_slot_ids(std::vector* slot_ids) const override { + int get_slot_ids(std::vector* slot_ids) const override; + + int get_captured_slot_ids(std::vector* slot_ids) const { + DCHECK(_is_prepared); slot_ids->insert(slot_ids->end(), _captured_slot_ids.begin(), _captured_slot_ids.end()); return _captured_slot_ids.size(); } - int get_lambda_arguments_ids(std::vector* ids) { + int get_lambda_arguments_ids(std::vector* ids) const { + DCHECK(_is_prepared); ids->assign(_arguments_ids.begin(), _arguments_ids.end()); return _arguments_ids.size(); } + const std::vector& get_lambda_arguments_ids() const { return _arguments_ids; } + const std::vector& get_common_sub_expr_ids() const { return _common_sub_expr_ids; } + + bool is_lambda_function() const override { return true; } + bool is_lambda_expr_independent() const { return _is_lambda_expr_independent; } + Expr* get_lambda_expr() const { return _children[0]; } + std::string debug_string() const override; + + SlotId max_used_slot_id() const; + + struct ExtractContext { + std::unordered_set lambda_arguments; + // slot id of common sub expr inside lambda expr + std::unordered_set common_sub_expr_ids; + SlotId next_slot_id; + std::map outer_common_exprs; + }; + + // Extract the outer common expression in lambda expr. + // Outer common expr is an expression that does not depend on lambda arguments at all. Such expressions can be calculated independently. + // NOTE: Calling this interface may rewrite Lambda expr, and all outer common expr will be replaced with ColumnRef expr. + // Functions using lambda expressions can extract common expressions first and calculate them separately, so as to optimize. + + // take `array_map(x->any_match(array_map(x->x < 10, arr1)), arr1)` as an example, + // `any_match(array_map(x->x<10, arr1))` is an outer common expr. it will create 2 column ref exprs to replace them. + // 1. slot 1 -> array_map(x->x<10, arr1) + // 2. slot 2 -> any_match(slot 1, arr1) + Status extract_outer_common_exprs(RuntimeState* state, ExtractContext* ctx); private: + Status collect_lambda_argument_ids(); + Status collect_capture_slot_ids(); + Status collect_common_sub_exprs(); + Status extract_outer_common_exprs(RuntimeState* state, Expr* expr, ExtractContext* ctx); + std::vector _captured_slot_ids; std::vector _arguments_ids; std::vector _common_sub_expr_ids; std::vector _common_sub_expr; + int _common_sub_expr_num; bool _is_prepared = false; + bool _is_lambda_expr_independent = false; }; } // namespace starrocks diff --git a/be/src/exprs/literal.h b/be/src/exprs/literal.h index a01001a602f45..f177995edef5e 100644 --- a/be/src/exprs/literal.h +++ b/be/src/exprs/literal.h @@ -39,7 +39,7 @@ class VectorizedLiteral final : public Expr { StatusOr generate_ir_impl(ExprContext* context, JITContext* jit_ctx) override; #endif - + bool is_literal() const override { return true; } std::string debug_string() const override; private: diff --git a/be/src/exprs/map_apply_expr.cpp b/be/src/exprs/map_apply_expr.cpp index bc21157271e9a..e262b6d94ed92 100644 --- a/be/src/exprs/map_apply_expr.cpp +++ b/be/src/exprs/map_apply_expr.cpp @@ -112,8 +112,9 @@ StatusOr MapApplyExpr::evaluate_checked(ExprContext* context, Chunk* cur_chunk->append_column(input_columns[i], _arguments_ids[i]); // column ref } // put captured columns into the new chunk aligning with the first map's offsets + auto lambda_func = dynamic_cast(_children[0]); std::vector slot_ids; - _children[0]->get_slot_ids(&slot_ids); + lambda_func->get_captured_slot_ids(&slot_ids); for (auto id : slot_ids) { DCHECK(id > 0); auto captured = chunk->get_column_by_slot_id(id); diff --git a/be/test/exprs/lambda_array_expr_test.cpp b/be/test/exprs/lambda_array_expr_test.cpp index 876a643fdf4e5..4bcfa3b18cd44 100644 --- a/be/test/exprs/lambda_array_expr_test.cpp +++ b/be/test/exprs/lambda_array_expr_test.cpp @@ -16,6 +16,7 @@ #include #include +#include #include "butil/time.h" #include "column/column_helper.h" @@ -25,6 +26,7 @@ #include "exprs/array_map_expr.h" #include "exprs/cast_expr.h" #include "exprs/function_call_expr.h" +#include "exprs/function_helper.h" #include "exprs/is_null_predicate.h" #include "exprs/lambda_function.h" #include "exprs/literal.h" @@ -253,7 +255,7 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_normal_array) { // check LambdaFunction::prepare() std::vector ids, arguments; - lambda->get_slot_ids(&ids); + lambda->get_captured_slot_ids(&ids); lambda->get_lambda_arguments_ids(&arguments); ASSERT_TRUE(arguments.size() == 1 && arguments[0] == 100000); // the x's slot_id = 100000 @@ -272,37 +274,37 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_normal_array) { ASSERT_FALSE(result->is_constant()); ASSERT_FALSE(result->is_numeric()); - EXPECT_EQ(3, result->size()); - EXPECT_EQ(1, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(4, result->get(0).get_array()[1].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(1, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(4, result->get(0).get_array()[1].get_int32()); ASSERT_TRUE(result->get(1).get_array()[0].is_null()); ASSERT_TRUE(result->get(1).get_array()[1].is_null()); ASSERT_TRUE(result->get(2).get_array()[0].is_null()); - EXPECT_EQ(12, result->get(2).get_array()[1].get_int32()); + ASSERT_EQ(12, result->get(2).get_array()[1].get_int32()); } else if (i == 0 && j == 1) { // array_map(x -> x is null, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(0, result->get(0).get_array()[0].get_int8()); - EXPECT_EQ(0, result->get(0).get_array()[1].get_int8()); - EXPECT_EQ(1, result->get(1).get_array()[0].get_int8()); - EXPECT_EQ(1, result->get(1).get_array()[1].get_int8()); - EXPECT_EQ(1, result->get(2).get_array()[0].get_int8()); - EXPECT_EQ(0, result->get(2).get_array()[1].get_int8()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(0, result->get(0).get_array()[0].get_int8()); + ASSERT_EQ(0, result->get(0).get_array()[1].get_int8()); + ASSERT_EQ(1, result->get(1).get_array()[0].get_int8()); + ASSERT_EQ(1, result->get(1).get_array()[1].get_int8()); + ASSERT_EQ(1, result->get(2).get_array()[0].get_int8()); + ASSERT_EQ(0, result->get(2).get_array()[1].get_int8()); } else if (i == 0 && j == 2) { // // array_map(x -> x+a, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(2, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(5, result->get(0).get_array()[1].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(2, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(5, result->get(0).get_array()[1].get_int32()); ASSERT_TRUE(result->get(1).get_array()[0].is_null()); ASSERT_TRUE(result->get(1).get_array()[1].is_null()); ASSERT_TRUE(result->get(2).get_array()[0].is_null()); - EXPECT_EQ(13, result->get(2).get_array()[1].get_int32()); + ASSERT_EQ(13, result->get(2).get_array()[1].get_int32()); } else if (i == 0 && j == 3) { - EXPECT_EQ(3, result->size()); - EXPECT_EQ(-110, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(0).get_array()[1].get_int32()); - EXPECT_EQ(-110, result->get(1).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(1).get_array()[1].get_int32()); - EXPECT_EQ(-110, result->get(2).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(2).get_array()[1].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(-110, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(0).get_array()[1].get_int32()); + ASSERT_EQ(-110, result->get(1).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(1).get_array()[1].get_int32()); + ASSERT_EQ(-110, result->get(2).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(2).get_array()[1].get_int32()); } Expr::close(expr_ctxs, &_runtime_state); @@ -329,7 +331,7 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_special_array) { // check LambdaFunction::prepare() std::vector ids, arguments; - lambda->get_slot_ids(&ids); + lambda->get_captured_slot_ids(&ids); lambda->get_lambda_arguments_ids(&arguments); ASSERT_TRUE(arguments.size() == 1 && arguments[0] == 100000); // the x's slot_id = 100000 @@ -338,40 +340,39 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_special_array) { } else { ASSERT_TRUE(ids.empty()); } - ColumnPtr result = array_map_expr.evaluate(&exprContext, cur_chunk.get()); if (i == 1) { // array_map(x->xxx,null) - EXPECT_EQ(1, result->size()); + ASSERT_EQ(3, result->size()); ASSERT_TRUE(result->is_null(0)); } else if (i == 2 && (j == 0 || j == 2)) { // array_map( x->x || x->x+a, [null]) - EXPECT_EQ(1, result->size()); + ASSERT_EQ(1, result->size()); ASSERT_TRUE(result->get(0).get_array()[0].is_null()); } else if (i == 2 && j == 1) { // array_map(x -> x is null,[null]) - EXPECT_EQ(1, result->size()); - EXPECT_EQ(1, result->get(0).get_array()[0].get_int8()); + ASSERT_EQ(1, result->size()); + ASSERT_EQ(1, result->get(0).get_array()[0].get_int8()); } else if (i == 2 && j == 3) { // array_map(x -> -110,[null]) - EXPECT_EQ(1, result->size()); - EXPECT_EQ(-110, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(1, result->size()); + ASSERT_EQ(-110, result->get(0).get_array()[0].get_int32()); } else if (i == 3) { // array_map(x->xxx,[]) - EXPECT_EQ(1, result->size()); + ASSERT_EQ(3, result->size()); ASSERT_TRUE(result->get(0).get_array().empty()); } else if (i == 4 && (j == 0 || j == 2)) { // array_map(x->x || x->x+a, array) // [null] // [] // NULL - EXPECT_EQ(3, result->size()); + ASSERT_EQ(3, result->size()); ASSERT_TRUE(result->get(0).get_array()[0].is_null()); ASSERT_TRUE(result->get(1).get_array().empty()); ASSERT_TRUE(result->is_null(2)); } else if (i == 4 && j == 1) { // array_map(x->x is null, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(1, result->get(0).get_array()[0].get_int8()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(1, result->get(0).get_array()[0].get_int8()); ASSERT_TRUE(result->get(1).get_array().empty()); ASSERT_TRUE(result->is_null(2)); } else if (i == 4 && j == 3) { // array_map(x-> -110, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(-110, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(-110, result->get(0).get_array()[0].get_int32()); ASSERT_TRUE(result->get(1).get_array().empty()); ASSERT_TRUE(result->is_null(2)); } @@ -400,7 +401,7 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_const_array) { // check LambdaFunction::prepare() std::vector ids, arguments; - lambda->get_slot_ids(&ids); + lambda->get_captured_slot_ids(&ids); lambda->get_lambda_arguments_ids(&arguments); ASSERT_TRUE(arguments.size() == 1 && arguments[0] == 100000); // the x's slot_id = 100000 @@ -409,77 +410,78 @@ TEST_F(VectorizedLambdaFunctionExprTest, array_map_lambda_test_const_array) { } else { ASSERT_TRUE(ids.empty()); } - ColumnPtr result = array_map_expr.evaluate(&exprContext, cur_chunk.get()); if (i == 5 && j == 0) { // array_map( x->x, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(1, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(4, result->get(0).get_array()[1].get_int32()); - EXPECT_EQ(1, result->get(1).get_array()[0].get_int32()); - EXPECT_EQ(4, result->get(1).get_array()[1].get_int32()); - EXPECT_EQ(1, result->get(2).get_array()[0].get_int32()); - EXPECT_EQ(4, result->get(2).get_array()[1].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(1, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(4, result->get(0).get_array()[1].get_int32()); + ASSERT_EQ(1, result->get(1).get_array()[0].get_int32()); + ASSERT_EQ(4, result->get(1).get_array()[1].get_int32()); + ASSERT_EQ(1, result->get(2).get_array()[0].get_int32()); + ASSERT_EQ(4, result->get(2).get_array()[1].get_int32()); } else if (i == 5 && j == 1) { // array_map(x->x is null, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(0, result->get(0).get_array()[0].get_int8()); - EXPECT_EQ(0, result->get(0).get_array()[1].get_int8()); - EXPECT_EQ(0, result->get(1).get_array()[0].get_int8()); - EXPECT_EQ(0, result->get(1).get_array()[1].get_int8()); - EXPECT_EQ(0, result->get(2).get_array()[0].get_int8()); - EXPECT_EQ(0, result->get(2).get_array()[1].get_int8()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(0, result->get(0).get_array()[0].get_int8()); + ASSERT_EQ(0, result->get(0).get_array()[1].get_int8()); + ASSERT_EQ(0, result->get(1).get_array()[0].get_int8()); + ASSERT_EQ(0, result->get(1).get_array()[1].get_int8()); + ASSERT_EQ(0, result->get(2).get_array()[0].get_int8()); + ASSERT_EQ(0, result->get(2).get_array()[1].get_int8()); + LOG(INFO) << "pass"; } else if (i == 5 && j == 2) { // // array_map( x->x + a, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(2, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(5, result->get(0).get_array()[1].get_int32()); - EXPECT_EQ(2, result->get(1).get_array()[0].get_int32()); - EXPECT_EQ(5, result->get(1).get_array()[1].get_int32()); - EXPECT_EQ(2, result->get(2).get_array()[0].get_int32()); - EXPECT_EQ(5, result->get(2).get_array()[1].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(2, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(5, result->get(0).get_array()[1].get_int32()); + ASSERT_EQ(2, result->get(1).get_array()[0].get_int32()); + ASSERT_EQ(5, result->get(1).get_array()[1].get_int32()); + ASSERT_EQ(2, result->get(2).get_array()[0].get_int32()); + ASSERT_EQ(5, result->get(2).get_array()[1].get_int32()); } else if (i == 5 && j == 3) { // // array_map( x-> -110, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(-110, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(0).get_array()[1].get_int32()); - EXPECT_EQ(-110, result->get(1).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(1).get_array()[1].get_int32()); - EXPECT_EQ(-110, result->get(2).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(2).get_array()[1].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(-110, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(0).get_array()[1].get_int32()); + ASSERT_EQ(-110, result->get(1).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(1).get_array()[1].get_int32()); + ASSERT_EQ(-110, result->get(2).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(2).get_array()[1].get_int32()); } else if (i == 6) { // array_map(x -> x || x->x is null || x -> x+a, array) - EXPECT_EQ(3, result->size()); + ASSERT_EQ(3, result->size()); ASSERT_TRUE(result->is_null(0)); ASSERT_TRUE(result->is_null(1)); ASSERT_TRUE(result->is_null(2)); } else if (i == 7 && (j == 0 || j == 2)) { // array_map(x -> x || x-> x+a,array) - EXPECT_EQ(3, result->size()); + ASSERT_EQ(3, result->size()); ASSERT_TRUE(result->get(0).get_array()[0].is_null()); ASSERT_TRUE(result->get(1).get_array()[0].is_null()); ASSERT_TRUE(result->get(2).get_array()[0].is_null()); } else if (i == 7 && j == 1) { // array_map(x -> x is null, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(1, result->get(0).get_array()[0].get_int8()); - EXPECT_EQ(1, result->get(1).get_array()[0].get_int8()); - EXPECT_EQ(1, result->get(2).get_array()[0].get_int8()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(1, result->get(0).get_array()[0].get_int8()); + ASSERT_EQ(1, result->get(1).get_array()[0].get_int8()); + ASSERT_EQ(1, result->get(2).get_array()[0].get_int8()); } else if (i == 7 && j == 3) { // array_map(x -> -110, array) - EXPECT_EQ(3, result->size()); - EXPECT_EQ(-110, result->get(0).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(1).get_array()[0].get_int32()); - EXPECT_EQ(-110, result->get(2).get_array()[0].get_int32()); + ASSERT_EQ(3, result->size()); + ASSERT_EQ(-110, result->get(0).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(1).get_array()[0].get_int32()); + ASSERT_EQ(-110, result->get(2).get_array()[0].get_int32()); } else if (i == 8) { // array_map(x -> x || x -> x is null || x -> x+a || x -> -110, array) - EXPECT_EQ(3, result->size()); + ASSERT_EQ(3, result->size()); ASSERT_TRUE(result->get(0).get_array().empty()); ASSERT_TRUE(result->get(1).get_array().empty()); ASSERT_TRUE(result->get(2).get_array().empty()); } if (j == 1) { // array -> array - if (result->is_nullable()) { - auto col = std::dynamic_pointer_cast(result); - auto array_col = std::dynamic_pointer_cast(col->data_column()); - EXPECT_EQ(2, array_col->elements_column()->type_size()); // nullable bool - } else { - auto array_col = std::dynamic_pointer_cast(result); - EXPECT_EQ(2, array_col->elements_column()->type_size()); // nullable bool + auto data_column = result; + if (data_column->is_constant()) { + data_column = FunctionHelper::get_data_column_of_const(data_column); + } + if (data_column->is_nullable()) { + data_column = down_cast(data_column.get())->data_column(); } + auto array_col = std::dynamic_pointer_cast(data_column); + ASSERT_EQ(2, array_col->elements_column()->type_size()); } Expr::close(expr_ctxs, &_runtime_state); } diff --git a/test/sql/test_array/R/test_array_map b/test/sql/test_array/R/test_array_map index f76f33c62b23d..9806b92c57b63 100644 --- a/test/sql/test_array/R/test_array_map +++ b/test/sql/test_array/R/test_array_map @@ -81,4 +81,5 @@ VALUES ARRAY_GENERATE(1, 1000) )); -- result: --- !result \ No newline at end of file +-- !result + diff --git a/test/sql/test_array/T/test_array_map b/test/sql/test_array/T/test_array_map index a91bc1024ff92..6655e83d8e189 100644 --- a/test/sql/test_array/T/test_array_map +++ b/test/sql/test_array/T/test_array_map @@ -78,5 +78,3 @@ VALUES x -> CAST(x AS STRING), ARRAY_GENERATE(1, 1000) )); - - diff --git a/test/sql/test_array_fn/R/test_array_map_2 b/test/sql/test_array_fn/R/test_array_map_2 new file mode 100644 index 0000000000000..dac1097c06b0e --- /dev/null +++ b/test/sql/test_array_fn/R/test_array_map_2 @@ -0,0 +1,150 @@ +-- name: test_array_map_2 +CREATE TABLE `array_map_test` ( + `id` tinyint(4) NOT NULL COMMENT "", + `arr_str` array NULL COMMENT "", + `arr_largeint` array NULL COMMENT "" +) ENGINE=OLAP +DUPLICATE KEY(`id`) +DISTRIBUTED BY RANDOM +PROPERTIES ( +"replication_num" = "1" +); +-- result: +-- !result +insert into array_map_test values (1, array_repeat("abcdefghasdasdasirnqwrq", 20000), array_repeat(100, 20000)); +-- result: +-- !result +select count() from array_map_test where array_length(array_map((x,y)->(id+length(x)+y), arr_str, arr_largeint)) > 10 ; +-- result: +1 +-- !result +select count(array_length(array_map((x,y)->(id+length(x)+y), arr_str, arr_largeint))) from array_map_test; +-- result: +1 +-- !result +select count() from array_map_test where any_match(x->any_match(x->x<10, arr_largeint), arr_largeint); +-- result: +0 +-- !result +select count(any_match(x->any_match(x->x<10, arr_largeint), arr_largeint)) from array_map_test; +-- result: +1 +-- !result +select count(array_map(x->array_length(array_concat(arr_str,[])), arr_largeint)) from array_map_test; +-- result: +1 +-- !result +set @arr=array_repeat("12345",1000000); +-- result: +-- !result +select array_length(array_map((x,y)->x > y, @arr,@arr)) from table(generate_series(1,10,1)); +-- result: +1000000 +1000000 +1000000 +1000000 +1000000 +1000000 +1000000 +1000000 +1000000 +1000000 +-- !result +-- name: test_array_map_3 +CREATE TABLE `t` ( + `k` bigint NOT NULL COMMENT "", + `arr_0` array NOT NULL COMMENT "", + `arr_1` array NULL COMMENT "", + `arr_2` array NULL COMMENT "" +) ENGINE=OLAP +PRIMARY KEY(`k`) +DISTRIBUTED BY HASH(`k`) BUCKETS 1 +PROPERTIES ( +"replication_num" = "1" +); +-- result: +-- !result +insert into t values (1, [1,2], [1,2],[2,3]), (2, [1,2], null, [2,3]), (3, [1,2],[1,2],null),(4, [1,2],[null,null],[2,3]), (5, [1], [1,2], [3]); +-- result: +-- !result +select array_map((x,y,z)->x+y+z, arr_0, arr_1, arr_2) from t; +-- result: +[REGEX].*Input array element's size is not equal in array_map().* +-- !result +select array_map((x,y,z)->x+y+z, arr_0, arr_1, arr_2) from t where k != 5 order by k; +-- result: +[4,7] +None +None +[null,null] +-- !result +delete from t where k = 5; +-- result: +-- !result +select array_map((x,y,z)->x+y+z, arr_0, arr_1, arr_2) from t order by k; +-- result: +[4,7] +None +None +[null,null] +-- !result +select array_map((x,y,z,d)->x+y+z+d, arr_0, arr_1, arr_2, [1,2]) from t order by k; +-- result: +[5,9] +None +None +[null,null] +-- !result +select array_map((x,y,z,d)->x+y+z+d, arr_0, arr_1, arr_2, [1]) from t order by k; +-- result: +[REGEX].*Input array element's size is not equal in array_map().* +-- !result +select array_map(x->x, arr_0) from t order by k; +-- result: +[1,2] +[1,2] +[1,2] +[1,2] +-- !result +select array_map((x,y,z)->10, arr_0, arr_1, arr_2) from t; +-- result: +[10,10] +None +None +[10,10] +-- !result +select array_map((x,y)-> k, arr_0, arr_1) from t order by k; +-- result: +[1,1] +None +[3,3] +[4,4] +-- !result +select array_map((x,y)->k, [1,2],[2,3]) from t order by k; +-- result: +[1,1] +[1,1] +[1,1] +[1,1] +-- !result +select array_map((x,y,z)->x+y+z, [1,2],[2,3],[3,4]) from t; +-- result: +[6,9] +[6,9] +[6,9] +[6,9] +-- !result +select array_map((x,y,z)->x+y+z, [1,2],[2,null],[3,4]) from t; +-- result: +[6,null] +[6,null] +[6,null] +[6,null] +-- !result +select array_map((x,y,z)->x+y+z, [1,2],[2,null],null) from t; +-- result: +None +None +None +None +-- !result \ No newline at end of file diff --git a/test/sql/test_array_fn/T/test_array_map_2 b/test/sql/test_array_fn/T/test_array_map_2 new file mode 100644 index 0000000000000..cfc22d00c61fc --- /dev/null +++ b/test/sql/test_array_fn/T/test_array_map_2 @@ -0,0 +1,60 @@ + +-- name: test_array_map_2 +CREATE TABLE `array_map_test` ( + `id` tinyint(4) NOT NULL COMMENT "", + `arr_str` array NULL COMMENT "", + `arr_largeint` array NULL COMMENT "" +) ENGINE=OLAP +DUPLICATE KEY(`id`) +DISTRIBUTED BY RANDOM +PROPERTIES ( +"replication_num" = "1" +); + +insert into array_map_test values (1, array_repeat("abcdefghasdasdasirnqwrq", 20000), array_repeat(100, 20000)); + +select count() from array_map_test where array_length(array_map((x,y)->(id+length(x)+y), arr_str, arr_largeint)) > 10 ; +select count(array_length(array_map((x,y)->(id+length(x)+y), arr_str, arr_largeint))) from array_map_test; + +select count() from array_map_test where any_match(x->any_match(x->x<10, arr_largeint), arr_largeint); +select count(any_match(x->any_match(x->x<10, arr_largeint), arr_largeint)) from array_map_test; +select count(array_map(x->array_length(array_concat(arr_str,[])), arr_largeint)) from array_map_test; + +set @arr=array_repeat("12345",1000000); +select array_length(array_map((x,y)->x > y, @arr,@arr)) from table(generate_series(1,10,1)); + +-- name: test_array_map_3 +CREATE TABLE `t` ( + `k` bigint NOT NULL COMMENT "", + `arr_0` array NOT NULL COMMENT "", + `arr_1` array NULL COMMENT "", + `arr_2` array NULL COMMENT "" +) ENGINE=OLAP +PRIMARY KEY(`k`) +DISTRIBUTED BY HASH(`k`) BUCKETS 1 +PROPERTIES ( +"replication_num" = "1" +); + +insert into t values (1, [1,2], [1,2],[2,3]), (2, [1,2], null, [2,3]), (3, [1,2],[1,2],null),(4, [1,2],[null,null],[2,3]), (5, [1], [1,2], [3]); +select array_map((x,y,z)->x+y+z, arr_0, arr_1, arr_2) from t; +select array_map((x,y,z)->x+y+z, arr_0, arr_1, arr_2) from t where k != 5 order by k; +delete from t where k = 5; + +select array_map((x,y,z)->x+y+z, arr_0, arr_1, arr_2) from t order by k; +select array_map((x,y,z,d)->x+y+z+d, arr_0, arr_1, arr_2, [1,2]) from t order by k; +select array_map((x,y,z,d)->x+y+z+d, arr_0, arr_1, arr_2, [1]) from t order by k; + +select array_map(x->x, arr_0) from t order by k; +-- independent expr +select array_map((x,y,z)->10, arr_0, arr_1, arr_2) from t; +select array_map((x,y)-> k, arr_0, arr_1) from t order by k; + +-- independent expr with all const +select array_map((x,y)->k, [1,2],[2,3]) from t order by k; + +-- non indepentdent with all const +select array_map((x,y,z)->x+y+z, [1,2],[2,3],[3,4]) from t; +select array_map((x,y,z)->x+y+z, [1,2],[2,null],[3,4]) from t; +select array_map((x,y,z)->x+y+z, [1,2],[2,null],null) from t; +