diff --git a/be/src/column/chunk.cpp b/be/src/column/chunk.cpp
index 9f455d42fc321..887ed8f480658 100644
--- a/be/src/column/chunk.cpp
+++ b/be/src/column/chunk.cpp
@@ -132,6 +132,17 @@ void Chunk::append_column(ColumnPtr column, const FieldPtr& field) {
     check_or_die();
 }
 
+// Appends `column` and records its position under both the field's column id and
+// `slot_id`; `field` is appended to the schema as well.
+void Chunk::append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id) {
+    DCHECK(!_cid_to_index.contains(field->id()));
+    _cid_to_index[field->id()] = _columns.size();
+    _slot_id_to_index[slot_id] = _columns.size();
+    _columns.emplace_back(std::move(column));
+    _schema->append(field);
+    check_or_die();
+}
+
 void Chunk::append_column(ColumnPtr column, SlotId slot_id) {
     _slot_id_to_index[slot_id] = _columns.size();
     _columns.emplace_back(std::move(column));
diff --git a/be/src/column/chunk.h b/be/src/column/chunk.h
index cf38261077111..92578f5cf6699 100644
--- a/be/src/column/chunk.h
+++ b/be/src/column/chunk.h
@@ -112,6 +112,9 @@ class Chunk {
     // schema must exist and will be updated.
     void append_column(ColumnPtr column, const FieldPtr& field);
 
+    // Appends a column reachable by both its field id and `slot_id`; the schema is updated too.
+    void append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id);
+
     void append_column(ColumnPtr column, SlotId slot_id);
 
     void insert_column(size_t idx, ColumnPtr column, const FieldPtr& field);
diff --git a/be/src/column/schema.cpp b/be/src/column/schema.cpp
index 92e554c86cc54..92290ca2d0d91 100644
--- a/be/src/column/schema.cpp
+++ b/be/src/column/schema.cpp
@@ -45,7 +45,19 @@ Schema::Schema(Schema* schema, const std::vector& cids)
     auto ori_sort_idxes = schema->sort_key_idxes();
     std::unordered_set scids(ori_sort_idxes.begin(), ori_sort_idxes.end());
+    // Number of cids skipped so far because `schema` has no field at that index.
+    // NOTE(review): assumes _fields was pre-sized to cids.size() before this loop
+    // (the sizing is outside this hunk) -- confirm.
+    int num_skipped = 0;
     for (int i = 0; i < cids.size(); i++) {
-        DCHECK_LT(cids[i], schema->_fields.size());
+        if (cids[i] >= schema->_fields.size()) {
+            // No backing field for this cid: drop one trailing slot and keep the
+            // surviving fields densely packed instead of leaving a hole at index i.
+            _fields.pop_back();
+            ++num_skipped;
+            continue;
+        }
+        // Destination index once the skipped cids are compacted away.
+        const int dst = i - num_skipped;
-        _fields[i] = schema->_fields[cids[i]];
+        _fields[dst] = schema->_fields[cids[i]];
         if (scids.find(cids[i]) != scids.end()) {
-            _sort_key_idxes.emplace_back(i);
+            _sort_key_idxes.emplace_back(dst);
diff --git
a/be/src/common/config.h b/be/src/common/config.h index 523a35d164312..cb1b3d0894447 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -1228,6 +1228,9 @@ CONF_Int64(query_cache_capacity, "536870912"); // ranges in [1,16], default value is 4. CONF_mInt32(query_cache_num_lanes_per_driver, "4"); +// Capacity of the vector query cache in bytes; the default 536870912 is 512 MiB. +CONF_Int64(vector_query_cache_capacity, "536870912"); + // Used to limit buffer size of tablet send channel. CONF_mInt64(send_channel_buffer_limit, "67108864"); @@ -1359,6 +1362,9 @@ CONF_mInt64(arrow_io_coalesce_read_max_buffer_size, "8388608"); CONF_mInt64(arrow_io_coalesce_read_max_distance_size, "1048576"); CONF_mInt64(arrow_read_batch_size, "4096"); +// Default tenann build threshold, 0 by default. NOTE(review): the original note read "default not to build the empty index" -- confirm the exact threshold semantics. +CONF_mInt32(config_tenann_default_build_threshold, "0"); + // Set to true to enable socket_keepalive option in brpc CONF_mBool(brpc_socket_keepalive, "false"); CONF_mBool(apply_del_vec_after_all_index_filter, "true"); diff --git a/be/src/exec/olap_scan_node.cpp b/be/src/exec/olap_scan_node.cpp index 3f46b349cb720..e2b07ed1767ba 100644 --- a/be/src/exec/olap_scan_node.cpp +++ b/be/src/exec/olap_scan_node.cpp @@ -583,6 +583,10 @@ void OlapScanNode::_init_counter(RuntimeState* state) { _bf_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "BloomFilterFilterRows", TUnit::UNIT, "SegmentInit"); _gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, "SegmentInit"); _gin_filtered_timer = ADD_CHILD_TIMER(_runtime_profile, "GinFilter", "SegmentInit"); + _get_row_ranges_by_vector_index_timer = ADD_CHILD_TIMER(_scan_profile, "GetVectorRowRangesTime", "SegmentInit"); + _vector_search_timer = ADD_CHILD_TIMER(_scan_profile, "VectorSearchTime", "SegmentInit"); + _process_vector_distance_and_id_timer = + ADD_CHILD_TIMER(_scan_profile, "ProcessVectorDistanceAndIdTime", "SegmentInit"); _seg_zm_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "SegmentZoneMapFilterRows", TUnit::UNIT, "SegmentInit");
_seg_rt_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, "SegmentInit");
diff --git a/be/src/exec/olap_scan_node.h b/be/src/exec/olap_scan_node.h
index 1ad14f10edee9..345745fe72efd 100644
--- a/be/src/exec/olap_scan_node.h
+++ b/be/src/exec/olap_scan_node.h
@@ -248,6 +248,9 @@ class OlapScanNode final : public starrocks::ScanNode {
     RuntimeProfile::Counter* _bi_filter_timer = nullptr;
     RuntimeProfile::Counter* _gin_filtered_counter = nullptr;
     RuntimeProfile::Counter* _gin_filtered_timer = nullptr;
+    RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr;
+    RuntimeProfile::Counter* _vector_search_timer = nullptr;
+    RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr;
     RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr;
     RuntimeProfile::Counter* _rowsets_read_count = nullptr;
     RuntimeProfile::Counter* _segments_read_count = nullptr;
diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.cpp b/be/src/exec/pipeline/scan/olap_chunk_source.cpp
index 0e5ed2a90f250..5048e6b4ec4bc 100644
--- a/be/src/exec/pipeline/scan/olap_chunk_source.cpp
+++ b/be/src/exec/pipeline/scan/olap_chunk_source.cpp
@@ -75,6 +75,12 @@ void OlapChunkSource::close(RuntimeState* state) {
 Status OlapChunkSource::prepare(RuntimeState* state) {
     RETURN_IF_ERROR(ChunkSource::prepare(state));
     _runtime_state = state;
+    const TQueryOptions& query_options = _runtime_state->query_options();
+    _use_vector_index = query_options.enable_use_ann;
+    if (_use_vector_index) {
+        _use_ivfpq = query_options.vector_search_options.use_ivfpq;
+        _vector_distance_column_name = query_options.vector_search_options.vector_distance_column_name;
+    }
     const TOlapScanNode& thrift_olap_scan_node = _scan_node->thrift_olap_scan_node();
     const TupleDescriptor* tuple_desc = state->desc_tbl().get_tuple_descriptor(thrift_olap_scan_node.tuple_id);
     _slots = &tuple_desc->slots();
@@ -129,6 +135,11 @@ void OlapChunkSource::_init_counter(RuntimeState* state) {
     const std::string segment_init_name = "SegmentInit";
     _seg_init_timer = ADD_CHILD_TIMER(_runtime_profile, segment_init_name, IO_TASK_EXEC_TIMER_NAME);
     _bi_filter_timer = ADD_CHILD_TIMER(_runtime_profile, "BitmapIndexFilter", segment_init_name);
+    _get_row_ranges_by_vector_index_timer =
+            ADD_CHILD_TIMER(_runtime_profile, "GetVectorRowRangesTime", segment_init_name);
+    _vector_search_timer = ADD_CHILD_TIMER(_runtime_profile, "VectorSearchTime", segment_init_name);
+    _process_vector_distance_and_id_timer =
+            ADD_CHILD_TIMER(_runtime_profile, "ProcessVectorDistanceAndIdTime", segment_init_name);
     _bi_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BitmapIndexFilterRows", TUnit::UNIT, segment_init_name);
     _bf_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BloomFilterFilterRows", TUnit::UNIT, segment_init_name);
     _gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, segment_init_name);
@@ -214,6 +225,26 @@ Status OlapChunkSource::_init_reader_params(const std::vectorquery_options().vector_search_options;
+        _params.vector_distance_column_name = _vector_distance_column_name;
+        _params.k = vector_options.vector_limit_k;
+        for (const std::string& str : vector_options.query_vector) {
+            // std::stof throws on malformed or out-of-range text; report that as a Status instead.
+            try {
+                _params.query_vector.push_back(std::stof(str));
+            } catch (const std::exception&) {
+                return Status::InvalidArgument("invalid query vector element: " + str);
+            }
+        }
+        _params.query_params = vector_options.query_params;
+        _params.vector_range = vector_options.vector_range;
+        _params.result_order = vector_options.result_order;
+        _params.use_ivfpq = _use_ivfpq;
+        _params.pq_refine_factor = vector_options.pq_refine_factor;
+        _params.k_factor = vector_options.k_factor;
+    }
     if (thrift_olap_scan_node.__isset.sorted_by_keys_per_tablet) {
         _params.sorted_by_keys_per_tablet = thrift_olap_scan_node.sorted_by_keys_per_tablet;
     }
@@ -272,7 +303,17 @@ Status OlapChunkSource::_init_reader_params(const std::vector& scanner_columns) {
     for (auto slot : *_slots) {
         DCHECK(slot->is_materialized());
-        int32_t index = _tablet_schema->field_index(slot->col_name());
+        int32_t index;
+        if (_use_vector_index && !_use_ivfpq) {
+            index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name);
+            if (slot->col_name() == _vector_distance_column_name) {
+                _params.vector_column_id = index;
+                _params.vector_slot_id = slot->id();
+            }
+        } else {
+            // Non-vector path: keep the lookup this hunk replaces, as the sibling hunks do.
+            index = _tablet_schema->field_index(slot->col_name());
+        }
         if (index < 0) {
             std::stringstream ss;
             ss << "invalid field name: " << slot->col_name();
@@ -296,7 +337,12 @@ Status OlapChunkSource::_init_scanner_columns(std::vector& scanner_col
 
 Status OlapChunkSource::_init_unused_output_columns(const std::vector& unused_output_columns) {
     for (const auto& col_name : unused_output_columns) {
-        int32_t index = _tablet_schema->field_index(col_name);
+        int32_t index;
+        if (_use_vector_index && !_use_ivfpq) {
+            index = _tablet_schema->field_index(col_name, _vector_distance_column_name);
+        } else {
+            index = _tablet_schema->field_index(col_name);
+        }
         if (index < 0) {
             std::stringstream ss;
             ss << "invalid field name: " << col_name;
@@ -451,7 +497,9 @@ Status OlapChunkSource::_init_olap_reader(RuntimeState* runtime_state) {
         _prj_iter = _reader;
     } else {
         starrocks::Schema output_schema = ChunkHelper::convert_schema(_tablet_schema, scanner_columns);
-        _prj_iter = new_projection_iterator(output_schema, _reader);
+        _prj_iter = new_projection_iterator(output_schema, _reader, _use_vector_index && !_use_ivfpq,
+                                            _params.vector_column_id, _params.vector_slot_id,
+                                            _params.vector_distance_column_name);
     }
 
     if (!_scan_ctx->not_push_down_conjuncts().empty() || !_non_pushdown_pred_tree.empty()) {
@@ -505,7 +553,12 @@ Status OlapChunkSource::_init_global_dicts(TabletReaderParams* params) {
         auto iter = global_dict_map.find(slot->id());
         if (iter != global_dict_map.end()) {
             auto& dict_map = iter->second.first;
-            int32_t index = _tablet_schema->field_index(slot->col_name());
+            int32_t index;
+            if (_use_vector_index && !_use_ivfpq) {
+                index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name);
+            } else {
+                index = _tablet_schema->field_index(slot->col_name());
+            }
             DCHECK(index >= 0);
             global_dict->emplace(index, const_cast(&dict_map));
         }
@@ -625,8 +678,11 @@ void OlapChunkSource::_update_counter() {
     COUNTER_UPDATE(_bi_filtered_counter, _reader->stats().rows_bitmap_index_filtered);
     COUNTER_UPDATE(_bi_filter_timer, _reader->stats().bitmap_index_filter_timer);
     COUNTER_UPDATE(_gin_filtered_counter, _reader->stats().rows_gin_filtered);
     COUNTER_UPDATE(_gin_filtered_timer, _reader->stats().gin_index_filter_ns);
+    COUNTER_UPDATE(_get_row_ranges_by_vector_index_timer, _reader->stats().get_row_ranges_by_vector_index_timer);
+    COUNTER_UPDATE(_vector_search_timer, _reader->stats().vector_search_timer);
+    COUNTER_UPDATE(_process_vector_distance_and_id_timer, _reader->stats().process_vector_distance_and_id_timer);
     COUNTER_UPDATE(_block_seek_counter, _reader->stats().block_seek_num);
 
     COUNTER_UPDATE(_rowsets_read_count, _reader->stats().rowsets_read_count);
diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.h b/be/src/exec/pipeline/scan/olap_chunk_source.h
index fc396d99896b9..93f88cd5c9c4d 100644
--- a/be/src/exec/pipeline/scan/olap_chunk_source.h
+++ b/be/src/exec/pipeline/scan/olap_chunk_source.h
@@ -76,7 +76,7 @@ class OlapChunkSource final : public ChunkSource {
     OlapScanNode* _scan_node;
     OlapScanContext* _scan_ctx;
 
-    const int64_t _limit; // -1: no limit
+    int64_t _limit; // -1: no limit
     TInternalScanRange* _scan_range;
 
     PredicateTree _non_pushdown_pred_tree;
@@ -106,6 +106,16 @@ class OlapChunkSource final : public ChunkSource {
 
     std::vector _column_access_paths;
 
+    bool _use_vector_index = false;
+
+    bool _use_ivfpq = false;
+
+    std::string _vector_distance_column_name;
+    // NOTE(review): the two members below are not referenced in the visible hunks; zero-initialised so they are never read indeterminate.
+    double _vector_range = 0;
+
+    int result_order = 0;
+
     // The following are profile meatures
     int64_t _num_rows_read = 0;
 
@@ -150,6 +160,9 @@ class OlapChunkSource final : public ChunkSource {
     RuntimeProfile::Counter*
_bi_filter_timer = nullptr;
     RuntimeProfile::Counter* _gin_filtered_counter = nullptr;
     RuntimeProfile::Counter* _gin_filtered_timer = nullptr;
+    RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr;
+    RuntimeProfile::Counter* _vector_search_timer = nullptr;
+    RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr;
     RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr;
     RuntimeProfile::Counter* _non_pushdown_predicates_counter = nullptr;
     RuntimeProfile::Counter* _rowsets_read_count = nullptr;
diff --git a/be/src/exprs/column_ref.cpp b/be/src/exprs/column_ref.cpp
index 8799f91c1647c..04422069610ae 100644
--- a/be/src/exprs/column_ref.cpp
+++ b/be/src/exprs/column_ref.cpp
@@ -54,7 +54,7 @@ StatusOr ColumnRef::evaluate_checked(ExprContext* context, Chunk* ptr
 ColumnPtr& ColumnRef::get_column(Expr* expr, Chunk* chunk) {
     auto* ref = (ColumnRef*)expr;
     ColumnPtr& column = (chunk)->get_column_by_slot_id(ref->slot_id());
-    return column;
+    return column;
 }
 
 } // namespace starrocks
diff --git a/be/src/exprs/math_functions.cpp b/be/src/exprs/math_functions.cpp
index 361036d8f6243..541703352977d 100644
--- a/be/src/exprs/math_functions.cpp
+++ b/be/src/exprs/math_functions.cpp
@@ -886,7 +886,6 @@ StatusOr MathFunctions::cosine_similarity(FunctionContext* context, c
                 target_sum += target_data[j] * target_data[j];
             }
         }
-
         if constexpr (!isNorm) {
             result_value = sum / (std::sqrt(base_sum) * std::sqrt(target_sum));
         } else {
@@ -899,10 +898,308 @@ StatusOr MathFunctions::cosine_similarity(FunctionContext* context, c
     return result;
 }
 
-// explicitly instaniate template function.
+// cosine_similarity2: cosine similarity between one constant query vector and every
+// array row of columns[1]. The query vector arrives in columns[0] as a comma-separated
+// string; when isNorm is set only the dot product is returned.
+template
+StatusOr MathFunctions::cosine_similarity2(FunctionContext* context, const Columns& columns) {
+    DCHECK_EQ(columns.size(), 2);
+
+    const Column* base = columns[0].get();
+    const Column* target = columns[1].get();
+
+    if (base->has_null() || target->has_null()) {
+        return Status::InvalidArgument(
+                fmt::format("cosine_similarity does not support null values. {} array has null value.",
+                            base->has_null() ? "base" : "target"));
+    }
+    if (base->is_nullable()) {
+        base = down_cast(base)->data_column().get();
+    }
+    if (target->is_nullable()) {
+        target = down_cast(target)->data_column().get();
+    }
+    // The query vector is read through get_const_value(), so columns[0] has to be constant.
+    if (!base->is_constant()) {
+        return Status::InvalidArgument("cosine_similarity2 requires its first argument to be a constant");
+    }
+
+    using CppType = RunTimeCppType;
+    using ColumnType = RunTimeColumnType;
+
+    std::vector base_elements;
+    auto pattern = ColumnHelper::get_const_value(base);
+    std::string pattern_str = pattern.to_string();
+    std::string delimiter = ",";
+    size_t pos_start = 0, pos_end, delim_len = delimiter.length();
+    while ((pos_end = pattern_str.find(delimiter, pos_start)) != std::string::npos) {
+        const std::string token = pattern_str.substr(pos_start, pos_end - pos_start);
+        CppType num = ::atof(token.c_str());
+        base_elements.push_back(num);
+        pos_start = pos_end + delim_len;
+    }
+    CppType last = ::atof(pattern_str.substr(pos_start).c_str());
+    base_elements.push_back(last);
+    size_t base_flat_size = base_elements.size();
+
+    const Column* target_flat = down_cast(target)->elements_column().get();
+    const uint32_t* target_offset = down_cast(target)->offsets().get_data().data();
+
+    if (target_flat->has_null()) {
+        return Status::InvalidArgument("cosine_similarity does not support null values");
+    }
+    if (target_flat->is_nullable()) {
+        target_flat = down_cast(target_flat)->data_column().get();
+    }
+    const CppType* target_data_head = down_cast(target_flat)->get_data().data();
+    size_t target_size = target->size();
+
+    // allocate the result column, one value per target row.
+    ColumnPtr result = ColumnHelper::create_column(TypeDescriptor{TYPE}, false, false, target_size);
+    ColumnType* data_result = down_cast(result.get());
+    CppType* result_data = data_result->get_data().data();
+
+    for (size_t i = 0; i < target_size; i++) {
+        size_t t_dim_size = target_offset[i + 1] - target_offset[i];
+        if (t_dim_size != base_flat_size) {
+            return Status::InvalidArgument(
+                    fmt::format("cosine_similarity requires equal length arrays in each row. base array dimension size "
+                                "is {}, target array dimension size is {}.",
+                                base_flat_size, t_dim_size));
+        }
+        if (t_dim_size == 0) {
+            return Status::InvalidArgument("cosine_similarity requires non-empty arrays in each row");
+        }
+    }
+
+    const CppType* target_data = target_data_head;
+    const CppType* base_data = base_elements.data();
+    for (size_t i = 0; i < target_size; i++) {
+        CppType sum = 0;
+        CppType base_sum = 0;
+        CppType target_sum = 0;
+        size_t dim_size = target_offset[i + 1] - target_offset[i];
+        CppType result_value = 0;
+        for (size_t j = 0; j < dim_size; j++) {
+            sum += base_data[j] * target_data[j];
+            if constexpr (!isNorm) {
+                base_sum += base_data[j] * base_data[j];
+                target_sum += target_data[j] * target_data[j];
+            }
+        }
+        if constexpr (!isNorm) {
+            result_value = sum / (std::sqrt(base_sum) * std::sqrt(target_sum));
+        } else {
+            result_value = sum;
+        }
+        result_data[i] = result_value;
+        target_data += dim_size;
+    }
+    return result;
+}
+
+// explicitly instantiate template function.
template StatusOr MathFunctions::cosine_similarity(FunctionContext* context, const Columns& columns);
 template StatusOr MathFunctions::cosine_similarity(FunctionContext* context, const Columns& columns);
+template StatusOr MathFunctions::cosine_similarity2(FunctionContext* context,
+                                                    const Columns& columns);
+template StatusOr MathFunctions::cosine_similarity2(FunctionContext* context,
+                                                    const Columns& columns);
+
+// l2_distance: row-wise squared Euclidean distance between two array columns.
+// For every row the squared element differences are summed; no square root is
+// applied to the sum. Null rows or elements, per-row length mismatches and empty
+// arrays are rejected with InvalidArgument.
+template
+StatusOr MathFunctions::l2_distance(FunctionContext* context, const Columns& columns) {
+    DCHECK_EQ(columns.size(), 2);
+
+    const Column* base = columns[0].get();
+    const Column* target = columns[1].get();
+    size_t target_size = target->size();
+    if (base->size() != target_size) {
+        return Status::InvalidArgument(fmt::format(
+                "l2_distance requires equal length arrays. base array size is {} and target array size is {}.",
+                base->size(), target->size()));
+    }
+    if (base->has_null() || target->has_null()) {
+        return Status::InvalidArgument(fmt::format("l2_distance does not support null values. {} array has null value.",
+                                                   base->has_null() ? "base" : "target"));
+    }
+    if (base->is_nullable()) {
+        base = down_cast(base)->data_column().get();
+    }
+    if (target->is_nullable()) {
+        target = down_cast(target)->data_column().get();
+    }
+
+    // check dimension equality.
+    const Column* base_flat = down_cast(base)->elements_column().get();
+    const uint32_t* base_offset = down_cast(base)->offsets().get_data().data();
+    size_t base_flat_size = base_flat->size();
+
+    const Column* target_flat = down_cast(target)->elements_column().get();
+    size_t target_flat_size = target_flat->size();
+    const uint32_t* target_offset = down_cast(target)->offsets().get_data().data();
+
+    if (base_flat_size != target_flat_size) {
+        return Status::InvalidArgument("l2_distance requires equal length arrays");
+    }
+
+    if (base_flat->has_null() || target_flat->has_null()) {
+        return Status::InvalidArgument("l2_distance does not support null values");
+    }
+    if (base_flat->is_nullable()) {
+        base_flat = down_cast(base_flat)->data_column().get();
+    }
+    if (target_flat->is_nullable()) {
+        target_flat = down_cast(target_flat)->data_column().get();
+    }
+
+    using CppType = RunTimeCppType;
+    using ColumnType = RunTimeColumnType;
+
+    const CppType* base_data_head = down_cast(base_flat)->get_data().data();
+    const CppType* target_data_head = down_cast(target_flat)->get_data().data();
+
+    // allocate the result column, one value per row.
+    ColumnPtr result = ColumnHelper::create_column(TypeDescriptor{TYPE}, false, false, target_size);
+    ColumnType* data_result = down_cast(result.get());
+    CppType* result_data = data_result->get_data().data();
+
+    for (size_t i = 0; i < target_size; i++) {
+        size_t t_dim_size = target_offset[i + 1] - target_offset[i];
+        size_t b_dim_size = base_offset[i + 1] - base_offset[i];
+        if (t_dim_size != b_dim_size) {
+            return Status::InvalidArgument(
+                    fmt::format("l2_distance requires equal length arrays in each row. base array dimension size "
+                                "is {}, target array dimension size is {}.",
+                                b_dim_size, t_dim_size));
+        }
+        if (t_dim_size == 0) {
+            return Status::InvalidArgument("l2_distance requires non-empty arrays in each row");
+        }
+    }
+
+    const CppType* target_data = target_data_head;
+    const CppType* base_data = base_data_head;
+
+    for (size_t i = 0; i < target_size; i++) {
+        CppType sum = 0;
+        size_t dim_size = target_offset[i + 1] - target_offset[i];
+        // accumulate squared differences; the sum is stored as is.
+        for (size_t j = 0; j < dim_size; j++) {
+            const CppType diff = base_data[j] - target_data[j];
+            sum += diff * diff;
+        }
+        result_data[i] = sum;
+        target_data += dim_size;
+        base_data += dim_size;
+    }
+
+    return result;
+}
+
+// l2_distance2: squared Euclidean distance between one constant query vector and
+// every array row of columns[1]. The query vector arrives in columns[0] as a
+// comma-separated string and is parsed once with atof(); each row must have the
+// same number of elements as the query vector. Like l2_distance, the result is
+// the plain sum of squared differences (no square root).
+template
+StatusOr MathFunctions::l2_distance2(FunctionContext* context, const Columns& columns) {
+    DCHECK_EQ(columns.size(), 2);
+
+    const Column* base = columns[0].get();
+    const Column* target = columns[1].get();
+
+    if (base->has_null() || target->has_null()) {
+        return Status::InvalidArgument(
+                fmt::format("l2_distance2 does not support null values. {} array has null value.",
+                            base->has_null() ? "base" : "target"));
+    }
+    if (base->is_nullable()) {
+        base = down_cast(base)->data_column().get();
+    }
+    if (target->is_nullable()) {
+        target = down_cast(target)->data_column().get();
+    }
+    // The query vector is read through get_const_value(), so columns[0] has to be
+    // a constant column; a constant columns[1] alone is not sufficient.
+    if (!base->is_constant()) {
+        return Status::InvalidArgument("l2_distance2 requires its first argument to be a constant");
+    }
+
+    using CppType = RunTimeCppType;
+    using ColumnType = RunTimeColumnType;
+
+    std::vector base_elements;
+    auto pattern = ColumnHelper::get_const_value(base);
+    std::string pattern_str = pattern.to_string();
+    std::string delimiter = ",";
+    size_t pos_start = 0, pos_end, delim_len = delimiter.length();
+    while ((pos_end = pattern_str.find(delimiter, pos_start)) != std::string::npos) {
+        const std::string token = pattern_str.substr(pos_start, pos_end - pos_start);
+        CppType num = ::atof(token.c_str());
+        base_elements.push_back(num);
+        pos_start = pos_end + delim_len;
+    }
+    CppType last = ::atof(pattern_str.substr(pos_start).c_str());
+    base_elements.push_back(last);
+    size_t base_flat_size = base_elements.size();
+
+    const Column* target_flat = down_cast(target)->elements_column().get();
+    const uint32_t* target_offset = down_cast(target)->offsets().get_data().data();
+
+    if (target_flat->has_null()) {
+        return Status::InvalidArgument("l2_distance2 does not support null values");
+    }
+    if (target_flat->is_nullable()) {
+        target_flat = down_cast(target_flat)->data_column().get();
+    }
+    const CppType* target_data_head = down_cast(target_flat)->get_data().data();
+    size_t target_size = target->size();
+
+    // allocate the result column, one value per target row.
+    ColumnPtr result = ColumnHelper::create_column(TypeDescriptor{TYPE}, false, false, target_size);
+    ColumnType* data_result = down_cast(result.get());
+    CppType* result_data = data_result->get_data().data();
+
+    for (size_t i = 0; i < target_size; i++) {
+        size_t t_dim_size = target_offset[i + 1] - target_offset[i];
+        if (t_dim_size != base_flat_size) {
+            return Status::InvalidArgument(
+                    fmt::format("l2_distance2 requires equal length arrays in each row. base array dimension size "
+                                "is {}, target array dimension size is {}.",
+                                base_flat_size, t_dim_size));
+        }
+        if (t_dim_size == 0) {
+            return Status::InvalidArgument("l2_distance2 requires non-empty arrays in each row");
+        }
+    }
+
+    const CppType* target_data = target_data_head;
+    const CppType* base_data = base_elements.data();
+
+    for (size_t i = 0; i < target_size; i++) {
+        CppType sum = 0;
+        size_t dim_size = target_offset[i + 1] - target_offset[i];
+        for (size_t j = 0; j < dim_size; j++) {
+            const CppType diff = base_data[j] - target_data[j];
+            sum += diff * diff;
+        }
+        // base_data is deliberately not advanced: the same query vector is
+        // compared against every row, only target_data moves to the next row.
+        result_data[i] = sum;
+        target_data += dim_size;
+    }
+
+    return result;
+}
+
+// explicitly instantiate template function.
+template StatusOr MathFunctions::l2_distance(FunctionContext* context, const Columns& columns); +template StatusOr MathFunctions::l2_distance2(FunctionContext* context, const Columns& columns); } // namespace starrocks diff --git a/be/src/exprs/math_functions.h b/be/src/exprs/math_functions.h index 6a5b3abc850aa..87ab849df393e 100644 --- a/be/src/exprs/math_functions.h +++ b/be/src/exprs/math_functions.h @@ -167,6 +167,15 @@ class MathFunctions { template DEFINE_VECTORIZED_FN(cosine_similarity); + template + DEFINE_VECTORIZED_FN(cosine_similarity2); + // l2_distance: per-row sum of squared element differences of two array columns (no square root). + template + DEFINE_VECTORIZED_FN(l2_distance); + // l2_distance2: like l2_distance, with the first argument given as a constant comma-separated string. + template + DEFINE_VECTORIZED_FN(l2_distance2); + /** * @param columns: [DoubleColumn] * @return BigIntColumn diff --git a/be/src/storage/CMakeLists.txt b/be/src/storage/CMakeLists.txt index bb75e11bddb04..ee184bbdb0601 100644 --- a/be/src/storage/CMakeLists.txt +++ b/be/src/storage/CMakeLists.txt @@ -254,8 +254,12 @@ add_library(Storage STATIC index/inverted/clucene/clucene_inverted_writer.cpp index/inverted/clucene/clucene_inverted_reader.cpp index/inverted/clucene/match_operator.cpp + index/vector/empty_index_reader.cpp index/vector/vector_index_builder_factory.cpp index/vector/vector_index_writer.cpp index/vector/vector_index_builder.cpp + index/vector/vector_index_reader_factory.cpp + index/vector/tenann_index_reader.cpp + index/vector/tenann/del_id_filter.cpp index/vector/tenann/tenann_index_builder.cpp index/vector/tenann/tenann_index_utils.cpp) diff --git a/be/src/storage/index/index_descriptor.h b/be/src/storage/index/index_descriptor.h index 4891d18d23627..e330046f64b3d 100644 --- a/be/src/storage/index/index_descriptor.h +++ b/be/src/storage/index/index_descriptor.h @@ -25,6 +25,7 @@ namespace starrocks { class IndexDescriptor { public: inline static const std::string mark_word = "TENANNEMPTYMARK"; + inline static const int64_t mark_word_len = 15; static StatusOr get_index_file_path(const IndexType index_type, const std::string& rowset_dir, const
std::string& rowset_id, int segment_id, int64_t index_id) { diff --git a/be/src/storage/index/vector/empty_index_reader.cpp b/be/src/storage/index/vector/empty_index_reader.cpp new file mode 100644 index 0000000000000..ca625892e1cd5 --- /dev/null +++ b/be/src/storage/index/vector/empty_index_reader.cpp @@ -0,0 +1,17 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "storage/index/vector/empty_index_reader.h" + +namespace starrocks {} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/empty_index_reader.h b/be/src/storage/index/vector/empty_index_reader.h new file mode 100644 index 0000000000000..f3d541bcffbcc --- /dev/null +++ b/be/src/storage/index/vector/empty_index_reader.h @@ -0,0 +1,41 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "storage/index/vector/vector_index_reader.h" + +namespace starrocks { +// VectorIndexReader implementation in which init_searcher, search and range_search all return Status::NotSupported. +class EmptyIndexReader final : public VectorIndexReader { +public: + ~EmptyIndexReader() override = default; + // Always returns NotSupported; meta and index_path are not used. + Status init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) override { + return Status::NotSupported("Not implement"); + } + // Always returns NotSupported; no output buffer is written. + Status search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, uint8_t* result_distances, + tenann::IdFilter* id_filter = nullptr) override { + return Status::NotSupported("Not implement"); + } + // Always returns NotSupported; the result vectors are left untouched. + Status range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) override { + return Status::NotSupported("Not implement"); + } +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann/del_id_filter.cpp b/be/src/storage/index/vector/tenann/del_id_filter.cpp new file mode 100644 index 0000000000000..350ef11b7852a --- /dev/null +++ b/be/src/storage/index/vector/tenann/del_id_filter.cpp @@ -0,0 +1,51 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+
+// This file is based on code available under the Apache license here:
+//   https://github.com/apache/incubator-doris/blob/master/be/src/olap/rowset/segment_v2/column_reader.h
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "storage/index/vector/tenann/del_id_filter.h"
+
+#include "storage/range.h"
+#include "storage/roaring2range.h"
+#include "tenann/common/type_traits.h"
+#include "tenann/searcher/id_filter.h"
+
+namespace starrocks {
+
+DelIdFilter::DelIdFilter(const SparseRange<>& scan_range)
+        : _row_bitmap(range2roaring(scan_range)) {}
+
+bool DelIdFilter::IsMember(tenann::idx_t id) const {
+    return _row_bitmap.contains(id);
+}
+
+} // namespace starrocks
\ No newline at end of file
diff --git a/be/src/storage/index/vector/tenann/del_id_filter.h b/be/src/storage/index/vector/tenann/del_id_filter.h
new file mode 100644
index 0000000000000..8735c1658e06f
--- /dev/null
+++ b/be/src/storage/index/vector/tenann/del_id_filter.h
@@ -0,0 +1,61 @@
+// Copyright 2021-present StarRocks, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file is based on code available under the Apache license here:
+//   https://github.com/apache/incubator-doris/blob/master/be/src/olap/rowset/segment_v2/column_reader.h
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include "storage/del_vector.h"
+#include "storage/range.h"
+#include "tenann/common/type_traits.h"
+#include "tenann/searcher/id_filter.h"
+
+namespace starrocks {
+
+// tenann::IdFilter backed by a roaring bitmap built from a scan range:
+// IsMember(id) is true exactly when `id` is contained in that bitmap.
+// The bitmap is materialised in the constructor, so the range is not retained.
+class DelIdFilter : public tenann::IdFilter {
+public:
+    // Converts `scan_range` into the row bitmap; `scan_range` is only read here.
+    DelIdFilter(const SparseRange<>& scan_range);
+    ~DelIdFilter() = default;
+
+    // Returns whether `id` is present in the row bitmap.
+    bool IsMember(tenann::idx_t id) const override;
+
+private:
+    // Bitmap produced by range2roaring() in the constructor.
+    Roaring _row_bitmap;
+};
+
+} // namespace starrocks
\ No newline at end of file
diff --git a/be/src/storage/index/vector/tenann_index_reader.cpp b/be/src/storage/index/vector/tenann_index_reader.cpp
new file mode 100644
index 0000000000000..d40630a3f7d5b
--- /dev/null
+++ b/be/src/storage/index/vector/tenann_index_reader.cpp
@@ -0,0 +1,98 @@
+// Copyright 2021-present StarRocks, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file is based on code available under the Apache license here:
+//   https://github.com/apache/incubator-doris/blob/master/be/src/olap/tablet.h
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "tenann_index_reader.h" + +#include "common/config.h" +#include "common/status.h" +#include "tenann/common/error.h" +#include "tenann/common/seq_view.h" +#include "tenann/searcher/id_filter.h" + +namespace starrocks { + +Status TenANNReader::init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) { + try { + auto meta_copy = meta; + if (meta.index_type() == tenann::IndexType::kFaissIvfPq) { + if (config::enable_vector_index_block_cache) { + // cache index blocks + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_file_key] = false; + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_block_key] = true; + } else { + // cache index file + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_file_key] = true; + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_block_key] = false; + } + } else { + // cache index file + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_file_key] = true; + } + + tenann::IndexCache::GetGlobalInstance()->SetCapacity(config::vector_query_cache_capacity); + + _searcher = tenann::AnnSearcherFactory::CreateSearcherFromMeta(meta_copy); + _searcher->index_reader()->SetIndexCache(tenann::IndexCache::GetGlobalInstance()); + _searcher->ReadIndex(index_path); + + DCHECK(_searcher->is_index_loaded()); + } catch (tenann::Error& e) { + return Status::InternalError(e.what()); + } + return Status::OK(); +} + +Status TenANNReader::search(tenann::PrimitiveSeqView query_vector, int k, int64_t* 
result_ids, + uint8_t* result_distances, tenann::IdFilter* id_filter) { + try { + _searcher->AnnSearch(query_vector, k, result_ids, result_distances, id_filter); + } catch (tenann::Error& e) { + return Status::InternalError(e.what()); + } + return Status::OK(); +}; + +Status TenANNReader::range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) { + try { + _searcher->RangeSearch(query_vector, range, k, tenann::AnnSearcher::ResultOrder(order), result_ids, + result_distances, id_filter); + } catch (tenann::Error& e) { + return Status::InternalError(e.what()); + } + return Status::OK(); +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann_index_reader.h b/be/src/storage/index/vector/tenann_index_reader.h new file mode 100644 index 0000000000000..0751d86e262d6 --- /dev/null +++ b/be/src/storage/index/vector/tenann_index_reader.h @@ -0,0 +1,47 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "common/status.h" +#include "storage/index/vector/vector_index_reader.h" +#include "tenann/common/seq_view.h" +#include "tenann/common/type_traits.h" +#include "tenann/factory/ann_searcher_factory.h" +#include "tenann/factory/index_factory.h" +#include "tenann/searcher/ann_searcher.h" +#include "tenann/searcher/faiss_hnsw_ann_searcher.h" +#include "tenann/searcher/id_filter.h" +#include "tenann/store/index_meta.h" + +namespace starrocks { + +class TenANNReader final : public VectorIndexReader { +public: + TenANNReader() = default; + ~TenANNReader() override{}; + + Status init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) override; + + Status search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, uint8_t* result_distances, + tenann::IdFilter* id_filter = nullptr) override; + Status range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) override; + +private: + std::shared_ptr _searcher; +}; + +} // namespace starrocks diff --git a/be/src/storage/index/vector/vector_index_reader.h b/be/src/storage/index/vector/vector_index_reader.h new file mode 100644 index 0000000000000..d43d8d0c58483 --- /dev/null +++ b/be/src/storage/index/vector/vector_index_reader.h @@ -0,0 +1,58 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is based on code available under the Apache license here: +// https://github.com/apache/incubator-doris/blob/master/be/src/olap/rowset/segment_v2/column_reader.h + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include "common/status.h" +#include "tenann/common/seq_view.h" +#include "tenann/searcher/id_filter.h" +#include "tenann/store/index_meta.h" + +namespace starrocks { + +class VectorIndexReader { +public: + VectorIndexReader() = default; + virtual ~VectorIndexReader() = default; + + virtual Status init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) = 0; + + virtual Status search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, uint8_t* result_distances, + tenann::IdFilter* id_filter = nullptr) = 0; + virtual Status range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) = 0; +}; + +} // namespace starrocks diff --git a/be/src/storage/index/vector/vector_index_reader_factory.cpp b/be/src/storage/index/vector/vector_index_reader_factory.cpp new file mode 100644 index 0000000000000..1eebbdd9be84a --- /dev/null +++ b/be/src/storage/index/vector/vector_index_reader_factory.cpp @@ -0,0 +1,44 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "storage/index/vector/vector_index_reader_factory.h" + +#include "storage/index/vector/empty_index_reader.h" +#include "storage/index/vector/tenann_index_reader.h" +#include "storage/index/vector/vector_index_reader.h" + +namespace starrocks { +Status VectorIndexReaderFactory::create_from_file(const std::string& index_path, + const std::shared_ptr& index_meta, + std::shared_ptr* vector_index_reader) { + if (!fs::path_exist(index_path)) { + return Status::NotFound(fmt::format("index path {} not found", index_path)); + } + ASSIGN_OR_RETURN(auto index_file, fs::new_random_access_file(index_path)) + ASSIGN_OR_RETURN(auto file_size, index_file->get_size()) + + if (file_size == IndexDescriptor::mark_word_len) { + std::unique_ptr buf(new (std::nothrow) unsigned char[file_size]); + index_file->read_fully(buf.get(), file_size); + std::string buf_str(reinterpret_cast(buf.get()), file_size); + if (buf_str == IndexDescriptor::mark_word) { + (*vector_index_reader) = std::make_shared(); + return Status::OK(); + } + } + (*vector_index_reader) = std::make_shared(); + return Status::OK(); +} + +} // namespace starrocks diff --git a/be/src/storage/index/vector/vector_index_reader_factory.h b/be/src/storage/index/vector/vector_index_reader_factory.h new file mode 100644 index 0000000000000..c0714e2d094b7 --- /dev/null +++ b/be/src/storage/index/vector/vector_index_reader_factory.h @@ -0,0 +1,30 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fmt/format.h" +#include "fs/fs_util.h" +#include "storage/index/index_descriptor.h" +#include "vector_index_reader.h" + +namespace starrocks { + +class VectorIndexReaderFactory { +public: + static Status create_from_file(const std::string& index_path, const std::shared_ptr& index_meta, + std::shared_ptr* vector_index_reader); +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/olap_common.h b/be/src/storage/olap_common.h index ebf1c111dd47a..c371588d71165 100644 --- a/be/src/storage/olap_common.h +++ b/be/src/storage/olap_common.h @@ -262,6 +262,9 @@ struct OlapReaderStatistics { int64_t rows_bitmap_index_filtered = 0; int64_t bitmap_index_filter_timer = 0; + int64_t get_row_ranges_by_vector_index_timer = 0; + int64_t vector_search_timer = 0; + int64_t process_vector_distance_and_id_timer = 0; int64_t rows_del_vec_filtered = 0; diff --git a/be/src/storage/projection_iterator.cpp b/be/src/storage/projection_iterator.cpp index 64c5b820703bb..fa8f483adc6df 100644 --- a/be/src/storage/projection_iterator.cpp +++ b/be/src/storage/projection_iterator.cpp @@ -28,6 +28,16 @@ class ProjectionIterator final : public ChunkIterator { build_index_map(this->_schema, _child->schema()); } + ProjectionIterator(Schema schema, ChunkIteratorPtr child, bool has_vector_index_search, int vector_column_id, + SlotId vector_slot_id, std::string vector_distance_column_name) + : ChunkIterator(std::move(schema), child->chunk_size()), _child(std::move(child)) { + build_index_map(this->_schema, _child->schema()); + _has_vector_index_search = has_vector_index_search; + _vector_column_id = vector_column_id; + _vector_slot_id = vector_slot_id; + _vector_distance_column_name = vector_distance_column_name; + } + void close() override; size_t merged_rows() const override { return _child->merged_rows(); } @@ -50,11 +60,16 @@ class 
ProjectionIterator final : public ChunkIterator { private: void build_index_map(const Schema& output, const Schema& input); + FieldPtr _make_field(size_t i); ChunkIteratorPtr _child; // mapping from index of column in output chunk to index of column in input chunk. std::vector _index_map; ChunkPtr _chunk; + bool _has_vector_index_search = false; + int _vector_column_id; + SlotId _vector_slot_id; + std::string _vector_distance_column_name; }; void ProjectionIterator::build_index_map(const Schema& output, const Schema& input) { @@ -84,6 +99,11 @@ Status ProjectionIterator::do_get_next(Chunk* chunk) { for (size_t i = 0; i < _index_map.size(); i++) { chunk->get_column_by_index(i).swap(input_columns[_index_map[i]]); } + if (_has_vector_index_search) { + std::shared_ptr distance_column = FloatColumn::create(); + distance_column.swap(_chunk->get_column_by_name(_vector_distance_column_name)); + chunk->append_vector_column(distance_column, _make_field(_vector_column_id), _vector_slot_id); + } } #ifndef NDEBUG if (st.ok()) { @@ -107,4 +127,15 @@ ChunkIteratorPtr new_projection_iterator(const Schema& schema, const ChunkIterat return std::make_shared(schema, child); } +ChunkIteratorPtr new_projection_iterator(const Schema& schema, const ChunkIteratorPtr& child, + bool _has_vector_index_search, int vector_column_id, SlotId vector_slot_id, + std::string vector_distance_column_name) { + return std::make_shared(schema, child, _has_vector_index_search, vector_column_id, + vector_slot_id, vector_distance_column_name); +} + +FieldPtr ProjectionIterator::_make_field(size_t i) { + return std::make_shared(i, _vector_distance_column_name, get_type_info(TYPE_FLOAT), false); +} + } // namespace starrocks diff --git a/be/src/storage/projection_iterator.h b/be/src/storage/projection_iterator.h index c19fa1f96ce78..da60bb4a352b7 100644 --- a/be/src/storage/projection_iterator.h +++ b/be/src/storage/projection_iterator.h @@ -26,4 +26,8 @@ namespace starrocks { // columns in the order `c3`, 
`c1`. ChunkIteratorPtr new_projection_iterator(const Schema& schema, const ChunkIteratorPtr& child); +ChunkIteratorPtr new_projection_iterator(const Schema& schema, const ChunkIteratorPtr& child, + bool _has_vector_index_search, int vector_column_id, SlotId vector_slot_id, + std::string vector_distance_column_name); + } // namespace starrocks diff --git a/be/src/storage/rowset/rowset.cpp b/be/src/storage/rowset/rowset.cpp index 7c25754b08c13..862309aa5e46e 100644 --- a/be/src/storage/rowset/rowset.cpp +++ b/be/src/storage/rowset/rowset.cpp @@ -708,6 +708,19 @@ Status Rowset::get_segment_iterators(const Schema& schema, const RowsetReadOptio seg_options.runtime_range_pruner = options.runtime_range_pruner; seg_options.column_access_paths = options.column_access_paths; seg_options.tablet_schema = options.tablet_schema; + seg_options.use_vector_index = options.use_vector_index; + seg_options.k = options.k; + seg_options.query_vector = options.query_vector; + seg_options.vector_distance_column_name = options.vector_distance_column_name; + seg_options.vector_column_id = options.vector_column_id; + seg_options.vector_slot_id = options.vector_slot_id; + seg_options.query_params = options.query_params; + seg_options.vector_range = options.vector_range; + seg_options.result_order = options.result_order; + seg_options.use_ivfpq = options.use_ivfpq; + seg_options.pq_refine_factor = options.pq_refine_factor; + seg_options.k_factor = options.k_factor; + if (options.delete_predicates != nullptr) { seg_options.delete_predicates = options.delete_predicates->get_predicates(end_version()); } diff --git a/be/src/storage/rowset/rowset_options.h b/be/src/storage/rowset/rowset_options.h index 00640b8392f36..40fdcd8f9ad8d 100644 --- a/be/src/storage/rowset/rowset_options.h +++ b/be/src/storage/rowset/rowset_options.h @@ -87,6 +87,31 @@ class RowsetReadOptions { bool prune_column_after_index_filter = false; bool enable_gin_filter = false; + + // vector index params + int64_t k; + + 
std::vector query_vector; + + bool use_vector_index = false; + + bool use_ivfpq = false; + + std::string vector_distance_column_name; + + int vector_column_id; + + SlotId vector_slot_id; + + std::map query_params; + + double vector_range; + + int result_order; + + double pq_refine_factor; + + double k_factor; }; } // namespace starrocks diff --git a/be/src/storage/rowset/segment_iterator.cpp b/be/src/storage/rowset/segment_iterator.cpp index e34ffe1e72b27..c30225a63c398 100644 --- a/be/src/storage/rowset/segment_iterator.cpp +++ b/be/src/storage/rowset/segment_iterator.cpp @@ -40,6 +40,10 @@ #include "storage/column_predicate_rewriter.h" #include "storage/del_vector.h" #include "storage/index/index_descriptor.h" +#include "storage/index/vector/tenann/del_id_filter.h" +#include "storage/index/vector/tenann/tenann_index_utils.h" +#include "storage/index/vector/vector_index_reader.h" +#include "storage/index/vector/vector_index_reader_factory.h" #include "storage/lake/update_manager.h" #include "storage/olap_runtime_range_pruner.hpp" #include "storage/projection_iterator.h" @@ -212,8 +216,10 @@ class SegmentIterator final : public ChunkIterator { StatusOr> _get_row_ranges_by_key_ranges(); StatusOr> _get_row_ranges_by_short_key_ranges(); Status _get_row_ranges_by_zone_map(); + Status _get_row_ranges_by_vector_index(); Status _get_row_ranges_by_bloom_filter(); Status _get_row_ranges_by_rowid_range(); + Status _get_row_ranges_by_row_ids(std::vector* result_ids, SparseRange<>* r); uint32_t segment_id() const { return _segment->id(); } uint32_t num_rows() const { return _segment->num_rows(); } @@ -248,6 +254,8 @@ class SegmentIterator final : public ChunkIterator { Status _encode_to_global_id(ScanContext* ctx); + FieldPtr _make_field(size_t i); + Status _switch_context(ScanContext* to); // `_check_low_cardinality_optimization` and `_init_column_iterators` must have been called @@ -286,6 +294,9 @@ class SegmentIterator final : public ChunkIterator { bool 
need_early_materialize_subfield(const FieldPtr& field); + Status _unfold_vector_distance_column(std::vector* result_distances, ColumnPtr& vector_column_iterator); + Status _init_ann_reader(); + private: using RawColumnIterators = std::vector>; using ColumnDecoders = std::vector; @@ -295,6 +306,7 @@ class SegmentIterator final : public ChunkIterator { RawColumnIterators _column_iterators; std::vector _io_coalesce_column_index; ColumnDecoders _column_decoders; + std::shared_ptr _ann_reader; BitmapIndexEvaluator _bitmap_index_evaluator; // delete predicates std::map _del_predicates; @@ -349,6 +361,25 @@ class SegmentIterator final : public ChunkIterator { std::unordered_map _predicate_column_access_paths; std::unordered_set _prune_cols_candidate_by_inverted_index; + + // vector index params + int64_t _k; + tenann::PrimitiveSeqView _query_view; + bool _use_vector_index; + std::string _vector_distance_column_name; + std::shared_ptr _index_meta; + int _vector_column_id; + SlotId _vector_slot_id; + std::unordered_map _id2distance_map; + std::vector _first_rowids; + std::map _query_params; + double _vector_range; + int _result_order; + bool _use_ivfpq; + Buffer _filter_selection; + Buffer _filter_by_expr_selection; + + Status _init_reader_from_file(const std::string& index_path, const std::shared_ptr& tablet_index_meta); }; SegmentIterator::SegmentIterator(std::shared_ptr segment, Schema schema, SegmentReadOptions options) @@ -356,7 +387,28 @@ SegmentIterator::SegmentIterator(std::shared_ptr segment, Schema schema _segment(std::move(segment)), _opts(std::move(options)), _bitmap_index_evaluator(_schema, _opts.pred_tree), - _predicate_columns(_opts.pred_tree.num_columns()) { + _predicate_columns(_opts.pred_tree.num_columns()), + _use_vector_index(_opts.use_vector_index) { + if (_use_vector_index) { + // The K in front of Fe is long, which can be changed to uint32. 
This can be a problem, + // but this k is wasted memory allocation, so it should not exceed the accuracy of uint32 + // options.query_vector is a string, passed to tenann as a float string, see if you need to use the stof function to convert + // and consider precision loss + _vector_distance_column_name = _opts.vector_distance_column_name; + _vector_column_id = _opts.vector_column_id; + _vector_slot_id = _opts.vector_slot_id; + _vector_range = _opts.vector_range; + _result_order = _opts.result_order; + _use_ivfpq = _opts.use_ivfpq; + if (_vector_range > 0 && _use_ivfpq) { + _k = _opts.k * _opts.pq_refine_factor * _opts.k_factor; + } else { + _k = _opts.k * _opts.k_factor; + } + _query_view = tenann::PrimitiveSeqView{.data = reinterpret_cast(_opts.query_vector.data()), + .size = static_cast(_opts.query_vector.size()), + .elem_type = tenann::PrimitiveType::kFloatType}; + } // For small segment file (the number of rows is less than chunk_size), // the segment iterator will reserve a large amount of memory, // especially when there are many columns, many small files, many versions, @@ -432,7 +484,7 @@ Status SegmentIterator::_init() { _init_column_access_paths(); RETURN_IF_ERROR(_check_low_cardinality_optimization()); RETURN_IF_ERROR(_init_column_iterators(_schema)); - + RETURN_IF_ERROR(_init_ann_reader()); // filter by index stage // Use indexes and predicates to filter some data page RETURN_IF_ERROR(_get_row_ranges_by_rowid_range()); @@ -449,6 +501,7 @@ Status SegmentIterator::_init() { if (apply_del_vec_after_all_index_filter) { RETURN_IF_ERROR(_apply_del_vector()); } + RETURN_IF_ERROR(_get_row_ranges_by_vector_index()); // rewrite stage // Rewriting predicates using segment dictionary codes RETURN_IF_ERROR(_rewrite_predicates()); @@ -469,6 +522,158 @@ Status SegmentIterator::_init() { return Status::OK(); } +Status SegmentIterator::_unfold_vector_distance_column(std::vector* result_distances, + ColumnPtr& vector_column_iterator) { + std::shared_ptr distance_column 
= UInt8Column::create(); + distance_column->append(*result_distances); + vector_column_iterator = distance_column; + return Status::OK(); +} + +inline Status SegmentIterator::_init_reader_from_file(const std::string& index_path, + const std::shared_ptr& tablet_index_meta) { + // todo: support query params + ASSIGN_OR_RETURN(auto meta, get_vector_meta(tablet_index_meta, std::map{})) + _index_meta = std::make_shared(meta); + RETURN_IF_ERROR(VectorIndexReaderFactory::create_from_file(index_path, _index_meta, &_ann_reader)); + auto status = _ann_reader->init_searcher(*_index_meta.get(), index_path); + // means empty ann reader + if (status.is_not_supported()) { + _use_vector_index = false; + return Status::OK(); + } + return status; +} + +Status SegmentIterator::_init_ann_reader() { + RETURN_IF(!_use_vector_index, Status::OK()); + std::unordered_map col_map_index; + for (const auto& index : *_segment->tablet_schema().indexes()) { + if (index.index_type() == VECTOR) { + col_map_index.emplace(index.col_unique_ids()[0], index); + } + } + + std::vector hit_indexes; + for (auto& field : _schema.fields()) { + if (col_map_index.count(field->uid()) > 0) { + hit_indexes.emplace_back(col_map_index.at(field->uid())); + } + } + + // TODO: Support more index in one segment iterator, only support one index for now + CHECK(hit_indexes.size() <= 1) << "Only support query no more than one index now"; + + if (hit_indexes.empty()) { + return Status::OK(); + } + + auto tablet_index_meta = std::make_shared(hit_indexes[0]); + // TODO: to support below when tenann support this interface + // tenann::IndexMeta::Deserialize(tablet_index_meta.properties_to_json()); + + // std::string index_path = IndexDescriptor::vector_index_file_path(_opts.rowset_path, _opts.rowsetid.to_string(), + // segment_id(), tablet_index_meta->index_id()); + + return _init_reader_from_file("index_path", tablet_index_meta); +} + +Status SegmentIterator::_get_row_ranges_by_vector_index() { + RETURN_IF(!_use_vector_index, 
Status::OK()); + RETURN_IF(_scan_range.empty(), Status::OK()); + + SCOPED_RAW_TIMER(&_opts.stats->get_row_ranges_by_vector_index_timer); + + Status st; + std::map id2distance_map; + std::vector result_ids; + std::vector result_distances; + std::vector filtered_result_ids; + DelIdFilter del_id_filter(_scan_range); + + { + SCOPED_RAW_TIMER(&_opts.stats->vector_search_timer); + if (_vector_range > 0) { + st = _ann_reader->range_search(_query_view, _k, &result_ids, &result_distances, &del_id_filter, + static_cast(_vector_range), _result_order); + if (!st.ok()) { + // fall over to normal search and host range search to column expression predicate + LOG(WARNING) << "Vector index search failed with range_serach: " << st.to_string(); + result_ids.resize(_k); + result_distances.resize(_k); + st = _ann_reader->search(_query_view, _k, (result_ids.data()), + reinterpret_cast(result_distances.data()), &del_id_filter); + } + } else { + result_ids.resize(_k); + result_distances.resize(_k); + st = _ann_reader->search(_query_view, _k, (result_ids.data()), + reinterpret_cast(result_distances.data()), &del_id_filter); + } + } + + if (!st.ok()) { + LOG(WARNING) << "Vector index search failed: " << st.to_string(); + return Status::InternalError(st.to_string()); + } + SCOPED_RAW_TIMER(&_opts.stats->process_vector_distance_and_id_timer); + + for (size_t i = 0; i < result_ids.size() && result_ids[i] != -1; i++) { + id2distance_map[result_ids[i]] = result_distances[i]; + } + + SparseRange r; + RETURN_IF_ERROR(_get_row_ranges_by_row_ids(&result_ids, &r)); + + size_t prev_size = _scan_range.span_size(); + _scan_range = _scan_range.intersection(r); + _opts.stats->rows_stats_filtered += (prev_size - _scan_range.span_size()); + + SparseRangeIterator range_iter = _scan_range.new_iterator(); + size_t to_read = _scan_range.span_size(); + while (range_iter.has_more()) { + Range r = range_iter.next(to_read); + for (uint32_t i = r.begin(); i < r.end(); i++) { + filtered_result_ids.emplace_back(i); + 
} + } + + _id2distance_map.reserve(filtered_result_ids.size()); + for (size_t i = 0; i < filtered_result_ids.size(); i++) { + _id2distance_map[static_cast(filtered_result_ids[i])] = id2distance_map[filtered_result_ids[i]]; + } + return Status::OK(); +} + +Status SegmentIterator::_get_row_ranges_by_row_ids(std::vector* result_ids, SparseRange<>* r) { + if (result_ids->empty()) { + return Status::OK(); + } + + std::sort(result_ids->begin(), result_ids->end()); + + // filter -1 above + size_t first_valid_id = 0; + for (; (*result_ids)[first_valid_id] < 0; first_valid_id++); + + int64_t range_start = (*result_ids)[first_valid_id]; + int64_t range_end = range_start + 1; + + for (size_t i = first_valid_id + 1; i < result_ids->size(); ++i) { + if ((*result_ids)[i] == range_end) { + ++range_end; + } else { + r->add(Range<>(range_start, range_end)); + range_start = (*result_ids)[i]; + range_end = range_start + 1; + } + } + + r->add(Range<>(range_start, range_end)); + + return Status::OK(); +} + Status SegmentIterator::_try_to_update_ranges_by_runtime_filter() { return _opts.runtime_range_pruner.update_range_if_arrived( _opts.global_dictmaps, @@ -1096,6 +1301,12 @@ inline Status SegmentIterator::_read(Chunk* chunk, vector* rowids, size chunk->check_or_die(); } + if (_use_vector_index) { + for (uint32_t i = range.begin(); i < range.end(); i++) { + _first_rowids.push_back(i); + } + } + if (rowids != nullptr) { rowids->reserve(rowids->size() + n); SparseRangeIterator<> iter = range.new_iterator(); @@ -1277,6 +1488,37 @@ Status SegmentIterator::_do_get_next(Chunk* result, vector* rowid) { chunk = _context->_adapt_global_dict_chunk.get(); } + if (_use_vector_index && !_use_ivfpq) { + std::shared_ptr distance_column = FloatColumn::create(); + vector rowids; + for (const auto& id : _first_rowids) { + auto it = _id2distance_map.find(id); + if (it != _id2distance_map.end()) { + rowids.emplace_back(it->first); + } + } + if (rowid != nullptr) { + for (auto id : *rowid) { + auto it = 
_id2distance_map.find(id); + if (it != _id2distance_map.end()) { + rowids.emplace_back(it->first); + } + } + } + if (!rowids.empty()) { + // is sorted? + std::sort(rowids.begin(), rowids.end()); + for (const auto& vrid : rowids) { + distance_column->append(_id2distance_map[vrid]); + } + } + if (has_non_expr_predicate && _filter_selection.size() == distance_column->size()) { + distance_column->filter_range(_filter_selection, 0, distance_column->size()); + } + // distance_column->filter_range(_filter_by_expr_selection, 0, distance_column->size()); + chunk->append_vector_column(distance_column, _make_field(_vector_column_id), _vector_slot_id); + } + result->swap_chunk(*chunk); if (need_switch_context) { @@ -1286,6 +1528,10 @@ Status SegmentIterator::_do_get_next(Chunk* result, vector* rowid) { return Status::OK(); } +FieldPtr SegmentIterator::_make_field(size_t i) { + return std::make_shared(i, _vector_distance_column_name, get_type_info(TYPE_FLOAT), false); +} + Status SegmentIterator::_switch_context(ScanContext* to) { if (_context != nullptr) { const ordinal_t ordinal = _context->_column_iterators[0]->get_current_ordinal(); @@ -1379,6 +1625,9 @@ StatusOr SegmentIterator::_filter_by_non_expr_predicates(Chunk* chunk, } } _opts.stats->rows_vec_cond_filtered += (to - chunk_size); + for (int i = from; i < to; i++) { + _filter_selection.push_back(_selection[i]); + } return chunk_size; } @@ -1404,6 +1653,9 @@ StatusOr SegmentIterator::_filter_by_expr_predicates(Chunk* chunk, vec } } _opts.stats->rows_vec_cond_filtered += (chunk_size - new_size); + for (int i = 0; i < chunk_size; i++) { + _filter_by_expr_selection.push_back(_selection[i]); + } chunk_size = new_size; } return chunk_size; @@ -2194,7 +2446,9 @@ ChunkIteratorPtr new_segment_iterator(const std::shared_ptr& segment, c } else { Schema ordered_schema = reorder_schema(schema, options.pred_tree); auto seg_iter = std::make_shared(segment, ordered_schema, options); - return new_projection_iterator(schema, seg_iter); + 
return new_projection_iterator(schema, seg_iter, options.use_vector_index && !options.use_ivfpq, + options.vector_column_id, options.vector_slot_id, + options.vector_distance_column_name); } } diff --git a/be/src/storage/rowset/segment_options.h b/be/src/storage/rowset/segment_options.h index 89d94a1955c1a..9706b6eaa85ea 100644 --- a/be/src/storage/rowset/segment_options.h +++ b/be/src/storage/rowset/segment_options.h @@ -104,6 +104,29 @@ class SegmentReadOptions { bool prune_column_after_index_filter = false; bool enable_gin_filter = false; + int64_t k = 0; + std::vector query_vector; + + bool use_vector_index = false; + + std::string vector_distance_column_name; + + int vector_column_id = 0; + + SlotId vector_slot_id = 0; + + std::map query_params; + + double vector_range = -1; + + int result_order = 0; + + bool use_ivfpq = false; + + double pq_refine_factor = 1; + + double k_factor = 1; + public: Status convert_to(SegmentReadOptions* dst, const std::vector& new_types, ObjectPool* obj_pool) const; diff --git a/be/src/storage/tablet.h b/be/src/storage/tablet.h index 94c3c11a3565b..fd1b7debc73c9 100644 --- a/be/src/storage/tablet.h +++ b/be/src/storage/tablet.h @@ -119,6 +119,7 @@ class Tablet : public BaseTablet { size_t num_rows_per_row_block_with_max_version() const; size_t next_unique_id() const; size_t field_index_with_max_version(const string& field_name) const; + size_t field_index(const string& field_name, const string& extra_column_name) const; std::string schema_debug_string() const; std::string debug_string() const; bool enable_shortcut_compaction() const; @@ -505,6 +506,10 @@ inline size_t Tablet::field_index_with_max_version(const string& field_name) con return tablet_schema()->field_index(field_name); } +inline size_t Tablet::field_index(const string& field_name, const string& extra_column_name) const { + return tablet_schema()->field_index(field_name, extra_column_name); +} + inline bool Tablet::enable_shortcut_compaction() const { std::shared_lock rdlock(_meta_lock); return 
_tablet_meta->enable_shortcut_compaction(); diff --git a/be/src/storage/tablet_reader.cpp b/be/src/storage/tablet_reader.cpp index fd68e6e5f0950..9a2d3b3c0e1c4 100644 --- a/be/src/storage/tablet_reader.cpp +++ b/be/src/storage/tablet_reader.cpp @@ -218,6 +218,18 @@ Status TabletReader::_init_collector_for_pk_index_read() { rs_opts.runtime_range_pruner = _reader_params->runtime_range_pruner; // single row fetch, no need to use delvec rs_opts.is_primary_keys = false; + rs_opts.use_vector_index = _reader_params->use_vector_index; + rs_opts.k = _reader_params->k; + rs_opts.query_vector = _reader_params->query_vector; + rs_opts.vector_distance_column_name = _reader_params->vector_distance_column_name; + rs_opts.vector_range = _reader_params->vector_range; + rs_opts.result_order = _reader_params->result_order; + rs_opts.use_ivfpq = _reader_params->use_ivfpq; + rs_opts.vector_column_id = _reader_params->vector_column_id; + rs_opts.vector_slot_id = _reader_params->vector_slot_id; + rs_opts.query_params = _reader_params->query_params; + rs_opts.pq_refine_factor = _reader_params->pq_refine_factor; + rs_opts.k_factor = _reader_params->k_factor; rs_opts.rowid_range_option = std::make_shared(); auto rowid_range = std::make_shared>(); @@ -290,6 +302,18 @@ Status TabletReader::get_segment_iterators(const TabletReaderParams& params, std rs_opts.unused_output_column_ids = params.unused_output_column_ids; rs_opts.runtime_range_pruner = params.runtime_range_pruner; rs_opts.column_access_paths = params.column_access_paths; + rs_opts.k = params.k; + rs_opts.query_vector = params.query_vector; + rs_opts.use_vector_index = params.use_vector_index; + rs_opts.vector_distance_column_name = params.vector_distance_column_name; + rs_opts.vector_column_id = params.vector_column_id; + rs_opts.vector_slot_id = params.vector_slot_id; + rs_opts.query_params = params.query_params; + rs_opts.pq_refine_factor = params.pq_refine_factor; + rs_opts.k_factor = params.k_factor; + rs_opts.vector_range = 
params.vector_range; + rs_opts.result_order = params.result_order; + rs_opts.use_ivfpq = params.use_ivfpq; if (keys_type == KeysType::PRIMARY_KEYS) { rs_opts.is_primary_keys = true; rs_opts.version = _version.second; diff --git a/be/src/storage/tablet_reader_params.h b/be/src/storage/tablet_reader_params.h index eeba0678c128f..54db64e536890 100644 --- a/be/src/storage/tablet_reader_params.h +++ b/be/src/storage/tablet_reader_params.h @@ -94,6 +94,31 @@ struct TabletReaderParams { bool prune_column_after_index_filter = false; bool enable_gin_filter = false; + // vector index params; scalars are default-initialized so copying an unset params struct never reads indeterminate values + int64_t k = 0; + + std::vector query_vector; + + std::string vector_distance_column_name; + + bool use_vector_index = false; + + int vector_column_id = 0; + + SlotId vector_slot_id = 0; + + std::map query_params; + + double vector_range = -1; + + int result_order = 0; + + bool use_ivfpq = false; + + double pq_refine_factor = 1; + + double k_factor = 1; + public: std::string to_string() const; }; diff --git a/be/src/storage/tablet_schema.cpp b/be/src/storage/tablet_schema.cpp index 7f02277319cda..e9b15f957e8a6 100644 --- a/be/src/storage/tablet_schema.cpp +++ b/be/src/storage/tablet_schema.cpp @@ -659,6 +659,20 @@ size_t TabletSchema::field_index(std::string_view field_name) const { return -1; } +size_t TabletSchema::field_index(std::string_view field_name, std::string extra_column_name) const { + int ordinal = -1; + for (auto& column : _cols) { + ordinal++; + if (column.name() == field_name) { + return ordinal; + } + } + if (field_name == extra_column_name) { + return ordinal + 1; + } + return -1; +} + int32_t TabletSchema::field_index(int32_t col_unique_id) const { const auto& found = _unique_id_to_index.find(col_unique_id); return (found == _unique_id_to_index.end()) ? 
-1 : found->second; diff --git a/be/src/storage/tablet_schema.h b/be/src/storage/tablet_schema.h index 36dda51aa3b71..f83ff50964bec 100644 --- a/be/src/storage/tablet_schema.h +++ b/be/src/storage/tablet_schema.h @@ -283,6 +283,7 @@ class TabletSchema { size_t estimate_row_size(size_t variable_len) const; int32_t field_index(int32_t col_unique_id) const; size_t field_index(std::string_view field_name) const; + size_t field_index(std::string_view field_name, std::string extra_column_name) const; const TabletColumn& column(size_t ordinal) const; const std::vector& columns() const; const std::vector sort_key_idxes() const { return _sort_key_idxes; } diff --git a/be/src/util/system_metrics.h b/be/src/util/system_metrics.h index 04bcb34826a68..38652de6f8182 100644 --- a/be/src/util/system_metrics.h +++ b/be/src/util/system_metrics.h @@ -32,6 +32,7 @@ class SnmpMetrics; class QueryCacheMetrics; class VectorIndexCacheMetrics; class RuntimeFilterMetrics; +class VectorIndexCacheMetrics; class MemoryMetrics { public: diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/com/starrocks/analysis/FunctionCallExpr.java index 618d1ef1c68ae..ec06feb3b4d3c 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/FunctionCallExpr.java @@ -34,6 +34,8 @@ package com.starrocks.analysis; +import org.apache.commons.lang3.StringUtils; + import com.google.common.base.Joiner; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; @@ -55,6 +57,7 @@ import java.io.IOException; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; import static com.starrocks.catalog.FunctionSet.IGNORE_NULL_WINDOW_FUNCTION; @@ -64,6 +67,8 @@ public class FunctionCallExpr extends Expr { // private BuiltinAggregateFunction.Operator aggOp; private FunctionParams fnParams; + private Optional functionHint = 
Optional.empty(); + // check analytic function private boolean isAnalyticFnCall = false; @@ -434,6 +439,14 @@ public void setNondeterministicId(ExprId nondeterministicId) { this.nondeterministicId = nondeterministicId; } + public void setFunctionHint(String functionHint) { + this.functionHint = Optional.ofNullable(StringUtils.upperCase(functionHint)); + } + + public Optional getFunctionHint() { + return functionHint; + } + public boolean isNondeterministicBuiltinFnName() { return FunctionSet.nonDeterministicFunctions.contains(fnName.getFunction().toLowerCase()); } @@ -454,7 +467,8 @@ public boolean equals(Object obj) { && fnParams.isDistinct() == o.fnParams.isDistinct() && fnParams.isStar() == o.fnParams.isStar() && nondeterministicId.equals(o.nondeterministicId) - && Objects.equals(fnParams.getOrderByElements(), o.fnParams.getOrderByElements()); + && Objects.equals(fnParams.getOrderByElements(), o.fnParams.getOrderByElements()) + && Objects.equals(functionHint, o.functionHint); } /** diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java index 57d2cb3f72289..ba642a08aa7fd 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java @@ -150,6 +150,11 @@ public class FunctionSet { public static final String SHA2 = "sha2"; public static final String SM3 = "sm3"; + // Vector Index functions: + public static final String APPROX_COSINE_SIMILARITY = "approx_cosine_similarity"; + public static final String APPROX_COSINE_SIMILARITY_NORM = "approx_cosine_similarity_norm"; + public static final String APPROX_L2_DISTANCE = "approx_l2_distance"; + // Geo functions: public static final String ST_ASTEXT = "st_astext"; public static final String ST_ASWKT = "st_aswkt"; @@ -624,6 +629,13 @@ public class FunctionSet { .add(SLEEP) .build(); + public static final Set VECTOR_COMPUTE_FUNCTIONS = + ImmutableSet.builder() + 
.add(APPROX_COSINE_SIMILARITY) + .add(APPROX_COSINE_SIMILARITY_NORM) + .add(APPROX_L2_DISTANCE) + .build(); + // Only use query cache if these time function can be reduced into a constant // date/datetime value after applying FoldConstantRule, otherwise BE would yield // non-deterministic result when these function is delivered to BE. diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java b/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java index bcdaefdf55642..95d4f7cb6c207 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java @@ -231,6 +231,10 @@ public enum OlapTableState { @SerializedName(value = "bfColumns") protected Set bfColumns; + // distance columns for vector + @SerializedName(value = "distanceColumns") + protected Set distanceColumns; + @SerializedName(value = "bfFpp") protected double bfFpp; @@ -299,6 +303,7 @@ public OlapTable(TableType type) { super(type); this.bfColumns = null; + this.distanceColumns = null; this.bfFpp = 0; this.colocateGroup = null; @@ -1697,6 +1702,10 @@ public Set getBfColumnIds() { return bfColumns; } + public Set getDistanceColumns() { + return distanceColumns; + } + public Set getBfColumnNames() { if (bfColumns == null) { return null; @@ -1734,6 +1743,10 @@ public void setBloomFilterInfo(Set bfColumns, double bfFpp) { this.bfFpp = bfFpp; } + public void setDistanceColumnsInfo(Set distanceColumns) { + this.distanceColumns = distanceColumns; + } + public void setIndexes(List indexes) { if (this.indexes == null) { this.indexes = new TableIndexes(null); @@ -2833,7 +2846,7 @@ public int getCompressionLevel() { if (tableProperty == null) { return -1; } - return tableProperty.getCompressionLevel(); + return tableProperty.getCompressionLevel(); } public void setPartitionLiveNumber(int number) { diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java b/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java index 
e2ff0071911fa..3224e82d8395b 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java @@ -455,6 +455,11 @@ public List getColumns() { return new ArrayList<>(nameToColumn.values()); } + public void addColumn(Column column) { + fullSchema.add(column); + nameToColumn.put(column.getName(), column); + } + public long getCreateTime() { return createTime; } diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java index d777871edc7bf..bb82a2c2c1ba2 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java @@ -756,6 +756,13 @@ protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) output.append(prefix).append("PREAGGREGATION: OFF. Reason: ").append(reasonOfPreAggregation) .append("\n"); } + if (ConnectContext.get() != null) { + if (ConnectContext.get().getSessionVariable().isRewrittenVectorPlan()) { + output.append(prefix).append("VECTORINDEX: ON").append("\n"); + } else { + output.append(prefix).append("VECTORINDEX: OFF").append("\n"); + } + } if (!conjuncts.isEmpty()) { output.append(prefix).append("PREDICATES: ").append( getExplainString(conjuncts)).append("\n"); diff --git a/fe/fe-core/src/main/java/com/starrocks/plugin/AuditEvent.java b/fe/fe-core/src/main/java/com/starrocks/plugin/AuditEvent.java index b7f1c0c2149b3..8ee121434517f 100644 --- a/fe/fe-core/src/main/java/com/starrocks/plugin/AuditEvent.java +++ b/fe/fe-core/src/main/java/com/starrocks/plugin/AuditEvent.java @@ -93,6 +93,8 @@ public enum EventType { public String errorCode = ""; @AuditField(value = "Time") public long queryTime = -1; + @AuditField(value = "ParserTime") + public long parserTime = -1; @AuditField(value = "ScanBytes") public long scanBytes = -1; @AuditField(value = "ScanRows") @@ -209,6 +211,11 @@ public 
AuditEventBuilder setQueryTime(long queryTime) { return this; } + public AuditEventBuilder setParserTime(long parserTime) { + auditEvent.parserTime = parserTime; + return this; + } + public AuditEventBuilder setScanBytes(long scanBytes) { auditEvent.scanBytes = scanBytes; return this; diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/ConnectProcessor.java b/fe/fe-core/src/main/java/com/starrocks/qe/ConnectProcessor.java index 8c279afdacc77..84ca2a0c66138 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/ConnectProcessor.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/ConnectProcessor.java @@ -306,6 +306,9 @@ protected void handleQuery() { } catch (ParsingException parsingException) { throw new AnalysisException(parsingException.getMessage()); } + long endParserTimestamp = System.currentTimeMillis(); + long parserTimeMs = endParserTimestamp - ctx.getAuditEventBuilder().build().timestamp; + ctx.getAuditEventBuilder().setParserTime(parserTimeMs); for (int i = 0; i < stmts.size(); ++i) { ctx.getState().reset(); diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index e6c5c0b728ce5..d6744a429cf39 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -70,6 +70,7 @@ import com.starrocks.thrift.TSpillToRemoteStorageOptions; import com.starrocks.thrift.TTabletInternalParallelMode; import com.starrocks.thrift.TTimeUnit; +import com.starrocks.thrift.TVectorSearchOptions; import org.apache.commons.lang3.EnumUtils; import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; @@ -82,6 +83,7 @@ import java.io.IOException; import java.io.Serializable; import java.lang.reflect.Field; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; @@ -106,6 +108,7 @@ public class SessionVariable implements Serializable, 
Writable, Cloneable { // The optional values are "compute_nodes_only" and "all_nodes". public static final String COMPUTATION_FRAGMENT_SCHEDULING_POLICY = "computation_fragment_scheduling_policy"; public static final String EXEC_MEM_LIMIT = "exec_mem_limit"; + public static final String VECTOR_DISTANCE_COLUMN_NAME = "vector_distance_column_name"; /** * configure the mem limit of load process on BE. @@ -703,6 +706,46 @@ public static MaterializedViewRewriteMode parse(String str) { public static final String WAREHOUSE_NAME = "warehouse"; + public static final String ENABLE_USE_ANN = "enable_use_ann"; + + public static final String USE_IVFPQ = "use_ivfpq"; + + public static final String ANN_TYPE = "ann_type"; + + public static final String VECTOR_LIMIT_K = "vector_limit_k"; + + public static final String IS_REWRITTEN_VECTOR_PLAN = "is_rewritten_vector_plan"; + + public static final String QUERY_VECTOR = "query_vector"; + + public static final String VECTOR_RANGE = "vector_range"; + + public static final String RESULT_ORDER = "result_order"; + + public static final String QUERY_PARAMS = "query_params"; + + public static final String PQ_REFINE_FACTOR = "pq_refine_factor"; + + public static final String K_FACTOR = "k_factor"; + + public enum ANNAlgorithmType { + APPROX_L2_DISTANCE, // L2 distance + APPROX_COSINE_SIMILARITY, // Cosine distance + BRUTE, // Brute + DEFAULT; // default Type: Brute + + public static final String TYPE_L2 = APPROX_L2_DISTANCE.toString(); + public static final String TYPE_COSINE = APPROX_COSINE_SIMILARITY.toString(); + public static final String TYPE_BRUTE = BRUTE.toString(); + public static final String TYPE_DEFAULT = DEFAULT.toString(); + + public static ANNAlgorithmType parse(String str) { + return EnumUtils.getEnumIgnoreCase(ANNAlgorithmType.class, str); + } + } + + public static final String DEFAULT_VECTOR_DISTANCE_COLUMN_NAME = "vector_distance"; + public static final String HDFS_BACKEND_SELECTOR_HASH_ALGORITHM = "hdfs_backend_selector_hash_algorithm"; public 
static final String CONSISTENT_HASH_VIRTUAL_NUMBER = "consistent_hash_virtual_number"; @@ -1275,7 +1318,7 @@ public static MaterializedViewRewriteMode parse(String str) { private double broadcastRightTableScaleFactor = 10.0; @VariableMgr.VarAttr(name = NEW_PLANNER_OPTIMIZER_TIMEOUT) - private long optimizerExecuteTimeout = 3000; + private long optimizerExecuteTimeout = 300000000; @VariableMgr.VarAttr(name = QUERY_DEBUG_OPTIONS, flag = VariableMgr.INVISIBLE) private String queryDebugOptions = ""; @@ -1973,6 +2016,42 @@ public long getConnectorSinkTargetMaxFileSize() { @VarAttr(name = ENABLE_PARTITION_COLUMN_VALUE_ONLY_OPTIMIZATION, flag = VariableMgr.INVISIBLE) private boolean enablePartitionColumnValueOnlyOptimization = true; + @VarAttr(name = ENABLE_USE_ANN) + private boolean enableUseANN = false; + + @VarAttr(name = USE_IVFPQ) + private boolean useIVFPQ = false; + + @VarAttr(name = ANN_TYPE) + private String annType = ANNAlgorithmType.TYPE_DEFAULT; + + @VarAttr(name = VECTOR_DISTANCE_COLUMN_NAME) + private String vectorDistanceColumnName = DEFAULT_VECTOR_DISTANCE_COLUMN_NAME; + + @VarAttr(name = VECTOR_LIMIT_K) + private long vectorLimitK; + + @VarAttr(name = IS_REWRITTEN_VECTOR_PLAN) + private boolean isRewrittenVectorPlan = false; + + @VarAttr(name = QUERY_VECTOR) + private List queryVector = new ArrayList<>(); + + @VarAttr(name = VECTOR_RANGE) + private double vectorRange = -1; + + @VarAttr(name = RESULT_ORDER) + private int resultOrder = 0; + + @VarAttr(name = QUERY_PARAMS) + private Map queryParams = Maps.newHashMap(); + + @VarAttr(name = PQ_REFINE_FACTOR) + private double pqRefineFactor = 1; + + @VarAttr(name = K_FACTOR) + private double kFactor = 1; + // This variable is introduced to solve compatibility issues/ // see more details: https://github.com/StarRocks/starrocks/pull/29678 @VarAttr(name = ENABLE_COLLECT_TABLE_LEVEL_SCAN_STATS) @@ -3972,6 +4051,102 @@ public void setLargeDecimalUnderlyingType(String type) { } } + public boolean isEnableUseANN() { + 
return enableUseANN; + } + + public void setEnableUseANN(boolean v) { + enableUseANN = v; + } + + public boolean isUseIVFPQ() { + return useIVFPQ; + } + + public void setUseIVFPQ(boolean useIVFPQ) { + this.useIVFPQ = useIVFPQ; + } + + public String getANNType() { + return annType; + } + + public void setANNType(String annType) { + this.annType = annType; + } + + public String getVectorDistanceColumnName() { + return vectorDistanceColumnName; + } + + public void setVectorDistanceColumnName(String vectorDistanceColumnName) { + this.vectorDistanceColumnName = vectorDistanceColumnName; + } + + public long getVectorLimitK() { + return vectorLimitK; + } + + public void setVectorLimitK(long vectorLimitK) { + this.vectorLimitK = vectorLimitK; + } + + public boolean isRewrittenVectorPlan() { + return isRewrittenVectorPlan; + } + + public void setIsRewrittenVectorPlan(boolean isRewrittenVectorPlan) { + this.isRewrittenVectorPlan = isRewrittenVectorPlan; + } + + public List getQueryVector() { + return queryVector; + } + + public void setQueryVector(List queryVector) { + this.queryVector = queryVector; + } + + public double getVectorRange() { + return vectorRange; + } + + public void setVectorRange(double vectorRange) { + this.vectorRange = vectorRange; + } + + public int getResultOrder() { + return resultOrder; + } + + public void setResultOrder(int resultOrder) { + this.resultOrder = resultOrder; + } + + public Map getQueryParams() { + return queryParams; + } + + public void setQueryParams(Map queryParams) { + this.queryParams = queryParams; + } + + public double getPqRefineFactor() { + return pqRefineFactor; + } + + public void setPqRefineFactor(double pqRefineFactor) { + this.pqRefineFactor = pqRefineFactor; + } + + public double getKFactor() { + return kFactor; + } + + public void setKFactor(double kFactor) { + this.kFactor = kFactor; + } + public String getLargeDecimalUnderlyingType() { return largeDecimalUnderlyingType; } @@ -4234,6 +4409,20 @@ public TQueryOptions 
toThrift() { tResult.setConnector_io_tasks_slow_io_latency_ms(connectorIoTasksSlowIoLatency); tResult.setConnector_scan_use_query_mem_ratio(connectorScanUseQueryMemRatio); tResult.setScan_use_query_mem_ratio(scanUseQueryMemRatio); + if (enableUseANN) { + TVectorSearchOptions tVectorSearchOptions = new TVectorSearchOptions(); + tVectorSearchOptions.setVector_distance_column_name(vectorDistanceColumnName); + tVectorSearchOptions.setVector_limit_k(vectorLimitK); + tVectorSearchOptions.setQuery_vector(queryVector); + tVectorSearchOptions.setQuery_params(queryParams); + tVectorSearchOptions.setPq_refine_factor(pqRefineFactor); + tVectorSearchOptions.setK_factor(kFactor); + tVectorSearchOptions.setVector_range(vectorRange); + tVectorSearchOptions.setResult_order(resultOrder); + tVectorSearchOptions.setUse_ivfpq(useIVFPQ); + tResult.setVector_search_options(tVectorSearchOptions); + } + tResult.setEnable_use_ann(enableUseANN); tResult.setEnable_collect_table_level_scan_stats(enableCollectTableLevelScanStats); tResult.setEnable_pipeline_level_shuffle(enablePipelineLevelShuffle); tResult.setEnable_hyperscan_vec(enableHyperscanVec); diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/VariableMgr.java b/fe/fe-core/src/main/java/com/starrocks/qe/VariableMgr.java index 60b726f08ae6d..3dd3e13b85349 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/VariableMgr.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/VariableMgr.java @@ -66,6 +66,7 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.reflect.Field; +import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -74,6 +75,9 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.stream.Collectors; + +import static com.starrocks.qe.SessionVariable.QUERY_PARAMS; /** * Variable manager, merge session variable and 
global variable. @@ -255,6 +259,9 @@ private static boolean setValue(Object obj, Field field, String value) throws Dd case "String": field.set(obj, convertedVal); break; + case "List": + case "Map": + break; default: // Unsupported type variable. ErrorReport.reportDdlException(ErrorCode.ERR_WRONG_TYPE_FOR_VAR, variableName); @@ -309,6 +316,14 @@ public static boolean containsVariable(String name) { // setVar: variable information that needs to be set public static void setSystemVariable(SessionVariable sessionVariable, SystemVariable setVar, boolean onlySetSessionVar) throws DdlException { + if (setVar.getVariable().equalsIgnoreCase(QUERY_PARAMS) && setVar.getType() != SetType.GLOBAL) { + String value = setVar.getResolvedExpression().getStringValue().replaceAll("[{}]", ""); + Map queryParamsMap = Arrays.stream(value.split(",")).map(s -> s.split("=", 2)) + .filter(kv -> kv.length == 2) // skip empty or malformed entries instead of indexing past the array + .collect(Collectors.toMap(kv -> kv[0].trim(), kv -> kv[1].trim(), (first, last) -> last)); + sessionVariable.setQueryParams(queryParamsMap); + return; + } if (SessionVariable.DEPRECATED_VARIABLES.stream().anyMatch(c -> c.equalsIgnoreCase(setVar.getVariable()))) { return; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/SimplePlanner.java b/fe/fe-core/src/main/java/com/starrocks/sql/SimplePlanner.java new file mode 100644 index 0000000000000..5355655aa8337 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/SimplePlanner.java @@ -0,0 +1,22 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.starrocks.sql;
+
+import com.starrocks.qe.ConnectContext;
+import com.starrocks.sql.ast.StatementBase;
+import com.starrocks.sql.plan.ExecPlan;
+
+public interface SimplePlanner<S extends StatementBase> {
+    ExecPlan plan(S simpleStatement, ConnectContext session);
+}
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java b/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java
index e27f6e7a26914..d446251707dc9 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java
@@ -21,6 +21,7 @@
 import com.starrocks.analysis.ParseNode;
 import com.starrocks.catalog.Database;
 import com.starrocks.catalog.ExternalOlapTable;
+import com.starrocks.catalog.Index;
 import com.starrocks.catalog.KeysType;
 import com.starrocks.catalog.OlapTable;
 import com.starrocks.catalog.Table;
@@ -29,6 +30,8 @@
 import com.starrocks.common.Config;
 import com.starrocks.common.DuplicatedRequestException;
 import com.starrocks.common.LabelAlreadyUsedException;
+import com.starrocks.common.VectorIndexParams.CommonIndexParamKey;
+import com.starrocks.common.VectorIndexParams.VectorIndexType;
 import com.starrocks.common.profile.Timer;
 import com.starrocks.common.profile.Tracers;
 import com.starrocks.http.HttpConnectContext;
@@ -46,6 +49,7 @@
 import com.starrocks.sql.analyzer.SemanticException;
 import com.starrocks.sql.ast.DeleteStmt;
 import com.starrocks.sql.ast.DmlStmt;
+import com.starrocks.sql.ast.IndexDef;
 import com.starrocks.sql.ast.InsertStmt;
 import com.starrocks.sql.ast.QueryRelation;
 import com.starrocks.sql.ast.QueryStatement;
@@ -80,12 +84,17 @@
 
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 
+import static com.starrocks.catalog.FunctionSet.VECTOR_COMPUTE_FUNCTIONS;
+
 public class StatementPlanner {
     private static final Logger LOG = LogManager.getLogger(StatementPlanner.class);
 
+    private static final String VECTOR_FUNCTION_PREFIX = "approx_";
+
     public static ExecPlan plan(StatementBase stmt, ConnectContext session) {
         if (session instanceof HttpConnectContext) {
             return plan(stmt, session, TResultSinkType.HTTP_PROTOCAL);
@@ -227,6 +236,7 @@ private static ExecPlan createQueryPlan(StatementBase stmt,
                                             ConnectContext session,
                                             TResultSinkType resultSinkType) {
         QueryStatement queryStmt = (QueryStatement) stmt;
+        checkVectorIndex(queryStmt, session);
         QueryRelation query = (QueryRelation) queryStmt.getQueryRelation();
         List<String> colNames = query.getColumnOutputNames();
         // 1. Build Logical plan
@@ -256,6 +266,16 @@ private static ExecPlan createQueryPlan(StatementBase stmt,
                     columnRefFactory);
         }
 
+        // TODO: avoid relying on session variables for performing tricks
+        String vectorFunction = session.getSessionVariable().getANNType().toLowerCase(Locale.ROOT);
+        if (VECTOR_COMPUTE_FUNCTIONS.contains(vectorFunction) && !session.getSessionVariable().isRewrittenVectorPlan()) {
+            String msg = String.format(
+                    "%s is used but the given query cannot trigger vector index scan, " +
+                            "please consider using brute-force search or adjusting your usage of %s",
+                    vectorFunction, vectorFunction);
+            throw new SemanticException(msg);
+        }
+
         try (Timer ignored = Tracers.watchScope("ExecPlanBuild")) {
             // 3. Build fragment exec plan
             /*
@@ -281,6 +301,8 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt,
         QueryRelation query = queryStmt.getQueryRelation();
         List<String> colNames = query.getColumnOutputNames();
 
+        checkVectorIndex(queryStmt, session);
+
         // 1. Build Logical plan
         ColumnRefFactory columnRefFactory = new ColumnRefFactory();
         boolean isSchemaValid = true;
@@ -326,6 +348,16 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt,
                         columnRefFactory);
             }
 
+            // TODO: avoid relying on session variables for performing tricks
+            String vectorFunction = session.getSessionVariable().getANNType().toLowerCase(Locale.ROOT);
+            if (VECTOR_COMPUTE_FUNCTIONS.contains(vectorFunction) && !session.getSessionVariable().isRewrittenVectorPlan()) {
+                String msg = String.format(
+                        "%s is used but the given query cannot trigger vector index scan, " +
+                                "please consider using brute-force search or adjusting your usage of %s",
+                        vectorFunction, vectorFunction);
+                throw new SemanticException(msg);
+            }
+
             try (Timer ignored = Tracers.watchScope("ExecPlanBuild")) {
                 // 3. Build fragment exec plan
                 // SingleNodeExecPlan is set in TableQueryPlanAction to generate a single-node Plan,
@@ -356,6 +388,54 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt,
                 "schema of %s had been updated frequently during the plan generation", updatedTables);
     }
 
+    /**
+     * Looks at the first vector index of {@code olapTable} and configures the session for ANN search.
+     *
+     * @return true if such an index exists and its metric type matches the ANN function in use
+     */
+    private static boolean checkAndSetVectorIndex(OlapTable olapTable, SessionVariable sessionVariable) {
+        String annType = sessionVariable.getANNType();
+        for (Index index : olapTable.getIndexes()) {
+            if (index.getIndexType() == IndexDef.IndexType.VECTOR) {
+                Map<String, String> indexProperties = index.getProperties();
+                String metricType = indexProperties.get(CommonIndexParamKey.METRIC_TYPE.name().toLowerCase(Locale.ROOT));
+                String indexType = indexProperties.get(CommonIndexParamKey.INDEX_TYPE.name().toLowerCase(Locale.ROOT));
+
+                if (!isValidMetricType(metricType, annType)) {
+                    sessionVariable.setEnableUseANN(false);
+                    return false;
+                }
+
+                // Always overwrite the flag: it is stored on the session, so a value set while planning
+                // a query against an IVFPQ index must not carry over to a query on another index type.
+                sessionVariable.setUseIVFPQ(VectorIndexType.IVFPQ.name().equalsIgnoreCase(indexType));
+
+                sessionVariable.setEnableUseANN(true);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /** Returns true when {@code annType} equals "approx_" + {@code metricType}, ignoring case. */
+    private static boolean isValidMetricType(String metricType, String annType) {
+        return (VECTOR_FUNCTION_PREFIX + metricType).equalsIgnoreCase(annType);
+    }
+
+    /** Enables ANN search on the session iff one of the queried OLAP tables has a usable vector index. */
+    private static void checkVectorIndex(QueryStatement queryStmt, ConnectContext session) {
+        Set<OlapTable> olapTables = Sets.newHashSet();
+        AnalyzerUtils.copyOlapTable(queryStmt, olapTables);
+        boolean hasVectorIndex = false;
+        for (OlapTable olapTable : olapTables) {
+            if (checkAndSetVectorIndex(olapTable, session.getSessionVariable())) {
+                hasVectorIndex = true;
+                break;
+            }
+        }
+        session.getSessionVariable().setEnableUseANN(hasVectorIndex);
+    }
+
     public static Set<OlapTable> collectOriginalOlapTables(ConnectContext session, StatementBase queryStmt) {
         Set<OlapTable> olapTables = Sets.newHashSet();
         PlannerMetaLocker locker = new PlannerMetaLocker(session, queryStmt);
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/ExpressionAnalyzer.java
index 48cff09402c71..7931599467022 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/ExpressionAnalyzer.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/analyzer/ExpressionAnalyzer.java
@@ -89,6 +89,7 @@
 import com.starrocks.privilege.RolePrivilegeCollectionV2;
 import com.starrocks.qe.ConnectContext;
 import com.starrocks.qe.SessionVariable;
+import com.starrocks.qe.SessionVariable.ANNAlgorithmType;
 import com.starrocks.qe.SqlModeHelper;
 import com.starrocks.qe.VariableMgr;
 import com.starrocks.server.GlobalStateMgr;
@@ -118,6 +119,7 @@
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
@@ -1032,10 +1034,10 @@ public Void visitFunctionCall(FunctionCallExpr node, Scope scope) {
             checkFunction(fnName, node, argumentTypes);
             if (fnName.equalsIgnoreCase("typeof") && argumentTypes.length == 1) {
                 // For the typeof function, the parameter type of the function is the result of this function.
-                // At this time, the parameter type has been obtained. You can directly replace the current
-                // function with StringLiteral. However, since the parent node of the current node in ast
-                // cannot be obtained, this cannot be done directly. Replacement, here the StringLiteral is
-                // stored in the parameter of the function, so that the StringLiteral can be obtained in the
+                // At this time, the parameter type has been obtained. You can directly replace the current
+                // function with StringLiteral. However, since the parent node of the current node in ast
+                // cannot be obtained, this cannot be done directly. Replacement, here the StringLiteral is
+                // stored in the parameter of the function, so that the StringLiteral can be obtained in the
                 // subsequent rule rewriting, and then the typeof can be replaced.
                 Type originType = argumentTypes[0];
                 argumentTypes[0] = Type.STRING;
@@ -1237,12 +1239,55 @@ public Void visitFunctionCall(FunctionCallExpr node, Scope scope) {
                 }
             }
 
+            if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(fnName)) {
+                // Set vector index parameters
+                if (node.getFunctionHint().isPresent()) {
+                    String functionHint = node.getFunctionHint().get();
+                    // The TYPE_BRUTE and TYPE_DEFAULT hints opt out of ANN search; any other hint enables it.
+                    boolean optOutOfAnn = functionHint.equalsIgnoreCase(ANNAlgorithmType.TYPE_BRUTE)
+                            || functionHint.equalsIgnoreCase(ANNAlgorithmType.TYPE_DEFAULT);
+                    if (optOutOfAnn) {
+                        session.getSessionVariable().setEnableUseANN(false);
+                    } else {
+                        setVectorSessionVariable(session, node);
+                    }
+                    session.getSessionVariable().setANNType(functionHint.toUpperCase(Locale.ROOT));
+                } else {
+                    session.getSessionVariable().setANNType(fnName.toUpperCase(Locale.ROOT));
+                    setVectorSessionVariable(session, node);
+                }
+            }
+
             node.setFn(fn);
             node.setType(fn.getReturnType());
             FunctionAnalyzer.analyze(node);
             return null;
         }
 
+        /**
+         * Enables ANN search on the session and, for a two-argument call, records the first argument
+         * (an array literal or a comma-separated string) as the query vector.
+         */
+        private void setVectorSessionVariable(ConnectContext session, FunctionCallExpr node) {
+            SessionVariable sessionVariable = session.getSessionVariable();
+            sessionVariable.setEnableUseANN(true);
+
+            // Set query_vector for ann searcher
+            List<Expr> arguments = node.getParams().exprs();
+            if (arguments.size() != 2) {
+                return;
+            }
+            Expr queryVector = arguments.get(0);
+            // thrift doesn't have a float type, so we use string to prevent precision loss
+            if (queryVector instanceof ArrayExpr) {
+                sessionVariable.setQueryVector(queryVector.childrenToSql());
+            } else if (queryVector instanceof StringLiteral) {
+                // Split on commas and drop the blanks around them, e.g. "1.1, 2.2, 3.3".
+                String literal = ((StringLiteral) queryVector).getValue().trim();
+                sessionVariable.setQueryVector(Arrays.asList(literal.split("\\s*,\\s*")));
+            }
+        }
+
         private void checkFunction(String fnName, FunctionCallExpr node, Type[] argumentTypes) {
             switch (fnName) {
                 case FunctionSet.TIME_SLICE:
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java
index a039db0933c44..0a6f48a21be6a 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java
@@ -616,6 +616,8 @@ private OptExpression logicalRuleRewrite(
 
         ruleRewriteOnlyOnce(tree, rootTaskContext, UnionToValuesRule.getInstance());
 
+        ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.VECTOR_REWRITE);
+
         tree = SimplifyCaseWhenPredicateRule.INSTANCE.rewrite(tree, rootTaskContext);
         deriveLogicalProperty(tree);
         return tree.getInputs().get(0);
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java
index 2389b8e59c24c..46f2b55357aa9 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java
@@ -26,6 +26,7 @@
 import com.starrocks.sql.optimizer.operator.Operator;
 import com.starrocks.sql.optimizer.operator.OperatorType;
 import
com.starrocks.sql.optimizer.operator.OperatorVisitor;
+import com.starrocks.sql.optimizer.operator.Projection;
 import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
 import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
 
@@ -56,6 +57,12 @@ public LogicalOlapScanOperator(Table table) {
         this(table, Maps.newHashMap(), Maps.newHashMap(), null, Operator.DEFAULT_LIMIT, null);
     }
 
+    // Only for UT
+    public LogicalOlapScanOperator(Table table, Projection projection, long limit) {
+        super(OperatorType.LOGICAL_OLAP_SCAN, table, Maps.newHashMap(), Maps.newHashMap(), limit, null,
+                projection);
+    }
+
     public LogicalOlapScanOperator(
             Table table,
             Map<ColumnRefOperator, Column> colRefToColumnMetaMap,
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java
index d628851435568..f16e7216ce9d5 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java
@@ -68,7 +68,7 @@ private LogicalTopNOperator() {
         super(OperatorType.LOGICAL_TOPN);
     }
 
-    private LogicalTopNOperator(long limit,
+    public LogicalTopNOperator(long limit,
                                 ScalarOperator predicate, Projection projection,
                                 List<ColumnRefOperator> partitionByColumns,
                                 long partitionLimit,
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java
index c5e19a3a4cd70..237cffce967c6 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java
@@ -150,6 +150,7 @@
 import com.starrocks.sql.optimizer.rule.transformation.RewriteSimpleAggToHDFSScanRule;
 import com.starrocks.sql.optimizer.rule.transformation.RewriteSimpleAggToMetaScanRule;
 import com.starrocks.sql.optimizer.rule.transformation.RewriteSumByAssociativeRule;
+import com.starrocks.sql.optimizer.rule.transformation.RewriteToVectorPlanRule;
 import com.starrocks.sql.optimizer.rule.transformation.ScalarApply2AnalyticRule;
 import com.starrocks.sql.optimizer.rule.transformation.ScalarApply2JoinRule;
 import com.starrocks.sql.optimizer.rule.transformation.SplitLimitRule;
@@ -270,6 +271,10 @@ public class RuleSet {
                 new LimitPruneTabletsRule()
         ));
 
+        REWRITE_RULES.put(RuleSetType.VECTOR_REWRITE, ImmutableList.of(
+                new RewriteToVectorPlanRule()
+        ));
+
         REWRITE_RULES.put(RuleSetType.PRUNE_COLUMNS, ImmutableList.of(
                 PruneScanColumnRule.OLAP_SCAN,
                 PruneScanColumnRule.SCHEMA_SCAN,
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java
index 0d149a9316389..ae4c78fa89cd2 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java
@@ -43,5 +43,7 @@ public enum RuleSetType {
     FINE_GRAINED_RANGE_PREDICATE,
 
     ELIMINATE_GROUP_BY,
+    // NOTE(review): NUM_RULE_SET looks like a count sentinel, so new rule sets are added above it.
+    VECTOR_REWRITE,
     NUM_RULE_SET
 }
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java
index c4b0bc956f678..23df6cb3ca1fd 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java
@@ -200,6 +200,8 @@ public enum RuleType {
 
     TF_PREDICATE_PROPAGATE,
 
+    TF_VECTOR_REWRITE_RULE,
+
     // The following are implementation rules:
     IMP_OLAP_LSCAN_TO_PSCAN,
     IMP_HIVE_LSCAN_TO_PSCAN,
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java
new file mode 100644
index
0000000000000..f4202f5748b76
--- /dev/null
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java
@@ -0,0 +1,260 @@
+// Copyright 2021-present StarRocks, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package com.starrocks.sql.optimizer.rule.transformation;
+
+import com.google.common.collect.Lists;
+import com.starrocks.analysis.BinaryType;
+import com.starrocks.catalog.Column;
+import com.starrocks.catalog.FunctionSet;
+import com.starrocks.catalog.Type;
+import com.starrocks.qe.SessionVariable;
+import com.starrocks.qe.SessionVariable.ANNAlgorithmType;
+import com.starrocks.sql.optimizer.OptExpression;
+import com.starrocks.sql.optimizer.OptimizerContext;
+import com.starrocks.sql.optimizer.base.ColumnRefFactory;
+import com.starrocks.sql.optimizer.base.Ordering;
+import com.starrocks.sql.optimizer.operator.Operator;
+import com.starrocks.sql.optimizer.operator.OperatorType;
+import com.starrocks.sql.optimizer.operator.Projection;
+import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
+import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
+import com.starrocks.sql.optimizer.operator.logical.LogicalTopNOperator;
+import com.starrocks.sql.optimizer.operator.pattern.Pattern;
+import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
+import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
+import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
+import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator;
+import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
+import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
+import com.starrocks.sql.optimizer.rule.RuleType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import static com.starrocks.analysis.BinaryType.GE;
+import static com.starrocks.analysis.BinaryType.GT;
+import static com.starrocks.analysis.BinaryType.LE;
+import static com.starrocks.analysis.BinaryType.LT;
+import static com.starrocks.catalog.FunctionSet.APPROX_COSINE_SIMILARITY;
+import static com.starrocks.catalog.FunctionSet.APPROX_L2_DISTANCE;
+
+public class RewriteToVectorPlanRule extends TransformationRule {
+
+    public RewriteToVectorPlanRule() {
+        super(RuleType.TF_VECTOR_REWRITE_RULE,
+                Pattern.create(OperatorType.LOGICAL_TOPN)
+                        .addChildren(Pattern.create(OperatorType.LOGICAL_OLAP_SCAN)));
+    }
+
+    /**
+     * Matches a TopN with an explicit limit whose first ordering column is a vector function that the
+     * projection of the OLAP scan below it computes with a function call.
+     */
+    @Override
+    public boolean check(OptExpression input, OptimizerContext context) {
+        SessionVariable sessionVariable = context.getSessionVariable();
+        if (input.getInputs().isEmpty() || !sessionVariable.isEnableUseANN()) {
+            return false;
+        }
+        String annType = sessionVariable.getANNType();
+        if (ANNAlgorithmType.TYPE_BRUTE.equalsIgnoreCase(annType)
+                || ANNAlgorithmType.TYPE_DEFAULT.equalsIgnoreCase(annType)) {
+            return false;
+        }
+
+        LogicalTopNOperator topNOperator = (LogicalTopNOperator) input.getOp();
+        LogicalScanOperator scanOperator = (LogicalScanOperator) input.getInputs().get(0).getOp();
+        if (topNOperator.getOrderByElements().isEmpty() || !FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(
+                topNOperator.getOrderByElements().get(0).getColumnRef().getName())) {
+            return false;
+        }
+        // Without a projection the scan does not compute the vector function, so there is nothing to rewrite.
+        Projection scanProjection = scanOperator.getProjection();
+        if (topNOperator.getLimit() == Operator.DEFAULT_LIMIT || scanProjection == null) {
+            return false;
+        }
+        return scanProjection.getColumnRefMap().entrySet().stream()
+                .anyMatch(entry -> FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName())
+                        && entry.getValue() instanceof CallOperator);
+    }
+
+    /**
+     * Records limit, order and range on the session and rewrites the TopN/scan pair to reference a
+     * dedicated distance column instead of the vector function call. The input is returned unchanged
+     * when the ordering or a range predicate rules out ANN search, or when an IVFPQ index is used.
+     */
+    @Override
+    public List<OptExpression> transform(OptExpression input, OptimizerContext context) {
+        LogicalTopNOperator topNOperator = (LogicalTopNOperator) input.getOp();
+        SessionVariable sessionVariable = context.getSessionVariable();
+        // set limit_K for ann searcher
+        sessionVariable.setVectorLimitK(topNOperator.getLimit());
+        LogicalScanOperator scanOperator = (LogicalScanOperator) input.getInputs().get(0).getOp();
+        ScalarOperator predicate = scanOperator.getPredicate();
+        Optional<ScalarOperator> newPredicate = Optional.empty();
+        boolean isAscending = topNOperator.getOrderByElements().get(0).isAscending();
+        if (predicate != null) {
+            newPredicate = findAndSetVectorRange(predicate, sessionVariable, isAscending);
+            if (!sessionVariable.isEnableUseANN()) {
+                return Lists.newArrayList(input);
+            }
+        }
+        sessionVariable.setResultOrder(isAscending ? 0 : 1);
+        String functionName = topNOperator.getOrderByElements().get(0).getColumnRef().getName();
+
+        // Only ascending L2 distance and descending cosine similarity are rewritten.
+        boolean unsupportedOrder = (functionName.equalsIgnoreCase(APPROX_L2_DISTANCE) && !isAscending)
+                || (functionName.equalsIgnoreCase(APPROX_COSINE_SIMILARITY) && isAscending);
+        if (unsupportedOrder || !sessionVariable.isEnableUseANN()) {
+            sessionVariable.setEnableUseANN(false);
+            return Lists.newArrayList(input);
+        }
+        sessionVariable.setIsRewrittenVectorPlan(true);
+        if (sessionVariable.isUseIVFPQ()) {
+            input.getInputs().get(0).getOp().setPredicate(newPredicate.orElse(null));
+            return Lists.newArrayList(input);
+        }
+
+        OptExpression result = buildVectorSortScanOperator(topNOperator,
+                (LogicalOlapScanOperator) scanOperator, context, newPredicate);
+        return Lists.newArrayList(result);
+    }
+
+    /**
+     * Builds a new TopN over a new OLAP scan in which every reference to a vector function is replaced
+     * by a column named after the session's vector distance column name.
+     */
+    public OptExpression buildVectorSortScanOperator(LogicalTopNOperator topNOperator,
+            LogicalOlapScanOperator scanOperator, OptimizerContext context,
+            Optional<ScalarOperator> newPredicate) {
+        // bottom-up
+        String distanceColumnName = context.getSessionVariable().getVectorDistanceColumnName();
+        Column distanceColumn = new Column(distanceColumnName, Type.FLOAT);
+        // NOTE(review): this adds a column to the table object referenced by the scan on every rewrite;
+        // confirm that the table is not a shared catalog instance.
+        scanOperator.getTable().addColumn(distanceColumn);
+
+        ColumnRefFactory columnRefFactory = context.getColumnRefFactory();
+        ColumnRefOperator distanceColumnRefOperator = columnRefFactory.create(distanceColumnName, Type.FLOAT, false);
+
+        Map<ColumnRefOperator, Column> colRefToColumnMetaMap = new HashMap<>(scanOperator.getColRefToColumnMetaMap());
+        colRefToColumnMetaMap.put(distanceColumnRefOperator, distanceColumn);
+
+        Map<Column, ColumnRefOperator> columnMetaToColRefMap = new HashMap<>(scanOperator.getColumnMetaToColRefMap());
+        columnMetaToColRefMap.put(distanceColumn, distanceColumnRefOperator);
+
+        // new Scan operator
+        LogicalOlapScanOperator newScanOperator = new LogicalOlapScanOperator(scanOperator.getTable(),
+                colRefToColumnMetaMap, columnMetaToColRefMap, scanOperator.getDistributionSpec(),
+                scanOperator.getLimit(), newPredicate.orElse(null),
+                scanOperator.getSelectedIndexId(), scanOperator.getSelectedPartitionId(),
+                scanOperator.getPartitionNames(), scanOperator.hasTableHints(),
+                scanOperator.getSelectedTabletId(), scanOperator.getHintsTabletIds(),
+                scanOperator.getHintsReplicaIds(), scanOperator.isUsePkIndex());
+
+        Map<ColumnRefOperator, ScalarOperator> projectMap = new HashMap<>();
+        scanOperator.getProjection().getColumnRefMap().forEach((columnRef, expr) -> {
+            if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(columnRef.getName()) && expr instanceof CallOperator) {
+                projectMap.put(distanceColumnRefOperator, distanceColumnRefOperator);
+            } else {
+                projectMap.put(columnRef, expr);
+            }
+        });
+        newScanOperator.setProjection(new Projection(projectMap));
+
+        List<Ordering> orderByElements = topNOperator.getOrderByElements().stream()
+                .map(ordering -> FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(ordering.getColumnRef().getName())
+                        ? new Ordering(distanceColumnRefOperator, ordering.isAscending(), ordering.isNullsFirst())
+                        : ordering)
+                .collect(Collectors.toList());
+
+        Map<ColumnRefOperator, ScalarOperator> topNProjectMap = new HashMap<>();
+        boolean hasProjection = topNOperator.getProjection() != null;
+        if (hasProjection) {
+            topNOperator.getProjection().getColumnRefMap().forEach((columnRef, expr) -> {
+                if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(columnRef.getName())) {
+                    topNProjectMap.put(distanceColumnRefOperator, distanceColumnRefOperator);
+                } else {
+                    topNProjectMap.put(columnRef, expr);
+                }
+            });
+        }
+
+        // new TopN operator
+        LogicalTopNOperator newTopNOperator = new LogicalTopNOperator(topNOperator.getLimit(),
+                topNOperator.getPredicate(), hasProjection ? new Projection(topNProjectMap) : null,
+                topNOperator.getPartitionByColumns(), topNOperator.getPartitionLimit(), orderByElements,
+                topNOperator.getOffset(), topNOperator.getSortPhase(), topNOperator.getTopNType(),
+                topNOperator.isSplit());
+
+        OptExpression topNExpression = OptExpression.create(newTopNOperator);
+        topNExpression.getInputs().clear();
+        topNExpression.getInputs().add(OptExpression.create(newScanOperator));
+
+        return topNExpression;
+    }
+
+    // todo: prune in dfs
+    /**
+     * Strips binary predicates that compare an "l2"/"cosine" expression with a constant out of
+     * {@code operator} and stores the constant as the session's vector range.
+     *
+     * @return the remaining predicate; empty when nothing is left or ANN search is (or gets) disabled
+     */
+    public Optional<ScalarOperator> findAndSetVectorRange(ScalarOperator operator,
+            SessionVariable sessionVariable, boolean isAscending) {
+        if (!sessionVariable.isEnableUseANN()) {
+            return Optional.empty();
+        }
+
+        if (operator instanceof BinaryPredicateOperator && operator.getChild(1) instanceof ConstantOperator &&
+                (operator.getChild(0).toString().contains("l2") || operator.getChild(0).toString().contains("cosine"))) {
+            BinaryType binaryType = ((BinaryPredicateOperator) operator).getBinaryType();
+            if (((binaryType.equals(LE) || binaryType.equals(LT)) && !isAscending) ||
+                    ((binaryType.equals(GE) || binaryType.equals(GT)) && isAscending)) {
+                sessionVariable.setEnableUseANN(false);
+                return Optional.empty();
+            }
+            Object bound = ((ConstantOperator) operator.getChild(1)).getValue();
+            if (!(bound instanceof Number)) {
+                // A null or non-numeric bound cannot be used as a search range; give up on ANN search.
+                sessionVariable.setEnableUseANN(false);
+                return Optional.empty();
+            }
+            // The constant may be boxed as Float, Integer, ...; a plain (double) cast only accepts Double.
+            sessionVariable.setVectorRange(((Number) bound).doubleValue());
+            return Optional.empty();
+        } else if (operator instanceof CompoundPredicateOperator) {
+            List<ScalarOperator> newOperators = new ArrayList<>();
+            for (ScalarOperator child : operator.getChildren()) {
+                findAndSetVectorRange(child, sessionVariable, isAscending).ifPresent(newOperators::add);
+            }
+            if (newOperators.size() > 1) {
+                return Optional.of(new CompoundPredicateOperator(
+                        ((CompoundPredicateOperator) operator).getCompoundType(), newOperators));
+            } else if (newOperators.size() == 1) {
+                return Optional.of(newOperators.get(0));
+            } else {
+                return Optional.empty();
+            }
+        } else {
+            return Optional.of(operator.clone());
+        }
+    }
+
+}
diff --git
a/fe/fe-core/src/main/java/com/starrocks/sql/parser/AstBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/parser/AstBuilder.java
index b5d63fab49053..7d02a6d90e48f 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/parser/AstBuilder.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/parser/AstBuilder.java
@@ -6402,6 +6402,10 @@ public ParseNode visitSimpleFunctionCall(StarRocksParser.SimpleFunctionCallConte
         if (context.over() != null) {
             return buildOverClause(functionCallExpr, context.over(), pos);
         }
+
+        if (context.bracketHint() != null) {
+            functionCallExpr.setFunctionHint(((Identifier) visit(context.bracketHint().identifier().get(0))).getValue());
+        }
         return SyntaxSugars.parse(functionCallExpr);
     }
 
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/parser/StarRocks.g4 b/fe/fe-core/src/main/java/com/starrocks/sql/parser/StarRocks.g4
index ed032e8e1238f..ef35c29657197 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/parser/StarRocks.g4
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/parser/StarRocks.g4
@@ -2300,7 +2300,7 @@ functionCall
     | specialFunctionExpression                                                           #specialFunction
     | aggregationFunction over?                                                           #aggregationFunctionCall
     | windowFunction over                                                                 #windowFunctionCall
-    | qualifiedName '(' (expression (',' expression)*)? ')' over?                         #simpleFunctionCall
+    | qualifiedName '(' (expression (',' expression)*)? ')' over? bracketHint?            #simpleFunctionCall
     ;
 
 aggregationFunction
diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java
index ff89d13f3ce4f..00cd4a2d3c7c1 100644
--- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java
+++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java
@@ -235,10 +235,31 @@ public static ExecPlan createPhysicalPlan(OptExpression plan, ConnectContext con
                                               List<String> colNames,
                                               TResultSinkType resultSinkType,
                                               boolean hasOutputFragment) {
+        // filter vector function name
+        List<ColumnRefOperator> newOutputColumns = outputColumns;
+        if (connectContext.getSessionVariable().isEnableUseANN()
+                && !connectContext.getSessionVariable().isUseIVFPQ()) {
+            String distanceColumnName = connectContext.getSessionVariable().getVectorDistanceColumnName();
+            newOutputColumns = new ArrayList<>();
+            for (ColumnRefOperator outputColumn : outputColumns) {
+                ColumnRefOperator replacement = outputColumn;
+                if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(outputColumn.getName())) {
+                    for (ColumnRefOperator columnRefOperator : columnRefFactory.getColumnRefs()) {
+                        if (columnRefOperator.getName().equalsIgnoreCase(distanceColumnName)) {
+                            replacement = columnRefOperator;
+                            break;
+                        }
+                    }
+                }
+                // Exactly one column is emitted per original output column so the list stays aligned with
+                // colNames; the original column is kept when no distance column ref exists.
+                newOutputColumns.add(replacement);
+            }
+        }
         UKFKConstraintsCollector.collectColumnConstraints(plan);
-        ExecPlan execPlan = new ExecPlan(connectContext, colNames, plan, outputColumns);
+        ExecPlan execPlan = new ExecPlan(connectContext, colNames, plan, newOutputColumns);
         createOutputFragment(new PhysicalPlanTranslator(columnRefFactory).translate(plan, execPlan), execPlan,
-                outputColumns, hasOutputFragment);
+                newOutputColumns, hasOutputFragment);
         execPlan.setPlanCount(plan.getPlanCount());
         return finalizeFragments(execPlan, resultSinkType);
     }
diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java
new file mode 100644
index 0000000000000..fcab8ff1ab145
---
/dev/null
+++ b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java
@@ -0,0 +1,164 @@
+// Copyright 2021-present StarRocks, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file is based on code available under the Apache license here:
+//   https://github.com/apache/incubator-doris/blob/master/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package com.starrocks.planner;
+
+import com.google.common.collect.Lists;
+import com.starrocks.catalog.OlapTable;
+import com.starrocks.catalog.Type;
+import com.starrocks.qe.DefaultCoordinator;
+import com.starrocks.sql.ast.StatementBase;
+import com.starrocks.common.util.UUIDUtil;
+import com.starrocks.qe.ConnectContext;
+import com.starrocks.qe.StmtExecutor;
+import com.starrocks.sql.optimizer.Memo;
+import com.starrocks.sql.optimizer.OptExpression;
+import com.starrocks.sql.optimizer.OptimizerContext;
+import com.starrocks.sql.optimizer.base.ColumnRefFactory;
+import com.starrocks.sql.optimizer.base.Ordering;
+import com.starrocks.sql.optimizer.operator.Projection;
+import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator;
+import com.starrocks.sql.optimizer.operator.logical.LogicalTopNOperator;
+import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
+import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
+import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
+import com.starrocks.sql.optimizer.rule.transformation.RewriteToVectorPlanRule;
+import com.starrocks.utframe.StarRocksAssert;
+import com.starrocks.utframe.UtFrameUtils;
+
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import mockit.Mocked;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class VectorIndexTest {
+    private static ConnectContext connectContext;
+    private static StarRocksAssert starRocksAssert;
+    private static final String DB_NAME = "test";
+
+    @BeforeClass
+    public static void beforeClass() throws Exception {
+        UtFrameUtils.createMinStarRocksCluster();
+
+        // create connect context
+        connectContext = UtFrameUtils.createDefaultCtx();
+        connectContext.setQueryId(UUIDUtil.genUUID());
+        starRocksAssert = new StarRocksAssert(connectContext);
+        starRocksAssert.withDatabase(DB_NAME).useDatabase(DB_NAME);
+
+        starRocksAssert.withTable("create table test.test_vector(c0 INT, c1 array<float>) " +
+                " duplicate key(c0) distributed by hash(c0) buckets 1 " +
+                "properties('replication_num'='1');");
+    }
+
+    /** Parses {@code sql} and executes its first statement on the shared connect context. */
+    private static void execute(String sql) throws Exception {
+        List<StatementBase> stmts = com.starrocks.sql.parser.SqlParser.parse(sql,
+                connectContext.getSessionVariable());
+        new StmtExecutor(connectContext, stmts.get(0)).execute();
+    }
+
+    @Test
+    public void testVectorIndexSyntax() throws Exception {
+        // cosine_similarity is not an approx_* function, so ANN search is expected to stay disabled.
+        execute("select c1 from test.test_vector " +
+                "order by cosine_similarity(c1, [1.1,2.2,3.3]) desc limit 10");
+        Assert.assertFalse(connectContext.getSessionVariable().isEnableUseANN());
+
+        execute("select c1 from test.test_vector " +
+                "order by approx_cosine_similarity([1.1,2.2,3.3], c1) desc limit 10");
+        Assert.assertTrue(connectContext.getSessionVariable().isEnableUseANN());
+
+        execute("select c1 from test.test_vector " +
+                "order by approx_cosine_similarity(\"1.1,2.2,3.3\", c1) desc limit 10");
+        Assert.assertTrue(connectContext.getSessionVariable().isEnableUseANN());
+
+        execute("select c1 from test.test_vector " +
+                "order by approx_l2_distance(\"1.1,2.2,3.3\", c1) limit 10");
+        Assert.assertTrue(connectContext.getSessionVariable().isEnableUseANN());
+
+        // Sorting in desc order doesn't make sense in l2_distance
+        // , which won't trigger the vector retrieval logic.
+        execute("select c1 from test.test_vector " +
+                "order by approx_l2_distance([1.1,2.2,3.3], c1) desc limit 10");
+        Assert.assertFalse(connectContext.getSessionVariable().isEnableUseANN());
+    }
+
+    @Test
+    public void testRewriteToVectorPlanRule(@Mocked OlapTable olapTable1) {
+        // sql: select c1 from test.test_vector order by approx_l2_distance([1.1,2.2,3.3], c1) limit 10;
+        connectContext.getSessionVariable().setEnableUseANN(true);
+        ColumnRefOperator c1ColumnRef = new ColumnRefOperator(2, Type.ARRAY_FLOAT, "c1", true);
+        Map<ColumnRefOperator, ScalarOperator> topNProjectMap = new HashMap<>();
+        topNProjectMap.put(c1ColumnRef, c1ColumnRef);
+
+        ColumnRefOperator l2DistanceColumnRef = new ColumnRefOperator(3, Type.FLOAT, "approx_l2_distance", true);
+        CallOperator l2DistanceCall = new CallOperator("approx_l2_distance", Type.FLOAT, new ArrayList<>());
+        OptExpression topN = new OptExpression(new LogicalTopNOperator(10, null, new Projection(topNProjectMap), null,
+                0, Lists.newArrayList(new Ordering(l2DistanceColumnRef, true, true)),
+                0, null, null, false)
+        );
+
+        Map<ColumnRefOperator, ScalarOperator> scanProjectMap = new HashMap<>();
+        scanProjectMap.put(c1ColumnRef, c1ColumnRef);
+        scanProjectMap.put(l2DistanceColumnRef, l2DistanceCall);
+        OptExpression scan = new OptExpression(new LogicalOlapScanOperator(olapTable1, new Projection(scanProjectMap),
+                10));
+
+        topN.getInputs().add(scan);
+
+        RewriteToVectorPlanRule rule = new RewriteToVectorPlanRule();
+        List<OptExpression> list = rule.transform(topN,
+                new OptimizerContext(new Memo(), new ColumnRefFactory(), connectContext));
+
+        String distanceColumnName = connectContext.getSessionVariable().getVectorDistanceColumnName();
+
+        // Substitution of sort column
+        assertEquals(distanceColumnName,
+                ((LogicalTopNOperator) list.get(0).getOp()).getOrderByElements().get(0).getColumnRef().getName());
+        // Substitution of the projected function call; search all projected values so the assertion
+        // does not depend on map iteration order.
+        assertTrue(list.get(0).getInputs().get(0).getOp().getProjection().getColumnRefMap().values().stream()
+                .anyMatch(op -> op instanceof ColumnRefOperator
+                        && ((ColumnRefOperator) op).getName().equals(distanceColumnName)));
+    }
+}
diff --git a/gensrc/script/functions.py b/gensrc/script/functions.py
index 0c8f330931cef..82e88ecd11520 100644
--- a/gensrc/script/functions.py
+++ b/gensrc/script/functions.py
@@ -56,15 +56,25 @@
     [10090, "tan", True, False, "DOUBLE", ["DOUBLE"], "MathFunctions::tan"],
     [10100, "atan", True, False, "DOUBLE", ["DOUBLE"], "MathFunctions::atan"],
     [10101, "tanh", True, False, "DOUBLE", ["DOUBLE"], "MathFunctions::tanh"],
-    [10102, "cosine_similarity", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"],
-     "MathFunctions::cosine_similarity"],
-    [10103, "cosine_similarity_norm", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"],
-     "MathFunctions::cosine_similarity"],
+
+    # cosine function
+    [10102, "cosine_similarity", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity"],
+    [10103, "cosine_similarity_norm", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity"],
+    [10104, "cosine_similarity", True, False, "FLOAT", ["VARCHAR", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity2"],
+    [10105, "cosine_similarity_norm", True, False, "FLOAT", ["VARCHAR", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity2"],
+    [10106, "approx_cosine_similarity", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity"],
+    [10107, "approx_cosine_similarity", True, False, "FLOAT", ["VARCHAR", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity2"],
     [10110, "ceil", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::ceil"],
     [10111, "ceiling", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::ceil"],
     [10112, "dceil",
True, False, "BIGINT", ["DOUBLE"], "MathFunctions::ceil"], + # l2 function + [10114, "l2_distance", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::l2_distance"], + [10115, "l2_distance", True, False, "FLOAT", ["VARCHAR", "ARRAY_FLOAT"], "MathFunctions::l2_distance2"], + [10116, "approx_l2_distance", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::l2_distance"], + [10117, "approx_l2_distance", True, False, "FLOAT", ["VARCHAR", "ARRAY_FLOAT"], "MathFunctions::l2_distance2"], + [10120, "floor", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::floor"], [10121, "dfloor", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::floor"], diff --git a/gensrc/thrift/InternalService.thrift b/gensrc/thrift/InternalService.thrift index 87848cce039b0..910e373113667 100644 --- a/gensrc/thrift/InternalService.thrift +++ b/gensrc/thrift/InternalService.thrift @@ -134,6 +134,18 @@ enum TTimeUnit { MINUTE = 4; } +struct TVectorSearchOptions { + 1: optional i64 vector_limit_k; + 2: optional string vector_distance_column_name; + 3: optional list query_vector; + 4: optional map query_params; + 5: optional double vector_range; + 6: optional i32 result_order; + 7: optional bool use_ivfpq; + 8: optional double pq_refine_factor; + 9: optional double k_factor; +} + struct TQueryQueueOptions { 1: optional bool enable_global_query_queue; 2: optional bool enable_group_level_query_queue; @@ -253,7 +265,7 @@ struct TQueryOptions { 85: optional TSpillMode spill_mode; 82: optional TSpillOptions spill_options; - + 86: optional i32 io_tasks_per_scan_operator = 4; 87: optional i32 connector_io_tasks_per_scan_operator = 16; 88: optional double runtime_filter_early_return_selectivity = 0.05; @@ -319,6 +331,9 @@ struct TQueryOptions { 140: optional string catalog; 141: optional i32 datacache_evict_probability; + + 150: optional bool enable_use_ann; + 151: optional TVectorSearchOptions vector_search_options; }