diff --git a/be/src/column/chunk.cpp b/be/src/column/chunk.cpp index 9f455d42fc321..887ed8f480658 100644 --- a/be/src/column/chunk.cpp +++ b/be/src/column/chunk.cpp @@ -132,6 +132,15 @@ void Chunk::append_column(ColumnPtr column, const FieldPtr& field) { check_or_die(); } +void Chunk::append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id) { + DCHECK(!_cid_to_index.contains(field->id())); + _cid_to_index[field->id()] = _columns.size(); + _slot_id_to_index[slot_id] = _columns.size(); + _columns.emplace_back(std::move(column)); + _schema->append(field); + check_or_die(); +} + void Chunk::append_column(ColumnPtr column, SlotId slot_id) { _slot_id_to_index[slot_id] = _columns.size(); _columns.emplace_back(std::move(column)); diff --git a/be/src/column/chunk.h b/be/src/column/chunk.h index 67de2c8768540..a97f5c8c25106 100644 --- a/be/src/column/chunk.h +++ b/be/src/column/chunk.h @@ -112,6 +112,8 @@ class Chunk { // schema must exist and will be updated. void append_column(ColumnPtr column, const FieldPtr& field); + void append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id); + void append_column(ColumnPtr column, SlotId slot_id); void insert_column(size_t idx, ColumnPtr column, const FieldPtr& field); diff --git a/be/src/column/schema.cpp b/be/src/column/schema.cpp index 92e554c86cc54..d0e8625fda3ad 100644 --- a/be/src/column/schema.cpp +++ b/be/src/column/schema.cpp @@ -45,7 +45,10 @@ Schema::Schema(Schema* schema, const std::vector& cids) auto ori_sort_idxes = schema->sort_key_idxes(); std::unordered_set scids(ori_sort_idxes.begin(), ori_sort_idxes.end()); for (int i = 0; i < cids.size(); i++) { - DCHECK_LT(cids[i], schema->_fields.size()); + if (cids[i] >= schema->_fields.size()) { + _fields.resize(_fields.size() - 1); + continue; + } _fields[i] = schema->_fields[cids[i]]; if (scids.find(cids[i]) != scids.end()) { _sort_key_idxes.emplace_back(i); diff --git a/be/src/common/config.h b/be/src/common/config.h index 97e7fe81f92b5..3d7f23a0f981d 100644 --- a/be/src/common/config.h +++ b/be/src/common/config.h @@ -1226,6 +1226,9 @@ CONF_Int64(query_cache_capacity, "536870912"); // ranges in [1,16], default value is 4. CONF_mInt32(query_cache_num_lanes_per_driver, "4"); +// Used by vector query cache, 500MB in default +CONF_Int64(vector_query_cache_capacity, "536870912"); + // Used to limit buffer size of tablet send channel. CONF_mInt64(send_channel_buffer_limit, "67108864"); @@ -1369,6 +1372,9 @@ CONF_mInt64(arrow_io_coalesce_read_max_buffer_size, "8388608"); CONF_mInt64(arrow_io_coalesce_read_max_distance_size, "1048576"); CONF_mInt64(arrow_read_batch_size, "4096"); +// default not to build the empty index +CONF_mInt32(config_tenann_default_build_threshold, "0"); + // Set to true to enable socket_keepalive option in brpc CONF_mBool(brpc_socket_keepalive, "false"); CONF_mBool(apply_del_vec_after_all_index_filter, "true"); diff --git a/be/src/exec/olap_scan_node.cpp b/be/src/exec/olap_scan_node.cpp index 3f46b349cb720..65a5400806868 100644 --- a/be/src/exec/olap_scan_node.cpp +++ b/be/src/exec/olap_scan_node.cpp @@ -583,6 +583,12 @@ void OlapScanNode::_init_counter(RuntimeState* state) { _bf_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "BloomFilterFilterRows", TUnit::UNIT, "SegmentInit"); _gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, "SegmentInit"); _gin_filtered_timer = ADD_CHILD_TIMER(_runtime_profile, "GinFilter", "SegmentInit"); + _get_row_ranges_by_vector_index_timer = ADD_CHILD_TIMER(_scan_profile, "GetVectorRowRangesTime", "SegmentInit"); + _vector_search_timer = ADD_CHILD_TIMER(_scan_profile, "VectorSearchTime", "SegmentInit"); + _vector_index_filtered_counter = + ADD_CHILD_COUNTER(_scan_profile, "VectorIndexFilterRows", TUnit::UNIT, "SegmentInit"); + _process_vector_distance_and_id_timer = + ADD_CHILD_TIMER(_scan_profile, "ProcessVectorDistanceAndIdTime", "SegmentInit"); _seg_zm_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "SegmentZoneMapFilterRows", TUnit::UNIT, "SegmentInit"); _seg_rt_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, "SegmentInit"); diff --git a/be/src/exec/olap_scan_node.h b/be/src/exec/olap_scan_node.h index 1ad14f10edee9..f2a9c1174df27 100644 --- a/be/src/exec/olap_scan_node.h +++ b/be/src/exec/olap_scan_node.h @@ -248,6 +248,10 @@ class OlapScanNode final : public starrocks::ScanNode { RuntimeProfile::Counter* _bi_filter_timer = nullptr; RuntimeProfile::Counter* _gin_filtered_counter = nullptr; RuntimeProfile::Counter* _gin_filtered_timer = nullptr; + RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr; + RuntimeProfile::Counter* _vector_search_timer = nullptr; + RuntimeProfile::Counter* _vector_index_filtered_counter = nullptr; + RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr; RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr; RuntimeProfile::Counter* _rowsets_read_count = nullptr; RuntimeProfile::Counter* _segments_read_count = nullptr; diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.cpp b/be/src/exec/pipeline/scan/olap_chunk_source.cpp index d0c469985099a..6425eb48414e5 100644 --- a/be/src/exec/pipeline/scan/olap_chunk_source.cpp +++ b/be/src/exec/pipeline/scan/olap_chunk_source.cpp @@ -36,6 +36,7 @@ #include "runtime/exec_env.h" #include "storage/chunk_helper.h" #include "storage/column_predicate_rewriter.h" +#include "storage/index/vector/vector_search_option.h" #include "storage/olap_runtime_range_pruner.hpp" #include "storage/predicate_parser.h" #include "storage/projection_iterator.h" @@ -76,6 +77,13 @@ Status OlapChunkSource::prepare(RuntimeState* state) { RETURN_IF_ERROR(ChunkSource::prepare(state)); _runtime_state = state; const TOlapScanNode& thrift_olap_scan_node = _scan_node->thrift_olap_scan_node(); + const TVectorSearchOptions& vector_search_options = thrift_olap_scan_node.vector_search_options; + _use_vector_index = thrift_olap_scan_node.__isset.vector_search_options && vector_search_options.enable_use_ann; + if (_use_vector_index) { + _use_ivfpq = vector_search_options.use_ivfpq; + _vector_distance_column_name = vector_search_options.vector_distance_column_name; + _params.vector_search_option = std::make_shared(); + } const TupleDescriptor* tuple_desc = state->desc_tbl().get_tuple_descriptor(thrift_olap_scan_node.tuple_id); _slots = &tuple_desc->slots(); @@ -129,6 +137,11 @@ void OlapChunkSource::_init_counter(RuntimeState* state) { const std::string segment_init_name = "SegmentInit"; _seg_init_timer = ADD_CHILD_TIMER(_runtime_profile, segment_init_name, IO_TASK_EXEC_TIMER_NAME); _bi_filter_timer = ADD_CHILD_TIMER(_runtime_profile, "BitmapIndexFilter", segment_init_name); + _get_row_ranges_by_vector_index_timer = + ADD_CHILD_TIMER(_runtime_profile, "GetVectorRowRangesTime", segment_init_name); + _vector_search_timer = ADD_CHILD_TIMER(_runtime_profile, "VectorSearchTime", segment_init_name); + _process_vector_distance_and_id_timer = + ADD_CHILD_TIMER(_runtime_profile, "ProcessVectorDistanceAndIdTime", segment_init_name); _bi_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BitmapIndexFilterRows", TUnit::UNIT, segment_init_name); _bf_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BloomFilterFilterRows", TUnit::UNIT, segment_init_name); _gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, segment_init_name); @@ -140,6 +153,8 @@ void OlapChunkSource::_init_counter(RuntimeState* state) { ADD_CHILD_COUNTER(_runtime_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, segment_init_name); _zm_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "ZoneMapIndexFilterRows", TUnit::UNIT, segment_init_name); + _vector_index_filtered_counter = + ADD_CHILD_COUNTER(_runtime_profile, "VectorIndexFilterRows", TUnit::UNIT, segment_init_name); _sk_filtered_counter = ADD_CHILD_COUNTER_SKIP_MIN_MAX(_runtime_profile, "ShortKeyFilterRows", TUnit::UNIT, _get_counter_min_max_type("ShortKeyFilterRows"), segment_init_name); @@ -213,6 +228,24 @@ Status OlapChunkSource::_init_reader_params(const std::vectorvector_distance_column_name = _vector_distance_column_name; + _params.vector_search_option->k = vector_options.vector_limit_k; + for (const std::string& str : vector_options.query_vector) { + _params.vector_search_option->query_vector.push_back(std::stof(str)); + } + if (_runtime_state->query_options().__isset.ann_params) { + _params.vector_search_option->query_params = _runtime_state->query_options().ann_params; + } + _params.vector_search_option->vector_range = vector_options.vector_range; + _params.vector_search_option->result_order = vector_options.result_order; + _params.vector_search_option->use_ivfpq = _use_ivfpq; + _params.vector_search_option->k_factor = _runtime_state->query_options().k_factor; + _params.vector_search_option->pq_refine_factor = _runtime_state->query_options().pq_refine_factor; + } if (thrift_olap_scan_node.__isset.sorted_by_keys_per_tablet) { _params.sorted_by_keys_per_tablet = thrift_olap_scan_node.sorted_by_keys_per_tablet; } @@ -271,7 +304,16 @@ Status OlapChunkSource::_init_reader_params(const std::vector& scanner_columns) { for (auto slot : *_slots) { DCHECK(slot->is_materialized()); - int32_t index = _tablet_schema->field_index(slot->col_name()); + int32_t index; + if (_use_vector_index && !_use_ivfpq) { + index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name); + if (slot->col_name() == _vector_distance_column_name) { + _params.vector_search_option->vector_column_id = index; + _params.vector_search_option->vector_slot_id = slot->id(); + } + } else { + index = _tablet_schema->field_index(slot->col_name()); + } if (index < 0) { std::stringstream ss; ss << "invalid field name: " << slot->col_name(); @@ -295,7 +337,12 @@ Status OlapChunkSource::_init_scanner_columns(std::vector& scanner_col Status OlapChunkSource::_init_unused_output_columns(const std::vector& unused_output_columns) { for (const auto& col_name : unused_output_columns) { - int32_t index = _tablet_schema->field_index(col_name); + int32_t index; + if (_use_vector_index && !_use_ivfpq) { + index = _tablet_schema->field_index(col_name, _vector_distance_column_name); + } else { + index = _tablet_schema->field_index(col_name); + } if (index < 0) { std::stringstream ss; ss << "invalid field name: " << col_name; @@ -500,7 +547,12 @@ Status OlapChunkSource::_init_global_dicts(TabletReaderParams* params) { auto iter = global_dict_map.find(slot->id()); if (iter != global_dict_map.end()) { auto& dict_map = iter->second.first; - int32_t index = _tablet_schema->field_index(slot->col_name()); + int32_t index; + if (_use_vector_index && !_use_ivfpq) { + index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name); + } else { + index = _tablet_schema->field_index(slot->col_name()); + } DCHECK(index >= 0); global_dict->emplace(index, const_cast(&dict_map)); } @@ -610,6 +662,7 @@ void OlapChunkSource::_update_counter() { COUNTER_UPDATE(_seg_zm_filtered_counter, _reader->stats().segment_stats_filtered); COUNTER_UPDATE(_seg_rt_filtered_counter, _reader->stats().runtime_stats_filtered); COUNTER_UPDATE(_zm_filtered_counter, _reader->stats().rows_stats_filtered); + COUNTER_UPDATE(_vector_index_filtered_counter, _reader->stats().rows_vector_index_filtered); COUNTER_UPDATE(_bf_filtered_counter, _reader->stats().rows_bf_filtered); COUNTER_UPDATE(_sk_filtered_counter, _reader->stats().rows_key_range_filtered); COUNTER_UPDATE(_rows_after_sk_filtered_counter, _reader->stats().rows_after_key_range); @@ -620,8 +673,9 @@ void OlapChunkSource::_update_counter() { COUNTER_UPDATE(_bi_filtered_counter, _reader->stats().rows_bitmap_index_filtered); COUNTER_UPDATE(_bi_filter_timer, _reader->stats().bitmap_index_filter_timer); - COUNTER_UPDATE(_gin_filtered_counter, _reader->stats().rows_gin_filtered); - COUNTER_UPDATE(_gin_filtered_timer, _reader->stats().gin_index_filter_ns); + COUNTER_UPDATE(_get_row_ranges_by_vector_index_timer, _reader->stats().get_row_ranges_by_vector_index_timer); + COUNTER_UPDATE(_vector_search_timer, _reader->stats().vector_search_timer); + COUNTER_UPDATE(_process_vector_distance_and_id_timer, _reader->stats().process_vector_distance_and_id_timer); COUNTER_UPDATE(_block_seek_counter, _reader->stats().block_seek_num); COUNTER_UPDATE(_rowsets_read_count, _reader->stats().rowsets_read_count); diff --git a/be/src/exec/pipeline/scan/olap_chunk_source.h b/be/src/exec/pipeline/scan/olap_chunk_source.h index 50e73383e30b8..eafc28043002b 100644 --- a/be/src/exec/pipeline/scan/olap_chunk_source.h +++ b/be/src/exec/pipeline/scan/olap_chunk_source.h @@ -74,7 +74,7 @@ class OlapChunkSource final : public ChunkSource { OlapScanNode* _scan_node; OlapScanContext* _scan_ctx; - const int64_t _limit; // -1: no limit + int64_t _limit; // -1: no limit TInternalScanRange* _scan_range; PredicateTree _non_pushdown_pred_tree; @@ -104,6 +104,12 @@ class OlapChunkSource final : public ChunkSource { std::vector _column_access_paths; + bool _use_vector_index = false; + + bool _use_ivfpq = false; + + std::string _vector_distance_column_name; + // The following are profile meatures int64_t _num_rows_read = 0; @@ -132,6 +138,7 @@ class OlapChunkSource final : public ChunkSource { RuntimeProfile::Counter* _rows_key_range_counter = nullptr; RuntimeProfile::Counter* _bf_filter_timer = nullptr; RuntimeProfile::Counter* _zm_filtered_counter = nullptr; + RuntimeProfile::Counter* _vector_index_filtered_counter = nullptr; RuntimeProfile::Counter* _bf_filtered_counter = nullptr; RuntimeProfile::Counter* _seg_zm_filtered_counter = nullptr; RuntimeProfile::Counter* _seg_rt_filtered_counter = nullptr; @@ -148,6 +155,9 @@ class OlapChunkSource final : public ChunkSource { RuntimeProfile::Counter* _bi_filter_timer = nullptr; RuntimeProfile::Counter* _gin_filtered_counter = nullptr; RuntimeProfile::Counter* _gin_filtered_timer = nullptr; + RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr; + RuntimeProfile::Counter* _vector_search_timer = nullptr; + RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr; RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr; RuntimeProfile::Counter* _non_pushdown_predicates_counter = nullptr; RuntimeProfile::Counter* _rowsets_read_count = nullptr; diff --git a/be/src/exprs/math_functions.cpp b/be/src/exprs/math_functions.cpp index 361036d8f6243..71e7fece64918 100644 --- a/be/src/exprs/math_functions.cpp +++ b/be/src/exprs/math_functions.cpp @@ -886,7 +886,6 @@ StatusOr MathFunctions::cosine_similarity(FunctionContext* context, c target_sum += target_data[j] * target_data[j]; } } - if constexpr (!isNorm) { result_value = sum / (std::sqrt(base_sum) * std::sqrt(target_sum)); } else { @@ -899,10 +898,113 @@ StatusOr MathFunctions::cosine_similarity(FunctionContext* context, c return result; } -// explicitly instaniate template function. +// explicitly instantiate template function. template StatusOr MathFunctions::cosine_similarity(FunctionContext* context, const Columns& columns); template StatusOr MathFunctions::cosine_similarity(FunctionContext* context, const Columns& columns); +template +StatusOr MathFunctions::l2_distance(FunctionContext* context, const Columns& columns) { + DCHECK_EQ(columns.size(), 2); + + const Column* base = columns[0].get(); + const Column* target = columns[1].get(); + size_t target_size = target->size(); + if (base->size() != target_size) { + return Status::InvalidArgument(fmt::format( + "l2_distance requires equal length arrays. base array size is {} and target array size is {}.", + base->size(), target->size())); + } + if (base->has_null() || target->has_null()) { + return Status::InvalidArgument(fmt::format("l2_distance does not support null values. {} array has null value.", + base->has_null() ? "base" : "target")); + } + if (base->is_constant()) { + auto* const_column = down_cast(base); + const_column->data_column()->assign(base->size(), 0); + base = const_column->data_column().get(); + } + if (target->is_constant()) { + auto* const_column = down_cast(target); + const_column->data_column()->assign(target->size(), 0); + target = const_column->data_column().get(); + } + if (base->is_nullable()) { + base = down_cast(base)->data_column().get(); + } + if (target->is_nullable()) { + target = down_cast(target)->data_column().get(); + } + + // check dimension equality. + const Column* base_flat = down_cast(base)->elements_column().get(); + const uint32_t* base_offset = down_cast(base)->offsets().get_data().data(); + size_t base_flat_size = base_flat->size(); + + const Column* target_flat = down_cast(target)->elements_column().get(); + size_t target_flat_size = target_flat->size(); + const uint32_t* target_offset = down_cast(target)->offsets().get_data().data(); + + if (base_flat_size != target_flat_size) { + return Status::InvalidArgument("l2_distance requires equal length arrays"); + } + + if (base_flat->has_null() || target_flat->has_null()) { + return Status::InvalidArgument("l2_distance does not support null values"); + } + if (base_flat->is_nullable()) { + base_flat = down_cast(base_flat)->data_column().get(); + } + if (target_flat->is_nullable()) { + target_flat = down_cast(target_flat)->data_column().get(); + } + + using CppType = RunTimeCppType; + using ColumnType = RunTimeColumnType; + + const CppType* base_data_head = down_cast(base_flat)->get_data().data(); + const CppType* target_data_head = down_cast(target_flat)->get_data().data(); + + // prepare result with nullable value. + ColumnPtr result = ColumnHelper::create_column(TypeDescriptor{TYPE}, false, false, target_size); + ColumnType* data_result = down_cast(result.get()); + CppType* result_data = data_result->get_data().data(); + + for (size_t i = 0; i < target_size; i++) { + size_t t_dim_size = target_offset[i + 1] - target_offset[i]; + size_t b_dim_size = base_offset[i + 1] - base_offset[i]; + if (t_dim_size != b_dim_size) { + return Status::InvalidArgument( + fmt::format("l2_distance requires equal length arrays in each row. base array dimension size " + "is {}, target array dimension size is {}.", + b_dim_size, t_dim_size)); + } + if (t_dim_size == 0) { + return Status::InvalidArgument("l2_distance requires non-empty arrays in each row"); + } + } + + const CppType* target_data = target_data_head; + const CppType* base_data = base_data_head; + + for (size_t i = 0; i < target_size; i++) { + CppType sum = 0; + size_t dim_size = target_offset[i + 1] - target_offset[i]; + for (size_t j = 0; j < dim_size; j++) { + CppType distance; + distance = (base_data[j] - target_data[j]) * (base_data[j] - target_data[j]); + sum += distance; + } + result_data[i] = sum; + target_data += dim_size; + base_data += dim_size; + } + + return result; +} + +// explicitly instantiate template function. +template StatusOr MathFunctions::l2_distance(FunctionContext* context, const Columns& columns); + } // namespace starrocks diff --git a/be/src/exprs/math_functions.h b/be/src/exprs/math_functions.h index 6a5b3abc850aa..87ab849df393e 100644 --- a/be/src/exprs/math_functions.h +++ b/be/src/exprs/math_functions.h @@ -167,6 +167,15 @@ class MathFunctions { template DEFINE_VECTORIZED_FN(cosine_similarity); + template + DEFINE_VECTORIZED_FN(cosine_similarity2); + + template + DEFINE_VECTORIZED_FN(l2_distance); + + template + DEFINE_VECTORIZED_FN(l2_distance2); + /** * @param columns: [DoubleColumn] * @return BigIntColumn diff --git a/be/src/storage/CMakeLists.txt b/be/src/storage/CMakeLists.txt index bb75e11bddb04..ee184bbdb0601 100644 --- a/be/src/storage/CMakeLists.txt +++ b/be/src/storage/CMakeLists.txt @@ -254,8 +254,12 @@ add_library(Storage STATIC index/inverted/clucene/clucene_inverted_writer.cpp index/inverted/clucene/clucene_inverted_reader.cpp index/inverted/clucene/match_operator.cpp + index/vector/empty_index_reader.cpp index/vector/vector_index_builder_factory.cpp index/vector/vector_index_writer.cpp index/vector/vector_index_builder.cpp + index/vector/vector_index_reader_factory.cpp + index/vector/tenann_index_reader.cpp + index/vector/tenann/del_id_filter.cpp index/vector/tenann/tenann_index_builder.cpp index/vector/tenann/tenann_index_utils.cpp) diff --git a/be/src/storage/index/index_descriptor.h b/be/src/storage/index/index_descriptor.h index 4891d18d23627..e330046f64b3d 100644 --- a/be/src/storage/index/index_descriptor.h +++ b/be/src/storage/index/index_descriptor.h @@ -25,6 +25,7 @@ namespace starrocks { class IndexDescriptor { public: inline static const std::string mark_word = "TENANNEMPTYMARK"; + inline static const int64_t mark_word_len = 15; static StatusOr get_index_file_path(const IndexType index_type, const std::string& rowset_dir, const std::string& rowset_id, int segment_id, int64_t index_id) { diff --git a/be/src/storage/index/vector/empty_index_reader.cpp b/be/src/storage/index/vector/empty_index_reader.cpp new file mode 100644 index 0000000000000..ca625892e1cd5 --- /dev/null +++ b/be/src/storage/index/vector/empty_index_reader.cpp @@ -0,0 +1,17 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "storage/index/vector/empty_index_reader.h" + +namespace starrocks {} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/empty_index_reader.h b/be/src/storage/index/vector/empty_index_reader.h new file mode 100644 index 0000000000000..f3d541bcffbcc --- /dev/null +++ b/be/src/storage/index/vector/empty_index_reader.h @@ -0,0 +1,41 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "storage/index/vector/vector_index_reader.h" + +namespace starrocks { + +class EmptyIndexReader final : public VectorIndexReader { +public: + ~EmptyIndexReader() override = default; + + Status init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) override { + return Status::NotSupported("Not implement"); + } + + Status search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, uint8_t* result_distances, + tenann::IdFilter* id_filter = nullptr) override { + return Status::NotSupported("Not implement"); + } + + Status range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) override { + return Status::NotSupported("Not implement"); + } +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann/del_id_filter.cpp b/be/src/storage/index/vector/tenann/del_id_filter.cpp new file mode 100644 index 0000000000000..deb4ec8d75886 --- /dev/null +++ b/be/src/storage/index/vector/tenann/del_id_filter.cpp @@ -0,0 +1,52 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is based on code available under the Apache license here: +// https://github.com/apache/incubator-doris/blob/master/be/src/olap/rowset/segment_v2/column_reader.h + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifdef WITH_TENANN +#include "storage/index/vector/tenann/del_id_filter.h" + +#include "storage/range.h" +#include "storage/roaring2range.h" +#include "tenann/common/type_traits.h" +#include "tenann/searcher/id_filter.h" + +namespace starrocks { + +DelIdFilter::DelIdFilter(const SparseRange<>& scan_range) : _row_bitmap(range2roaring(scan_range)) {} + +bool DelIdFilter::IsMember(tenann::idx_t id) const { + return _row_bitmap.contains(id); +} + +} // namespace starrocks +#endif \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann/del_id_filter.h b/be/src/storage/index/vector/tenann/del_id_filter.h new file mode 100644 index 0000000000000..b23b3e245fdd5 --- /dev/null +++ b/be/src/storage/index/vector/tenann/del_id_filter.h @@ -0,0 +1,54 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is based on code available under the Apache license here: +// https://github.com/apache/incubator-doris/blob/master/be/src/olap/rowset/segment_v2/column_reader.h + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#ifdef WITH_TENANN +#include "storage/del_vector.h" +#include "storage/range.h" +#include "tenann/common/type_traits.h" +#include "tenann/searcher/id_filter.h" + +namespace starrocks { + +class DelIdFilter : public tenann::IdFilter { +public: + DelIdFilter(const SparseRange<>& scan_range); + ~DelIdFilter() = default; + + bool IsMember(tenann::idx_t id) const override; + +private: + Roaring _row_bitmap; +}; + +} // namespace starrocks +#endif \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann/tenann_index_utils.h b/be/src/storage/index/vector/tenann/tenann_index_utils.h index c4761ba50b448..44be262da46d7 100644 --- a/be/src/storage/index/vector/tenann/tenann_index_utils.h +++ b/be/src/storage/index/vector/tenann/tenann_index_utils.h @@ -41,10 +41,8 @@ static const std::string RANGE_SEARCH_CONFIDENCE = "range_search_confidence"; }; // namespace starrocks::index::vector namespace starrocks { - StatusOr get_vector_meta(const std::shared_ptr& tablet_index, const std::map& query_params); - } // namespace starrocks #endif \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann_index_reader.cpp b/be/src/storage/index/vector/tenann_index_reader.cpp new file mode 100644 index 0000000000000..d40630a3f7d5b --- /dev/null +++ b/be/src/storage/index/vector/tenann_index_reader.cpp @@ -0,0 +1,98 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is based on code available under the Apache license here: +// https://github.com/apache/incubator-doris/blob/master/be/src/olap/tablet.h + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "tenann_index_reader.h" + +#include "common/config.h" +#include "common/status.h" +#include "tenann/common/error.h" +#include "tenann/common/seq_view.h" +#include "tenann/searcher/id_filter.h" + +namespace starrocks { + +Status TenANNReader::init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) { + try { + auto meta_copy = meta; + if (meta.index_type() == tenann::IndexType::kFaissIvfPq) { + if (config::enable_vector_index_block_cache) { + // cache index blocks + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_file_key] = false; + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_block_key] = true; + } else { + // cache index file + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_file_key] = true; + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_block_key] = false; + } + } else { + // cache index file + meta_copy.index_reader_options()[tenann::IndexReaderOptions::cache_index_file_key] = true; + } + + tenann::IndexCache::GetGlobalInstance()->SetCapacity(config::vector_query_cache_capacity); + + _searcher = tenann::AnnSearcherFactory::CreateSearcherFromMeta(meta_copy); + _searcher->index_reader()->SetIndexCache(tenann::IndexCache::GetGlobalInstance()); + _searcher->ReadIndex(index_path); + + DCHECK(_searcher->is_index_loaded()); + } catch (tenann::Error& e) { + return Status::InternalError(e.what()); + } + return Status::OK(); +} + +Status TenANNReader::search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, + uint8_t* result_distances, tenann::IdFilter* id_filter) { + try { + _searcher->AnnSearch(query_vector, k, result_ids, result_distances, id_filter); + } catch (tenann::Error& e) { + return Status::InternalError(e.what()); + } + return Status::OK(); +}; + +Status TenANNReader::range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) { + try { + _searcher->RangeSearch(query_vector, range, k, tenann::AnnSearcher::ResultOrder(order), result_ids, + result_distances, id_filter); + } catch (tenann::Error& e) { + return Status::InternalError(e.what()); + } + return Status::OK(); +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/tenann_index_reader.h b/be/src/storage/index/vector/tenann_index_reader.h new file mode 100644 index 0000000000000..0751d86e262d6 --- /dev/null +++ b/be/src/storage/index/vector/tenann_index_reader.h @@ -0,0 +1,47 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "common/status.h" +#include "storage/index/vector/vector_index_reader.h" +#include "tenann/common/seq_view.h" +#include "tenann/common/type_traits.h" +#include "tenann/factory/ann_searcher_factory.h" +#include "tenann/factory/index_factory.h" +#include "tenann/searcher/ann_searcher.h" +#include "tenann/searcher/faiss_hnsw_ann_searcher.h" +#include "tenann/searcher/id_filter.h" +#include "tenann/store/index_meta.h" + +namespace starrocks { + +class TenANNReader final : public VectorIndexReader { +public: + TenANNReader() = default; + ~TenANNReader() override{}; + + Status init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) override; + + Status search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, uint8_t* result_distances, + tenann::IdFilter* id_filter = nullptr) override; + Status range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) override; + +private: + std::shared_ptr _searcher; +}; + +} // namespace starrocks diff --git a/be/src/storage/index/vector/vector_index_reader.h b/be/src/storage/index/vector/vector_index_reader.h new file mode 100644 index 0000000000000..d43d8d0c58483 --- /dev/null +++ b/be/src/storage/index/vector/vector_index_reader.h @@ -0,0 +1,58 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is based on code available under the Apache license here: +// https://github.com/apache/incubator-doris/blob/master/be/src/olap/rowset/segment_v2/column_reader.h + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "common/status.h" +#include "tenann/common/seq_view.h" +#include "tenann/searcher/id_filter.h" +#include "tenann/store/index_meta.h" + +namespace starrocks { + +class VectorIndexReader { +public: + VectorIndexReader() = default; + virtual ~VectorIndexReader() = default; + + virtual Status init_searcher(const tenann::IndexMeta& meta, const std::string& index_path) = 0; + + virtual Status search(tenann::PrimitiveSeqView query_vector, int k, int64_t* result_ids, uint8_t* result_distances, + tenann::IdFilter* id_filter = nullptr) = 0; + virtual Status range_search(tenann::PrimitiveSeqView query_vector, int k, std::vector* result_ids, + std::vector* result_distances, tenann::IdFilter* id_filter, float range, + int order) = 0; +}; + +} // namespace starrocks diff --git a/be/src/storage/index/vector/vector_index_reader_factory.cpp b/be/src/storage/index/vector/vector_index_reader_factory.cpp new file mode 100644 index 0000000000000..84e21dc25c03a --- /dev/null +++ b/be/src/storage/index/vector/vector_index_reader_factory.cpp @@ -0,0 +1,44 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "storage/index/vector/vector_index_reader_factory.h" + +#include "storage/index/vector/empty_index_reader.h" +#include "storage/index/vector/tenann_index_reader.h" +#include "storage/index/vector/vector_index_reader.h" + +namespace starrocks { +Status VectorIndexReaderFactory::create_from_file(const std::string& index_path, + const std::shared_ptr& index_meta, + std::shared_ptr* vector_index_reader) { + if (!fs::path_exist(index_path)) { + return Status::NotFound(fmt::format("index path {} not found", index_path)); + } + ASSIGN_OR_RETURN(auto index_file, fs::new_random_access_file(index_path)) + ASSIGN_OR_RETURN(auto file_size, index_file->get_size()) + + if (file_size == IndexDescriptor::mark_word_len) { + auto buf = std::make_unique(file_size); + RETURN_IF_ERROR(index_file->read_fully(buf.get(), file_size)); + std::string_view buf_str = std::string_view(reinterpret_cast(buf.get()), file_size); + if (buf_str == IndexDescriptor::mark_word) { + (*vector_index_reader) = std::make_shared(); + return Status::OK(); + } + } + (*vector_index_reader) = std::make_shared(); + return Status::OK(); +} + +} // namespace starrocks diff --git a/be/src/storage/index/vector/vector_index_reader_factory.h b/be/src/storage/index/vector/vector_index_reader_factory.h new file mode 100644 index 0000000000000..c0714e2d094b7 --- /dev/null +++ b/be/src/storage/index/vector/vector_index_reader_factory.h @@ -0,0 +1,30 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fmt/format.h" +#include "fs/fs_util.h" +#include "storage/index/index_descriptor.h" +#include "vector_index_reader.h" + +namespace starrocks { + +class VectorIndexReaderFactory { +public: + static Status create_from_file(const std::string& index_path, const std::shared_ptr& index_meta, + std::shared_ptr* vector_index_reader); +}; + +} // namespace starrocks \ No newline at end of file diff --git a/be/src/storage/index/vector/vector_search_option.h b/be/src/storage/index/vector/vector_search_option.h new file mode 100644 index 0000000000000..3395a1487275a --- /dev/null +++ b/be/src/storage/index/vector/vector_search_option.h @@ -0,0 +1,57 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "common/global_types.h" + +namespace starrocks { + +class BaseRowset; +class Segment; + +struct VectorSearchOption { +public: + int64_t k; + + std::vector query_vector; + + std::string vector_distance_column_name; + + bool use_vector_index = false; + + int vector_column_id; + + SlotId vector_slot_id; + + std::map query_params; + + double vector_range; + + int result_order; + + bool use_ivfpq = false; + + double pq_refine_factor; + + double k_factor; + + VectorSearchOption() = default; +}; + +} // namespace starrocks diff --git a/be/src/storage/olap_common.h b/be/src/storage/olap_common.h index ebf1c111dd47a..877e1b979f49c 100644 --- a/be/src/storage/olap_common.h +++ b/be/src/storage/olap_common.h @@ -253,6 +253,7 @@ struct OlapReaderStatistics { int64_t rows_after_key_range = 0; int64_t rows_key_range_num = 0; int64_t rows_stats_filtered = 0; + int64_t rows_vector_index_filtered = 0; int64_t rows_bf_filtered = 0; int64_t rows_del_filtered = 0; int64_t del_filter_ns = 0; @@ -262,6 +263,9 @@ struct OlapReaderStatistics { int64_t rows_bitmap_index_filtered = 0; int64_t bitmap_index_filter_timer = 0; + int64_t get_row_ranges_by_vector_index_timer = 0; + int64_t vector_search_timer = 0; + int64_t process_vector_distance_and_id_timer = 0; int64_t rows_del_vec_filtered = 0; diff --git a/be/src/storage/rowset/rowset.cpp b/be/src/storage/rowset/rowset.cpp index 8d82f0d64500b..0ab75e924e8b5 100644 --- a/be/src/storage/rowset/rowset.cpp +++ b/be/src/storage/rowset/rowset.cpp @@ -708,6 +708,9 @@ Status Rowset::get_segment_iterators(const Schema& schema, const RowsetReadOptio seg_options.runtime_range_pruner = options.runtime_range_pruner; seg_options.column_access_paths = options.column_access_paths; seg_options.tablet_schema = options.tablet_schema; + seg_options.use_vector_index = options.use_vector_index; + seg_options.vector_search_option = options.vector_search_option; + if (options.delete_predicates != nullptr) { seg_options.delete_predicates = options.delete_predicates->get_predicates(end_version()); } diff --git a/be/src/storage/rowset/rowset_options.h b/be/src/storage/rowset/rowset_options.h index 79a8319d502e9..141d69d1c2c5f 100644 --- a/be/src/storage/rowset/rowset_options.h +++ b/be/src/storage/rowset/rowset_options.h @@ -41,6 +41,8 @@ class DeletePredicates; class ChunkPredicate; struct RowidRangeOption; struct ShortKeyRangesOption; +struct VectorSearchOption; +using VectorSearchOptionPtr = std::shared_ptr; class RowsetReadOptions { using RowidRangeOptionPtr = std::shared_ptr; @@ -88,6 +90,10 @@ class RowsetReadOptions { bool prune_column_after_index_filter = false; bool enable_gin_filter = false; bool has_preaggregation = true; + + bool use_vector_index = false; + + VectorSearchOptionPtr vector_search_option = nullptr; }; } // namespace starrocks diff --git a/be/src/storage/rowset/segment_iterator.cpp b/be/src/storage/rowset/segment_iterator.cpp index 960ce8a3fce0e..8e9a4e2cda00f 100644 --- a/be/src/storage/rowset/segment_iterator.cpp +++ b/be/src/storage/rowset/segment_iterator.cpp @@ -40,6 +40,11 @@ #include "storage/column_predicate_rewriter.h" #include "storage/del_vector.h" #include "storage/index/index_descriptor.h" +#include "storage/index/vector/tenann/del_id_filter.h" +#include "storage/index/vector/tenann/tenann_index_utils.h" +#include "storage/index/vector/vector_index_reader.h" +#include "storage/index/vector/vector_index_reader_factory.h" +#include "storage/index/vector/vector_search_option.h" #include "storage/lake/update_manager.h" #include "storage/olap_runtime_range_pruner.hpp" #include "storage/projection_iterator.h" @@ -212,8 +217,10 @@ class SegmentIterator final : public ChunkIterator { StatusOr> _get_row_ranges_by_key_ranges(); StatusOr> _get_row_ranges_by_short_key_ranges(); Status _get_row_ranges_by_zone_map(); + Status _get_row_ranges_by_vector_index(); Status _get_row_ranges_by_bloom_filter(); Status _get_row_ranges_by_rowid_range(); + Status _get_row_ranges_by_row_ids(std::vector* result_ids, SparseRange<>* r); uint32_t segment_id() const { return _segment->id(); } uint32_t num_rows() const { return _segment->num_rows(); } @@ -248,6 +255,8 @@ class SegmentIterator final : public ChunkIterator { Status _encode_to_global_id(ScanContext* ctx); + FieldPtr _make_field(size_t i); + Status _switch_context(ScanContext* to); // `_check_low_cardinality_optimization` and `_init_column_iterators` must have been called @@ -286,6 +295,8 @@ class SegmentIterator final : public ChunkIterator { bool need_early_materialize_subfield(const FieldPtr& field); + Status _init_ann_reader(); + private: using RawColumnIterators = std::vector>; using ColumnDecoders = std::vector; @@ -295,6 +306,7 @@ class SegmentIterator final : public ChunkIterator { RawColumnIterators _column_iterators; std::vector _io_coalesce_column_index; ColumnDecoders _column_decoders; + std::shared_ptr _ann_reader; BitmapIndexEvaluator _bitmap_index_evaluator; // delete predicates std::map _del_predicates; @@ -349,6 +361,28 @@ class SegmentIterator final : public ChunkIterator { std::unordered_map _predicate_column_access_paths; std::unordered_set _prune_cols_candidate_by_inverted_index; + + // vector index params + int64_t _k; +#ifdef WITH_TENANN + tenann::PrimitiveSeqView _query_view; + std::shared_ptr _index_meta; +#endif + bool _use_vector_index; + std::string _vector_distance_column_name; + int _vector_column_id; + SlotId _vector_slot_id; + std::unordered_map _id2distance_map; + std::vector _first_rowids; + std::map _query_params; + double _vector_range; + int _result_order; + bool _use_ivfpq; + Buffer _filter_selection; + Buffer _filter_by_expr_selection; + + Status _init_reader_from_file(const std::string& index_path, const std::shared_ptr& tablet_index_meta, + const std::map& query_params); }; SegmentIterator::SegmentIterator(std::shared_ptr segment, Schema schema, SegmentReadOptions options) @@ -356,7 +390,33 @@ SegmentIterator::SegmentIterator(std::shared_ptr segment, Schema schema _segment(std::move(segment)), _opts(std::move(options)), _bitmap_index_evaluator(_schema, _opts.pred_tree), - _predicate_columns(_opts.pred_tree.num_columns()) { + _predicate_columns(_opts.pred_tree.num_columns()), + _use_vector_index(_opts.use_vector_index) { + if (_use_vector_index) { + // The K in front of Fe is long, which can be changed to uint32. This can be a problem, + // but this k is wasted memory allocation, so it should not exceed the accuracy of uint32 + // options.query_vector is a string, passed to tenann as a float string, see if you need to use the stof function to convert + // and consider precision loss + _vector_distance_column_name = _opts.vector_search_option->vector_distance_column_name; + _vector_column_id = _opts.vector_search_option->vector_column_id; + _vector_slot_id = _opts.vector_search_option->vector_slot_id; + _vector_range = _opts.vector_search_option->vector_range; + _result_order = _opts.vector_search_option->result_order; + _use_ivfpq = _opts.vector_search_option->use_ivfpq; + _query_params = _opts.vector_search_option->query_params; + if (_vector_range > 0 && _use_ivfpq) { + _k = _opts.vector_search_option->k * _opts.vector_search_option->pq_refine_factor * + _opts.vector_search_option->k_factor; + } else { + _k = _opts.vector_search_option->k * _opts.vector_search_option->k_factor; + } +#ifdef WITH_TENANN + _query_view = tenann::PrimitiveSeqView{ + .data = reinterpret_cast(_opts.vector_search_option->query_vector.data()), + .size = static_cast(_opts.vector_search_option->query_vector.size()), + .elem_type = tenann::PrimitiveType::kFloatType}; +#endif + } // For small segment file (the number of rows is less than chunk_size), // the segment iterator will reserve a large amount of memory, // especially when there are many columns, many small files, many versions, @@ -432,7 +492,7 @@ Status SegmentIterator::_init() { _init_column_access_paths(); RETURN_IF_ERROR(_check_low_cardinality_optimization()); RETURN_IF_ERROR(_init_column_iterators(_schema)); - + RETURN_IF_ERROR(_init_ann_reader()); // filter by index stage // Use indexes and predicates to filter some data page RETURN_IF_ERROR(_get_row_ranges_by_rowid_range()); @@ -449,6 +509,7 @@ Status SegmentIterator::_init() { if (apply_del_vec_after_all_index_filter) { RETURN_IF_ERROR(_apply_del_vector()); } + RETURN_IF_ERROR(_get_row_ranges_by_vector_index()); // rewrite stage // Rewriting predicates using segment dictionary codes RETURN_IF_ERROR(_rewrite_predicates()); @@ -469,6 +530,155 @@ Status SegmentIterator::_init() { return Status::OK(); } +inline Status SegmentIterator::_init_reader_from_file(const std::string& index_path, + const std::shared_ptr& tablet_index_meta, + const std::map& query_params) { +#ifdef WITH_TENANN + ASSIGN_OR_RETURN(auto meta, get_vector_meta(tablet_index_meta, query_params)) + _index_meta = std::make_shared(std::move(meta)); + RETURN_IF_ERROR(VectorIndexReaderFactory::create_from_file(index_path, _index_meta, &_ann_reader)); + auto status = _ann_reader->init_searcher(*_index_meta.get(), index_path); + // means empty ann reader + if (status.is_not_supported()) { + _use_vector_index = false; + return Status::OK(); + } + return status; +#else + return Status::OK(); +#endif +} + +Status SegmentIterator::_init_ann_reader() { +#ifdef WITH_TENANN + RETURN_IF(!_use_vector_index, Status::OK()); + std::unordered_map col_map_index; + for (const auto& index : *_segment->tablet_schema().indexes()) { + if (index.index_type() == VECTOR) { + col_map_index.emplace(index.col_unique_ids()[0], index); + } + } + + std::vector hit_indexes; + for (auto& field : _schema.fields()) { + if (col_map_index.count(field->uid()) > 0) { + hit_indexes.emplace_back(col_map_index.at(field->uid())); + } + } + + // TODO: Support more index in one segment iterator, only support one index for now + DCHECK(hit_indexes.size() <= 1) << "Only support query no more than one index now"; + + if (hit_indexes.empty()) { + return Status::OK(); + } + + auto tablet_index_meta = std::make_shared(hit_indexes[0]); + + std::string index_path = IndexDescriptor::vector_index_file_path(_opts.rowset_path, _opts.rowsetid.to_string(), + segment_id(), tablet_index_meta->index_id()); + + return _init_reader_from_file(index_path, tablet_index_meta, _query_params); +#else + return Status::OK(); +#endif +} + +Status SegmentIterator::_get_row_ranges_by_vector_index() { +#ifdef WITH_TENANN + RETURN_IF(!_use_vector_index, Status::OK()); + RETURN_IF(_scan_range.empty(), Status::OK()); + + SCOPED_RAW_TIMER(&_opts.stats->get_row_ranges_by_vector_index_timer); + + Status st; + std::map id2distance_map; + std::vector result_ids; + std::vector result_distances; + std::vector filtered_result_ids; + DelIdFilter del_id_filter(_scan_range); + + { + SCOPED_RAW_TIMER(&_opts.stats->vector_search_timer); + if (_vector_range > 0) { + st = _ann_reader->range_search(_query_view, _k, &result_ids, &result_distances, &del_id_filter, + static_cast(_vector_range), _result_order); + } else { + result_ids.resize(_k); + result_distances.resize(_k); + st = _ann_reader->search(_query_view, _k, (result_ids.data()), + reinterpret_cast(result_distances.data()), &del_id_filter); + } + } + + if (!st.ok()) { + LOG(WARNING) << "Vector index search failed: " << st.to_string(); + return Status::InternalError(st.to_string()); + } + SCOPED_RAW_TIMER(&_opts.stats->process_vector_distance_and_id_timer); + + for (size_t i = 0; i < result_ids.size() && result_ids[i] != -1; i++) { + id2distance_map[result_ids[i]] = result_distances[i]; + } + + SparseRange r; + RETURN_IF_ERROR(_get_row_ranges_by_row_ids(&result_ids, &r)); + + size_t prev_size = _scan_range.span_size(); + _scan_range = _scan_range.intersection(r); + _opts.stats->rows_vector_index_filtered += (prev_size - _scan_range.span_size()); + + SparseRangeIterator range_iter = _scan_range.new_iterator(); + size_t to_read = _scan_range.span_size(); + while (range_iter.has_more()) { + Range r = range_iter.next(to_read); + for (uint32_t i = r.begin(); i < r.end(); i++) { + filtered_result_ids.emplace_back(i); + } + } + + _id2distance_map.reserve(filtered_result_ids.size()); + for (size_t i = 0; i < filtered_result_ids.size(); i++) { + _id2distance_map[static_cast(filtered_result_ids[i])] = id2distance_map[filtered_result_ids[i]]; + } + return Status::OK(); +#else + return Status::OK(); +#endif +} + +Status SegmentIterator::_get_row_ranges_by_row_ids(std::vector* result_ids, SparseRange<>* r) { + if (result_ids->empty()) { + return Status::OK(); + } + + std::sort(result_ids->begin(), result_ids->end()); + + // filter -1 above + auto first_valid_id_iter = std::upper_bound(result_ids->begin(), result_ids->end(), -1); + if (first_valid_id_iter == result_ids->end()) { + // All elements are less than 0 + return Status::OK(); + } + + int64_t range_start = *first_valid_id_iter; + int64_t range_end = range_start + 1; + + for (auto it = first_valid_id_iter + 1; it != result_ids->end(); ++it) { + if (*it == range_end) { + ++range_end; + } else { + r->add(Range<>(range_start, range_end)); + range_start = *it; + range_end = range_start + 1; + } + } + + r->add(Range<>(range_start, range_end)); + + return Status::OK(); +} + Status SegmentIterator::_try_to_update_ranges_by_runtime_filter() { return _opts.runtime_range_pruner.update_range_if_arrived( _opts.global_dictmaps, @@ -1098,6 +1308,12 @@ inline Status SegmentIterator::_read(Chunk* chunk, vector* rowids, size chunk->check_or_die(); } + if (_use_vector_index) { + for (uint32_t i = range.begin(); i < range.end(); i++) { + _first_rowids.push_back(i); + } + } + if (rowids != nullptr) { rowids->reserve(rowids->size() + n); SparseRangeIterator<> iter = range.new_iterator(); @@ -1279,6 +1495,28 @@ Status SegmentIterator::_do_get_next(Chunk* result, vector* rowid) { chunk = _context->_adapt_global_dict_chunk.get(); } + if (_use_vector_index && !_use_ivfpq) { + std::shared_ptr distance_column = FloatColumn::create(); + vector rowids; + for (const auto& id : _first_rowids) { + auto it = _id2distance_map.find(id); + if (it != _id2distance_map.end()) { + rowids.emplace_back(it->first); + } + } + if (!rowids.empty()) { + std::sort(rowids.begin(), rowids.end()); + for (const auto& vrid : rowids) { + distance_column->append(_id2distance_map[vrid]); + } + } + if (has_non_expr_predicate && _filter_selection.size() == distance_column->size()) { + distance_column->filter_range(_filter_selection, 0, distance_column->size()); + } + chunk->append_vector_column(distance_column, _make_field(_vector_column_id), _vector_slot_id); + } + _first_rowids.clear(); + result->swap_chunk(*chunk); if (need_switch_context) { @@ -1288,6 +1526,10 @@ Status SegmentIterator::_do_get_next(Chunk* result, vector* rowid) { return Status::OK(); } +FieldPtr SegmentIterator::_make_field(size_t i) { + return std::make_shared(i, _vector_distance_column_name, get_type_info(TYPE_FLOAT), false); +} + Status SegmentIterator::_switch_context(ScanContext* to) { if (_context != nullptr) { const ordinal_t ordinal = _context->_column_iterators[0]->get_current_ordinal(); @@ -1381,6 +1623,9 @@ StatusOr SegmentIterator::_filter_by_non_expr_predicates(Chunk* chunk, } } _opts.stats->rows_vec_cond_filtered += (to - chunk_size); + for (int i = from; i < to; i++) { + _filter_selection.push_back(_selection[i]); + } return chunk_size; } @@ -1406,6 +1651,9 @@ StatusOr SegmentIterator::_filter_by_expr_predicates(Chunk* chunk, vec } } _opts.stats->rows_vec_cond_filtered += (chunk_size - new_size); + for (int i = 0; i < chunk_size; i++) { + _filter_by_expr_selection.push_back(_selection[i]); + } chunk_size = new_size; } return chunk_size; diff --git a/be/src/storage/rowset/segment_options.h b/be/src/storage/rowset/segment_options.h index 2d8de90c80fab..a7c66cec7f6e0 100644 --- a/be/src/storage/rowset/segment_options.h +++ b/be/src/storage/rowset/segment_options.h @@ -45,6 +45,8 @@ struct RowidRangeOption; using RowidRangeOptionPtr = std::shared_ptr; struct ShortKeyRangeOption; using ShortKeyRangeOptionPtr = std::shared_ptr; +struct VectorSearchOption; +using VectorSearchOptionPtr = std::shared_ptr; class SegmentReadOptions { public: @@ -105,6 +107,10 @@ class SegmentReadOptions { bool enable_gin_filter = false; bool has_preaggregation = true; + bool use_vector_index = false; + + VectorSearchOptionPtr vector_search_option = nullptr; + public: Status convert_to(SegmentReadOptions* dst, const std::vector& new_types, ObjectPool* obj_pool) const; diff --git a/be/src/storage/tablet.h b/be/src/storage/tablet.h index 9a8489afc76dc..9f75b19d20898 100644 --- a/be/src/storage/tablet.h +++ b/be/src/storage/tablet.h @@ -119,6 +119,7 @@ class Tablet : public BaseTablet { size_t num_rows_per_row_block_with_max_version() const; size_t next_unique_id() const; size_t field_index_with_max_version(const string& field_name) const; + size_t field_index(const string& field_name, const string& extra_column_name) const; std::string schema_debug_string() const; std::string debug_string() const; bool enable_shortcut_compaction() const; @@ -506,6 +507,10 @@ inline size_t Tablet::field_index_with_max_version(const string& field_name) con return tablet_schema()->field_index(field_name); } +inline size_t Tablet::field_index(const string& field_name, const string& extra_column_name) const { + return tablet_schema()->field_index(field_name, extra_column_name); +} + inline bool Tablet::enable_shortcut_compaction() const { std::shared_lock rdlock(_meta_lock); return _tablet_meta->enable_shortcut_compaction(); diff --git a/be/src/storage/tablet_reader.cpp b/be/src/storage/tablet_reader.cpp index 4fd0267356cc6..29d9c6054bd36 100644 --- a/be/src/storage/tablet_reader.cpp +++ b/be/src/storage/tablet_reader.cpp @@ -286,6 +286,8 @@ Status TabletReader::_init_collector_for_pk_index_read() { rs_opts.runtime_range_pruner = _reader_params->runtime_range_pruner; // single row fetch, no need to use delvec rs_opts.is_primary_keys = false; + rs_opts.use_vector_index = _reader_params->use_vector_index; + rs_opts.vector_search_option = _reader_params->vector_search_option; rs_opts.rowid_range_option = std::make_shared(); auto rowid_range = std::make_shared>(); @@ -358,6 +360,8 @@ Status TabletReader::get_segment_iterators(const TabletReaderParams& params, std rs_opts.unused_output_column_ids = params.unused_output_column_ids; rs_opts.runtime_range_pruner = params.runtime_range_pruner; rs_opts.column_access_paths = params.column_access_paths; + rs_opts.use_vector_index = params.use_vector_index; + rs_opts.vector_search_option = params.vector_search_option; if (keys_type == KeysType::PRIMARY_KEYS) { rs_opts.is_primary_keys = true; rs_opts.version = _version.second; diff --git a/be/src/storage/tablet_reader_params.h b/be/src/storage/tablet_reader_params.h index eeba0678c128f..ff4b263c524d4 100644 --- a/be/src/storage/tablet_reader_params.h +++ b/be/src/storage/tablet_reader_params.h @@ -38,6 +38,8 @@ using RowidRangeOptionPtr = std::shared_ptr; struct ShortKeyRangesOption; using ShortKeyRangesOptionPtr = std::shared_ptr; struct OlapScanRange; +struct VectorSearchOption; +using VectorSearchOptionPtr = std::shared_ptr; static inline std::unordered_set EMPTY_FILTERED_COLUMN_IDS; // Params for TabletReader @@ -94,6 +96,10 @@ struct TabletReaderParams { bool prune_column_after_index_filter = false; bool enable_gin_filter = false; + bool use_vector_index = false; + + VectorSearchOptionPtr vector_search_option = nullptr; + public: std::string to_string() const; }; diff --git a/be/src/storage/tablet_schema.cpp b/be/src/storage/tablet_schema.cpp index f6fb092476606..410f22e069fae 100644 --- a/be/src/storage/tablet_schema.cpp +++ b/be/src/storage/tablet_schema.cpp @@ -678,6 +678,20 @@ size_t TabletSchema::field_index(std::string_view field_name) const { return -1; } +size_t TabletSchema::field_index(std::string_view field_name, std::string_view extra_column_name) const { + int ordinal = -1; + for (auto& column : _cols) { + ordinal++; + if (column.name() == field_name) { + return ordinal; + } + } + if (field_name == extra_column_name) { + return ordinal + 1; + } + return -1; +} + int32_t TabletSchema::field_index(int32_t col_unique_id) const { const auto& found = _unique_id_to_index.find(col_unique_id); return (found == _unique_id_to_index.end()) ? -1 : found->second; diff --git a/be/src/storage/tablet_schema.h b/be/src/storage/tablet_schema.h index 5e7dec9ec3596..26f4a5174341d 100644 --- a/be/src/storage/tablet_schema.h +++ b/be/src/storage/tablet_schema.h @@ -288,6 +288,7 @@ class TabletSchema { size_t estimate_row_size(size_t variable_len) const; int32_t field_index(int32_t col_unique_id) const; size_t field_index(std::string_view field_name) const; + size_t field_index(std::string_view field_name, std::string_view extra_column_name) const; const TabletColumn& column(size_t ordinal) const; const std::vector& columns() const; const std::vector sort_key_idxes() const { return _sort_key_idxes; } diff --git a/be/src/util/system_metrics.h b/be/src/util/system_metrics.h index 04bcb34826a68..38652de6f8182 100644 --- a/be/src/util/system_metrics.h +++ b/be/src/util/system_metrics.h @@ -32,6 +32,7 @@ class SnmpMetrics; class QueryCacheMetrics; class VectorIndexCacheMetrics; class RuntimeFilterMetrics; +class VectorIndexCacheMetrics; class MemoryMetrics { public: diff --git a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt index e1dbfa5acac80..5a9d72b911090 100644 --- a/be/test/CMakeLists.txt +++ b/be/test/CMakeLists.txt @@ -297,6 +297,7 @@ set(EXEC_FILES ./storage/rowset/index_page_test.cpp ./storage/rowset/metadata_cache_test.cpp ./storage/index/vector_index_test.cpp + ./storage/index/vector_search_test.cpp ./storage/snapshot_meta_test.cpp ./storage/short_key_index_test.cpp ./storage/storage_types_test.cpp diff --git a/be/test/storage/index/vector_search_test.cpp b/be/test/storage/index/vector_search_test.cpp new file mode 100644 index 0000000000000..5627ee5c9fa21 --- /dev/null +++ b/be/test/storage/index/vector_search_test.cpp @@ -0,0 +1,184 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#ifdef WITH_TENANN +#include +#include +#endif + +#include "column/column_helper.h" +#include "runtime/mem_pool.h" +#include "storage/index/index_descriptor.h" +#include "storage/index/vector/tenann/del_id_filter.h" +#include "storage/index/vector/tenann/tenann_index_utils.h" +#include "storage/index/vector/vector_index_writer.h" +#include "storage/rowset/bitmap_index_reader.h" +#include "storage/rowset/bitmap_index_writer.h" +#include "testutil/assert.h" + +namespace starrocks { + +class VectorIndexSearchTest : public testing::Test { +public: + VectorIndexSearchTest() = default; + +protected: + void SetUp() override { + srand(GetCurrentTimeMicros()); + CHECK_OK(fs::remove_all(test_vector_index_dir)); + CHECK_OK(fs::create_directories(test_vector_index_dir)); + ASSIGN_OR_ABORT(_fs, FileSystem::CreateSharedFromString(test_vector_index_dir)); + } + + void TearDown() override { fs::remove_all(test_vector_index_dir); } + + std::shared_ptr _fs; + const std::string test_vector_index_dir = "vector_search_test"; + const std::string vector_index_name = "vector_index.vi"; + const std::string empty_index_name = "empty_index.vi"; + + std::shared_ptr prepare_tablet_index() { + std::shared_ptr tablet_index = std::make_shared(); + TabletIndexPB index_pb; + index_pb.set_index_id(0); + index_pb.set_index_name("test_index"); + index_pb.set_index_type(IndexType::VECTOR); + index_pb.add_col_unique_id(1); + tablet_index->init_from_pb(index_pb); + return tablet_index; + } + + void write_vector_index(const std::string& path, const std::shared_ptr& tablet_index) { + DeferOp op([&] { ASSERT_TRUE(fs::path_exist(path)); }); + + std::unique_ptr vector_index_writer; + VectorIndexWriter::create(tablet_index, path, false, &vector_index_writer); + CHECK_OK(vector_index_writer->init()); + + // construct columns + std::shared_ptr> element = std::make_shared>(); + element->append(1); + element->append(2); + element->append(3); + NullColumnPtr null_column = std::make_shared(element->size(), 0); + std::shared_ptr nullable_column = std::make_shared(element, null_column); + std::shared_ptr offsets = std::make_shared(); + offsets->append(0); + offsets->append(3); + for (int i = 0; i < 10; i++) { + std::shared_ptr> e = std::make_shared>(); + e->append(i + 1.1); + e->append(i + 2.2); + e->append(i + 3.3); + nullable_column->append(*e, 0, e->size()); + offsets->append((i + 2) * 3); + } + + ArrayColumn array_column(nullable_column, offsets); + + CHECK_OK(vector_index_writer->append(array_column)); + + ASSERT_EQ(vector_index_writer->size(), 11); + + uint64_t size = 0; + CHECK_OK(vector_index_writer->finish(&size)); + + ASSERT_GT(size, 0); + } +}; + +TEST_F(VectorIndexSearchTest, test_search_vector_index) { + auto tablet_index = prepare_tablet_index(); + tablet_index->add_common_properties("index_type", "hnsw"); + tablet_index->add_common_properties("dim", "3"); + tablet_index->add_common_properties("is_vector_normed", "false"); + tablet_index->add_common_properties("metric_type", "l2_distance"); + tablet_index->add_index_properties("efconstruction", "40"); + tablet_index->add_index_properties("M", "16"); + tablet_index->add_search_properties("efsearch", "40"); + + auto index_path = test_vector_index_dir + "/" + vector_index_name; + write_vector_index(index_path, tablet_index); + +#ifdef WITH_TENANN + try { + const auto& empty_meta = std::map{}; + auto status = get_vector_meta(tablet_index, empty_meta); + + CHECK_OK(status); + auto meta = status.value(); + + std::shared_ptr ann_reader; + VectorIndexReaderFactory::create_from_file(index_path, meta, &ann_reader); + + auto status = ann_reader->init_searcher(meta, index_path); + + ASSERT_TRUE(!status.is_not_supported()); + + Status st; + std::vector result_ids; + std::vector result_distances; + SparseRange<> scan_range; + DelIdFilter del_id_filter(scan_range); + std::vector query_vector = {1.0f, 2.0f, 3.0f}; + tenann::PrimitiveSeqView query_view = + tenann::PrimitiveSeqView{.data = reinterpret_cast(query_vector.data()), + .size = static_cast(3), + .elem_type = tenann::PrimitiveType::kFloatType}; + + st = ann_reader->search(query_view, 1, (result_ids.data()), reinterpret_cast(result_distances.data()), + &del_id_filter); + CHECK_OK(st); + ASSERT_EQ(result_ids.size(), 0); + } catch (tenann::Error& e) { + LOG(WARNING) << e.what(); + } +#endif +} + +TEST_F(VectorIndexSearchTest, test_select_empty_mark) { + config::config_vector_index_default_build_threshold = 100; + auto tablet_index = prepare_tablet_index(); + + tablet_index->add_common_properties("index_type", "ivfpq"); + tablet_index->add_common_properties("dim", "3"); + tablet_index->add_common_properties("is_vector_normed", "false"); + tablet_index->add_common_properties("metric_type", "l2_distance"); + + auto index_path = test_vector_index_dir + "/" + empty_index_name; + write_vector_index(index_path, tablet_index); + +#ifdef WITH_TENANN + try { + const auto& empty_meta = std::map{}; + auto status = get_vector_meta(tablet_index, empty_meta); + + CHECK_OK(status); + auto meta = status.value(); + + std::shared_ptr ann_reader; + VectorIndexReaderFactory::create_from_file(index_path, meta, &ann_reader); + + auto status = ann_reader->init_searcher(meta, index_path); + + ASSERT_TRUE(status.is_not_supported()); + } catch (tenann::Error& e) { + LOG(WARNING) << e.what(); + } +#endif +} + +} // namespace starrocks \ No newline at end of file diff --git a/build.sh b/build.sh index f8d651d37cc38..49d25ea0850ba 100755 --- a/build.sh +++ b/build.sh @@ -98,8 +98,8 @@ Usage: $0 --output-compile-time save a list of the compile time for every C++ file in ${ROOT}/compile_times.txt. Turning this option on automatically disables ccache. - --with-tenann - build with vector index tenann library + --without-tenann + build without vector index tenann library --with-compress-debug-symbol {ON|OFF} build with compressing debug symbol. (default: $WITH_COMPRESS) -h,--help Show this help message @@ -133,7 +133,7 @@ OPTS=$(getopt \ -l 'use-staros' \ -l 'enable-shared-data' \ -l 'output-compile-time' \ - -l 'with-tenann' \ + -l 'without-tenann' \ -l 'with-compress-debug-symbol:' \ -l 'help' \ -- "$@") @@ -158,7 +158,7 @@ WITH_STARCACHE=ON USE_STAROS=OFF BUILD_JAVA_EXT=ON OUTPUT_COMPILE_TIME=OFF -WITH_TENANN=OFF +WITH_TENANN=ON MSG="" MSG_FE="Frontend" MSG_DPP="Spark Dpp application" @@ -180,10 +180,6 @@ if [[ -z ${CCACHE} ]] && [[ -x "$(command -v ccache)" ]]; then CCACHE=ccache fi -if [[ -z ${WITH_TENANN} ]]; then - WITH_TENANN=ON -fi - if [ -e /proc/cpuinfo ] ; then # detect cpuinfo if [[ -z $(grep -o 'avx[^ ]\+' /proc/cpuinfo) ]]; then @@ -246,7 +242,7 @@ else --without-java-ext) BUILD_JAVA_EXT=OFF; shift ;; --without-starcache) WITH_STARCACHE=OFF; shift ;; --output-compile-time) OUTPUT_COMPILE_TIME=ON; shift ;; - --with-tenann) WITH_TENANN=ON; shift ;; + --without-tenann) WITH_TENANN=OFF; shift ;; --with-compress-debug-symbol) WITH_COMPRESS=$2 ; shift 2 ;; -h) HELP=1; shift ;; --help) HELP=1; shift ;; diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java index eb46484b28d22..dea4bec0357c9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/FunctionSet.java @@ -150,6 +150,11 @@ public class FunctionSet { public static final String SHA2 = "sha2"; public static final String SM3 = "sm3"; + // Vector Index functions: + public static final String APPROX_COSINE_SIMILARITY = "approx_cosine_similarity"; + public static final String APPROX_COSINE_SIMILARITY_NORM = "approx_cosine_similarity_norm"; + public static final String APPROX_L2_DISTANCE = "approx_l2_distance"; + // Geo functions: public static final String ST_ASTEXT = "st_astext"; public static final String ST_ASWKT = "st_aswkt"; @@ -628,6 +633,13 @@ public class FunctionSet { .add(SLEEP) .build(); + public static final Set VECTOR_COMPUTE_FUNCTIONS = + ImmutableSet.builder() + .add(APPROX_COSINE_SIMILARITY) + .add(APPROX_COSINE_SIMILARITY_NORM) + .add(APPROX_L2_DISTANCE) + .build(); + // Only use query cache if these time function can be reduced into a constant // date/datetime value after applying FoldConstantRule, otherwise BE would yield // non-deterministic result when these function is delivered to BE. diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java b/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java index 1239d96191beb..a813fb3befff9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/OlapTable.java @@ -195,7 +195,7 @@ public enum OlapTableState { /* This state means table is updating table meta during alter operation(SCHEMA_CHANGE * or ROLLUP). * The query plan which is generate during this state is invalid because the meta - * during the creation of the logical plan and the physical plan might be inconsistent. + * during the creation of the logical plan and the physical plan might be inconsistent. */ UPDATING_META } diff --git a/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java b/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java index 931d3ccc12e08..e3aabf497bc92 100644 --- a/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java +++ b/fe/fe-core/src/main/java/com/starrocks/catalog/Table.java @@ -457,6 +457,11 @@ public List getColumns() { return new ArrayList<>(nameToColumn.values()); } + public void addColumn(Column column) { + fullSchema.add(column); + nameToColumn.put(column.getName(), column); + } + public long getCreateTime() { return createTime; } diff --git a/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java b/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java new file mode 100644 index 0000000000000..7f141a1ed3b42 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/common/VectorSearchOptions.java @@ -0,0 +1,120 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.starrocks.common; + +import com.google.gson.annotations.SerializedName; +import com.google.gson.reflect.TypeToken; +import com.starrocks.persist.gson.GsonUtils; + +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class VectorSearchOptions { + + public VectorSearchOptions() {} + + @SerializedName(value = "enableUseANN") + private boolean enableUseANN = false; + + @SerializedName(value = "useIVFPQ") + private boolean useIVFPQ = false; + + @SerializedName(value = "vectorDistanceColumnName") + private String vectorDistanceColumnName = "vector_distance"; + + @SerializedName(value = "vectorLimitK") + private long vectorLimitK; + + @SerializedName(value = "queryVector") + private List queryVector = new ArrayList<>(); + + @SerializedName(value = "vectorRange") + private double vectorRange = -1; + + @SerializedName(value = "resultOrder") + private int resultOrder = 0; + + public boolean isEnableUseANN() { + return enableUseANN; + } + + public void setEnableUseANN(boolean enableUseANN) { + this.enableUseANN = enableUseANN; + } + + public boolean isUseIVFPQ() { + return useIVFPQ; + } + + public void setUseIVFPQ(boolean useIVFPQ) { + this.useIVFPQ = useIVFPQ; + } + + public String getVectorDistanceColumnName() { + return vectorDistanceColumnName; + } + + public void setVectorDistanceColumnName(String vectorDistanceColumnName) { + this.vectorDistanceColumnName = vectorDistanceColumnName; + } + + public long getVectorLimitK() { + return vectorLimitK; + } + + public void setVectorLimitK(long vectorLimitK) { + this.vectorLimitK = vectorLimitK; + } + + public List getQueryVector() { + return queryVector; + } + + public void setQueryVector(List queryVector) { + this.queryVector = queryVector; + } + + public double getVectorRange() { + return vectorRange; + } + + public void setVectorRange(double vectorRange) { + this.vectorRange = vectorRange; + } + + public int getResultOrder() { + return resultOrder; + } + + public void setResultOrder(int resultOrder) { + this.resultOrder = resultOrder; + } + + public static VectorSearchOptions read(String json) { + return GsonUtils.GSON.fromJson(json, VectorSearchOptions.class); + } + + public static Map readAnnParams(String json) { + Type type = new TypeToken>() {}.getType(); + return GsonUtils.GSON.fromJson(json, type); + } + + @Override + public String toString() { + return GsonUtils.GSON.toJson(this); + } +} \ No newline at end of file diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java index 1fc6ad8354921..317747f456380 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/OlapScanNode.java @@ -75,6 +75,7 @@ import com.starrocks.common.FeConstants; import com.starrocks.common.Pair; import com.starrocks.common.UserException; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.lake.LakeTablet; import com.starrocks.qe.ConnectContext; import com.starrocks.rowstore.RowStoreUtils; @@ -100,6 +101,7 @@ import com.starrocks.thrift.TScanRange; import com.starrocks.thrift.TScanRangeLocation; import com.starrocks.thrift.TScanRangeLocations; +import com.starrocks.thrift.TVectorSearchOptions; import com.starrocks.warehouse.Warehouse; import org.apache.commons.collections4.CollectionUtils; import org.apache.logging.log4j.LogManager; @@ -184,6 +186,8 @@ public class OlapScanNode extends ScanNode { private Map scanPartitionVersions = Maps.newHashMap(); + private VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); + // Constructs node to scan given data files of table 'tbl'. public OlapScanNode(PlanNodeId id, TupleDescriptor desc, String planNodeName) { super(id, desc, planNodeName); @@ -199,6 +203,10 @@ public Map getScanPartitionVersions() { return scanPartitionVersions; } + public void setVectorSearchOptions(VectorSearchOptions vectorSearchOptions) { + this.vectorSearchOptions = vectorSearchOptions; + } + public void setIsPreAggregation(boolean isPreAggregation, String reason) { this.isPreAggregation = isPreAggregation; this.reasonOfPreAggregation = reason; @@ -756,6 +764,13 @@ protected String getNodeExplainString(String prefix, TExplainLevel detailLevel) output.append(prefix).append("PREAGGREGATION: OFF. Reason: ").append(reasonOfPreAggregation) .append("\n"); } + if (ConnectContext.get() != null && Config.enable_experimental_vector == true) { + if (vectorSearchOptions != null && vectorSearchOptions.isEnableUseANN()) { + output.append(prefix).append("VECTORINDEX: ON").append("\n"); + } else { + output.append(prefix).append("VECTORINDEX: OFF").append("\n"); + } + } if (!conjuncts.isEmpty()) { output.append(prefix).append("PREDICATES: ").append( getExplainString(conjuncts)).append("\n"); @@ -1023,6 +1038,18 @@ protected void toThrift(TPlanNode msg) { msg.olap_scan_node.setColumn_access_paths(columnAccessPathToThrift()); } + if (vectorSearchOptions != null && vectorSearchOptions.isEnableUseANN()) { + TVectorSearchOptions tVectorSearchOptions = new TVectorSearchOptions(); + tVectorSearchOptions.setEnable_use_ann(true); + tVectorSearchOptions.setVector_limit_k(vectorSearchOptions.getVectorLimitK()); + tVectorSearchOptions.setVector_distance_column_name(vectorSearchOptions.getVectorDistanceColumnName()); + tVectorSearchOptions.setQuery_vector(vectorSearchOptions.getQueryVector()); + tVectorSearchOptions.setVector_range(vectorSearchOptions.getVectorRange()); + tVectorSearchOptions.setResult_order(vectorSearchOptions.getResultOrder()); + tVectorSearchOptions.setUse_ivfpq(vectorSearchOptions.isUseIVFPQ()); + msg.olap_scan_node.setVector_search_options(tVectorSearchOptions); + } + msg.olap_scan_node.setUse_pk_index(usePkIndex); } } diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 71fdd09d73d82..b62c2729b9dcc 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -44,6 +44,7 @@ import com.google.gson.GsonBuilder; import com.google.gson.ToNumberPolicy; import com.starrocks.catalog.InternalCatalog; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.common.io.Text; import com.starrocks.common.io.Writable; import com.starrocks.common.util.CompressionUtils; @@ -106,6 +107,7 @@ public class SessionVariable implements Serializable, Writable, Cloneable { // The optional values are "compute_nodes_only" and "all_nodes". public static final String COMPUTATION_FRAGMENT_SCHEDULING_POLICY = "computation_fragment_scheduling_policy"; public static final String EXEC_MEM_LIMIT = "exec_mem_limit"; + public static final String VECTOR_DISTANCE_COLUMN_NAME = "vector_distance_column_name"; /** * configure the mem limit of load process on BE. @@ -323,6 +325,7 @@ public class SessionVariable implements Serializable, Writable, Cloneable { public static final String ENABLE_GROUPBY_USE_OUTPUT_ALIAS = "enable_groupby_use_output_alias"; public static final String ENABLE_QUERY_DUMP = "enable_query_dump"; public static final String QUERY_DEBUG_OPTIONS = "query_debug_options"; + public static final String VECTOR_SEARCH_OPTIONS = "vector_search_options"; // --------------------------- Limitations for Materialized View ------------------------------------ // public static final String OPTIMIZER_MATERIALIZED_VIEW_TIMELIMIT = "optimizer_materialized_view_timelimit"; @@ -760,6 +763,13 @@ public static MaterializedViewRewriteMode parse(String str) { public static final String JIT_LEVEL = "jit_level"; + // ann params like: nprobe + public static final String ANN_PARAMS = "ann_params"; + + public static final String PQ_REFINE_FACTOR = "pq_refine_factor"; + + public static final String K_FACTOR = "k_factor"; + /** * Used to split files stored in dfs such as object storage or hdfs into smaller files. */ @@ -2048,6 +2058,18 @@ public long getConnectorSinkTargetMaxFileSize() { @VarAttr(name = CBO_ENABLE_PARALLEL_PREPARE_METADATA) private boolean enableParallelPrepareMetadata = false; + // To set ANN tuning parameters for user. + // Since the session variables does not support map variables, + // it needs to be passed in the form of a JSON string. + @VarAttr(name = ANN_PARAMS) + private String annParams = ""; + + @VarAttr(name = PQ_REFINE_FACTOR) + private double pqRefineFactor = 1; + + @VarAttr(name = K_FACTOR) + private double kFactor = 1; + public int getPrepareMetadataPoolSize() { return prepareMetadataPoolSize; } @@ -4293,6 +4315,9 @@ public TQueryOptions toThrift() { tResult.setConnector_io_tasks_slow_io_latency_ms(connectorIoTasksSlowIoLatency); tResult.setConnector_scan_use_query_mem_ratio(connectorScanUseQueryMemRatio); tResult.setScan_use_query_mem_ratio(scanUseQueryMemRatio); + tResult.setAnn_params(VectorSearchOptions.readAnnParams(annParams)); + tResult.setPq_refine_factor(pqRefineFactor); + tResult.setK_factor(kFactor); tResult.setEnable_collect_table_level_scan_stats(enableCollectTableLevelScanStats); tResult.setEnable_pipeline_level_shuffle(enablePipelineLevelShuffle); tResult.setEnable_hyperscan_vec(enableHyperscanVec); diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java b/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java index 90a4a7e8ed9c3..dcd03783c0cbb 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/StmtExecutor.java @@ -472,7 +472,6 @@ public void execute() throws Exception { context.setExecutionId(UUIDUtil.toTUniqueId(uuid)); SessionVariable sessionVariableBackup = context.getSessionVariable(); - // if use http protocal, use httpResultSender to send result to netty channel if (context instanceof HttpConnectContext) { httpResultSender = new HttpResultSender((HttpConnectContext) context); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java b/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java index 2f35607191b03..d85bfcbad10cc 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/StatementPlanner.java @@ -20,6 +20,7 @@ import com.google.common.collect.Sets; import com.starrocks.catalog.Database; import com.starrocks.catalog.ExternalOlapTable; +import com.starrocks.catalog.Index; import com.starrocks.catalog.KeysType; import com.starrocks.catalog.OlapTable; import com.starrocks.catalog.Table; @@ -30,6 +31,9 @@ import com.starrocks.common.ErrorCode; import com.starrocks.common.ErrorReport; import com.starrocks.common.LabelAlreadyUsedException; +import com.starrocks.common.VectorIndexParams.CommonIndexParamKey; +import com.starrocks.common.VectorIndexParams.VectorIndexType; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.common.profile.Timer; import com.starrocks.common.profile.Tracers; import com.starrocks.http.HttpConnectContext; @@ -47,6 +51,7 @@ import com.starrocks.sql.analyzer.SemanticException; import com.starrocks.sql.ast.DeleteStmt; import com.starrocks.sql.ast.DmlStmt; +import com.starrocks.sql.ast.IndexDef; import com.starrocks.sql.ast.InsertStmt; import com.starrocks.sql.ast.QueryRelation; import com.starrocks.sql.ast.QueryStatement; @@ -81,6 +86,7 @@ import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; @@ -127,13 +133,14 @@ public static ExecPlan plan(StatementBase stmt, ConnectContext session, boolean areTablesCopySafe = AnalyzerUtils.areTablesCopySafe(queryStmt); needWholePhaseLock = isLockFree(areTablesCopySafe, session) ? false : true; ExecPlan plan; + VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); if (needWholePhaseLock) { - plan = createQueryPlan(queryStmt, session, resultSinkType); + plan = createQueryPlan(queryStmt, session, resultSinkType, vectorSearchOptions); } else { long planStartTime = OptimisticVersion.generate(); unLock(plannerMetaLocker); plan = createQueryPlanWithReTry(queryStmt, session, resultSinkType, plannerMetaLocker, - planStartTime); + planStartTime, vectorSearchOptions); } setOutfileSink(queryStmt, plan); return plan; @@ -226,8 +233,10 @@ public static MVTransformerContext makeMVTransformerContext(SessionVariable sess private static ExecPlan createQueryPlan(StatementBase stmt, ConnectContext session, - TResultSinkType resultSinkType) { + TResultSinkType resultSinkType, + VectorSearchOptions vectorSearchOptions) { QueryStatement queryStmt = (QueryStatement) stmt; + checkVectorIndex(queryStmt, vectorSearchOptions); QueryRelation query = (QueryRelation) queryStmt.getQueryRelation(); List colNames = query.getColumnOutputNames(); // 1. Build Logical plan @@ -254,7 +263,8 @@ private static ExecPlan createQueryPlan(StatementBase stmt, stmt, new PhysicalPropertySet(), new ColumnRefSet(logicalPlan.getOutputColumn()), - columnRefFactory); + columnRefFactory, + vectorSearchOptions); } try (Timer ignored = Tracers.watchScope("ExecPlanBuild")) { @@ -278,10 +288,13 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt, ConnectContext session, TResultSinkType resultSinkType, PlannerMetaLocker plannerMetaLocker, - long planStartTime) { + long planStartTime, + VectorSearchOptions vectorSearchOptions) { QueryRelation query = queryStmt.getQueryRelation(); List colNames = query.getColumnOutputNames(); + checkVectorIndex(queryStmt, vectorSearchOptions); + // 1. Build Logical plan ColumnRefFactory columnRefFactory = new ColumnRefFactory(); boolean isSchemaValid = true; @@ -324,7 +337,8 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt, queryStmt, new PhysicalPropertySet(), new ColumnRefSet(logicalPlan.getOutputColumn()), - columnRefFactory); + columnRefFactory, + vectorSearchOptions); } try (Timer ignored = Tracers.watchScope("ExecPlanBuild")) { @@ -357,6 +371,36 @@ public static ExecPlan createQueryPlanWithReTry(QueryStatement queryStmt, "schema of %s had been updated frequently during the plan generation", updatedTables); } + private static boolean checkAndSetVectorIndex(OlapTable olapTable, VectorSearchOptions vectorSearchOptions) { + for (Index index : olapTable.getIndexes()) { + if (index.getIndexType() == IndexDef.IndexType.VECTOR) { + Map indexProperties = index.getProperties(); + String indexType = indexProperties.get(CommonIndexParamKey.INDEX_TYPE.name().toLowerCase(Locale.ROOT)); + + if (VectorIndexType.IVFPQ.name().equalsIgnoreCase(indexType)) { + vectorSearchOptions.setUseIVFPQ(true); + } + + vectorSearchOptions.setEnableUseANN(true); + return true; + } + } + return false; + } + + private static void checkVectorIndex(QueryStatement queryStmt, VectorSearchOptions vectorSearchOptions) { + Set olapTables = Sets.newHashSet(); + AnalyzerUtils.copyOlapTable(queryStmt, olapTables); + boolean hasVectorIndex = false; + for (OlapTable olapTable : olapTables) { + if (checkAndSetVectorIndex(olapTable, vectorSearchOptions)) { + hasVectorIndex = true; + break; + } + } + vectorSearchOptions.setEnableUseANN(hasVectorIndex); + } + public static Set collectOriginalOlapTables(ConnectContext session, StatementBase queryStmt) { Set olapTables = Sets.newHashSet(); PlannerMetaLocker locker = new PlannerMetaLocker(session, queryStmt); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java index d4e72041b2657..f445748e93304 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/Optimizer.java @@ -20,6 +20,7 @@ import com.starrocks.analysis.JoinOperator; import com.starrocks.catalog.MaterializedView; import com.starrocks.catalog.OlapTable; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.common.profile.Timer; import com.starrocks.common.profile.Tracers; import com.starrocks.qe.ConnectContext; @@ -168,7 +169,7 @@ public OptExpression optimize(ConnectContext connectContext, ColumnRefSet requiredColumns, ColumnRefFactory columnRefFactory) { return optimize(connectContext, logicOperatorTree, null, null, requiredProperty, - requiredColumns, columnRefFactory); + requiredColumns, columnRefFactory, new VectorSearchOptions()); } public OptExpression optimize(ConnectContext connectContext, @@ -177,11 +178,14 @@ public OptExpression optimize(ConnectContext connectContext, StatementBase stmt, PhysicalPropertySet requiredProperty, ColumnRefSet requiredColumns, - ColumnRefFactory columnRefFactory) { + ColumnRefFactory columnRefFactory, + VectorSearchOptions vectorSearchOptions) { try { // prepare for optimizer prepare(connectContext, columnRefFactory, logicOperatorTree); + context.setVectorSearchOptions(vectorSearchOptions); + // prepare for mv rewrite prepareMvRewrite(connectContext, logicOperatorTree, columnRefFactory, requiredColumns); try (Timer ignored = Tracers.watchScope("MVTextRewrite")) { @@ -658,6 +662,8 @@ private OptExpression logicalRuleRewrite( ruleRewriteOnlyOnce(tree, rootTaskContext, UnionToValuesRule.getInstance()); + ruleRewriteOnlyOnce(tree, rootTaskContext, RuleSetType.VECTOR_REWRITE); + tree = SimplifyCaseWhenPredicateRule.INSTANCE.rewrite(tree, rootTaskContext); deriveLogicalProperty(tree); return tree.getInputs().get(0); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java index ca323d6c7bf48..fefaf64514d29 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/OptimizerContext.java @@ -19,6 +19,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.starrocks.catalog.OlapTable; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.qe.ConnectContext; import com.starrocks.qe.SessionVariable; import com.starrocks.qe.VariableMgr; @@ -85,6 +86,8 @@ public class OptimizerContext { // collect all LogicalOlapScanOperators in the query before any optimization private List allLogicalOlapScanOperators; + private VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); + @VisibleForTesting public OptimizerContext(Memo memo, ColumnRefFactory columnRefFactory) { this.memo = memo; @@ -310,4 +313,12 @@ public void setAllLogicalOlapScanOperators(List allScan public List getAllLogicalOlapScanOperators() { return allLogicalOlapScanOperators; } + + public void setVectorSearchOptions(VectorSearchOptions vectorSearchOptions) { + this.vectorSearchOptions = vectorSearchOptions; + } + + public VectorSearchOptions getVectorSearchOptions() { + return vectorSearchOptions; + } } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java index 2389b8e59c24c..ca5ca0b5fc5e5 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalOlapScanOperator.java @@ -21,6 +21,7 @@ import com.starrocks.catalog.Column; import com.starrocks.catalog.OlapTable; import com.starrocks.catalog.Table; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.sql.ast.PartitionNames; import com.starrocks.sql.optimizer.base.DistributionSpec; import com.starrocks.sql.optimizer.operator.Operator; @@ -51,6 +52,8 @@ public final class LogicalOlapScanOperator extends LogicalScanOperator { private long gtid = 0; + private VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); + // Only for UT public LogicalOlapScanOperator(Table table) { this(table, Maps.newHashMap(), Maps.newHashMap(), null, Operator.DEFAULT_LIMIT, null); @@ -164,6 +167,14 @@ public boolean isFromSplitOR() { return fromSplitOR; } + public VectorSearchOptions getVectorSearchOptions() { + return vectorSearchOptions; + } + + public void setVectorSearchOptions(VectorSearchOptions vectorSearchOptions) { + this.vectorSearchOptions = vectorSearchOptions; + } + @Override public R accept(OperatorVisitor visitor, C context) { return visitor.visitLogicalOlapScan(this, context); @@ -222,6 +233,7 @@ public Builder withOperator(LogicalOlapScanOperator scanOperator) { builder.hintsReplicaIds = scanOperator.hintsReplicaIds; builder.prunedPartitionPredicates = scanOperator.prunedPartitionPredicates; builder.usePkIndex = scanOperator.usePkIndex; + builder.vectorSearchOptions = scanOperator.vectorSearchOptions; return this; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java index d628851435568..f16e7216ce9d5 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/logical/LogicalTopNOperator.java @@ -68,7 +68,7 @@ private LogicalTopNOperator() { super(OperatorType.LOGICAL_TOPN); } - private LogicalTopNOperator(long limit, + public LogicalTopNOperator(long limit, ScalarOperator predicate, Projection projection, List partitionByColumns, long partitionLimit, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalOlapScanOperator.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalOlapScanOperator.java index e908d47aab19a..44846945a2cad 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalOlapScanOperator.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/operator/physical/PhysicalOlapScanOperator.java @@ -19,6 +19,7 @@ import com.starrocks.catalog.Column; import com.starrocks.catalog.Table; import com.starrocks.common.Pair; +import com.starrocks.common.VectorSearchOptions; import com.starrocks.sql.optimizer.OptExpression; import com.starrocks.sql.optimizer.OptExpressionVisitor; import com.starrocks.sql.optimizer.base.DistributionSpec; @@ -58,6 +59,8 @@ public class PhysicalOlapScanOperator extends PhysicalScanOperator { // Rewriting the scan column ref also needs to rewrite the pruned predicate at the same time. private List prunedPartitionPredicates = Lists.newArrayList(); + private VectorSearchOptions vectorSearchOptions = new VectorSearchOptions(); + private long gtid = 0; private PhysicalOlapScanOperator() { @@ -75,7 +78,8 @@ public PhysicalOlapScanOperator(Table table, List hintsReplicaId, List prunedPartitionPredicates, Projection projection, - boolean usePkIndex) { + boolean usePkIndex, + VectorSearchOptions vectorSearchOptions) { super(OperatorType.PHYSICAL_OLAP_SCAN, table, colRefToColumnMetaMap, limit, predicate, projection); this.distributionSpec = distributionDesc; this.selectedIndexId = selectedIndexId; @@ -84,6 +88,7 @@ public PhysicalOlapScanOperator(Table table, this.hintsReplicaId = hintsReplicaId; this.prunedPartitionPredicates = prunedPartitionPredicates; this.usePkIndex = usePkIndex; + this.vectorSearchOptions = vectorSearchOptions; } public PhysicalOlapScanOperator(LogicalOlapScanOperator scanOperator) { @@ -96,6 +101,11 @@ public PhysicalOlapScanOperator(LogicalOlapScanOperator scanOperator) { this.hintsReplicaId = scanOperator.getHintsReplicaIds(); this.prunedPartitionPredicates = scanOperator.getPrunedPartitionPredicates(); this.usePkIndex = scanOperator.isUsePkIndex(); + this.vectorSearchOptions = scanOperator.getVectorSearchOptions(); + } + + public VectorSearchOptions getVectorSearchOptions() { + return vectorSearchOptions; } public long getSelectedIndexId() { @@ -278,6 +288,7 @@ public Builder withOperator(PhysicalOlapScanOperator operator) { builder.usePkIndex = operator.usePkIndex; builder.globalDicts = operator.globalDicts; builder.prunedPartitionPredicates = operator.prunedPartitionPredicates; + builder.vectorSearchOptions = operator.vectorSearchOptions; return this; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java index b25d8f54af365..800bbc35bfce7 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSet.java @@ -151,6 +151,7 @@ import com.starrocks.sql.optimizer.rule.transformation.RewriteSimpleAggToHDFSScanRule; import com.starrocks.sql.optimizer.rule.transformation.RewriteSimpleAggToMetaScanRule; import com.starrocks.sql.optimizer.rule.transformation.RewriteSumByAssociativeRule; +import com.starrocks.sql.optimizer.rule.transformation.RewriteToVectorPlanRule; import com.starrocks.sql.optimizer.rule.transformation.ScalarApply2AnalyticRule; import com.starrocks.sql.optimizer.rule.transformation.ScalarApply2JoinRule; import com.starrocks.sql.optimizer.rule.transformation.SplitLimitRule; @@ -271,6 +272,10 @@ public class RuleSet { new LimitPruneTabletsRule() )); + REWRITE_RULES.put(RuleSetType.VECTOR_REWRITE, ImmutableList.of( + new RewriteToVectorPlanRule() + )); + REWRITE_RULES.put(RuleSetType.PRUNE_COLUMNS, ImmutableList.of( PruneScanColumnRule.OLAP_SCAN, PruneScanColumnRule.SCHEMA_SCAN, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java index cb96d67305dc5..627887c505f8f 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleSetType.java @@ -43,5 +43,6 @@ public enum RuleSetType { FINE_GRAINED_RANGE_PREDICATE, ELIMINATE_OP_WITH_CONSTANT, - NUM_RULE_SET + NUM_RULE_SET, + VECTOR_REWRITE } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java index ba63906b82917..ac3331ff5e612 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/RuleType.java @@ -202,6 +202,8 @@ public enum RuleType { TF_PREDICATE_PROPAGATE, + TF_VECTOR_REWRITE_RULE, + // The following are implementation rules: IMP_OLAP_LSCAN_TO_PSCAN, IMP_HIVE_LSCAN_TO_PSCAN, diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java new file mode 100644 index 0000000000000..1e4a60c409d45 --- /dev/null +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/RewriteToVectorPlanRule.java @@ -0,0 +1,296 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package com.starrocks.sql.optimizer.rule.transformation; + +import com.google.common.collect.Lists; +import com.starrocks.analysis.BinaryType; +import com.starrocks.catalog.Column; +import com.starrocks.catalog.FunctionSet; +import com.starrocks.catalog.Type; +import com.starrocks.common.Config; +import com.starrocks.common.VectorSearchOptions; +import com.starrocks.sql.optimizer.OptExpression; +import com.starrocks.sql.optimizer.OptimizerContext; +import com.starrocks.sql.optimizer.base.ColumnRefFactory; +import com.starrocks.sql.optimizer.base.Ordering; +import com.starrocks.sql.optimizer.operator.Operator; +import com.starrocks.sql.optimizer.operator.OperatorType; +import com.starrocks.sql.optimizer.operator.Projection; +import com.starrocks.sql.optimizer.operator.logical.LogicalOlapScanOperator; +import com.starrocks.sql.optimizer.operator.logical.LogicalTopNOperator; +import com.starrocks.sql.optimizer.operator.pattern.Pattern; +import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.CallOperator; +import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator; +import com.starrocks.sql.optimizer.operator.scalar.CompoundPredicateOperator; +import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator; +import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator; +import com.starrocks.sql.optimizer.rule.RuleType; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +import static com.starrocks.analysis.BinaryType.GE; +import static com.starrocks.analysis.BinaryType.GT; +import static com.starrocks.analysis.BinaryType.LE; +import static com.starrocks.analysis.BinaryType.LT; +import static com.starrocks.catalog.FunctionSet.APPROX_COSINE_SIMILARITY; +import static com.starrocks.catalog.FunctionSet.APPROX_L2_DISTANCE; + +public class RewriteToVectorPlanRule extends TransformationRule { + + public RewriteToVectorPlanRule() { + super(RuleType.TF_VECTOR_REWRITE_RULE, + Pattern.create(OperatorType.LOGICAL_TOPN) + .addChildren(Pattern.create(OperatorType.LOGICAL_OLAP_SCAN))); + } + + @Override + public boolean check(OptExpression input, OptimizerContext context) { + LogicalTopNOperator topNOperator = (LogicalTopNOperator) input.getOp(); + LogicalOlapScanOperator scanOperator = (LogicalOlapScanOperator) input.getInputs().get(0).getOp(); + + scanOperator.setVectorSearchOptions(context.getVectorSearchOptions()); + VectorSearchOptions vectorSearchOptions = scanOperator.getVectorSearchOptions(); + + if (!vectorSearchOptions.isEnableUseANN() || Config.enable_experimental_vector != true) { + return false; + } + + Map columnRefMap = scanOperator.getProjection().getColumnRefMap(); + + boolean isEnableUseANN = false; + for (Map.Entry entry : columnRefMap.entrySet()) { + if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName()) + && entry.getValue() instanceof CallOperator) { + CallOperator callOperator = (CallOperator) entry.getValue(); + vectorSearchOptions.setQueryVector(collectVectorQuery(callOperator)); + isEnableUseANN = true; + break; + } + } + + if (!isEnableUseANN) { + vectorSearchOptions.setEnableUseANN(false); + return false; + } + + if (!topNOperator.getOrderByElements().isEmpty() && + FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains( + topNOperator.getOrderByElements().get(0).getColumnRef().getName())) { + return topNOperator.getLimit() != Operator.DEFAULT_LIMIT && + columnRefMap.entrySet().stream() + .filter(entry -> FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName())) + .anyMatch(entry -> entry.getValue() instanceof CallOperator); + } + return false; + } + + @Override + public List transform(OptExpression input, OptimizerContext context) { + LogicalTopNOperator topNOperator = (LogicalTopNOperator) input.getOp(); + LogicalOlapScanOperator scanOperator = (LogicalOlapScanOperator) input.getInputs().get(0).getOp(); + + VectorSearchOptions options = scanOperator.getVectorSearchOptions(); + // set limit_K for ann searcher + options.setVectorLimitK(topNOperator.getLimit()); + ScalarOperator predicate = scanOperator.getPredicate(); + Optional newPredicate = Optional.empty(); + boolean isAscending = topNOperator.getOrderByElements().get(0).isAscending(); + if (predicate != null) { + newPredicate = findAndSetVectorRange(predicate, isAscending, options); + if (!options.isEnableUseANN()) { + return Lists.newArrayList(input); + } + } + options.setResultOrder(isAscending ? 0 : 1); + String functionName = topNOperator.getOrderByElements().get(0).getColumnRef().getName(); + + if (functionName.equalsIgnoreCase(APPROX_L2_DISTANCE) && !isAscending || + functionName.equalsIgnoreCase(APPROX_COSINE_SIMILARITY) && isAscending || + !options.isEnableUseANN()) { + options.setEnableUseANN(false); + return Lists.newArrayList(input); + } + if (options.isUseIVFPQ()) { + // Skip rewrite because IVFPQ is inaccurate and requires a brute force search after the ANN index search + input.getInputs().get(0).getOp() + .setPredicate(newPredicate.isPresent() ? newPredicate.get() : null); + return Lists.newArrayList(input); + } + + Optional result = buildVectorSortScanOperator(topNOperator, + scanOperator, context, newPredicate, options); + return result.isPresent() ? Lists.newArrayList(result.get()) : Lists.newArrayList(input); + } + + public Optional buildVectorSortScanOperator(LogicalTopNOperator topNOperator, + LogicalOlapScanOperator scanOperator, OptimizerContext context, + Optional newPredicate, VectorSearchOptions vectorSearchOptions) { + // bottom-up + String distanceColumnName = scanOperator.getVectorSearchOptions().getVectorDistanceColumnName(); + Column distanceColumn = new Column(distanceColumnName, Type.FLOAT); + scanOperator.getTable().addColumn(distanceColumn); + + ColumnRefFactory columnRefFactory = context.getColumnRefFactory(); + ColumnRefOperator distanceColumnRefOperator = columnRefFactory.create(distanceColumnName, Type.FLOAT, false); + + Map colRefToColumnMetaMap = new HashMap<>(scanOperator.getColRefToColumnMetaMap()); + colRefToColumnMetaMap.put(distanceColumnRefOperator, distanceColumn); + + Map columnMetaToColRefMap = new HashMap<>(scanOperator.getColumnMetaToColRefMap()); + columnMetaToColRefMap.put(distanceColumn, distanceColumnRefOperator); + + // new Scan operator + LogicalOlapScanOperator newScanOperator = new LogicalOlapScanOperator(scanOperator.getTable(), + colRefToColumnMetaMap, columnMetaToColRefMap, scanOperator.getDistributionSpec(), + scanOperator.getLimit(), newPredicate.isPresent() ? newPredicate.get() : null, + scanOperator.getSelectedIndexId(), scanOperator.getSelectedPartitionId(), + scanOperator.getPartitionNames(), scanOperator.hasTableHints(), + scanOperator.getSelectedTabletId(), scanOperator.getHintsTabletIds(), + scanOperator.getHintsReplicaIds(), scanOperator.isUsePkIndex()); + + newScanOperator.setVectorSearchOptions(vectorSearchOptions); + Map scanProjectMap = new HashMap<>(); + Map topNProjectMap = new HashMap<>(); + // find original column and project it onto the topN + Optional originalColRef = scanOperator.getProjection().getColumnRefMap() + .entrySet().stream().filter(entry -> FunctionSet.VECTOR_COMPUTE_FUNCTIONS + .contains(entry.getKey().getName())).map(entry -> entry.getKey()) + .findFirst(); + if (originalColRef.isEmpty()) { + return Optional.empty(); + } + + scanOperator.getProjection().getColumnRefMap().entrySet().stream() + .forEach(entry -> { + if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName()) + && entry.getValue() instanceof CallOperator) { + scanProjectMap.put(distanceColumnRefOperator, distanceColumnRefOperator); + } else { + scanProjectMap.put(entry.getKey(), entry.getValue()); + topNProjectMap.put(entry.getKey(), entry.getValue()); + } + }); + newScanOperator.setProjection(new Projection(scanProjectMap)); + + List orderByElements = topNOperator.getOrderByElements().stream().map(ordering -> + FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(ordering.getColumnRef().getName()) ? + new Ordering(distanceColumnRefOperator, ordering.isAscending(), ordering.isNullsFirst()) : ordering) + .collect(Collectors.toList()); + + boolean hasProjection = topNOperator.getProjection() != null; + Map newTopNProjectMap = new HashMap<>(); + if (hasProjection) { + topNOperator.getProjection().getColumnRefMap().entrySet().stream() + .forEach(entry -> { + if (FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(entry.getKey().getName())) { + newTopNProjectMap.put(originalColRef.get(), distanceColumnRefOperator); + } else { + newTopNProjectMap.put(entry.getKey(), entry.getValue()); + } + }); + } else { + topNProjectMap.put(originalColRef.get(), distanceColumnRefOperator); + } + + // new TopN operator + LogicalTopNOperator newTopNOperator = new LogicalTopNOperator(topNOperator.getLimit(), + topNOperator.getPredicate(), + hasProjection ? new Projection(newTopNProjectMap) : new Projection(topNProjectMap), + topNOperator.getPartitionByColumns(), topNOperator.getPartitionLimit(), orderByElements, + topNOperator.getOffset(), topNOperator.getSortPhase(), topNOperator.getTopNType(), topNOperator.isSplit()); + + OptExpression topNExpression = OptExpression.create(newTopNOperator); + topNExpression.getInputs().clear(); + topNExpression.getInputs().add(OptExpression.create(newScanOperator)); + + return Optional.of(topNExpression); + } + + public Optional findAndSetVectorRange(ScalarOperator operator, + boolean isAscending, VectorSearchOptions options) { + if (!options.isEnableUseANN()) { + return Optional.empty(); + } + + if (operator instanceof BinaryPredicateOperator && operator.getChild(1) instanceof ConstantOperator && + (isVectorCallOperator(operator.getChild(0)))) { + BinaryType binaryType = ((BinaryPredicateOperator) operator).getBinaryType(); + if (((binaryType.equals(LE) || binaryType.equals(LT)) && !isAscending) || + ((binaryType.equals(GE) || binaryType.equals(GT)) && isAscending)) { + options.setEnableUseANN(false); + return Optional.empty(); + } + options.setVectorRange((double) (((ConstantOperator) operator.getChild(1)).getValue())); + return Optional.empty(); + } else if (operator instanceof CompoundPredicateOperator) { + List newOperators = new ArrayList<>(); + for (ScalarOperator child : operator.getChildren()) { + Optional newChild = findAndSetVectorRange(child, isAscending, options); + if (newChild.isPresent()) { + newOperators.add(newChild.get()); + } + } + if (newOperators.size() > 1) { + return Optional.of(new CompoundPredicateOperator(((CompoundPredicateOperator) operator).getCompoundType(), + newOperators)); + } else if (newOperators.size() == 1) { + return Optional.of(newOperators.get(0)); + } else { + return Optional.empty(); + } + } else { + options.setEnableUseANN(false); + return Optional.of(operator.clone()); + } + } + + public boolean isVectorCallOperator(ScalarOperator scalarOperator) { + if (scalarOperator instanceof CallOperator && + FunctionSet.VECTOR_COMPUTE_FUNCTIONS.contains(((CallOperator) scalarOperator).getFnName())) { + return true; + } + if (scalarOperator.getChildren().size() == 0) { + return false; + } + return isVectorCallOperator(scalarOperator.getChild(0)); + } + + public List collectVectorQuery(CallOperator callOperator) { + // suppose it's a standard vector query + List vectorQuery = new ArrayList<>(); + collectVector(callOperator, vectorQuery); + return vectorQuery; + } + + public void collectVector(ScalarOperator scalarOperator, List vectorQuery) { + if (scalarOperator instanceof ColumnRefOperator) { + return; + } + + if (scalarOperator instanceof ConstantOperator) { + vectorQuery.add(String.valueOf(((ConstantOperator) scalarOperator).getValue())); + return; + } + + for (ScalarOperator child : scalarOperator.getChildren()) { + collectVector(child, vectorQuery); + } + } +} diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java index cd7bbdac403b4..519f7cc995ffa 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/AddDecodeNodeForDictStringRule.java @@ -484,7 +484,8 @@ public OptExpression visitPhysicalOlapScan(OptExpression optExpression, DecodeCo scanOperator.getSelectedIndexId(), scanOperator.getSelectedPartitionId(), scanOperator.getSelectedTabletId(), scanOperator.getHintsReplicaId(), newPrunedPredicates, - scanOperator.getProjection(), scanOperator.isUsePkIndex()); + scanOperator.getProjection(), scanOperator.isUsePkIndex(), + scanOperator.getVectorSearchOptions()); newOlapScan.setScanOptimzeOption(scanOperator.getScanOptimzeOption()); newOlapScan.setPreAggregation(scanOperator.isPreAggregation()); newOlapScan.setGlobalDicts(globalDicts); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index 6adf1decb19bd..20a6fff10d70e 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -822,6 +822,7 @@ public PlanFragment visitPhysicalOlapScan(OptExpression optExpr, ExecPlan contex scanNode.setIsOutputChunkByBucket(node.needOutputChunkByBucket()); scanNode.setWithoutColocateRequirement(node.isWithoutColocateRequirement()); scanNode.setGtid(node.getGtid()); + scanNode.setVectorSearchOptions(node.getVectorSearchOptions()); currentExecGroup.add(scanNode); // set tablet try { diff --git a/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java new file mode 100644 index 0000000000000..e9e0691c3f2dd --- /dev/null +++ b/fe/fe-core/src/test/java/com/starrocks/planner/VectorIndexTest.java @@ -0,0 +1,122 @@ +// Copyright 2021-present StarRocks, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file is based on code available under the Apache license here: +// https://github.com/apache/incubator-doris/blob/master/fe/fe-core/src/test/java/org/apache/doris/planner/QueryPlanTest.java + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package com.starrocks.planner; + +import com.starrocks.common.Config; +import com.starrocks.common.FeConstants; +import com.starrocks.sql.plan.PlanTestBase; + +import org.junit.BeforeClass; +import org.junit.Test; + +public class VectorIndexTest extends PlanTestBase { + + @BeforeClass + public static void beforeClass() throws Exception { + PlanTestBase.beforeClass(); + Config.enable_experimental_vector = true; + FeConstants.enablePruneEmptyOutputScan = false; + starRocksAssert.withTable("CREATE TABLE test.test_cosine (" + + " c0 INT," + + " c1 array," + + " INDEX index_vector1 (c1) USING VECTOR ('metric_type' = 'cosine_similarity', " + + "'is_vector_normed' = 'false', 'M' = '512', 'index_type' = 'hnsw', 'dim'='5') " + + ") " + + "DUPLICATE KEY(c0) " + + "DISTRIBUTED BY HASH(c0) BUCKETS 1 " + + "PROPERTIES ('replication_num'='1');"); + + starRocksAssert.withTable("CREATE TABLE test.test_l2 (" + + " c0 INT," + + " c1 array," + + " INDEX index_vector1 (c1) USING VECTOR ('metric_type' = 'l2_distance', " + + "'is_vector_normed' = 'false', 'M' = '512', 'index_type' = 'hnsw', 'dim'='5') " + + ") " + + "DUPLICATE KEY(c0) " + + "DISTRIBUTED BY HASH(c0) BUCKETS 1 " + + "PROPERTIES ('replication_num'='1');"); + + starRocksAssert.withTable("CREATE TABLE test.test_ivfpq (" + + " c0 INT," + + " c1 array," + + " INDEX index_vector1 (c1) USING VECTOR ('metric_type' = 'l2_distance', " + + "'is_vector_normed' = 'false', 'nbits' = '1', 'index_type' = 'ivfpq', 'dim'='5') " + + ") " + + "DUPLICATE KEY(c0) " + + "DISTRIBUTED BY HASH(c0) BUCKETS 1 " + + "PROPERTIES ('replication_num'='1');"); + } + + @Test + public void testVectorIndexSyntax() throws Exception { + String sql1 = "select c1 from test.test_cosine " + + "order by approx_cosine_similarity([1.1,2.2,3.3], c1) desc limit 10"; + assertPlanContains(sql1, "VECTORINDEX: ON"); + + String sql2 = "select c1 from test.test_l2 " + + "order by approx_l2_distance([1.1,2.2,3.3], c1) limit 10"; + assertPlanContains(sql2, "VECTORINDEX: ON"); + + // Sorting in desc order doesn't make sense in l2_distance, + // which won't trigger the vector retrieval logic. + String sql3 = "select c1 from test.test_l2 " + + "order by approx_l2_distance([1.1,2.2,3.3], c1) desc limit 10"; + assertPlanContains(sql3, "VECTORINDEX: OFF"); + + String sql4 = "select c1 from test.test_cosine " + + "order by cosine_similarity([1.1,2.2,3.3], c1) desc limit 10"; + assertPlanContains(sql4, "VECTORINDEX: OFF"); + + String sql5 = "select c1, approx_l2_distance([1.1,2.2,3.3], c1) as score" + + " from test.test_ivfpq order by score limit 10"; + assertPlanContains(sql5, "VECTORINDEX: ON"); + + String sql6 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" + + " from test.test_cosine order by score desc limit 10"; + assertPlanContains(sql6, "VECTORINDEX: ON"); + + String sql7 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" + + " from test.test_cosine where c0 = 1 order by score desc limit 10"; + assertPlanContains(sql7, "VECTORINDEX: OFF"); + + String sql8 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" + + " from test.test_cosine having score > 0.8 order by score desc limit 10"; + assertPlanContains(sql8, "VECTORINDEX: ON"); + + String sql9 = "select c1, approx_cosine_similarity([1.1,2.2,3.3], c1) as score" + + " from test.test_cosine having score < 0.8 order by score desc limit 10"; + assertPlanContains(sql9, "VECTORINDEX: OFF"); + } +} diff --git a/gensrc/script/functions.py b/gensrc/script/functions.py index 0c8f330931cef..ed08eb1958185 100644 --- a/gensrc/script/functions.py +++ b/gensrc/script/functions.py @@ -56,15 +56,20 @@ [10090, "tan", True, False, "DOUBLE", ["DOUBLE"], "MathFunctions::tan"], [10100, "atan", True, False, "DOUBLE", ["DOUBLE"], "MathFunctions::atan"], [10101, "tanh", True, False, "DOUBLE", ["DOUBLE"], "MathFunctions::tanh"], - [10102, "cosine_similarity", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], - "MathFunctions::cosine_similarity"], - [10103, "cosine_similarity_norm", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], - "MathFunctions::cosine_similarity"], + + # cosine function + [10102, "cosine_similarity", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity"], + [10103, "cosine_similarity_norm", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity"], + [10106, "approx_cosine_similarity", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::cosine_similarity"], [10110, "ceil", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::ceil"], [10111, "ceiling", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::ceil"], [10112, "dceil", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::ceil"], + # l2 function + [10114, "l2_distance", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::l2_distance"], + [10116, "approx_l2_distance", True, False, "FLOAT", ["ARRAY_FLOAT", "ARRAY_FLOAT"], "MathFunctions::l2_distance"], + [10120, "floor", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::floor"], [10121, "dfloor", True, False, "BIGINT", ["DOUBLE"], "MathFunctions::floor"], diff --git a/gensrc/thrift/InternalService.thrift b/gensrc/thrift/InternalService.thrift index 2cafa6e773502..db3de6f353876 100644 --- a/gensrc/thrift/InternalService.thrift +++ b/gensrc/thrift/InternalService.thrift @@ -320,6 +320,10 @@ struct TQueryOptions { 140: optional string catalog; 141: optional i32 datacache_evict_probability; + + 150: optional map ann_params; + 151: optional double pq_refine_factor; + 152: optional double k_factor; } diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index cbf9149e631a0..3650fc983fcd8 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -525,6 +525,19 @@ struct TColumnAccessPath { 5: optional Types.TTypeDesc type_desc } +struct TVectorSearchOptions { + 1: optional bool enable_use_ann; + 2: optional i64 vector_limit_k; + 3: optional string vector_distance_column_name; + 4: optional list query_vector; + 5: optional map query_params; + 6: optional double vector_range; + 7: optional i32 result_order; + 8: optional bool use_ivfpq; + 9: optional double pq_refine_factor; + 10: optional double k_factor; +} + // If you find yourself changing this struct, see also TLakeScanNode struct TOlapScanNode { 1: required Types.TTupleId tuple_id @@ -555,6 +568,8 @@ struct TOlapScanNode { 35: optional bool enable_prune_column_after_index_filter 36: optional bool enable_gin_filter 37: optional i64 schema_id + + 40: optional TVectorSearchOptions vector_search_options } struct TJDBCScanNode { diff --git a/test/sql/test_function/R/test_math b/test/sql/test_function/R/test_math index ddba7b6af76dc..71ab0a1678888 100644 --- a/test/sql/test_function/R/test_math +++ b/test/sql/test_function/R/test_math @@ -1,4 +1,4 @@ --- name: test_math_cosine_similarity +-- name: test_vector_math create table t1 (id int, data array) engine = olap distributed by hash(id) properties ("replication_num" = "1"); -- result: -- !result @@ -19,19 +19,29 @@ select cosine_similarity_norm(array[0.1, 0.2, 0.3], array[0.1, 0.2 -- result: 0.14000002 -- !result -create table test_cosine (id int, data array) ENGINE=olap DUPLICATE KEY (id) DISTRIBUTED BY HASH(id) properties ("replication_num" = "1"); +select l2_distance(array[0.1, 0.2, 0.3], data) as dist, id from t1 order by dist desc; -- result: +0.08000001 3 +0.020000001 2 +0.0 1 -- !result -insert into test_cosine values (1, array[0.1, 0.2, 0.3]), (2, array[0.2, 0.1, 0.3]); +select l2_distance(array[0.1, 0.2, 0.3], array[0.1, 0.2, 0.3]) as dist; -- result: +0.0 -- !result -insert into test_cosine values (3, array[0.15, 0.25, 0.32]), (4, array[0.12, 0.11, 0.32]); +create table test_vector (id int, data array) ENGINE=olap DUPLICATE KEY (id) DISTRIBUTED BY HASH(id) properties ("replication_num" = "1"); -- result: -- !result -insert into test_cosine values (5, array[0.25, 0.12, 0.13]), (6, array[0.22, 0.01, 0.39]); +insert into test_vector values (1, array[0.1, 0.2, 0.3]), (2, array[0.2, 0.1, 0.3]); -- result: -- !result -select id, data, cosine_similarity(array[0.1, 0.2, 0.3], data) as sim from test_cosine order by sim desc; +insert into test_vector values (3, array[0.15, 0.25, 0.32]), (4, array[0.12, 0.11, 0.32]); +-- result: +-- !result +insert into test_vector values (5, array[0.25, 0.12, 0.13]), (6, array[0.22, 0.01, 0.39]); +-- result: +-- !result +select id, data, cosine_similarity(array[0.1, 0.2, 0.3], data) as sim from test_vector order by sim desc; -- result: 1 [0.1,0.2,0.3] 0.9999999 3 [0.15,0.25,0.32] 0.99397856 @@ -40,7 +50,7 @@ select id, data, cosine_similarity(array[0.1, 0.2, 0.3], data) as sim fro 6 [0.22,0.01,0.39] 0.841375 5 [0.25,0.12,0.13] 0.76792216 -- !result -select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_cosine as a cross join test_cosine as b; +select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_vector as a cross join test_vector as b; -- result: 6 4 [0.22,0.01,0.39] [0.12,0.11,0.32] 0.94712645 6 1 [0.22,0.01,0.39] [0.1,0.2,0.3] 0.841375 @@ -79,7 +89,7 @@ select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from 4 5 [0.12,0.11,0.32] [0.25,0.12,0.13] 0.77120167 4 6 [0.12,0.11,0.32] [0.22,0.01,0.39] 0.94712645 -- !result -select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_cosine as a cross join test_cosine as b order by sim desc; +select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_vector as a cross join test_vector as b order by sim desc; -- result: 3 3 [0.15,0.25,0.32] [0.15,0.25,0.32] 1.0 5 5 [0.25,0.12,0.13] [0.25,0.12,0.13] 0.99999994 @@ -117,4 +127,13 @@ select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from 4 5 [0.12,0.11,0.32] [0.25,0.12,0.13] 0.77120167 1 5 [0.1,0.2,0.3] [0.25,0.12,0.13] 0.76792216 5 1 [0.25,0.12,0.13] [0.1,0.2,0.3] 0.76792216 +-- !result +select id, data, l2_distance(array[0.1, 0.2, 0.3], data) as sim from test_vector order by sim desc; +-- result: +6 [0.22,0.01,0.39] 0.058599994 +5 [0.25,0.12,0.13] 0.057800006 +2 [0.2,0.1,0.3] 0.020000001 +4 [0.12,0.11,0.32] 0.0089 +3 [0.15,0.25,0.32] 0.005399999 +1 [0.1,0.2,0.3] 0.0 -- !result \ No newline at end of file diff --git a/test/sql/test_function/T/test_math b/test/sql/test_function/T/test_math index 91747b0c2b96c..145dc83f5a4f0 100644 --- a/test/sql/test_function/T/test_math +++ b/test/sql/test_function/T/test_math @@ -1,4 +1,4 @@ --- name: test_math_cosine_similarity +-- name: test_vector_math create table t1 (id int, data array) engine = olap distributed by hash(id) properties ("replication_num" = "1"); @@ -10,14 +10,20 @@ select cosine_similarity(array[0.1, 0.2, 0.3], array[0.1, 0.2, 0.3 select cosine_similarity_norm(array[0.1, 0.2, 0.3], array[0.1, 0.2, 0.3]) as dist; +select l2_distance(array[0.1, 0.2, 0.3], data) as dist, id from t1 order by dist desc; + +select l2_distance(array[0.1, 0.2, 0.3], array[0.1, 0.2, 0.3]) as dist; + --------------- cross join ----------------- -create table test_cosine (id int, data array) ENGINE=olap DUPLICATE KEY (id) DISTRIBUTED BY HASH(id) properties ("replication_num" = "1"); +create table test_vector (id int, data array) ENGINE=olap DUPLICATE KEY (id) DISTRIBUTED BY HASH(id) properties ("replication_num" = "1"); + +insert into test_vector values (1, array[0.1, 0.2, 0.3]), (2, array[0.2, 0.1, 0.3]); +insert into test_vector values (3, array[0.15, 0.25, 0.32]), (4, array[0.12, 0.11, 0.32]); +insert into test_vector values (5, array[0.25, 0.12, 0.13]), (6, array[0.22, 0.01, 0.39]); -insert into test_cosine values (1, array[0.1, 0.2, 0.3]), (2, array[0.2, 0.1, 0.3]); -insert into test_cosine values (3, array[0.15, 0.25, 0.32]), (4, array[0.12, 0.11, 0.32]); -insert into test_cosine values (5, array[0.25, 0.12, 0.13]), (6, array[0.22, 0.01, 0.39]); +select id, data, cosine_similarity(array[0.1, 0.2, 0.3], data) as sim from test_vector order by sim desc; +select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_vector as a cross join test_vector as b; +select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_vector as a cross join test_vector as b order by sim desc; -select id, data, cosine_similarity(array[0.1, 0.2, 0.3], data) as sim from test_cosine order by sim desc; -select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_cosine as a cross join test_cosine as b; -select a.id, b.id, a.data, b.data, cosine_similarity(a.data, b.data) as sim from test_cosine as a cross join test_cosine as b order by sim desc; \ No newline at end of file +select id, data, l2_distance(array[0.1, 0.2, 0.3], data) as sim from test_vector order by sim desc; \ No newline at end of file diff --git a/test/sql/test_vector_index/R/test_vector_index b/test/sql/test_vector_index/R/test_vector_index index 0dc9e1a712cbc..09c176a587166 100644 --- a/test/sql/test_vector_index/R/test_vector_index +++ b/test/sql/test_vector_index/R/test_vector_index @@ -51,7 +51,7 @@ DROP TABLE t_test_vector_table; -- result: -- !result --- name: test_write_vector_index +-- name: test_vector_index CREATE TABLE `t_test_vector_table` ( `id` bigint(20) NOT NULL COMMENT "", `vector1` ARRAY NOT NULL COMMENT "", @@ -72,6 +72,20 @@ insert into t_test_vector_table values(1, [1,2,3,4,5]); -- result: -- !result +insert into t_test_vector_table values(2, [4,5,6,7,8]); +-- result: +-- !result + +select id, approx_l2_distance([1,1,1,1,1], vector1) from t_test_vector_table order by approx_l2_distance([1,1,1,1,1], vector1) limit 1; +-- result: +1 30.0 +-- !result + +select * from (select id, approx_l2_distance([1,1,1,1,1], vector1) score from t_test_vector_table) a where score < 40 order by score limit 1; +-- result: +1 30.0 +-- !result + DROP TABLE t_test_vector_table; -- result: -- !result \ No newline at end of file diff --git a/test/sql/test_vector_index/T/test_vector_index b/test/sql/test_vector_index/T/test_vector_index index e743fa3ee40ad..f50e60de399ad 100644 --- a/test/sql/test_vector_index/T/test_vector_index +++ b/test/sql/test_vector_index/T/test_vector_index @@ -22,7 +22,7 @@ ALTER TABLE t_test_vector_table drop index index_vector2; DROP TABLE t_test_vector_table; --- name: test_write_vector_index +-- name: test_vector_index CREATE TABLE `t_test_vector_table` ( `id` bigint(20) NOT NULL COMMENT "", `vector1` ARRAY NOT NULL COMMENT "", @@ -38,6 +38,8 @@ PROPERTIES ( ); insert into t_test_vector_table values(1, [1,2,3,4,5]); -insert into t_test_vector_table values(2, [0,0]); +insert into t_test_vector_table values(2, [4,5,6,7,8]); +select id, approx_l2_distance([1,1,1,1,1], vector1) from t_test_vector_table order by approx_l2_distance([1,1,1,1,1], vector1) limit 1; +select * from (select id, approx_l2_distance([1,1,1,1,1], vector1) score from t_test_vector_table) a where score < 40 order by score limit 1; -DROP TABLE t_test_vector_table; +DROP TABLE t_test_vector_table; \ No newline at end of file