-
Notifications
You must be signed in to change notification settings - Fork 256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
perf: implement XTR for retrieving multivector #3437
base: main
Are you sure you want to change the base?
Changes from all commits
1520bb7
0525f33
dbfa98b
dc7cc24
686cfae
417d776
c6aa615
f573c26
ec0310d
a598113
955471b
940fb85
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,7 @@ use crate::index::scalar::detect_scalar_index_type; | |
use crate::index::vector::utils::{get_vector_dim, get_vector_type}; | ||
use crate::index::DatasetIndexInternalExt; | ||
use crate::io::exec::fts::{FlatFtsExec, FtsExec}; | ||
use crate::io::exec::knn::MultivectorScoringExec; | ||
use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec}; | ||
use crate::io::exec::{get_physical_optimizer, LanceScanConfig}; | ||
use crate::io::exec::{ | ||
|
@@ -88,6 +89,9 @@ pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4; | |
lazy_static::lazy_static! { | ||
pub static ref DEFAULT_FRAGMENT_READAHEAD: Option<usize> = std::env::var("LANCE_DEFAULT_FRAGMENT_READAHEAD") | ||
.map(|val| Some(val.parse().unwrap())).unwrap_or(None); | ||
|
||
pub static ref DEFAULT_XTR_OVERFETCH: u32 = std::env::var("LANCE_XTR_OVERFETCH") | ||
.map(|val| val.parse().unwrap()).unwrap_or(10); | ||
} | ||
|
||
// We want to support ~256 concurrent reads to maximize throughput on cloud storage systems | ||
|
@@ -1697,13 +1701,13 @@ impl Scanner { | |
|
||
// Find all deltas with the same index name. | ||
let deltas = self.dataset.load_indices_by_name(&index.name).await?; | ||
let (ann_node, is_multivec) = match vector_type { | ||
DataType::FixedSizeList(_, _) => (self.ann(q, &deltas, filter_plan).await?, false), | ||
DataType::List(_) => (self.multivec_ann(q, &deltas, filter_plan).await?, true), | ||
let ann_node = match vector_type { | ||
DataType::FixedSizeList(_, _) => self.ann(q, &deltas, filter_plan).await?, | ||
DataType::List(_) => self.multivec_ann(q, &deltas, filter_plan).await?, | ||
_ => unreachable!(), | ||
}; | ||
|
||
let mut knn_node = if q.refine_factor.is_some() || is_multivec { | ||
let mut knn_node = if q.refine_factor.is_some() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't require refine for multivector search anymore. |
||
let vector_projection = self | ||
.dataset | ||
.empty_projection() | ||
|
@@ -2164,69 +2168,56 @@ impl Scanner { | |
index: &[Index], | ||
filter_plan: &FilterPlan, | ||
) -> Result<Arc<dyn ExecutionPlan>> { | ||
// we split the query procedure into two steps: | ||
// 1. collect the candidates by vector searching on each query vector | ||
// 2. scoring the candidates | ||
|
||
let over_fetch_factor = *DEFAULT_XTR_OVERFETCH; | ||
|
||
let prefilter_source = self.prefilter_source(filter_plan).await?; | ||
let dim = get_vector_dim(self.dataset.schema(), &q.column)?; | ||
// split the query multivectors | ||
|
||
let num_queries = q.key.len() / dim; | ||
let new_queries = (0..num_queries) | ||
.map(|i| q.key.slice(i * dim, dim)) | ||
.map(|query_vec| { | ||
let mut new_query = q.clone(); | ||
new_query.key = query_vec; | ||
// with XTR, we don't need to refine the result with original vectors, | ||
// but here we really need to over-fetch the candidates to reach good enough recall. | ||
// TODO: improve the recall with WARP, expose this parameter to the users. | ||
new_query.refine_factor = Some(over_fetch_factor); | ||
new_query | ||
}); | ||
let mut ann_nodes = Vec::with_capacity(new_queries.len()); | ||
let prefilter_source = self.prefilter_source(filter_plan).await?; | ||
for query in new_queries { | ||
// this produces `nprobes * k * over_fetch_factor * num_indices` candidates | ||
let ann_node = new_knn_exec( | ||
self.dataset.clone(), | ||
index, | ||
&query, | ||
prefilter_source.clone(), | ||
)?; | ||
ann_nodes.push(ann_node); | ||
let sort_expr = PhysicalSortExpr { | ||
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, | ||
options: SortOptions { | ||
descending: false, | ||
nulls_first: false, | ||
}, | ||
}; | ||
let ann_node = Arc::new( | ||
SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node) | ||
.with_fetch(Some(q.k * over_fetch_factor as usize)), | ||
); | ||
ann_nodes.push(ann_node as Arc<dyn ExecutionPlan>); | ||
} | ||
let ann_node = Arc::new(UnionExec::new(ann_nodes)); | ||
let ann_node = Arc::new(RepartitionExec::try_new( | ||
ann_node, | ||
datafusion::physical_plan::Partitioning::RoundRobinBatch(1), | ||
)?); | ||
let schema = ann_node.schema(); | ||
// unique by row ids, and get the min distance although it is not used. | ||
let group_expr = vec![( | ||
expressions::col(ROW_ID, schema.as_ref())?, | ||
ROW_ID.to_string(), | ||
)]; | ||
// for now multivector is always with cosine distance so here convert the distance to `1 - distance` | ||
// and calculate the sum across all rows with the same row id. | ||
let sum_expr = AggregateExprBuilder::new( | ||
functions_aggregate::sum::sum_udaf(), | ||
vec![expressions::binary( | ||
expressions::lit(1.0), | ||
datafusion_expr::Operator::Minus, | ||
expressions::cast( | ||
expressions::col(DIST_COL, &schema)?, | ||
&schema, | ||
DataType::Float64, | ||
)?, | ||
&schema, | ||
)?], | ||
) | ||
.schema(schema.clone()) | ||
.alias(DIST_COL) | ||
.build()?; | ||
let ann_node: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new( | ||
AggregateMode::Single, | ||
PhysicalGroupBy::new_single(group_expr), | ||
vec![Arc::new(sum_expr)], | ||
vec![None], | ||
ann_node, | ||
schema, | ||
)?); | ||
|
||
let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?); | ||
|
||
let sort_expr = PhysicalSortExpr { | ||
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?, | ||
options: SortOptions { | ||
descending: true, | ||
descending: false, | ||
nulls_first: false, | ||
}, | ||
}; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -683,7 +683,7 @@ mod tests { | |
.into_iter() | ||
.enumerate() | ||
.map(|(i, dist)| (dist, i as u64)) | ||
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap()) | ||
.sorted_by(|a, b| a.0.total_cmp(&b.0)) | ||
.take(k) | ||
.collect() | ||
} | ||
|
@@ -991,6 +991,8 @@ mod tests { | |
} | ||
|
||
async fn test_index_multivec(params: VectorIndexParams, nlist: usize, recall_requirement: f32) { | ||
// we introduce XTR for performance, which would reduce the recall a little bit | ||
let recall_requirement = recall_requirement * 0.9; | ||
match params.metric_type { | ||
DistanceType::Hamming => { | ||
test_index_multivec_impl::<UInt8Type>(params, nlist, recall_requirement, 0..2) | ||
|
@@ -1061,7 +1063,7 @@ mod tests { | |
let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type); | ||
let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>(); | ||
|
||
let recall = row_ids.intersection(>_set).count() as f32 / 10.0; | ||
let recall = row_ids.intersection(>_set).count() as f32 / 100.0; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The recall was calculated incorrectly; the previous algorithm required refine_factor=5 to reach good enough recall. |
||
assert!( | ||
recall >= recall_requirement, | ||
"recall: {}\n results: {:?}\n\ngt: {:?}", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing these so that the flat search results match those of IVF_FLAT, so the numbers won't confuse users.