Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support select vector index #48521

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions be/src/column/chunk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ void Chunk::append_column(ColumnPtr column, const FieldPtr& field) {
check_or_die();
}

void Chunk::append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id) {
DCHECK(!_cid_to_index.contains(field->id()));
_cid_to_index[field->id()] = _columns.size();
_slot_id_to_index[slot_id] = _columns.size();
_columns.emplace_back(std::move(column));
_schema->append(field);
check_or_die();
}

void Chunk::append_column(ColumnPtr column, SlotId slot_id) {
_slot_id_to_index[slot_id] = _columns.size();
_columns.emplace_back(std::move(column));
Expand Down
2 changes: 2 additions & 0 deletions be/src/column/chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ class Chunk {
// schema must exist and will be updated.
void append_column(ColumnPtr column, const FieldPtr& field);

void append_vector_column(ColumnPtr column, const FieldPtr& field, SlotId slot_id);

void append_column(ColumnPtr column, SlotId slot_id);
void insert_column(size_t idx, ColumnPtr column, const FieldPtr& field);

Expand Down
5 changes: 4 additions & 1 deletion be/src/column/schema.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ Schema::Schema(Schema* schema, const std::vector<ColumnId>& cids)
auto ori_sort_idxes = schema->sort_key_idxes();
std::unordered_set<ColumnId> scids(ori_sort_idxes.begin(), ori_sort_idxes.end());
for (int i = 0; i < cids.size(); i++) {
DCHECK_LT(cids[i], schema->_fields.size());
if (cids[i] >= schema->_fields.size()) {
_fields.resize(_fields.size() - 1);
continue;
}
yulongfufu marked this conversation as resolved.
Show resolved Hide resolved
_fields[i] = schema->_fields[cids[i]];
if (scids.find(cids[i]) != scids.end()) {
_sort_key_idxes.emplace_back(i);
Expand Down
6 changes: 6 additions & 0 deletions be/src/common/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,9 @@ CONF_Int64(query_cache_capacity, "536870912");
// ranges in [1,16], default value is 4.
CONF_mInt32(query_cache_num_lanes_per_driver, "4");

// Used by vector query cache, 500MB in default
CONF_Int64(vector_query_cache_capacity, "536870912");

// Used to limit buffer size of tablet send channel.
CONF_mInt64(send_channel_buffer_limit, "67108864");

Expand Down Expand Up @@ -1369,6 +1372,9 @@ CONF_mInt64(arrow_io_coalesce_read_max_buffer_size, "8388608");
CONF_mInt64(arrow_io_coalesce_read_max_distance_size, "1048576");
CONF_mInt64(arrow_read_batch_size, "4096");

// default not to build the empty index
CONF_mInt32(config_tenann_default_build_threshold, "0");

// Set to true to enable socket_keepalive option in brpc
CONF_mBool(brpc_socket_keepalive, "false");
CONF_mBool(apply_del_vec_after_all_index_filter, "true");
Expand Down
6 changes: 6 additions & 0 deletions be/src/exec/olap_scan_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,12 @@ void OlapScanNode::_init_counter(RuntimeState* state) {
_bf_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "BloomFilterFilterRows", TUnit::UNIT, "SegmentInit");
_gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, "SegmentInit");
_gin_filtered_timer = ADD_CHILD_TIMER(_runtime_profile, "GinFilter", "SegmentInit");
_get_row_ranges_by_vector_index_timer = ADD_CHILD_TIMER(_scan_profile, "GetVectorRowRangesTime", "SegmentInit");
_vector_search_timer = ADD_CHILD_TIMER(_scan_profile, "VectorSearchTime", "SegmentInit");
_vector_index_filtered_counter =
ADD_CHILD_COUNTER(_scan_profile, "VectorIndexFilterRows", TUnit::UNIT, "SegmentInit");
_process_vector_distance_and_id_timer =
ADD_CHILD_TIMER(_scan_profile, "ProcessVectorDistanceAndIdTime", "SegmentInit");
_seg_zm_filtered_counter = ADD_CHILD_COUNTER(_scan_profile, "SegmentZoneMapFilterRows", TUnit::UNIT, "SegmentInit");
_seg_rt_filtered_counter =
ADD_CHILD_COUNTER(_scan_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, "SegmentInit");
Expand Down
4 changes: 4 additions & 0 deletions be/src/exec/olap_scan_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ class OlapScanNode final : public starrocks::ScanNode {
RuntimeProfile::Counter* _bi_filter_timer = nullptr;
RuntimeProfile::Counter* _gin_filtered_counter = nullptr;
RuntimeProfile::Counter* _gin_filtered_timer = nullptr;
RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr;
RuntimeProfile::Counter* _vector_search_timer = nullptr;
RuntimeProfile::Counter* _vector_index_filtered_counter = nullptr;
RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr;
RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr;
RuntimeProfile::Counter* _rowsets_read_count = nullptr;
RuntimeProfile::Counter* _segments_read_count = nullptr;
Expand Down
64 changes: 59 additions & 5 deletions be/src/exec/pipeline/scan/olap_chunk_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "runtime/exec_env.h"
#include "storage/chunk_helper.h"
#include "storage/column_predicate_rewriter.h"
#include "storage/index/vector/vector_search_option.h"
#include "storage/olap_runtime_range_pruner.hpp"
#include "storage/predicate_parser.h"
#include "storage/projection_iterator.h"
Expand Down Expand Up @@ -76,6 +77,13 @@ Status OlapChunkSource::prepare(RuntimeState* state) {
RETURN_IF_ERROR(ChunkSource::prepare(state));
_runtime_state = state;
const TOlapScanNode& thrift_olap_scan_node = _scan_node->thrift_olap_scan_node();
const TVectorSearchOptions& vector_search_options = thrift_olap_scan_node.vector_search_options;
_use_vector_index = thrift_olap_scan_node.__isset.vector_search_options && vector_search_options.enable_use_ann;
if (_use_vector_index) {
_use_ivfpq = vector_search_options.use_ivfpq;
_vector_distance_column_name = vector_search_options.vector_distance_column_name;
_params.vector_search_option = std::make_shared<VectorSearchOption>();
}
const TupleDescriptor* tuple_desc = state->desc_tbl().get_tuple_descriptor(thrift_olap_scan_node.tuple_id);
_slots = &tuple_desc->slots();

Expand Down Expand Up @@ -129,6 +137,11 @@ void OlapChunkSource::_init_counter(RuntimeState* state) {
const std::string segment_init_name = "SegmentInit";
_seg_init_timer = ADD_CHILD_TIMER(_runtime_profile, segment_init_name, IO_TASK_EXEC_TIMER_NAME);
_bi_filter_timer = ADD_CHILD_TIMER(_runtime_profile, "BitmapIndexFilter", segment_init_name);
_get_row_ranges_by_vector_index_timer =
ADD_CHILD_TIMER(_runtime_profile, "GetVectorRowRangesTime", segment_init_name);
_vector_search_timer = ADD_CHILD_TIMER(_runtime_profile, "VectorSearchTime", segment_init_name);
_process_vector_distance_and_id_timer =
ADD_CHILD_TIMER(_runtime_profile, "ProcessVectorDistanceAndIdTime", segment_init_name);
_bi_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BitmapIndexFilterRows", TUnit::UNIT, segment_init_name);
_bf_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "BloomFilterFilterRows", TUnit::UNIT, segment_init_name);
_gin_filtered_counter = ADD_CHILD_COUNTER(_runtime_profile, "GinFilterRows", TUnit::UNIT, segment_init_name);
Expand All @@ -140,6 +153,8 @@ void OlapChunkSource::_init_counter(RuntimeState* state) {
ADD_CHILD_COUNTER(_runtime_profile, "SegmentRuntimeZoneMapFilterRows", TUnit::UNIT, segment_init_name);
_zm_filtered_counter =
ADD_CHILD_COUNTER(_runtime_profile, "ZoneMapIndexFilterRows", TUnit::UNIT, segment_init_name);
_vector_index_filtered_counter =
ADD_CHILD_COUNTER(_runtime_profile, "VectorIndexFilterRows", TUnit::UNIT, segment_init_name);
_sk_filtered_counter =
ADD_CHILD_COUNTER_SKIP_MIN_MAX(_runtime_profile, "ShortKeyFilterRows", TUnit::UNIT,
_get_counter_min_max_type("ShortKeyFilterRows"), segment_init_name);
Expand Down Expand Up @@ -213,6 +228,24 @@ Status OlapChunkSource::_init_reader_params(const std::vector<std::unique_ptr<Ol
if (thrift_olap_scan_node.__isset.enable_gin_filter) {
_params.enable_gin_filter = thrift_olap_scan_node.enable_gin_filter;
}
_params.use_vector_index = _use_vector_index;
if (_use_vector_index) {
const TVectorSearchOptions& vector_options = thrift_olap_scan_node.vector_search_options;

_params.vector_search_option->vector_distance_column_name = _vector_distance_column_name;
_params.vector_search_option->k = vector_options.vector_limit_k;
for (const std::string& str : vector_options.query_vector) {
_params.vector_search_option->query_vector.push_back(std::stof(str));
}
if (_runtime_state->query_options().__isset.ann_params) {
_params.vector_search_option->query_params = _runtime_state->query_options().ann_params;
}
_params.vector_search_option->vector_range = vector_options.vector_range;
_params.vector_search_option->result_order = vector_options.result_order;
_params.vector_search_option->use_ivfpq = _use_ivfpq;
_params.vector_search_option->k_factor = _runtime_state->query_options().k_factor;
_params.vector_search_option->pq_refine_factor = _runtime_state->query_options().pq_refine_factor;
}
if (thrift_olap_scan_node.__isset.sorted_by_keys_per_tablet) {
_params.sorted_by_keys_per_tablet = thrift_olap_scan_node.sorted_by_keys_per_tablet;
}
Expand Down Expand Up @@ -271,7 +304,16 @@ Status OlapChunkSource::_init_reader_params(const std::vector<std::unique_ptr<Ol
Status OlapChunkSource::_init_scanner_columns(std::vector<uint32_t>& scanner_columns) {
for (auto slot : *_slots) {
DCHECK(slot->is_materialized());
int32_t index = _tablet_schema->field_index(slot->col_name());
int32_t index;
if (_use_vector_index && !_use_ivfpq) {
index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name);
if (slot->col_name() == _vector_distance_column_name) {
_params.vector_search_option->vector_column_id = index;
_params.vector_search_option->vector_slot_id = slot->id();
}
} else {
index = _tablet_schema->field_index(slot->col_name());
}
if (index < 0) {
std::stringstream ss;
ss << "invalid field name: " << slot->col_name();
Expand All @@ -295,7 +337,12 @@ Status OlapChunkSource::_init_scanner_columns(std::vector<uint32_t>& scanner_col

Status OlapChunkSource::_init_unused_output_columns(const std::vector<std::string>& unused_output_columns) {
for (const auto& col_name : unused_output_columns) {
int32_t index = _tablet_schema->field_index(col_name);
int32_t index;
if (_use_vector_index && !_use_ivfpq) {
index = _tablet_schema->field_index(col_name, _vector_distance_column_name);
} else {
index = _tablet_schema->field_index(col_name);
}
if (index < 0) {
std::stringstream ss;
ss << "invalid field name: " << col_name;
Expand Down Expand Up @@ -500,7 +547,12 @@ Status OlapChunkSource::_init_global_dicts(TabletReaderParams* params) {
auto iter = global_dict_map.find(slot->id());
if (iter != global_dict_map.end()) {
auto& dict_map = iter->second.first;
int32_t index = _tablet_schema->field_index(slot->col_name());
int32_t index;
if (_use_vector_index && !_use_ivfpq) {
index = _tablet_schema->field_index(slot->col_name(), _vector_distance_column_name);
} else {
index = _tablet_schema->field_index(slot->col_name());
}
DCHECK(index >= 0);
global_dict->emplace(index, const_cast<GlobalDictMap*>(&dict_map));
}
Expand Down Expand Up @@ -610,6 +662,7 @@ void OlapChunkSource::_update_counter() {
COUNTER_UPDATE(_seg_zm_filtered_counter, _reader->stats().segment_stats_filtered);
COUNTER_UPDATE(_seg_rt_filtered_counter, _reader->stats().runtime_stats_filtered);
COUNTER_UPDATE(_zm_filtered_counter, _reader->stats().rows_stats_filtered);
COUNTER_UPDATE(_vector_index_filtered_counter, _reader->stats().rows_vector_index_filtered);
COUNTER_UPDATE(_bf_filtered_counter, _reader->stats().rows_bf_filtered);
COUNTER_UPDATE(_sk_filtered_counter, _reader->stats().rows_key_range_filtered);
COUNTER_UPDATE(_rows_after_sk_filtered_counter, _reader->stats().rows_after_key_range);
Expand All @@ -620,8 +673,9 @@ void OlapChunkSource::_update_counter() {

COUNTER_UPDATE(_bi_filtered_counter, _reader->stats().rows_bitmap_index_filtered);
COUNTER_UPDATE(_bi_filter_timer, _reader->stats().bitmap_index_filter_timer);
COUNTER_UPDATE(_gin_filtered_counter, _reader->stats().rows_gin_filtered);
COUNTER_UPDATE(_gin_filtered_timer, _reader->stats().gin_index_filter_ns);
COUNTER_UPDATE(_get_row_ranges_by_vector_index_timer, _reader->stats().get_row_ranges_by_vector_index_timer);
COUNTER_UPDATE(_vector_search_timer, _reader->stats().vector_search_timer);
COUNTER_UPDATE(_process_vector_distance_and_id_timer, _reader->stats().process_vector_distance_and_id_timer);
COUNTER_UPDATE(_block_seek_counter, _reader->stats().block_seek_num);

COUNTER_UPDATE(_rowsets_read_count, _reader->stats().rowsets_read_count);
Expand Down
12 changes: 11 additions & 1 deletion be/src/exec/pipeline/scan/olap_chunk_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class OlapChunkSource final : public ChunkSource {
OlapScanNode* _scan_node;
OlapScanContext* _scan_ctx;

const int64_t _limit; // -1: no limit
int64_t _limit; // -1: no limit
TInternalScanRange* _scan_range;

PredicateTree _non_pushdown_pred_tree;
Expand Down Expand Up @@ -104,6 +104,12 @@ class OlapChunkSource final : public ChunkSource {

std::vector<ColumnAccessPathPtr> _column_access_paths;

bool _use_vector_index = false;

bool _use_ivfpq = false;

std::string _vector_distance_column_name;

// The following are profile meatures
int64_t _num_rows_read = 0;

Expand Down Expand Up @@ -132,6 +138,7 @@ class OlapChunkSource final : public ChunkSource {
RuntimeProfile::Counter* _rows_key_range_counter = nullptr;
RuntimeProfile::Counter* _bf_filter_timer = nullptr;
RuntimeProfile::Counter* _zm_filtered_counter = nullptr;
RuntimeProfile::Counter* _vector_index_filtered_counter = nullptr;
RuntimeProfile::Counter* _bf_filtered_counter = nullptr;
RuntimeProfile::Counter* _seg_zm_filtered_counter = nullptr;
RuntimeProfile::Counter* _seg_rt_filtered_counter = nullptr;
Expand All @@ -148,6 +155,9 @@ class OlapChunkSource final : public ChunkSource {
RuntimeProfile::Counter* _bi_filter_timer = nullptr;
RuntimeProfile::Counter* _gin_filtered_counter = nullptr;
RuntimeProfile::Counter* _gin_filtered_timer = nullptr;
RuntimeProfile::Counter* _get_row_ranges_by_vector_index_timer = nullptr;
RuntimeProfile::Counter* _vector_search_timer = nullptr;
RuntimeProfile::Counter* _process_vector_distance_and_id_timer = nullptr;
RuntimeProfile::Counter* _pushdown_predicates_counter = nullptr;
RuntimeProfile::Counter* _non_pushdown_predicates_counter = nullptr;
RuntimeProfile::Counter* _rowsets_read_count = nullptr;
Expand Down
106 changes: 104 additions & 2 deletions be/src/exprs/math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,6 @@ StatusOr<ColumnPtr> MathFunctions::cosine_similarity(FunctionContext* context, c
target_sum += target_data[j] * target_data[j];
}
}

if constexpr (!isNorm) {
result_value = sum / (std::sqrt(base_sum) * std::sqrt(target_sum));
} else {
Expand All @@ -899,10 +898,113 @@ StatusOr<ColumnPtr> MathFunctions::cosine_similarity(FunctionContext* context, c
return result;
}

// explicitly instaniate template function.
// explicitly instantiate template function.
template StatusOr<ColumnPtr> MathFunctions::cosine_similarity<TYPE_FLOAT, true>(FunctionContext* context,
const Columns& columns);
template StatusOr<ColumnPtr> MathFunctions::cosine_similarity<TYPE_FLOAT, false>(FunctionContext* context,
const Columns& columns);

template <LogicalType TYPE>
StatusOr<ColumnPtr> MathFunctions::l2_distance(FunctionContext* context, const Columns& columns) {
DCHECK_EQ(columns.size(), 2);

const Column* base = columns[0].get();
const Column* target = columns[1].get();
size_t target_size = target->size();
if (base->size() != target_size) {
return Status::InvalidArgument(fmt::format(
"l2_distance requires equal length arrays. base array size is {} and target array size is {}.",
base->size(), target->size()));
}
if (base->has_null() || target->has_null()) {
return Status::InvalidArgument(fmt::format("l2_distance does not support null values. {} array has null value.",
base->has_null() ? "base" : "target"));
}
if (base->is_constant()) {
auto* const_column = down_cast<const ConstColumn*>(base);
const_column->data_column()->assign(base->size(), 0);
base = const_column->data_column().get();
}
if (target->is_constant()) {
auto* const_column = down_cast<const ConstColumn*>(target);
const_column->data_column()->assign(target->size(), 0);
target = const_column->data_column().get();
}
if (base->is_nullable()) {
base = down_cast<const NullableColumn*>(base)->data_column().get();
}
if (target->is_nullable()) {
target = down_cast<const NullableColumn*>(target)->data_column().get();
}

// check dimension equality.
const Column* base_flat = down_cast<const ArrayColumn*>(base)->elements_column().get();
const uint32_t* base_offset = down_cast<const ArrayColumn*>(base)->offsets().get_data().data();
size_t base_flat_size = base_flat->size();

const Column* target_flat = down_cast<const ArrayColumn*>(target)->elements_column().get();
size_t target_flat_size = target_flat->size();
const uint32_t* target_offset = down_cast<const ArrayColumn*>(target)->offsets().get_data().data();

if (base_flat_size != target_flat_size) {
return Status::InvalidArgument("l2_distance requires equal length arrays");
}

if (base_flat->has_null() || target_flat->has_null()) {
return Status::InvalidArgument("l2_distance does not support null values");
}
if (base_flat->is_nullable()) {
base_flat = down_cast<const NullableColumn*>(base_flat)->data_column().get();
}
if (target_flat->is_nullable()) {
target_flat = down_cast<const NullableColumn*>(target_flat)->data_column().get();
}

using CppType = RunTimeCppType<TYPE>;
using ColumnType = RunTimeColumnType<TYPE>;

const CppType* base_data_head = down_cast<const ColumnType*>(base_flat)->get_data().data();
const CppType* target_data_head = down_cast<const ColumnType*>(target_flat)->get_data().data();

// prepare result with nullable value.
ColumnPtr result = ColumnHelper::create_column(TypeDescriptor{TYPE}, false, false, target_size);
ColumnType* data_result = down_cast<ColumnType*>(result.get());
CppType* result_data = data_result->get_data().data();

for (size_t i = 0; i < target_size; i++) {
size_t t_dim_size = target_offset[i + 1] - target_offset[i];
size_t b_dim_size = base_offset[i + 1] - base_offset[i];
if (t_dim_size != b_dim_size) {
return Status::InvalidArgument(
fmt::format("l2_distance requires equal length arrays in each row. base array dimension size "
"is {}, target array dimension size is {}.",
b_dim_size, t_dim_size));
}
if (t_dim_size == 0) {
return Status::InvalidArgument("l2_distance requires non-empty arrays in each row");
}
}

const CppType* target_data = target_data_head;
const CppType* base_data = base_data_head;

for (size_t i = 0; i < target_size; i++) {
CppType sum = 0;
size_t dim_size = target_offset[i + 1] - target_offset[i];
for (size_t j = 0; j < dim_size; j++) {
CppType distance;
distance = (base_data[j] - target_data[j]) * (base_data[j] - target_data[j]);
sum += distance;
}
result_data[i] = sum;
target_data += dim_size;
base_data += dim_size;
}

return result;
}

// explicitly instantiate template function.
template StatusOr<ColumnPtr> MathFunctions::l2_distance<TYPE_FLOAT>(FunctionContext* context, const Columns& columns);

} // namespace starrocks
Loading
Loading