Skip to content

Commit

Permalink
[Feature] Support select vector index (StarRocks#48521)
Browse files Browse the repository at this point in the history
Signed-off-by: yulongzhao <[email protected]>
  • Loading branch information
yulongfufu authored and HangyuanLiu committed Sep 11, 2024
1 parent 5fde62b commit c1c3b4f
Show file tree
Hide file tree
Showing 64 changed files with 2,010 additions and 60 deletions.
9 changes: 9 additions & 0 deletions be/src/column/chunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ void Chunk::append_column(ColumnPtr column, const FieldPtr& field) {
check_or_die();
}

// Appends `column` at the end of the chunk and registers its position under
// both the field's column id and the given `slot_id`. The field is also
// appended to the chunk's schema.
// The field's column id is DCHECK-ed to be absent from the cid index.
void Chunk::append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id) {
    DCHECK(!_cid_to_index.contains(field->id()));
    // The new column lands at the current end of `_columns`.
    const auto position = _columns.size();
    _cid_to_index[field->id()] = position;
    _slot_id_to_index[slot_id] = position;
    _columns.emplace_back(std::move(column));
    _schema->append(field);
    check_or_die();
}

void Chunk::append_column(ColumnPtr column, SlotId slot_id) {
_slot_id_to_index[slot_id] = _columns.size();
_columns.emplace_back(std::move(column));
Expand Down
2 changes: 2 additions & 0 deletions be/src/column/chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class Chunk {
// schema must exist and will be updated.
void append_column(ColumnPtr column, const FieldPtr& field);

void append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id);

void append_column(ColumnPtr column, SlotId slot_id);
void insert_column(size_t idx, ColumnPtr column, const FieldPtr& field);

Expand Down
5 changes: 4 additions & 1 deletion be/src/column/schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ Schema::Schema(Schema* schema, const std::vector<ColumnId>& cids)
auto ori_sort_idxes = schema->sort_key_idxes();
std::unordered_set<ColumnId> scids(ori_sort_idxes.begin(), ori_sort_idxes.end());
for (int i = 0; i < cids.size(); i++) {
DCHECK_LT(cids[i], schema->_fields.size());
if (cids[i] >= schema->_fields.size()) {
_fields.resize(_fields.size() - 1);
continue;
}
_fields[i] = schema->_fields[cids[i]];
if (scids.find(cids[i]) != scids.end()) {
_sort_key_idxes.emplace_back(i);
Expand Down
6 changes: 6 additions & 0 deletions be/src/common/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,9 @@ CONF_Int64(query_cache_capacity, "536870912");
// ranges in [1,16], default value is 4.
CONF_mInt32(query_cache_num_lanes_per_driver, "4");

// Used by vector query cache, 512MB by default
CONF_Int64(vector_query_cache_capacity, "536870912");

// Used to limit buffer size of tablet send channel.
CONF_mInt64(send_channel_buffer_limit, "67108864");

Expand Down Expand Up @@ -1369,6 +1372,9 @@ CONF_mInt64(arrow_io_coalesce_read_max_buffer_size, "8388608");
CONF_mInt64(arrow_io_coalesce_read_max_distance_size, "1048576");
CONF_mInt64(arrow_read_batch_size, "4096");

// default not to build the empty index
CONF_mInt32(config_tenann_default_build_threshold, "0");

// Set to true to enable socket_keepalive option in brpc
CONF_mBool(brpc_socket_keepalive, "false");
CONF_mBool(apply_del_vec_after_all_index_filter, "true");
Expand Down
6 changes: 6 additions & 0 deletions be/src/exec/olap_scan_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,12 @@ void OlapScanNode::_init_counter(RuntimeState* state) {
_bf_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "BloomFilterFilterRows", TUnit::UNIT, "SegmentInit");
_gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, "SegmentInit");
_gin_filtered_timer = ADD_CHILD_TIMER(_runtime_profile, "GinFilter", "SegmentInit");
_get_row_ranges_by_vector_index_timer = ADD_CHILD_TIMER(_scan_profile, "GetVectorRowRangesTime", "SegmentInit");
_vector_search_timer = ADD_CHILD_TIMER(_scan_profile, "VectorSearchTime", "SegmentInit");
_vector_index_filtered_counter =
ADD_CHILD_COUNTER(_scan_profile, "VectorIndexFilterRows", TUnit::UNIT, "SegmentInit");
_process_vector_distance_and_id_timer =
ADD_CHILD_TIMER(_scan_profile, "ProcessVectorDistanceAndIdTime", "SegmentInit");
_seg_zm_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "SegmentZoneMapFilterRows", TUnit::UNIT, "SegmentInit");
_seg_rt_filtered_counter =
ADD_CHILD_COUNTER(_scan_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, "SegmentInit");
Expand Down
4 changes: 4 additions & 0 deletions be/src/exec/olap_scan_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ class OlapScanNode final : public starrocks::ScanNode {
RuntimeProfile::Counter* _bi_filter_timer = nullptr;
RuntimeProfile::Counter* _gin_filtered_counter = nullptr;
RuntimeProfile::Counter* _gin_filtered_timer = nullptr;
RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr;
RuntimeProfile::Counter* _vector_search_timer = nullptr;
RuntimeProfile::Counter* _vector_index_filtered_counter = nullptr;
RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr;
RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr;
RuntimeProfile::Counter* _rowsets_read_count = nullptr;
RuntimeProfile::Counter* _segments_read_count = nullptr;
Expand Down
64 changes: 59 additions & 5 deletions be/src/exec/pipeline/scan/olap_chunk_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "runtime/exec_env.h"
#include "storage/chunk_helper.h"
#include "storage/column_predicate_rewriter.h"
#include "storage/index/vector/vector_search_option.h"
#include "storage/olap_runtime_range_pruner.hpp"
#include "storage/predicate_parser.h"
#include "storage/projection_iterator.h"
Expand Down Expand Up @@ -76,6 +77,13 @@ Status OlapChunkSource::prepare(RuntimeState* state) {
RETURN_IF_ERROR(ChunkSource::prepare(state));
_runtime_state = state;
const TOlapScanNode& thrift_olap_scan_node = _scan_node->thrift_olap_scan_node();
const TVectorSearchOptions& vector_search_options = thrift_olap_scan_node.vector_search_options;
_use_vector_index = thrift_olap_scan_node.__isset.vector_search_options && vector_search_options.enable_use_ann;
if (_use_vector_index) {
_use_ivfpq = vector_search_options.use_ivfpq;
_vector_distance_column_name = vector_search_options.vector_distance_column_name;
_params.vector_search_option = std::make_shared<VectorSearchOption>();
}
const TupleDescriptor* tuple_desc = state->desc_tbl().get_tuple_descriptor(thrift_olap_scan_node.tuple_id);
_slots = &tuple_desc->slots();

Expand Down Expand Up @@ -129,6 +137,11 @@ void OlapChunkSource::_init_counter(RuntimeState* state) {
const std::string segment_init_name = "SegmentInit";
_seg_init_timer = ADD_CHILD_TIMER(_runtime_profile, segment_init_name, IO_TASK_EXEC_TIMER_NAME);
_bi_filter_timer = ADD_CHILD_TIMER(_runtime_profile, "BitmapIndexFilter", segment_init_name);
_get_row_ranges_by_vector_index_timer =
ADD_CHILD_TIMER(_runtime_profile, "GetVectorRowRangesTime", segment_init_name);
_vector_search_timer = ADD_CHILD_TIMER(_runtime_profile, "VectorSearchTime", segment_init_name);
_process_vector_distance_and_id_timer =
ADD_CHILD_TIMER(_runtime_profile, "ProcessVectorDistanceAndIdTime", segment_init_name);
_bi_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BitmapIndexFilterRows", TUnit::UNIT, segment_init_name);
_bf_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BloomFilterFilterRows", TUnit::UNIT, segment_init_name);
_gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, segment_init_name);
Expand All @@ -140,6 +153,8 @@ void OlapChunkSource::_init_counter(RuntimeState* state) {
ADD_CHILD_COUNTER(_runtime_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, segment_init_name);
_zm_filtered_counter =
ADD_CHILD_COUNTER(_runtime_profile, "ZoneMapIndexFilterRows", TUnit::UNIT, segment_init_name);
_vector_index_filtered_counter =
ADD_CHILD_COUNTER(_runtime_profile, "VectorIndexFilterRows", TUnit::UNIT, segment_init_name);
_sk_filtered_counter =
ADD_CHILD_COUNTER_SKIP_MIN_MAX(_runtime_profile, "ShortKeyFilterRows", TUnit::UNIT,
_get_counter_min_max_type("ShortKeyFilterRows"), segment_init_name);
Expand Down Expand Up @@ -213,6 +228,24 @@ Status OlapChunkSource::_init_reader_params(const std::vector<std::unique_ptr<Ol
if (thrift_olap_scan_node.__isset.enable_gin_filter) {
_params.enable_gin_filter = thrift_olap_scan_node.enable_gin_filter;
}
_params.use_vector_index = _use_vector_index;
if (_use_vector_index) {
const TVectorSearchOptions& vector_options = thrift_olap_scan_node.vector_search_options;

_params.vector_search_option->vector_distance_column_name = _vector_distance_column_name;
_params.vector_search_option->k = vector_options.vector_limit_k;
for (const std::string& str : vector_options.query_vector) {
_params.vector_search_option->query_vector.push_back(std::stof(str));
}
if (_runtime_state->query_options().__isset.ann_params) {
_params.vector_search_option->query_params = _runtime_state->query_options().ann_params;
}
_params.vector_search_option->vector_range = vector_options.vector_range;
_params.vector_search_option->result_order = vector_options.result_order;
_params.vector_search_option->use_ivfpq = _use_ivfpq;
_params.vector_search_option->k_factor = _runtime_state->query_options().k_factor;
_params.vector_search_option->pq_refine_factor = _runtime_state->query_options().pq_refine_factor;
}
if (thrift_olap_scan_node.__isset.sorted_by_keys_per_tablet) {
_params.sorted_by_keys_per_tablet = thrift_olap_scan_node.sorted_by_keys_per_tablet;
}
Expand Down Expand Up @@ -271,7 +304,16 @@ Status OlapChunkSource::_init_reader_params(const std::vector<std::unique_ptr<Ol
Status OlapChunkSource::_init_scanner_columns(std::vector<uint32_t>& scanner_columns) {
for (auto slot : *_slots) {
DCHECK(slot->is_materialized());
int32_t index = _tablet_schema->field_index(slot->col_name());
int32_t index;
if (_use_vector_index && !_use_ivfpq) {
index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name);
if (slot->col_name() == _vector_distance_column_name) {
_params.vector_search_option->vector_column_id = index;
_params.vector_search_option->vector_slot_id = slot->id();
}
} else {
index = _tablet_schema->field_index(slot->col_name());
}
if (index < 0) {
std::stringstream ss;
ss << "invalid field name: " << slot->col_name();
Expand All @@ -295,7 +337,12 @@ Status OlapChunkSource::_init_scanner_columns(std::vector<uint32_t>& scanner_col

Status OlapChunkSource::_init_unused_output_columns(const std::vector<std::string>& unused_output_columns) {
for (const auto& col_name : unused_output_columns) {
int32_t index = _tablet_schema->field_index(col_name);
int32_t index;
if (_use_vector_index && !_use_ivfpq) {
index = _tablet_schema->field_index(col_name, _vector_distance_column_name);
} else {
index = _tablet_schema->field_index(col_name);
}
if (index < 0) {
std::stringstream ss;
ss << "invalid field name: " << col_name;
Expand Down Expand Up @@ -500,7 +547,12 @@ Status OlapChunkSource::_init_global_dicts(TabletReaderParams* params) {
auto iter = global_dict_map.find(slot->id());
if (iter != global_dict_map.end()) {
auto& dict_map = iter->second.first;
int32_t index = _tablet_schema->field_index(slot->col_name());
int32_t index;
if (_use_vector_index && !_use_ivfpq) {
index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name);
} else {
index = _tablet_schema->field_index(slot->col_name());
}
DCHECK(index >= 0);
global_dict->emplace(index, const_cast<GlobalDictMap*>(&dict_map));
}
Expand Down Expand Up @@ -610,6 +662,7 @@ void OlapChunkSource::_update_counter() {
COUNTER_UPDATE(_seg_zm_filtered_counter, _reader->stats().segment_stats_filtered);
COUNTER_UPDATE(_seg_rt_filtered_counter, _reader->stats().runtime_stats_filtered);
COUNTER_UPDATE(_zm_filtered_counter, _reader->stats().rows_stats_filtered);
COUNTER_UPDATE(_vector_index_filtered_counter, _reader->stats().rows_vector_index_filtered);
COUNTER_UPDATE(_bf_filtered_counter, _reader->stats().rows_bf_filtered);
COUNTER_UPDATE(_sk_filtered_counter, _reader->stats().rows_key_range_filtered);
COUNTER_UPDATE(_rows_after_sk_filtered_counter, _reader->stats().rows_after_key_range);
Expand All @@ -620,8 +673,9 @@ void OlapChunkSource::_update_counter() {

COUNTER_UPDATE(_bi_filtered_counter, _reader->stats().rows_bitmap_index_filtered);
COUNTER_UPDATE(_bi_filter_timer, _reader->stats().bitmap_index_filter_timer);
COUNTER_UPDATE(_gin_filtered_counter, _reader->stats().rows_gin_filtered);
COUNTER_UPDATE(_gin_filtered_timer, _reader->stats().gin_index_filter_ns);
COUNTER_UPDATE(_get_row_ranges_by_vector_index_timer, _reader->stats().get_row_ranges_by_vector_index_timer);
COUNTER_UPDATE(_vector_search_timer, _reader->stats().vector_search_timer);
COUNTER_UPDATE(_process_vector_distance_and_id_timer, _reader->stats().process_vector_distance_and_id_timer);
COUNTER_UPDATE(_block_seek_counter, _reader->stats().block_seek_num);

COUNTER_UPDATE(_rowsets_read_count, _reader->stats().rowsets_read_count);
Expand Down
12 changes: 11 additions & 1 deletion be/src/exec/pipeline/scan/olap_chunk_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class OlapChunkSource final : public ChunkSource {
OlapScanNode* _scan_node;
OlapScanContext* _scan_ctx;

const int64_t _limit; // -1: no limit
int64_t _limit; // -1: no limit
TInternalScanRange* _scan_range;

PredicateTree _non_pushdown_pred_tree;
Expand Down Expand Up @@ -104,6 +104,12 @@ class OlapChunkSource final : public ChunkSource {

std::vector<ColumnAccessPathPtr> _column_access_paths;

bool _use_vector_index = false;

bool _use_ivfpq = false;

std::string _vector_distance_column_name;

// The following are profile measures
int64_t _num_rows_read = 0;

Expand Down Expand Up @@ -132,6 +138,7 @@ class OlapChunkSource final : public ChunkSource {
RuntimeProfile::Counter* _rows_key_range_counter = nullptr;
RuntimeProfile::Counter* _bf_filter_timer = nullptr;
RuntimeProfile::Counter* _zm_filtered_counter = nullptr;
RuntimeProfile::Counter* _vector_index_filtered_counter = nullptr;
RuntimeProfile::Counter* _bf_filtered_counter = nullptr;
RuntimeProfile::Counter* _seg_zm_filtered_counter = nullptr;
RuntimeProfile::Counter* _seg_rt_filtered_counter = nullptr;
Expand All @@ -148,6 +155,9 @@ class OlapChunkSource final : public ChunkSource {
RuntimeProfile::Counter* _bi_filter_timer = nullptr;
RuntimeProfile::Counter* _gin_filtered_counter = nullptr;
RuntimeProfile::Counter* _gin_filtered_timer = nullptr;
RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr;
RuntimeProfile::Counter* _vector_search_timer = nullptr;
RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr;
RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr;
RuntimeProfile::Counter* _non_pushdown_predicates_counter = nullptr;
RuntimeProfile::Counter* _rowsets_read_count = nullptr;
Expand Down
106 changes: 104 additions & 2 deletions be/src/exprs/math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,6 @@ StatusOr<ColumnPtr> MathFunctions::cosine_similarity(FunctionContext* context, c
target_sum += target_data[j] * target_data[j];
}
}

if constexpr (!isNorm) {
result_value = sum / (std::sqrt(base_sum) * std::sqrt(target_sum));
} else {
Expand All @@ -899,10 +898,113 @@ StatusOr<ColumnPtr> MathFunctions::cosine_similarity(FunctionContext* context, c
return result;
}

// explicitly instaniate template function.
// explicitly instantiate template function.
template StatusOr<ColumnPtr> MathFunctions::cosine_similarity<TYPE_FLOAT, true>(FunctionContext* context,
const Columns& columns);
template StatusOr<ColumnPtr> MathFunctions::cosine_similarity<TYPE_FLOAT, false>(FunctionContext* context,
const Columns& columns);

// Computes, for every row, the sum of squared element-wise differences between
// the array in columns[0] ("base") and the array in columns[1] ("target").
// No square root is applied to the sum.
//
// Returns InvalidArgument when:
//   - the two columns have a different number of rows,
//   - either column, or either flattened element column, reports a null,
//   - the flattened element columns differ in total size,
//   - any row has arrays of different lengths, or an empty array.
// Otherwise returns a column of TYPE holding one value per row.
template <LogicalType TYPE>
StatusOr<ColumnPtr> MathFunctions::l2_distance(FunctionContext* context, const Columns& columns) {
    DCHECK_EQ(columns.size(), 2);

    const Column* base = columns[0].get();
    const Column* target = columns[1].get();
    size_t target_size = target->size();

    // Row counts of both inputs must agree.
    if (base->size() != target_size) {
        return Status::InvalidArgument(fmt::format(
                "l2_distance requires equal length arrays. base array size is {} and target array size is {}.",
                base->size(), target->size()));
    }

    // Nulls at the array level are rejected; name the offending side in the message.
    const bool base_has_null = base->has_null();
    if (base_has_null || target->has_null()) {
        return Status::InvalidArgument(fmt::format("l2_distance does not support null values. {} array has null value.",
                                                   base_has_null ? "base" : "target"));
    }

    // For a constant column: call assign(size, 0) on its data column, then use that data column.
    auto expand_const = [](const Column* col) -> const Column* {
        if (!col->is_constant()) {
            return col;
        }
        auto* wrapper = down_cast<const ConstColumn*>(col);
        wrapper->data_column()->assign(col->size(), 0);
        return wrapper->data_column().get();
    };
    // For a nullable column: use its data column.
    auto strip_nullable = [](const Column* col) -> const Column* {
        if (!col->is_nullable()) {
            return col;
        }
        return down_cast<const NullableColumn*>(col)->data_column().get();
    };

    base = expand_const(base);
    target = expand_const(target);
    base = strip_nullable(base);
    target = strip_nullable(target);

    // Look at the flattened element columns and the per-row offsets of both arrays.
    const auto* base_array = down_cast<const ArrayColumn*>(base);
    const Column* base_flat = base_array->elements_column().get();
    const uint32_t* base_offset = base_array->offsets().get_data().data();
    size_t base_flat_size = base_flat->size();

    const auto* target_array = down_cast<const ArrayColumn*>(target);
    const Column* target_flat = target_array->elements_column().get();
    size_t target_flat_size = target_flat->size();
    const uint32_t* target_offset = target_array->offsets().get_data().data();

    // The total number of elements must match.
    if (base_flat_size != target_flat_size) {
        return Status::InvalidArgument("l2_distance requires equal length arrays");
    }

    // Nulls inside the arrays are rejected as well.
    if (base_flat->has_null() || target_flat->has_null()) {
        return Status::InvalidArgument("l2_distance does not support null values");
    }
    base_flat = strip_nullable(base_flat);
    target_flat = strip_nullable(target_flat);

    using CppType = RunTimeCppType<TYPE>;
    using ColumnType = RunTimeColumnType<TYPE>;

    const CppType* base_values = down_cast<const ColumnType*>(base_flat)->get_data().data();
    const CppType* target_values = down_cast<const ColumnType*>(target_flat)->get_data().data();

    // Create the output column of TYPE, passing target_size as the size argument.
    ColumnPtr result = ColumnHelper::create_column(TypeDescriptor{TYPE}, false, false, target_size);
    CppType* result_data = down_cast<ColumnType*>(result.get())->get_data().data();

    // First pass: validate every row before computing anything.
    for (size_t row = 0; row < target_size; ++row) {
        size_t t_dim_size = target_offset[row + 1] - target_offset[row];
        size_t b_dim_size = base_offset[row + 1] - base_offset[row];
        if (t_dim_size != b_dim_size) {
            return Status::InvalidArgument(
                    fmt::format("l2_distance requires equal length arrays in each row. base array dimension size "
                                "is {}, target array dimension size is {}.",
                                b_dim_size, t_dim_size));
        }
        if (t_dim_size == 0) {
            return Status::InvalidArgument("l2_distance requires non-empty arrays in each row");
        }
    }

    // Second pass: accumulate squared differences row by row. `cursor` is the
    // position of the current row's first element in both flattened buffers.
    size_t cursor = 0;
    for (size_t row = 0; row < target_size; ++row) {
        const size_t dim_size = target_offset[row + 1] - target_offset[row];
        CppType acc = 0;
        for (size_t j = 0; j < dim_size; ++j) {
            const CppType diff = base_values[cursor + j] - target_values[cursor + j];
            acc += diff * diff;
        }
        result_data[row] = acc;
        cursor += dim_size;
    }

    return result;
}

// Explicitly instantiate l2_distance for TYPE_FLOAT.
template StatusOr<ColumnPtr> MathFunctions::l2_distance<TYPE_FLOAT>(FunctionContext* context, const Columns& columns);

} // namespace starrocks
Loading

0 comments on commit c1c3b4f

Please sign in to comment.