From 1520bb714955d7d5aed5e0160cd66cb61c7a5f05 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sat, 8 Feb 2025 11:41:43 +0800 Subject: [PATCH 01/10] xtr Signed-off-by: BubbleCal --- rust/lance/src/dataset/scanner.rs | 57 ++++++++--- rust/lance/src/io/exec/knn.rs | 162 +++++++++++++++++++++++++++++- 2 files changed, 202 insertions(+), 17 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 93cdf3ae34..1ced2a0e6f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -2123,8 +2123,14 @@ impl Scanner { index: &[Index], filter_plan: &FilterPlan, ) -> Result> { + // we split the query procedure into two steps: + // 1. collect the candidates by vector searching on each query vector + // 2. scoring the candidates + + let prefilter_source = self.prefilter_source(filter_plan).await?; let dim = get_vector_dim(self.dataset.schema(), &q.column)?; - // split the query multivectors + + // collect the candidates let num_queries = q.key.len() / dim; let new_queries = (0..num_queries) .map(|i| q.key.slice(i * dim, dim)) @@ -2134,7 +2140,6 @@ impl Scanner { new_query }); let mut ann_nodes = Vec::with_capacity(new_queries.len()); - let prefilter_source = self.prefilter_source(filter_plan).await?; for query in new_queries { let ann_node = new_knn_exec( self.dataset.clone(), @@ -2142,7 +2147,39 @@ impl Scanner { &query, prefilter_source.clone(), )?; - ann_nodes.push(ann_node); + // for single query vector, it's possible to retrieve multiple rows with the same row id + // need to max-reduce the results + let schema = ann_node.schema(); + let group_expr = vec![( + expressions::col(ROW_ID, ann_node.schema().as_ref())?, + ROW_ID.to_string(), + )]; + // for now multivector is always with cosine distance so here convert the distance to `1 - distance` + let max_expr = AggregateExprBuilder::new( + functions_aggregate::min_max::max_udaf(), + vec![expressions::binary( + expressions::lit(1.0), + datafusion_expr::Operator::Minus, + expressions::cast( + expressions::col(DIST_COL, &schema)?, + &schema, + DataType::Float64, + )?, + &schema, + )?], + ) + .schema(schema.clone()) + .alias(DIST_COL) + .build()?; + let ann_node = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(group_expr), + vec![Arc::new(max_expr)], + vec![None], + ann_node, + schema, + )?); + ann_nodes.push(ann_node as Arc); } let ann_node = Arc::new(UnionExec::new(ann_nodes)); let ann_node = Arc::new(RepartitionExec::try_new( @@ -2155,20 +2192,10 @@ impl Scanner { expressions::col(ROW_ID, schema.as_ref())?, ROW_ID.to_string(), )]; - // for now multivector is always with cosine distance so here convert the distance to `1 - distance` - // and calculate the sum across all rows with the same row id. + // calculate the sum across all rows with the same row id. let sum_expr = AggregateExprBuilder::new( functions_aggregate::sum::sum_udaf(), - vec![expressions::binary( - expressions::lit(1.0), - datafusion_expr::Operator::Minus, - expressions::cast( - expressions::col(DIST_COL, &schema)?, - &schema, - DataType::Float64, - )?, - &schema, - )?], + vec![expressions::col(DIST_COL, &schema)?], ) .schema(schema.clone()) .alias(DIST_COL) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index e9ff1f823d..6050aba5b6 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -2,14 +2,16 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use std::any::Any; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use arrow::datatypes::UInt32Type; +use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type}; use arrow_array::{ builder::{ListBuilder, UInt32Builder}, cast::AsArray, ArrayRef, RecordBatch, StringArray, }; +use arrow_array::{Float32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::PlanProperties; @@ -25,6 +27,7 @@ use datafusion_physical_expr::EquivalenceProperties; use futures::stream::repeat_with; use futures::{future, stream, StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; +use lance_core::ROW_ID; use lance_core::{utils::tokio::get_num_compute_intensive_cpus, ROW_ID_FIELD}; use lance_index::vector::{ flat::compute_distance, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN, @@ -36,7 +39,7 @@ use snafu::{location, Location}; use crate::dataset::Dataset; use crate::index::prefilter::{DatasetPreFilter, FilterLoader}; -use crate::index::vector::utils::get_vector_type; +use crate::index::vector::utils::{get_vector_dim, get_vector_type}; use crate::index::DatasetIndexInternalExt; use crate::{Error, Result}; use lance_arrow::*; @@ -650,6 +653,161 @@ impl ExecutionPlan for ANNIvfSubIndexExec { } } +#[derive(Debug)] +pub struct MultivectorScoringExec { + dataset: Arc, + inputs: Vec>, + query: Query, + dim: usize, + properties: PlanProperties, +} + +impl MultivectorScoringExec { + pub fn try_new( + dataset: Arc, + inputs: Vec>, + query: Query, + ) -> Result { + let dim = get_vector_dim(dataset.schema(), &query.column)?; + let properties = PlanProperties::new( + EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()), + Partitioning::RoundRobinBatch(1), + EmissionType::Incremental, + Boundedness::Bounded, + ); + + Ok(Self { + dataset, + inputs, + query, + dim, + properties, + }) + } +} + +impl DisplayAs for MultivectorScoringExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "MultivectorScoring: k={}", self.query.k) + } + } + } +} + +impl ExecutionPlan for MultivectorScoringExec { + fn name(&self) -> &str { + "MultivectorScoringExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> arrow_schema::SchemaRef { + self.inputs[0].schema() + } + + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> DataFusionResult> { + let plan = + MultivectorScoringExec::try_new(self.dataset.clone(), children, self.query.clone())?; + Ok(Arc::new(plan)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> DataFusionResult { + let inputs = self + .inputs + .iter() + .map(|input| input.execute(partition, context.clone())) + .collect::>>()?; + + let mut results = HashMap::with_capacity(self.query.k); + + let num_queries = self.query.key.len() / self.dim; + let mut min_dists = vec![0; num_queries]; + + // collect the top k results from each stream, + // and max-reduce for each query, + // records the minimum distance for each query as estimation. + let reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| { + stream.map(|batch| { + let batch = batch?; + let row_ids = batch[ROW_ID].as_primitive::(); + let dists = batch[DIST_COL].as_primitive::(); + + // max-reduce for the same row id + let min_dist = dists.values().last().cloned().unwrap_or(0.0); + let mut new_row_ids = Vec::with_capacity(row_ids.len()); + let mut new_dists = Vec::with_capacity(row_ids.len()); + let mut visited_row_ids = HashSet::with_capacity(row_ids.len()); + + for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) { + if visited_row_ids.contains(row_id) { + continue; + } + visited_row_ids.insert(row_id); + new_row_ids.push(*row_id); + new_dists.push(*dist); + } + let new_row_ids = UInt64Array::from(new_row_ids); + let new_dists = Float32Array::from(new_dists); + + let batch = RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![Arc::new(new_dists), Arc::new(new_row_ids)], + )?; + + Ok((min_dist, batch)) + }) + })); + + let stream = stream::once(async move { + let mut results = HashMap::with_capacity(self.query.k); + while let Some(res) = reduced_inputs.next().await { + let (min_dist, batch) = res?; + let row_ids = batch[ROW_ID].as_primitive::(); + let dists = batch[DIST_COL].as_primitive::(); + } + }); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream.boxed(), + ))) + } + + fn statistics(&self) -> DataFusionResult { + let inner_stats = self.input.statistics()?; + let dist_col_stats = inner_stats.column_statistics[0].clone(); + let column_statistics = inner_stats + .column_statistics + .into_iter() + .chain([dist_col_stats]) + .collect::>(); + Ok(Statistics { + num_rows: inner_stats.num_rows, + column_statistics, + ..Statistics::new_unknown(self.schema().as_ref()) + }) + // self.input.statistics() + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } +} + #[cfg(test)] mod tests { use super::*; From 0525f33c48c8af952a1aaa83e5d62e3b5232aa00 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sat, 8 Feb 2025 14:55:47 +0800 Subject: [PATCH 02/10] perf: implement XTR for retrieving multivector Signed-off-by: BubbleCal --- rust/lance/src/dataset/scanner.rs | 75 ++++---------------------- rust/lance/src/io/exec/knn.rs | 90 ++++++++++++++++++++----------- 2 files changed, 71 insertions(+), 94 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 1ced2a0e6f..5a65681688 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -63,6 +63,7 @@ use crate::index::scalar::detect_scalar_index_type; use crate::index::vector::utils::{get_vector_dim, get_vector_type}; use crate::index::DatasetIndexInternalExt; use crate::io::exec::fts::{FlatFtsExec, FtsExec}; +use crate::io::exec::knn::MultivectorScoringExec; use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec}; use crate::io::exec::{get_physical_optimizer, LanceScanConfig}; use crate::io::exec::{ @@ -1656,13 +1657,13 @@ impl Scanner { // Find all deltas with the same index name. let deltas = self.dataset.load_indices_by_name(&index.name).await?; - let (ann_node, is_multivec) = match vector_type { - DataType::FixedSizeList(_, _) => (self.ann(q, &deltas, filter_plan).await?, false), - DataType::List(_) => (self.multivec_ann(q, &deltas, filter_plan).await?, true), + let ann_node = match vector_type { + DataType::FixedSizeList(_, _) => self.ann(q, &deltas, filter_plan).await?, + DataType::List(_) => self.multivec_ann(q, &deltas, filter_plan).await?, _ => unreachable!(), }; - let mut knn_node = if q.refine_factor.is_some() || is_multivec { + let mut knn_node = if q.refine_factor.is_some() { let vector_projection = self .dataset .empty_projection() @@ -2130,7 +2131,6 @@ impl Scanner { let prefilter_source = self.prefilter_source(filter_plan).await?; let dim = get_vector_dim(self.dataset.schema(), &q.column)?; - // collect the candidates let num_queries = q.key.len() / dim; let new_queries = (0..num_queries) .map(|i| q.key.slice(i * dim, dim)) @@ -2147,72 +2147,19 @@ impl Scanner { &query, prefilter_source.clone(), )?; - // for single query vector, it's possible to retrieve multiple rows with the same row id - // need to max-reduce the results - let schema = ann_node.schema(); - let group_expr = vec![( - expressions::col(ROW_ID, ann_node.schema().as_ref())?, - ROW_ID.to_string(), - )]; - // for now multivector is always with cosine distance so here convert the distance to `1 - distance` - let max_expr = AggregateExprBuilder::new( - functions_aggregate::min_max::max_udaf(), - vec![expressions::binary( - expressions::lit(1.0), - datafusion_expr::Operator::Minus, - expressions::cast( - expressions::col(DIST_COL, &schema)?, - &schema, - DataType::Float64, - )?, - &schema, - )?], - ) - .schema(schema.clone()) - .alias(DIST_COL) - .build()?; - let ann_node = Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(group_expr), - vec![Arc::new(max_expr)], - vec![None], - ann_node, - schema, - )?); ann_nodes.push(ann_node as Arc); } - let ann_node = Arc::new(UnionExec::new(ann_nodes)); - let ann_node = Arc::new(RepartitionExec::try_new( - ann_node, - datafusion::physical_plan::Partitioning::RoundRobinBatch(1), - )?); - let schema = ann_node.schema(); - // unique by row ids, and get the min distance although it is not used. - let group_expr = vec![( - expressions::col(ROW_ID, schema.as_ref())?, - ROW_ID.to_string(), - )]; - // calculate the sum across all rows with the same row id. - let sum_expr = AggregateExprBuilder::new( - functions_aggregate::sum::sum_udaf(), - vec![expressions::col(DIST_COL, &schema)?], - ) - .schema(schema.clone()) - .alias(DIST_COL) - .build()?; - let ann_node: Arc = Arc::new(AggregateExec::try_new( - AggregateMode::Single, - PhysicalGroupBy::new_single(group_expr), - vec![Arc::new(sum_expr)], - vec![None], - ann_node, - schema, + + let ann_node = Arc::new(MultivectorScoringExec::try_new( + self.dataset.clone(), + ann_nodes, + q.clone(), )?); let sort_expr = PhysicalSortExpr { expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, options: SortOptions { - descending: true, + descending: false, nulls_first: false, }, }; diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 6050aba5b6..d26ce6f14c 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -39,7 +39,7 @@ use snafu::{location, Location}; use crate::dataset::Dataset; use crate::index::prefilter::{DatasetPreFilter, FilterLoader}; -use crate::index::vector::utils::{get_vector_dim, get_vector_type}; +use crate::index::vector::utils::get_vector_type; use crate::index::DatasetIndexInternalExt; use crate::{Error, Result}; use lance_arrow::*; @@ -656,9 +656,9 @@ impl ExecutionPlan for ANNIvfSubIndexExec { #[derive(Debug)] pub struct MultivectorScoringExec { dataset: Arc, + // the inputs are sorted ANN search results inputs: Vec>, query: Query, - dim: usize, properties: PlanProperties, } @@ -668,7 +668,6 @@ impl MultivectorScoringExec { inputs: Vec>, query: Query, ) -> Result { - let dim = get_vector_dim(dataset.schema(), &query.column)?; let properties = PlanProperties::new( EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), @@ -680,7 +679,6 @@ impl MultivectorScoringExec { dataset, inputs, query, - dim, properties, }) } @@ -706,7 +704,7 @@ impl ExecutionPlan for MultivectorScoringExec { } fn schema(&self) -> arrow_schema::SchemaRef { - self.inputs[0].schema() + KNN_INDEX_SCHEMA.clone() } fn children(&self) -> Vec<&Arc> { @@ -715,7 +713,7 @@ impl ExecutionPlan for MultivectorScoringExec { fn with_new_children( self: Arc, - mut children: Vec>, + children: Vec>, ) -> DataFusionResult> { let plan = MultivectorScoringExec::try_new(self.dataset.clone(), children, self.query.clone())?; @@ -733,53 +731,91 @@ impl ExecutionPlan for MultivectorScoringExec { .map(|input| input.execute(partition, context.clone())) .collect::>>()?; - let mut results = HashMap::with_capacity(self.query.k); - - let num_queries = self.query.key.len() / self.dim; - let mut min_dists = vec![0; num_queries]; - // collect the top k results from each stream, // and max-reduce for each query, // records the minimum distance for each query as estimation. - let reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| { + let mut reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| { stream.map(|batch| { let batch = batch?; let row_ids = batch[ROW_ID].as_primitive::(); let dists = batch[DIST_COL].as_primitive::(); // max-reduce for the same row id - let min_dist = dists.values().last().cloned().unwrap_or(0.0); + let min_sim = dists.values().last().map(|dist| 1.0 - *dist).unwrap_or(0.0); let mut new_row_ids = Vec::with_capacity(row_ids.len()); - let mut new_dists = Vec::with_capacity(row_ids.len()); + let mut new_sims = Vec::with_capacity(row_ids.len()); let mut visited_row_ids = HashSet::with_capacity(row_ids.len()); for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) { + // the results are sorted by distance, so we can skip if we have seen this row id before if visited_row_ids.contains(row_id) { continue; } visited_row_ids.insert(row_id); new_row_ids.push(*row_id); - new_dists.push(*dist); + // it's cosine distance, so we need to convert it to similarity + new_sims.push(1.0 - *dist); } let new_row_ids = UInt64Array::from(new_row_ids); - let new_dists = Float32Array::from(new_dists); + let new_dists = Float32Array::from(new_sims); let batch = RecordBatch::try_new( KNN_INDEX_SCHEMA.clone(), vec![Arc::new(new_dists), Arc::new(new_row_ids)], )?; - Ok((min_dist, batch)) + Ok::<_, DataFusionError>((min_sim, batch)) }) })); + let k = self.query.k; + let refactor = self.query.refine_factor.unwrap_or(1) as usize; let stream = stream::once(async move { - let mut results = HashMap::with_capacity(self.query.k); - while let Some(res) = reduced_inputs.next().await { - let (min_dist, batch) = res?; + // at most, we will have k * refine_factor results for each query + let mut results = HashMap::with_capacity(k * refactor); + let mut missed_similarities = 0.0; + while let Some((min_sim, batch)) = reduced_inputs.try_next().await? { let row_ids = batch[ROW_ID].as_primitive::(); - let dists = batch[DIST_COL].as_primitive::(); + let sims = batch[DIST_COL].as_primitive::(); + + let query_results = row_ids + .values() + .iter() + .copied() + .zip(sims.values().iter().copied()) + .collect::>(); + results.iter_mut().for_each(|(row_id, sim)| { + if let Some(new_dist) = query_results.get(row_id) { + *sim += new_dist; + } else { + *sim += min_sim; + } + }); + query_results.into_iter().for_each(|(row_id, sim)| { + if !results.contains_key(&row_id) { + results.insert(row_id, sim + missed_similarities); + } + }); + missed_similarities += min_sim; } + + let results = results.into_iter().collect::>(); + let row_ids = results + .iter() + .map(|(row_id, _)| *row_id) + .collect::>(); + let dists = results + .iter() + // it's similarity, so we need to convert it back to distance + .map(|(_, sim)| 1.0 - *sim) + .collect::>(); + let row_ids = UInt64Array::from(row_ids); + let dists = Float32Array::from(dists); + let batch = RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![Arc::new(dists), Arc::new(row_ids)], + )?; + Ok::<_, DataFusionError>(batch) }); Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), @@ -788,16 +824,10 @@ impl ExecutionPlan for MultivectorScoringExec { } fn statistics(&self) -> DataFusionResult { - let inner_stats = self.input.statistics()?; - let dist_col_stats = inner_stats.column_statistics[0].clone(); - let column_statistics = inner_stats - .column_statistics - .into_iter() - .chain([dist_col_stats]) - .collect::>(); Ok(Statistics { - num_rows: inner_stats.num_rows, - column_statistics, + num_rows: Precision::Inexact( + self.query.k * self.query.refine_factor.unwrap_or(1) as usize, + ), ..Statistics::new_unknown(self.schema().as_ref()) }) // self.input.statistics() From dbfa98b08436ffdca71ce992024fec7178fc2d81 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sat, 8 Feb 2025 15:31:36 +0800 Subject: [PATCH 03/10] fix Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index d26ce6f14c..017df7bbc1 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -715,8 +715,7 @@ impl ExecutionPlan for MultivectorScoringExec { self: Arc, children: Vec>, ) -> DataFusionResult> { - let plan = - MultivectorScoringExec::try_new(self.dataset.clone(), children, self.query.clone())?; + let plan = Self::try_new(self.dataset.clone(), children, self.query.clone())?; Ok(Arc::new(plan)) } @@ -792,9 +791,7 @@ impl ExecutionPlan for MultivectorScoringExec { } }); query_results.into_iter().for_each(|(row_id, sim)| { - if !results.contains_key(&row_id) { - results.insert(row_id, sim + missed_similarities); - } + results.entry(row_id).or_insert(sim + missed_similarities); }); missed_similarities += min_sim; } From dc7cc2489843cb89b534693a63afff57d4a02854 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sat, 8 Feb 2025 16:27:16 +0800 Subject: [PATCH 04/10] remove unused lines Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 017df7bbc1..93ab066b1c 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -827,7 +827,6 @@ impl ExecutionPlan for MultivectorScoringExec { ), ..Statistics::new_unknown(self.schema().as_ref()) }) - // self.input.statistics() } fn properties(&self) -> &PlanProperties { From 417d7764e075dd18d094f5357eb7592101572a75 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 10 Feb 2025 17:41:37 +0800 Subject: [PATCH 05/10] fix Signed-off-by: BubbleCal --- rust/lance-linalg/src/distance.rs | 5 +++-- rust/lance/src/dataset/scanner.rs | 17 +++++++++++++++++ rust/lance/src/index/vector/ivf/v2.rs | 6 ++++-- rust/lance/src/io/exec/knn.rs | 14 +++++--------- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/rust/lance-linalg/src/distance.rs b/rust/lance-linalg/src/distance.rs index a5575ec061..6e79c7d8b0 100644 --- a/rust/lance-linalg/src/distance.rs +++ b/rust/lance-linalg/src/distance.rs @@ -175,6 +175,7 @@ pub fn multivec_distance( }) .unwrap_or(f32::NAN) }) + .map(|sim| 1.0 - sim) .collect(); Ok(dists) } @@ -197,8 +198,8 @@ where .as_primitive::() .values() .chunks_exact(dim) - .map(|v| distance_type.func()(q, v)) - .min_by(|a, b| a.partial_cmp(b).unwrap()) + .map(|v| 1.0 - distance_type.func()(q, v)) + .max_by(|a, b| a.total_cmp(b)) .unwrap() }) .sum() diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 5a65681688..8ed24b5121 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -2128,6 +2128,8 @@ impl Scanner { // 1. collect the candidates by vector searching on each query vector // 2. scoring the candidates + let over_fetch_factor = q.refine_factor.unwrap_or(10); + let prefilter_source = self.prefilter_source(filter_plan).await?; let dim = get_vector_dim(self.dataset.schema(), &q.column)?; @@ -2137,6 +2139,10 @@ impl Scanner { .map(|query_vec| { let mut new_query = q.clone(); new_query.key = query_vec; + // with XTR, we don't need to refine the result with original vectors, + // but here we really need to over-fetch the candidates to reach good enough recall. + // TODO: improve the recall with WARP, expose this parameter to the users. + new_query.refine_factor = Some(over_fetch_factor); new_query }); let mut ann_nodes = Vec::with_capacity(new_queries.len()); @@ -2147,6 +2153,17 @@ impl Scanner { &query, prefilter_source.clone(), )?; + let sort_expr = PhysicalSortExpr { + expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + let ann_node = Arc::new( + SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node) + .with_fetch(Some(q.k * over_fetch_factor as usize)), + ); ann_nodes.push(ann_node as Arc); } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index ed313ea30e..8aa9ba6616 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -683,7 +683,7 @@ mod tests { .into_iter() .enumerate() .map(|(i, dist)| (dist, i as u64)) - .sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap()) + .sorted_by(|a, b| a.0.total_cmp(&b.0)) .take(k) .collect() } @@ -991,6 +991,8 @@ mod tests { } async fn test_index_multivec(params: VectorIndexParams, nlist: usize, recall_requirement: f32) { + // we introduce XTR for performance, which would reduce the recall a little bit + let recall_requirement = recall_requirement * 0.9; match params.metric_type { DistanceType::Hamming => { test_index_multivec_impl::(params, nlist, recall_requirement, 0..2) @@ -1061,7 +1063,7 @@ mod tests { let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type); let gt_set = gt.iter().map(|r| r.1).collect::>(); - let recall = row_ids.intersection(>_set).count() as f32 / 10.0; + let recall = row_ids.intersection(>_set).count() as f32 / 100.0; assert!( recall >= recall_requirement, "recall: {}\n results: {:?}\n\ngt: {:?}", diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 93ab066b1c..f2616c56d5 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -740,7 +740,7 @@ impl ExecutionPlan for MultivectorScoringExec { let dists = batch[DIST_COL].as_primitive::(); // max-reduce for the same row id - let min_sim = dists.values().last().map(|dist| 1.0 - *dist).unwrap_or(0.0); + let min_sim = 1.0 - dists.values().last().copied().unwrap_or(2.0); let mut new_row_ids = Vec::with_capacity(row_ids.len()); let mut new_sims = Vec::with_capacity(row_ids.len()); let mut visited_row_ids = HashSet::with_capacity(row_ids.len()); @@ -796,15 +796,11 @@ impl ExecutionPlan for MultivectorScoringExec { missed_similarities += min_sim; } - let results = results.into_iter().collect::>(); - let row_ids = results - .iter() - .map(|(row_id, _)| *row_id) - .collect::>(); - let dists = results - .iter() + let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip(); + let dists = sims + .into_iter() // it's similarity, so we need to convert it back to distance - .map(|(_, sim)| 1.0 - *sim) + .map(|sim| 1.0 - sim) .collect::>(); let row_ids = UInt64Array::from(row_ids); let dists = Float32Array::from(dists); From c6aa61568d84100df4a8a38d6b6c159d34469ae0 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Tue, 11 Feb 2025 12:08:35 +0800 Subject: [PATCH 06/10] fix Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index f2616c56d5..3f3e6e4fa9 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -11,7 +11,7 @@ use arrow_array::{ cast::AsArray, ArrayRef, RecordBatch, StringArray, }; -use arrow_array::{Float32Array, UInt64Array}; +use arrow_array::{Array, Float32Array, UInt64Array}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion::error::{DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::PlanProperties; @@ -738,6 +738,7 @@ impl ExecutionPlan for MultivectorScoringExec { let batch = batch?; let row_ids = batch[ROW_ID].as_primitive::(); let dists = batch[DIST_COL].as_primitive::(); + debug_assert_eq!(dists.null_count(), 0); // max-reduce for the same row id let min_sim = 1.0 - dists.values().last().copied().unwrap_or(2.0); @@ -783,6 +784,12 @@ impl ExecutionPlan for MultivectorScoringExec { .copied() .zip(sims.values().iter().copied()) .collect::>(); + + // for a row `r`: + // if `r` is in only `results``, then `results[r] += min_sim` + // if `r` is in only `query_results`, then `results[r] = query_results[r] + missed_similarities`, + // here `missed_similarities` is the sum of `min_sim` from previous iterations + // if `r` is in both, then `results[r] += query_results[r]` results.iter_mut().for_each(|(row_id, sim)| { if let Some(new_dist) = query_results.get(row_id) { *sim += new_dist; From f573c260541dc7830f260a66f920302ca806401d Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 12 Feb 2025 00:06:05 +0800 Subject: [PATCH 07/10] more comments Signed-off-by: BubbleCal --- rust/lance/src/dataset/scanner.rs | 6 +++++- rust/lance/src/io/exec/knn.rs | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 8ed24b5121..2e0ab7ca35 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -89,6 +89,9 @@ pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4; lazy_static::lazy_static! { pub static ref DEFAULT_FRAGMENT_READAHEAD: Option = std::env::var("LANCE_DEFAULT_FRAGMENT_READAHEAD") .map(|val| Some(val.parse().unwrap())).unwrap_or(None); + + pub static ref DEFAULT_XTR_OVERFETCH: u32 = std::env::var("LANCE_XTR_OVERFETCH") + .map(|val| val.parse().unwrap()).unwrap_or(10); } // We want to support ~256 concurrent reads to maximize throughput on cloud storage systems @@ -2128,7 +2131,7 @@ impl Scanner { // 1. collect the candidates by vector searching on each query vector // 2. scoring the candidates - let over_fetch_factor = q.refine_factor.unwrap_or(10); + let over_fetch_factor = *DEFAULT_XTR_OVERFETCH; let prefilter_source = self.prefilter_source(filter_plan).await?; let dim = get_vector_dim(self.dataset.schema(), &q.column)?; @@ -2147,6 +2150,7 @@ impl Scanner { }); let mut ann_nodes = Vec::with_capacity(new_queries.len()); for query in new_queries { + // this produces `nprobes * k * over_fetch_factor * num_indices` candidates let ann_node = new_knn_exec( self.dataset.clone(), index, diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 3f3e6e4fa9..1dd4965462 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -741,7 +741,11 @@ impl ExecutionPlan for MultivectorScoringExec { debug_assert_eq!(dists.null_count(), 0); // max-reduce for the same row id - let min_sim = 1.0 - dists.values().last().copied().unwrap_or(2.0); + let min_sim = dists + .values() + .last() + .map(|dist| 1.0 - *dist) + .unwrap_or_default(); let mut new_row_ids = Vec::with_capacity(row_ids.len()); let mut new_sims = Vec::with_capacity(row_ids.len()); let mut visited_row_ids = HashSet::with_capacity(row_ids.len()); From a598113eeb7b8570cdd44dc8e505a04781bdfb59 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 19 Feb 2025 14:33:04 +0800 Subject: [PATCH 08/10] test scoring Signed-off-by: BubbleCal --- rust/lance/src/dataset/scanner.rs | 6 +- rust/lance/src/io/exec/knn.rs | 93 +++++++++++++++++++++++++++---- rust/lance/src/io/exec/testing.rs | 6 +- 3 files changed, 88 insertions(+), 17 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 6d9370a8c1..266341ada2 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -2212,11 +2212,7 @@ impl Scanner { ann_nodes.push(ann_node as Arc); } - let ann_node = Arc::new(MultivectorScoringExec::try_new( - self.dataset.clone(), - ann_nodes, - q.clone(), - )?); + let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?); let sort_expr = PhysicalSortExpr { expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 65ef47b3c2..961e15d64c 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -668,7 +668,6 @@ impl ExecutionPlan for ANNIvfSubIndexExec { #[derive(Debug)] pub struct MultivectorScoringExec { - dataset: Arc, // the inputs are sorted ANN search results inputs: Vec>, query: Query, @@ -676,11 +675,7 @@ pub struct MultivectorScoringExec { } impl MultivectorScoringExec { - pub fn try_new( - dataset: Arc, - inputs: Vec>, - query: Query, - ) -> Result { + pub fn try_new(inputs: Vec>, query: Query) -> Result { let properties = PlanProperties::new( EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()), Partitioning::RoundRobinBatch(1), @@ -689,7 +684,6 @@ impl MultivectorScoringExec { ); Ok(Self { - dataset, inputs, query, properties, @@ -728,7 +722,7 @@ impl ExecutionPlan for MultivectorScoringExec { self: Arc, children: Vec>, ) -> DataFusionResult> { - let plan = Self::try_new(self.dataset.clone(), children, self.query.clone())?; + let plan = Self::try_new(children, self.query.clone())?; Ok(Arc::new(plan)) } @@ -790,7 +784,7 @@ impl ExecutionPlan for MultivectorScoringExec { let stream = stream::once(async move { // at most, we will have k * refine_factor results for each query let mut results = HashMap::with_capacity(k * refactor); - let mut missed_similarities = 0.0; + let mut missed_sim_sum = 0.0; while let Some((min_sim, batch)) = reduced_inputs.try_next().await? { let row_ids = batch[ROW_ID].as_primitive::(); let sims = batch[DIST_COL].as_primitive::(); @@ -815,9 +809,9 @@ impl ExecutionPlan for MultivectorScoringExec { } }); query_results.into_iter().for_each(|(row_id, sim)| { - results.entry(row_id).or_insert(sim + missed_similarities); + results.entry(row_id).or_insert(sim + missed_sim_sum); }); - missed_similarities += min_sim; + missed_sim_sum += min_sim; } let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip(); @@ -993,4 +987,81 @@ mod tests { ]) ); } + + #[tokio::test] + async fn test_multivector_score() { + let query = Query { + column: "vector".to_string(), + key: Arc::new(generate_random_array(1)), + k: 10, + lower_bound: None, + upper_bound: None, + nprobes: 1, + ef: None, + refine_factor: None, + metric_type: DistanceType::Cosine, + use_index: true, + }; + + async fn multivector_scoring( + inputs: Vec>, + query: Query, + ) -> Result> { + let ctx = Arc::new(datafusion::execution::context::TaskContext::default()); + let plan = MultivectorScoringExec::try_new(inputs, query.clone())?; + let batches = plan + .execute(0, ctx.clone()) + .unwrap() + .try_collect::>() + .await?; + let mut results = HashMap::new(); + for batch in batches { + let row_ids = batch + .column_by_name(ROW_ID) + .unwrap() + .as_primitive::(); + let dists = batch + .column_by_name(DIST_COL) + .unwrap() + .as_primitive::(); + for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) { + results.insert(*row_id, *dist); + } + } + Ok(results) + } + + let batches = (0..3) + .map(|i| { + RecordBatch::try_new( + KNN_INDEX_SCHEMA.clone(), + vec![ + Arc::new(Float32Array::from(vec![i as f32 + 1.0, i as f32 + 2.0])), + Arc::new(UInt64Array::from(vec![i + 1, i + 2])), + ], + ) + .unwrap() + }) + .collect::>(); + + let mut res: Option> = None; + for perm in batches.into_iter().permutations(3) { + let inputs = perm + .into_iter() + .map(|batch| { + let input: Arc = + Arc::new(TestingExec::new(vec![batch.clone()])); + input + }) + .collect::>(); + let new_res = multivector_scoring(inputs, query.clone()).await.unwrap(); + if let Some(res) = &res { + for (row_id, dist) in new_res.iter() { + assert_eq!(res.get(row_id).unwrap(), dist) + } + } else { + res = Some(new_res); + } + } + } } diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs index 611cf5480b..a9985d6a74 100644 --- a/rust/lance/src/io/exec/testing.rs +++ b/rust/lance/src/io/exec/testing.rs @@ -8,6 +8,7 @@ use std::any::Any; use std::sync::Arc; use arrow_array::RecordBatch; +use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::{ common::Statistics, execution::context::TaskContext, @@ -17,6 +18,7 @@ use datafusion::{ }, }; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use futures::StreamExt; #[derive(Debug)] pub struct TestingExec { @@ -76,7 +78,9 @@ impl ExecutionPlan for TestingExec { _partition: usize, _context: Arc, ) -> datafusion::error::Result { - todo!() + let stream = futures::stream::iter(self.batches.clone().into_iter().map(|batch| Ok(batch))); + let stream = RecordBatchStreamAdapter::new(self.schema(), stream.boxed()); + Ok(Box::pin(stream)) } fn statistics(&self) -> datafusion::error::Result { From 955471ba794801ecc948b0ffbbdba2d2b7f292b5 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 19 Feb 2025 14:36:27 +0800 Subject: [PATCH 09/10] more assertion Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 961e15d64c..9a552d5443 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -1016,14 +1016,8 @@ mod tests { .await?; let mut results = HashMap::new(); for batch in batches { - let row_ids = batch - .column_by_name(ROW_ID) - .unwrap() - .as_primitive::(); - let dists = batch - .column_by_name(DIST_COL) - .unwrap() - .as_primitive::(); + let row_ids = batch[ROW_ID].as_primitive::(); + let dists = batch[DIST_COL].as_primitive::(); for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) { results.insert(*row_id, *dist); } @@ -1055,6 +1049,7 @@ mod tests { }) .collect::>(); let new_res = multivector_scoring(inputs, query.clone()).await.unwrap(); + assert_eq!(new_res.len(), 4); if let Some(res) = &res { for (row_id, dist) in new_res.iter() { assert_eq!(res.get(row_id).unwrap(), dist) From 940fb85e5cd695f8e1f41eb540625c7331e7b129 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 19 Feb 2025 14:41:20 +0800 Subject: [PATCH 10/10] fmt Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 3 +-- rust/lance/src/io/exec/testing.rs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 9a552d5443..779189bbc8 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -1043,8 +1043,7 @@ mod tests { let inputs = perm .into_iter() .map(|batch| { - let input: Arc = - Arc::new(TestingExec::new(vec![batch.clone()])); + let input: Arc = Arc::new(TestingExec::new(vec![batch])); input }) .collect::>(); diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs index a9985d6a74..23a69ed7d0 100644 --- a/rust/lance/src/io/exec/testing.rs +++ b/rust/lance/src/io/exec/testing.rs @@ -78,7 +78,7 @@ impl ExecutionPlan for TestingExec { _partition: usize, _context: Arc, ) -> datafusion::error::Result { - let stream = futures::stream::iter(self.batches.clone().into_iter().map(|batch| Ok(batch))); + let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok)); let stream = RecordBatchStreamAdapter::new(self.schema(), stream.boxed()); Ok(Box::pin(stream)) }