Fix index scans with filters pushed down, fix lateral join optimization when grouping, update to v1.1.2 #30

Merged: 4 commits, Oct 16, 2024
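For context, a minimal sketch of the first fix (the index-scan one): when a WHERE clause is pushed down into the table scan as a table filter, the HNSW index scan that replaces the scan cannot evaluate that filter itself, so the optimizer now pulls the filter back up into a LogicalFilter on top of the index scan. The schema and values below are illustrative, not taken from this PR:

    -- assumed setup
    CREATE TABLE items (id INTEGER, cat TEXT, vec FLOAT[3]);
    CREATE INDEX items_hnsw_idx ON items USING HNSW (vec);

    -- a top-k scan whose filter gets pushed down into the seq_scan;
    -- exercised by the changes in src/hnsw/hnsw_optimize_scan.cpp below
    SELECT id, cat
    FROM items
    WHERE cat = 'a'
    ORDER BY array_distance(vec, [1.0, 2.0, 3.0]::FLOAT[3])
    LIMIT 5;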
14 changes: 7 additions & 7 deletions .github/workflows/StableDistributionPipeline.yml
@@ -5,37 +5,37 @@ name: Stable Extension Distribution Pipeline
 on:
   pull_request:
     branches:
-      - v0.10.2
+      - v1.1.2
     paths-ignore:
       - '**/README.md'
       - 'doc/**'
   push:
     branches:
-      - v0.10.2
+      - v1.1.2
     paths-ignore:
       - '**/README.md'
       - 'doc/**'
   workflow_dispatch:
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/v0.10.2' || github.sha }}
+  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || '' }}-${{ github.base_ref || '' }}-${{ github.ref != 'refs/heads/v1.1.2' || github.sha }}
   cancel-in-progress: true
 
 jobs:
   duckdb-stable-build:
     name: Build extension binaries
-    uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v0.10.2
+    uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.1.2
    with:
       vcpkg_commit: a42af01b72c28a8e1d7b48107b33e4f286a55ef6
-      duckdb_version: v0.10.2
+      duckdb_version: v1.1.2
       extension_name: vss
 
   duckdb-stable-deploy:
     name: Deploy extension binaries
     needs: duckdb-stable-build
-    uses: duckdb/extension-ci-tools/.github/workflows/_extension_deploy.yml@v0.10.2
+    uses: duckdb/extension-ci-tools/.github/workflows/_extension_deploy.yml@v1.1.2
     secrets: inherit
     with:
-      duckdb_version: v0.10.2
+      duckdb_version: v1.1.2
       extension_name: vss
       deploy_latest: ${{ startsWith(github.ref, 'refs/heads/v') || github.ref == 'refs/heads/main' }}
2 changes: 1 addition & 1 deletion duckdb
Submodule duckdb updated 547 files
2 changes: 1 addition & 1 deletion src/hnsw/hnsw_index.cpp
@@ -431,7 +431,7 @@ void HNSWIndex::Construct(DataChunk &input, Vector &row_ids, idx_t thread_idx) {
     // Now we can be sure that we have enough space in the index
     auto lock = rwlock.GetSharedLock();
     for (idx_t out_idx = 0; out_idx < count; out_idx++) {
-        if(FlatVector::IsNull(vec_vec, out_idx)) {
+        if (FlatVector::IsNull(vec_vec, out_idx)) {
            // Dont add nulls
            continue;
        }
41 changes: 33 additions & 8 deletions src/hnsw/hnsw_optimize_join.cpp
@@ -113,14 +113,22 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat
     auto &state = ostate.Cast<HNSWIndexJoinState>();
     auto &transcation = DuckTransaction::Get(context.client, table.catalog);
 
     // TODO: dont flatten
     input.Flatten();
 
+    // The first 0..inner_column_ids.size() columns are the inner table columns
+    const auto MATCH_COLUMN_OFFSET = inner_column_ids.size();
+    // The next column is the row number
+    const auto OUTER_COLUMN_OFFSET = MATCH_COLUMN_OFFSET + 1;
+    // The rest of the columns are the outer table columns
+
     auto &rhs_vector_vector = input.data[outer_vector_column];
     auto &rhs_vector_child = ArrayVector::GetEntry(rhs_vector_vector);
     const auto rhs_vector_size = ArrayType::GetSize(rhs_vector_vector.GetType());
     const auto rhs_vector_ptr = FlatVector::GetData<float>(rhs_vector_child);
 
+    // We mimic the window row_number() operator here and output the row number in each batch, basically.
+    const auto row_number_vector = FlatVector::GetData<idx_t>(chunk.data[MATCH_COLUMN_OFFSET]);
+
     hnsw_index.ResetMultiScan(*state.index_state);
 
     // How many batches are we going to process?
@@ -134,7 +142,9 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat
         // Scan the index for row ids
         const auto match_count = hnsw_index.ExecuteMultiScan(*state.index_state, rhs_vector_data, limit);
         for (idx_t i = 0; i < match_count; i++) {
-            state.match_sel.set_index(output_idx++, batch_idx);
+            state.match_sel.set_index(output_idx, batch_idx);
+            row_number_vector[output_idx] = i + 1; // Note: 1-indexed!
+            output_idx++;
         }
     }
 
@@ -144,7 +154,7 @@ OperatorResultType PhysicalHNSWIndexJoin::Execute(ExecutionContext &context, Dat
     table.GetStorage().Fetch(transcation, chunk, state.phyiscal_column_ids, row_ids, output_idx, state.fetch_state);
 
     // Now slice the chunk so that we include the rhs too
-    chunk.Slice(input, state.match_sel, output_idx, state.phyiscal_column_ids.size());
+    chunk.Slice(input, state.match_sel, output_idx, OUTER_COLUMN_OFFSET);
 
     // Set the cardinality
     chunk.SetCardinality(output_idx);
@@ -220,6 +230,9 @@ void LogicalHNSWIndexJoin::ResolveTypes() {
         }
     }
 
+    // Always add the row_number after the inner columns
+    types.emplace_back(LogicalType::BIGINT);
+
     // Also add the types of the right hand side
     auto &right_types = children[0]->types;
     types.insert(types.end(), right_types.begin(), right_types.end());
@@ -236,6 +249,10 @@ vector<ColumnBinding> LogicalHNSWIndexJoin::GetLeftBindings() {
             result.emplace_back(table_index, proj_id);
         }
     }
+
+    // Always add the row number last
+    result.emplace_back(table_index, inner_column_ids.size());
+
     return result;
 }
 
@@ -550,6 +567,14 @@ bool HNSWIndexJoinOptimizer::TryOptimize(Binder &binder, ClientContext &context,
         replacer.replacement_bindings.emplace_back(old_binding,
                                                    ColumnBinding(projection_table_index, new_binding_idx++));
     }
+
+    // Also add the window expression to the projection last. We will replace this with a reference to the index join
+    // in the next inlining step
+    ColumnBinding window_binding(window.window_index, 0);
+    projection_expressions.push_back(make_uniq<BoundColumnRefExpression>(LogicalType::BIGINT, window_binding));
+    replacer.replacement_bindings.emplace_back(window_binding,
+                                               ColumnBinding(projection_table_index, new_binding_idx++));
+
     auto new_projection = make_uniq<LogicalProjection>(projection_table_index, std::move(projection_expressions));
 
     // Replace all previous references with our new projection
@@ -583,12 +608,12 @@ bool HNSWIndexJoinOptimizer::TryOptimize(Binder &binder, ClientContext &context,
             expr = inner_expr->Copy();
             // These can still reference the delim_get, but we replace them in the next step.
         }
-        // Special case: the window row number expression. Is not used, but maybe we add it to the index scan as a third
-        // column export.
+
+        // Special case: the window row number expression. Forward this to the index join
         else if (ref.binding.table_index == window.window_index) {
-            // Just return a constant for now...
-            // TODO: Fix this!
-            expr = make_uniq<BoundConstantExpression>(Value::BIGINT(1337));
+            // The special "row_number" expression is always the last column of the index_join itself
+            ColumnBinding index_row_number_binding(index_join->table_index, index_join->inner_column_ids.size());
+            expr = make_uniq<BoundColumnRefExpression>(LogicalType::BIGINT, index_row_number_binding);
         }
     }
 
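The hunks above replace the old placeholder for the window's row_number() (a literal BIGINT 1337) with a real column: the index join now appends a 1-indexed match number after the inner table's columns, fills it per probe row in Execute, and the optimizer binds the window reference to that column. Semantically the join output now matches the window form the comment refers to; roughly the following, using the tables from the test file further down (the QUALIFY spelling is just an illustration, not the plan DuckDB actually builds):

    SELECT a.a_id, b.b_str,
           row_number() OVER (PARTITION BY a.a_id
                              ORDER BY array_distance(a.a_vec, b.b_vec)) AS rn
    FROM a, b
    QUALIFY rn <= 2;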
59 changes: 48 additions & 11 deletions src/hnsw/hnsw_optimize_scan.cpp
@@ -1,18 +1,20 @@
 #include "duckdb/catalog/catalog_entry/duck_table_entry.hpp"
 #include "duckdb/optimizer/column_lifetime_analyzer.hpp"
+#include "duckdb/optimizer/matcher/expression_matcher.hpp"
 #include "duckdb/optimizer/optimizer_extension.hpp"
+#include "duckdb/optimizer/remove_unused_columns.hpp"
 #include "duckdb/planner/expression/bound_constant_expression.hpp"
 #include "duckdb/planner/expression/bound_function_expression.hpp"
+#include "duckdb/planner/expression_iterator.hpp"
 #include "duckdb/planner/operator/logical_get.hpp"
 #include "duckdb/planner/operator/logical_projection.hpp"
 #include "duckdb/planner/operator/logical_top_n.hpp"
+#include "duckdb/planner/operator/logical_filter.hpp"
 #include "duckdb/storage/data_table.hpp"
 
 #include "hnsw/hnsw.hpp"
 #include "hnsw/hnsw_index.hpp"
 #include "hnsw/hnsw_index_scan.hpp"
-#include "duckdb/optimizer/remove_unused_columns.hpp"
-#include "duckdb/planner/expression_iterator.hpp"
-#include "duckdb/optimizer/matcher/expression_matcher.hpp"
 
 namespace duckdb {
 
 //-----------------------------------------------------------------------------
@@ -69,12 +71,18 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
             return false;
         }
 
-        auto &get = projection.children.front()->Cast<LogicalGet>();
+        auto &get_ptr = projection.children.front();
+        auto &get = get_ptr->Cast<LogicalGet>();
         // Check if the get is a table scan
         if (get.function.name != "seq_scan") {
             return false;
         }
 
+        if (get.dynamic_filters && get.dynamic_filters->HasFilters()) {
+            // Cant push down!
+            return false;
+        }
+
         // We have a top-n operator on top of a table scan
         // We can replace the function with a custom index scan (if the table has a custom index)
 
@@ -137,17 +145,46 @@ class HNSWIndexScanOptimizer : public OptimizerExtension {
             return false;
         }
 
-        // Replace the scan with our custom index scan function
-
-        get.function = HNSWIndexScanFunction::GetFunction();
+        // If there are no table filters pushed down into the get, we can just replace the get with the index scan
         const auto cardinality = get.function.cardinality(context, bind_data.get());
+        get.function = HNSWIndexScanFunction::GetFunction();
         get.has_estimated_cardinality = cardinality->has_estimated_cardinality;
         get.estimated_cardinality = cardinality->estimated_cardinality;
         get.bind_data = std::move(bind_data);
+        if (get.table_filters.filters.empty()) {
+
+            // Remove the TopN operator
+            plan = std::move(top_n.children[0]);
+            return true;
+        }
 
-        // Remove the distance function from the projection
-        // projection.expressions.erase(projection.expressions.begin() + static_cast<ptrdiff_t>(projection_index));
-        // top_n.expressions
+        // Otherwise, things get more complicated. We need to pullup the filters from the table scan as our index scan
+        // does not support regular filter pushdown.
+        get.projection_ids.clear();
+        get.types.clear();
+
+        auto new_filter = make_uniq<LogicalFilter>();
+        auto &column_ids = get.GetColumnIds();
+        for (const auto &entry : get.table_filters.filters) {
+            idx_t column_id = entry.first;
+            auto &type = get.returned_types[column_id];
+            bool found = false;
+            for (idx_t i = 0; i < column_ids.size(); i++) {
+                if (column_ids[i] == column_id) {
+                    column_id = i;
+                    found = true;
+                    break;
+                }
+            }
+            if (!found) {
+                throw InternalException("Could not find column id for filter");
+            }
+            auto column = make_uniq<BoundColumnRefExpression>(type, ColumnBinding(get.table_index, column_id));
+            new_filter->expressions.push_back(entry.second->ToExpression(*column));
+        }
+        new_filter->children.push_back(std::move(get_ptr));
+        new_filter->ResolveOperatorTypes();
+        get_ptr = std::move(new_filter);
 
         // Remove the TopN operator
         plan = std::move(top_n.children[0]);
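A schematic of the scan rewrite above when table filters are present (hand-drawn, not actual EXPLAIN output): the TopN collapses into the index scan, which applies the k-limit itself, while the projection stays and the pulled-up filters reappear as a LogicalFilter above the scan.

    -- before:  TOP_N(k) -> PROJECTION(dist) -> SEQ_SCAN(filters: cat = 'a')
    -- after:   PROJECTION(dist) -> FILTER(cat = 'a') -> HNSW_INDEX_SCAN(k)
    -- one way to inspect this, assuming the items table sketched earlier:
    EXPLAIN
    SELECT id
    FROM items
    WHERE cat = 'a'
    ORDER BY array_distance(vec, [1.0, 2.0, 3.0]::FLOAT[3])
    LIMIT 5;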
44 changes: 42 additions & 2 deletions src/hnsw/hnsw_optimize_topk.cpp
@@ -4,6 +4,7 @@
 #include "duckdb/planner/expression/bound_function_expression.hpp"
 #include "duckdb/planner/operator/logical_aggregate.hpp"
 #include "duckdb/planner/operator/logical_get.hpp"
+#include "duckdb/planner/operator/logical_filter.hpp"
 #include "duckdb/optimizer/optimizer.hpp"
 #include "duckdb/planner/expression/bound_aggregate_expression.hpp"
 #include "duckdb/optimizer/matcher/expression_matcher.hpp"
@@ -41,7 +42,7 @@ static unique_ptr<Expression> CreateListOrderByExpr(ClientContext &context, uniq
     new_agg_expr->order_bys = make_uniq<BoundOrderModifier>();
     new_agg_expr->order_bys->orders.push_back(std::move(order_by_node));
 
-    return new_agg_expr;
+    return std::move(new_agg_expr);
 }
 
 //------------------------------------------------------------------------------
@@ -99,11 +100,17 @@ class HNSWTopKOptimizer : public OptimizerExtension {
             return false;
         }
 
-        auto &get = agg.children[0]->Cast<LogicalGet>();
+        auto &get_ptr = agg.children[0];
+        auto &get = get_ptr->Cast<LogicalGet>();
         if (get.function.name != "seq_scan") {
             return false;
         }
 
+        if (get.dynamic_filters && get.dynamic_filters->HasFilters()) {
+            // Cant push down!
+            return false;
+        }
+
         // Get the table
         auto &table = *get.GetTable();
         if (!table.IsDuckTable()) {
@@ -175,6 +182,39 @@ class HNSWTopKOptimizer : public OptimizerExtension {
         // Replace the aggregate with a list() aggregate function ordered by the distance
         agg.expressions[0] = CreateListOrderByExpr(context, col_expr->Copy(), dist_expr->Copy(),
                                                    agg_func_expr.filter ? agg_func_expr.filter->Copy() : nullptr);
+
+        if (get.table_filters.filters.empty()) {
+            return true;
+        }
+
+        // We need to pullup the filters from the table scan as our index scan does not support regular filter pushdown.
+        get.projection_ids.clear();
+        get.types.clear();
+
+        auto new_filter = make_uniq<LogicalFilter>();
+        auto &column_ids = get.GetColumnIds();
+        for (const auto &entry : get.table_filters.filters) {
+            idx_t column_id = entry.first;
+            auto &type = get.returned_types[column_id];
+            bool found = false;
+            for (idx_t i = 0; i < column_ids.size(); i++) {
+                if (column_ids[i] == column_id) {
+                    column_id = i;
+                    found = true;
+                    break;
+                }
+            }
+            if (!found) {
+                throw InternalException("Could not find column id for filter");
+            }
+            auto column = make_uniq<BoundColumnRefExpression>(type, ColumnBinding(get.table_index, column_id));
+            new_filter->expressions.push_back(entry.second->ToExpression(*column));
+        }
+
+        new_filter->children.push_back(std::move(get_ptr));
+        new_filter->ResolveOperatorTypes();
+        get_ptr = std::move(new_filter);
+
         return true;
     }
 
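The same filter pullup is mirrored here for the top-k aggregate path. As I read this file, the optimizer targets aggregate queries of roughly the following shape (illustrative, reusing the hypothetical items table from earlier; min_by with a third argument returns a list of the k rows with the smallest distance, which the rewrite turns into a list() ordered by distance over an index scan):

    SELECT min_by(items, array_distance(vec, [1.0, 2.0, 3.0]::FLOAT[3]), 5) AS top_5
    FROM items
    WHERE cat = 'a';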
8 changes: 8 additions & 0 deletions test/sql/hnsw/hnsw_lateral_join.test
@@ -65,3 +65,11 @@ CREATE INDEX my_idx ON b USING HNSW (b_vec);

 query IIIIII rowsort a_has_null
 select * from a, lateral (select *, a_id as id_dup from b order by array_distance(a.a_vec, b.b_vec) limit 2);
+
+# Test with a grouping function
+query II rowsort
+select a_id, list(b_str) from a, lateral (select *, a_id as id_dup from b order by array_distance(a.a_vec, b.b_vec) limit 2) GROUP BY a_id;
+----
+1	[a, b]
+2	[b, a]
+3	[a, b]