diff --git a/rust/lance-linalg/src/distance.rs b/rust/lance-linalg/src/distance.rs
index a5575ec061..6e79c7d8b0 100644
--- a/rust/lance-linalg/src/distance.rs
+++ b/rust/lance-linalg/src/distance.rs
@@ -175,6 +175,7 @@ pub fn multivec_distance(
             })
             .unwrap_or(f32::NAN)
         })
+        .map(|sim| 1.0 - sim)
         .collect();
     Ok(dists)
 }
@@ -197,8 +198,8 @@ where
                 .as_primitive::<Float32Type>()
                 .values()
                 .chunks_exact(dim)
-                .map(|v| distance_type.func()(q, v))
-                .min_by(|a, b| a.partial_cmp(b).unwrap())
+                .map(|v| 1.0 - distance_type.func()(q, v))
+                .max_by(|a, b| a.total_cmp(b))
                 .unwrap()
         })
         .sum()
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index 9c36f17999..266341ada2 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -63,6 +63,7 @@ use crate::index::scalar::detect_scalar_index_type;
 use crate::index::vector::utils::{get_vector_dim, get_vector_type};
 use crate::index::DatasetIndexInternalExt;
 use crate::io::exec::fts::{FlatFtsExec, FtsExec};
+use crate::io::exec::knn::MultivectorScoringExec;
 use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec};
 use crate::io::exec::{get_physical_optimizer, LanceScanConfig};
 use crate::io::exec::{
@@ -88,6 +89,9 @@ pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
 lazy_static::lazy_static! {
     pub static ref DEFAULT_FRAGMENT_READAHEAD: Option<usize> = std::env::var("LANCE_DEFAULT_FRAGMENT_READAHEAD")
         .map(|val| Some(val.parse().unwrap())).unwrap_or(None);
+
+    pub static ref DEFAULT_XTR_OVERFETCH: u32 = std::env::var("LANCE_XTR_OVERFETCH")
+        .map(|val| val.parse().unwrap()).unwrap_or(10);
 }
 
 // We want to support ~256 concurrent reads to maximize throughput on cloud storage systems
@@ -1697,13 +1701,13 @@ impl Scanner {
         // Find all deltas with the same index name.
         let deltas = self.dataset.load_indices_by_name(&index.name).await?;
 
-        let (ann_node, is_multivec) = match vector_type {
-            DataType::FixedSizeList(_, _) => (self.ann(q, &deltas, filter_plan).await?, false),
-            DataType::List(_) => (self.multivec_ann(q, &deltas, filter_plan).await?, true),
+        let ann_node = match vector_type {
+            DataType::FixedSizeList(_, _) => self.ann(q, &deltas, filter_plan).await?,
+            DataType::List(_) => self.multivec_ann(q, &deltas, filter_plan).await?,
             _ => unreachable!(),
         };
 
-        let mut knn_node = if q.refine_factor.is_some() || is_multivec {
+        let mut knn_node = if q.refine_factor.is_some() {
             let vector_projection = self
                 .dataset
                 .empty_projection()
@@ -2164,69 +2168,56 @@ impl Scanner {
         index: &[Index],
         filter_plan: &FilterPlan,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        // we split the query procedure into two steps:
+        // 1. collect candidates by running a vector search on each query vector
+        // 2. score the candidates
+
+        let over_fetch_factor = *DEFAULT_XTR_OVERFETCH;
+
+        let prefilter_source = self.prefilter_source(filter_plan).await?;
         let dim = get_vector_dim(self.dataset.schema(), &q.column)?;
-        // split the query multivectors
+
         let num_queries = q.key.len() / dim;
         let new_queries = (0..num_queries)
             .map(|i| q.key.slice(i * dim, dim))
             .map(|query_vec| {
                 let mut new_query = q.clone();
                 new_query.key = query_vec;
+                // with XTR, we don't need to refine the results with the original vectors,
+                // but we do need to over-fetch candidates to reach good enough recall.
+                // TODO: improve the recall with WARP; expose this parameter to users.
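+                // e.g. with k = 10 and the default LANCE_XTR_OVERFETCH of 10, each
+                // per-query search keeps up to k * refine_factor = 100 candidates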
+                new_query.refine_factor = Some(over_fetch_factor);
                 new_query
             });
         let mut ann_nodes = Vec::with_capacity(new_queries.len());
-        let prefilter_source = self.prefilter_source(filter_plan).await?;
         for query in new_queries {
+            // this produces `nprobes * k * over_fetch_factor * num_indices` candidates
             let ann_node = new_knn_exec(
                 self.dataset.clone(),
                 index,
                 &query,
                 prefilter_source.clone(),
             )?;
-            ann_nodes.push(ann_node);
+            let sort_expr = PhysicalSortExpr {
+                expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
+                options: SortOptions {
+                    descending: false,
+                    nulls_first: false,
+                },
+            };
+            let ann_node = Arc::new(
+                SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node)
+                    .with_fetch(Some(q.k * over_fetch_factor as usize)),
+            );
+            ann_nodes.push(ann_node as Arc<dyn ExecutionPlan>);
         }
-        let ann_node = Arc::new(UnionExec::new(ann_nodes));
-        let ann_node = Arc::new(RepartitionExec::try_new(
-            ann_node,
-            datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
-        )?);
-        let schema = ann_node.schema();
-        // unique by row ids, and get the min distance although it is not used.
-        let group_expr = vec![(
-            expressions::col(ROW_ID, schema.as_ref())?,
-            ROW_ID.to_string(),
-        )];
-        // for now multivector is always with cosine distance so here convert the distance to `1 - distance`
-        // and calculate the sum across all rows with the same row id.
-        let sum_expr = AggregateExprBuilder::new(
-            functions_aggregate::sum::sum_udaf(),
-            vec![expressions::binary(
-                expressions::lit(1.0),
-                datafusion_expr::Operator::Minus,
-                expressions::cast(
-                    expressions::col(DIST_COL, &schema)?,
-                    &schema,
-                    DataType::Float64,
-                )?,
-                &schema,
-            )?],
-        )
-        .schema(schema.clone())
-        .alias(DIST_COL)
-        .build()?;
-        let ann_node: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
-            AggregateMode::Single,
-            PhysicalGroupBy::new_single(group_expr),
-            vec![Arc::new(sum_expr)],
-            vec![None],
-            ann_node,
-            schema,
-        )?);
+
+        let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?);
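+        // MultivectorScoringExec emits `1 - sum(similarity)` in DIST_COL, so the
+        // final top-k below sorts ascending again (smaller distance is better)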
         let sort_expr = PhysicalSortExpr {
             expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
             options: SortOptions {
-                descending: true,
+                descending: false,
                 nulls_first: false,
             },
         };
diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs
index 95e8379251..8f68529c91 100644
--- a/rust/lance/src/index/vector/ivf/v2.rs
+++ b/rust/lance/src/index/vector/ivf/v2.rs
@@ -683,7 +683,7 @@ mod tests {
         .into_iter()
         .enumerate()
         .map(|(i, dist)| (dist, i as u64))
-        .sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
+        .sorted_by(|a, b| a.0.total_cmp(&b.0))
         .take(k)
         .collect()
 }
@@ -991,6 +991,8 @@ mod tests {
     }
 
     async fn test_index_multivec(params: VectorIndexParams, nlist: usize, recall_requirement: f32) {
+        // we introduce XTR for performance, which reduces recall slightly
+        let recall_requirement = recall_requirement * 0.9;
         match params.metric_type {
             DistanceType::Hamming => {
                 test_index_multivec_impl::<UInt8Type>(params, nlist, recall_requirement, 0..2)
@@ -1061,7 +1063,7 @@ mod tests {
 
         let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type);
         let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();
 
-        let recall = row_ids.intersection(&gt_set).count() as f32 / 10.0;
+        let recall = row_ids.intersection(&gt_set).count() as f32 / 100.0;
         assert!(
             recall >= recall_requirement,
             "recall: {}\n results: {:?}\n\ngt: {:?}",
diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs
index 6bde182545..779189bbc8 100644
--- a/rust/lance/src/io/exec/knn.rs
+++ b/rust/lance/src/io/exec/knn.rs
@@ -2,14 +2,16 @@
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
 use std::any::Any;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
-use arrow::datatypes::UInt32Type;
+use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type};
 use arrow_array::{
     builder::{ListBuilder, UInt32Builder},
     cast::AsArray,
     ArrayRef, RecordBatch, StringArray,
 };
+use arrow_array::{Array, Float32Array, UInt64Array};
 use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use datafusion::common::ColumnStatistics;
 use datafusion::error::{DataFusionError, Result as DataFusionResult};
@@ -26,6 +28,7 @@ use datafusion_physical_expr::EquivalenceProperties;
 use futures::stream::repeat_with;
 use futures::{future, stream, StreamExt, TryFutureExt, TryStreamExt};
 use itertools::Itertools;
+use lance_core::ROW_ID;
 use lance_core::{utils::tokio::get_num_compute_intensive_cpus, ROW_ID_FIELD};
 use lance_index::vector::{
     flat::compute_distance, Query, DIST_COL, INDEX_UUID_COLUMN, PART_ID_COLUMN,
@@ -663,6 +666,188 @@ impl ExecutionPlan for ANNIvfSubIndexExec {
     }
 }
 
+#[derive(Debug)]
+pub struct MultivectorScoringExec {
+    // the inputs are sorted ANN search results
+    inputs: Vec<Arc<dyn ExecutionPlan>>,
+    query: Query,
+    properties: PlanProperties,
+}
+
+impl MultivectorScoringExec {
+    pub fn try_new(inputs: Vec<Arc<dyn ExecutionPlan>>, query: Query) -> Result<Self> {
+        let properties = PlanProperties::new(
+            EquivalenceProperties::new(KNN_INDEX_SCHEMA.clone()),
+            Partitioning::RoundRobinBatch(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        );
+
+        Ok(Self {
+            inputs,
+            query,
+            properties,
+        })
+    }
+}
+
+impl DisplayAs for MultivectorScoringExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "MultivectorScoring: k={}", self.query.k)
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for MultivectorScoringExec {
+    fn name(&self) -> &str {
+        "MultivectorScoringExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> arrow_schema::SchemaRef {
+        KNN_INDEX_SCHEMA.clone()
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        self.inputs.iter().collect()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        let plan = Self::try_new(children, self.query.clone())?;
+        Ok(Arc::new(plan))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> DataFusionResult<SendableRecordBatchStream> {
+        let inputs = self
+            .inputs
+            .iter()
+            .map(|input| input.execute(partition, context.clone()))
+            .collect::<DataFusionResult<Vec<_>>>()?;
+
+        // collect the top-k results from each per-query stream,
+        // max-reduce entries that share a row id, and record each stream's
+        // smallest similarity as the estimate for rows it did not return
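+        // e.g. if a stream's last (largest) distance is 0.4, its smallest
+        // similarity is 0.6; rows that stream never returned are later
+        // credited 0.6 as an estimate of their unseen similarity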
+        let mut reduced_inputs = stream::select_all(inputs.into_iter().map(|stream| {
+            stream.map(|batch| {
+                let batch = batch?;
+                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
+                let dists = batch[DIST_COL].as_primitive::<Float32Type>();
+                debug_assert_eq!(dists.null_count(), 0);
+
+                // the batch is sorted by distance (ascending), so the last value
+                // is the largest distance, i.e. the smallest similarity seen
+                let min_sim = dists
+                    .values()
+                    .last()
+                    .map(|dist| 1.0 - *dist)
+                    .unwrap_or_default();
+
+                // max-reduce for the same row id: keep only the first (best) hit
+                let mut new_row_ids = Vec::with_capacity(row_ids.len());
+                let mut new_sims = Vec::with_capacity(row_ids.len());
+                let mut visited_row_ids = HashSet::with_capacity(row_ids.len());
+
+                for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
+                    // the results are sorted by distance, so we can skip
+                    // any row id we have seen before
+                    if visited_row_ids.contains(row_id) {
+                        continue;
+                    }
+                    visited_row_ids.insert(row_id);
+                    new_row_ids.push(*row_id);
+                    // it's cosine distance, so convert it to similarity
+                    new_sims.push(1.0 - *dist);
+                }
+                let new_row_ids = UInt64Array::from(new_row_ids);
+                let new_dists = Float32Array::from(new_sims);
+
+                let batch = RecordBatch::try_new(
+                    KNN_INDEX_SCHEMA.clone(),
+                    vec![Arc::new(new_dists), Arc::new(new_row_ids)],
+                )?;
+
+                Ok::<_, DataFusionError>((min_sim, batch))
+            })
+        }));
+
+        let k = self.query.k;
+        let refactor = self.query.refine_factor.unwrap_or(1) as usize;
+        let stream = stream::once(async move {
+            // at most, we will have k * refine_factor results for each query
+            let mut results = HashMap::with_capacity(k * refactor);
+            let mut missed_sim_sum = 0.0;
+            while let Some((min_sim, batch)) = reduced_inputs.try_next().await? {
+                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
+                let sims = batch[DIST_COL].as_primitive::<Float32Type>();
+
+                let query_results = row_ids
+                    .values()
+                    .iter()
+                    .copied()
+                    .zip(sims.values().iter().copied())
+                    .collect::<HashMap<_, _>>();
+
+                // for a row `r`:
+                // - if `r` is only in `results`, then `results[r] += min_sim`
+                // - if `r` is only in `query_results`, then
+                //   `results[r] = query_results[r] + missed_sim_sum`, where
+                //   `missed_sim_sum` is the sum of `min_sim` over all previous iterations
+                // - if `r` is in both, then `results[r] += query_results[r]`
+                results.iter_mut().for_each(|(row_id, sim)| {
+                    if let Some(new_sim) = query_results.get(row_id) {
+                        *sim += new_sim;
+                    } else {
+                        *sim += min_sim;
+                    }
+                });
+                query_results.into_iter().for_each(|(row_id, sim)| {
+                    results.entry(row_id).or_insert(sim + missed_sim_sum);
+                });
+                missed_sim_sum += min_sim;
+            }
+
+            let (row_ids, sims): (Vec<_>, Vec<_>) = results.into_iter().unzip();
+            let dists = sims
+                .into_iter()
+                // it's a similarity, so convert it back to a distance
+                .map(|sim| 1.0 - sim)
+                .collect::<Vec<_>>();
+            let row_ids = UInt64Array::from(row_ids);
+            let dists = Float32Array::from(dists);
+            let batch = RecordBatch::try_new(
+                KNN_INDEX_SCHEMA.clone(),
+                vec![Arc::new(dists), Arc::new(row_ids)],
+            )?;
+            Ok::<_, DataFusionError>(batch)
+        });
+        Ok(Box::pin(RecordBatchStreamAdapter::new(
+            self.schema(),
+            stream.boxed(),
+        )))
+    }
+
+    fn statistics(&self) -> DataFusionResult<Statistics> {
+        Ok(Statistics {
+            num_rows: Precision::Inexact(
+                self.query.k * self.query.refine_factor.unwrap_or(1) as usize,
+            ),
+            ..Statistics::new_unknown(self.schema().as_ref())
+        })
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.properties
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -802,4 +987,75 @@
             ])
         );
     }
+
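+    // the merged scores must not depend on the order in which the per-query
+    // streams arrive; the test below checks every permutation of its inputs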
+    #[tokio::test]
+    async fn test_multivector_score() {
+        let query = Query {
+            column: "vector".to_string(),
+            key: Arc::new(generate_random_array(1)),
+            k: 10,
+            lower_bound: None,
+            upper_bound: None,
+            nprobes: 1,
+            ef: None,
+            refine_factor: None,
+            metric_type: DistanceType::Cosine,
+            use_index: true,
+        };
+
+        async fn multivector_scoring(
+            inputs: Vec<Arc<dyn ExecutionPlan>>,
+            query: Query,
+        ) -> Result<HashMap<u64, f32>> {
+            let ctx = Arc::new(datafusion::execution::context::TaskContext::default());
+            let plan = MultivectorScoringExec::try_new(inputs, query.clone())?;
+            let batches = plan
+                .execute(0, ctx.clone())
+                .unwrap()
+                .try_collect::<Vec<_>>()
+                .await?;
+            let mut results = HashMap::new();
+            for batch in batches {
+                let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>();
+                let dists = batch[DIST_COL].as_primitive::<Float32Type>();
+                for (row_id, dist) in row_ids.values().iter().zip(dists.values().iter()) {
+                    results.insert(*row_id, *dist);
+                }
+            }
+            Ok(results)
+        }
+
+        let batches = (0..3)
+            .map(|i| {
+                RecordBatch::try_new(
+                    KNN_INDEX_SCHEMA.clone(),
+                    vec![
+                        Arc::new(Float32Array::from(vec![i as f32 + 1.0, i as f32 + 2.0])),
+                        Arc::new(UInt64Array::from(vec![i + 1, i + 2])),
+                    ],
+                )
+                .unwrap()
+            })
+            .collect::<Vec<_>>();
+
+        let mut res: Option<HashMap<u64, f32>> = None;
+        for perm in batches.into_iter().permutations(3) {
+            let inputs = perm
+                .into_iter()
+                .map(|batch| {
+                    let input: Arc<dyn ExecutionPlan> = Arc::new(TestingExec::new(vec![batch]));
+                    input
+                })
+                .collect::<Vec<_>>();
+            let new_res = multivector_scoring(inputs, query.clone()).await.unwrap();
+            assert_eq!(new_res.len(), 4);
+            if let Some(res) = &res {
+                for (row_id, dist) in new_res.iter() {
+                    assert_eq!(res.get(row_id).unwrap(), dist)
+                }
+            } else {
+                res = Some(new_res);
+            }
+        }
+    }
 }
diff --git a/rust/lance/src/io/exec/testing.rs b/rust/lance/src/io/exec/testing.rs
index 611cf5480b..23a69ed7d0 100644
--- a/rust/lance/src/io/exec/testing.rs
+++ b/rust/lance/src/io/exec/testing.rs
@@ -8,6 +8,7 @@ use std::any::Any;
 use std::sync::Arc;
 
 use arrow_array::RecordBatch;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::{
     common::Statistics,
     execution::context::TaskContext,
     physical_plan::{
@@ -17,6 +18,7 @@
     },
 };
 use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use futures::StreamExt;
 
 #[derive(Debug)]
 pub struct TestingExec {
@@ -76,7 +78,9 @@ impl ExecutionPlan for TestingExec {
         _partition: usize,
         _context: Arc<TaskContext>,
     ) -> datafusion::error::Result<SendableRecordBatchStream> {
-        todo!()
+        let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok));
+        let stream = RecordBatchStreamAdapter::new(self.schema(), stream.boxed());
+        Ok(Box::pin(stream))
     }
 
     fn statistics(&self) -> datafusion::error::Result<Statistics> {