
perf: implement XTR for retrieving multivector #3437

Open · wants to merge 12 commits into main
5 changes: 3 additions & 2 deletions rust/lance-linalg/src/distance.rs
@@ -175,6 +175,7 @@ pub fn multivec_distance(
})
.unwrap_or(f32::NAN)
})
.map(|sim| 1.0 - sim)
.collect();
Ok(dists)
}
@@ -197,8 +198,8 @@
.as_primitive::<T>()
.values()
.chunks_exact(dim)
.map(|v| distance_type.func()(q, v))
.min_by(|a, b| a.partial_cmp(b).unwrap())
.map(|v| 1.0 - distance_type.func()(q, v))
.max_by(|a, b| a.total_cmp(b))
Contributor Author: Changing these so that the flat search results match IVF_FLAT's, so the numbers won't confuse users. (See the sketch after this hunk.)

.unwrap()
})
.sum()
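The hunks above flip the per-query aggregation from "minimum distance" to "maximum similarity" (`1.0 - distance`) so flat multivector search ranks results the same way as IVF_FLAT. A minimal standalone sketch of the conversion (function names are illustrative, not the crate's API):

```rust
/// Convert a distance to a similarity so that larger is better,
/// matching how IVF_FLAT reports multivector results.
/// Illustrative sketch, not the crate's API.
fn distance_to_similarity(dist: f32) -> f32 {
    1.0 - dist
}

/// Best similarity between one query vector and a document's vectors.
fn best_similarity(distances: &[f32]) -> Option<f32> {
    distances
        .iter()
        .map(|&d| distance_to_similarity(d))
        // total_cmp is a total order, so a NaN distance cannot panic here.
        .max_by(|a, b| a.total_cmp(b))
}
```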
81 changes: 36 additions & 45 deletions rust/lance/src/dataset/scanner.rs
@@ -63,6 +63,7 @@ use crate::index::scalar::detect_scalar_index_type;
use crate::index::vector::utils::{get_vector_dim, get_vector_type};
use crate::index::DatasetIndexInternalExt;
use crate::io::exec::fts::{FlatFtsExec, FtsExec};
use crate::io::exec::knn::MultivectorScoringExec;
use crate::io::exec::scalar_index::{MaterializeIndexExec, ScalarIndexExec};
use crate::io::exec::{get_physical_optimizer, LanceScanConfig};
use crate::io::exec::{
@@ -88,6 +89,9 @@ pub const LEGACY_DEFAULT_FRAGMENT_READAHEAD: usize = 4;
lazy_static::lazy_static! {
pub static ref DEFAULT_FRAGMENT_READAHEAD: Option<usize> = std::env::var("LANCE_DEFAULT_FRAGMENT_READAHEAD")
.map(|val| Some(val.parse().unwrap())).unwrap_or(None);

pub static ref DEFAULT_XTR_OVERFETCH: u32 = std::env::var("LANCE_XTR_OVERFETCH")
.map(|val| val.parse().unwrap()).unwrap_or(10);
}
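The new over-fetch factor is process-wide configuration read once from the environment. A hedged usage sketch (the value 20 is illustrative; lazy_static caches whatever is set at first access):

```rust
fn main() {
    // Must run before the first dereference of DEFAULT_XTR_OVERFETCH,
    // because lazy_static parses and caches the value on first access.
    std::env::set_var("LANCE_XTR_OVERFETCH", "20");
    // Multivector ANN queries in this process now over-fetch
    // 20 * k candidates per query vector instead of the default 10 * k.
}
```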

// We want to support ~256 concurrent reads to maximize throughput on cloud storage systems
@@ -1697,13 +1701,13 @@ impl Scanner {

// Find all deltas with the same index name.
let deltas = self.dataset.load_indices_by_name(&index.name).await?;
let (ann_node, is_multivec) = match vector_type {
DataType::FixedSizeList(_, _) => (self.ann(q, &deltas, filter_plan).await?, false),
DataType::List(_) => (self.multivec_ann(q, &deltas, filter_plan).await?, true),
let ann_node = match vector_type {
DataType::FixedSizeList(_, _) => self.ann(q, &deltas, filter_plan).await?,
DataType::List(_) => self.multivec_ann(q, &deltas, filter_plan).await?,
_ => unreachable!(),
};

let mut knn_node = if q.refine_factor.is_some() || is_multivec {
let mut knn_node = if q.refine_factor.is_some() {
Contributor Author: We no longer require refine for multivector search.

let vector_projection = self
.dataset
.empty_projection()
@@ -2164,69 +2168,56 @@ impl Scanner {
index: &[Index],
filter_plan: &FilterPlan,
) -> Result<Arc<dyn ExecutionPlan>> {
// We split the query procedure into two steps:
// 1. collect candidates by running a vector search for each query vector
// 2. score the candidates

let over_fetch_factor = *DEFAULT_XTR_OVERFETCH;

let prefilter_source = self.prefilter_source(filter_plan).await?;
let dim = get_vector_dim(self.dataset.schema(), &q.column)?;
// split the query multivectors

let num_queries = q.key.len() / dim;
let new_queries = (0..num_queries)
.map(|i| q.key.slice(i * dim, dim))
.map(|query_vec| {
let mut new_query = q.clone();
new_query.key = query_vec;
// With XTR, we don't need to refine the results with the original vectors,
// but we do need to over-fetch candidates to reach good enough recall.
// TODO: improve recall with WARP; expose this parameter to users.
new_query.refine_factor = Some(over_fetch_factor);
new_query
});
let mut ann_nodes = Vec::with_capacity(new_queries.len());
let prefilter_source = self.prefilter_source(filter_plan).await?;
for query in new_queries {
// this produces `nprobes * k * over_fetch_factor * num_indices` candidates
let ann_node = new_knn_exec(
self.dataset.clone(),
index,
&query,
prefilter_source.clone(),
)?;
ann_nodes.push(ann_node);
let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: false,
nulls_first: false,
},
};
let ann_node = Arc::new(
SortExec::new(LexOrdering::new(vec![sort_expr]), ann_node)
.with_fetch(Some(q.k * over_fetch_factor as usize)),
);
ann_nodes.push(ann_node as Arc<dyn ExecutionPlan>);
}
let ann_node = Arc::new(UnionExec::new(ann_nodes));
let ann_node = Arc::new(RepartitionExec::try_new(
ann_node,
datafusion::physical_plan::Partitioning::RoundRobinBatch(1),
)?);
let schema = ann_node.schema();
// unique by row ids, and get the min distance although it is not used.
let group_expr = vec![(
expressions::col(ROW_ID, schema.as_ref())?,
ROW_ID.to_string(),
)];
// for now multivector is always with cosine distance so here convert the distance to `1 - distance`
// and calculate the sum across all rows with the same row id.
let sum_expr = AggregateExprBuilder::new(
functions_aggregate::sum::sum_udaf(),
vec![expressions::binary(
expressions::lit(1.0),
datafusion_expr::Operator::Minus,
expressions::cast(
expressions::col(DIST_COL, &schema)?,
&schema,
DataType::Float64,
)?,
&schema,
)?],
)
.schema(schema.clone())
.alias(DIST_COL)
.build()?;
let ann_node: Arc<dyn ExecutionPlan> = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new_single(group_expr),
vec![Arc::new(sum_expr)],
vec![None],
ann_node,
schema,
)?);

let ann_node = Arc::new(MultivectorScoringExec::try_new(ann_nodes, q.clone())?);

let sort_expr = PhysicalSortExpr {
expr: expressions::col(DIST_COL, ann_node.schema().as_ref())?,
options: SortOptions {
descending: true,
descending: false,
nulls_first: false,
},
};
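For context, the plan this hunk removes (UnionExec + RepartitionExec + an AggregateExec summing `1 - distance` per row id) is collapsed into the dedicated MultivectorScoringExec. A minimal sketch of XTR-style scoring under two stated assumptions: a similarity missing from a query's candidate list is imputed as zero (XTR's simplification, instead of refining with the original vectors), and each query vector contributes at most its best similarity per row:

```rust
use std::collections::HashMap;

/// One retrieved candidate for a single query vector: (row_id, similarity).
type Candidate = (u64, f32);

/// XTR-style merge of per-query candidate lists: each query vector
/// contributes its best similarity per row, rows absent from a query's
/// list implicitly contribute 0.0, and rows are ranked by the sum.
/// Illustrative sketch, not the MultivectorScoringExec implementation.
fn xtr_score(per_query_candidates: &[Vec<Candidate>]) -> Vec<(u64, f32)> {
    let mut scores: HashMap<u64, f32> = HashMap::new();
    for candidates in per_query_candidates {
        // A row's multivector may surface several times for one query;
        // keep only the best similarity per row for this query vector.
        let mut best: HashMap<u64, f32> = HashMap::new();
        for &(row_id, sim) in candidates {
            let entry = best.entry(row_id).or_insert(f32::NEG_INFINITY);
            if sim > *entry {
                *entry = sim;
            }
        }
        for (row_id, sim) in best {
            *scores.entry(row_id).or_insert(0.0) += sim;
        }
    }
    let mut ranked: Vec<(u64, f32)> = scores.into_iter().collect();
    // Rank by summed similarity, best first; total_cmp is NaN-safe.
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
    ranked
}
```

With the defaults, each per-query ANN branch produces up to `nprobes * k * over_fetch_factor * num_indices` candidates and its SortExec keeps the best `k * over_fetch_factor` of them before the merge. The SortOptions flip from descending to ascending presumably reflects the new node emitting a distance-style score (smaller is better) rather than the old summed similarity.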
6 changes: 4 additions & 2 deletions rust/lance/src/index/vector/ivf/v2.rs
@@ -683,7 +683,7 @@ mod tests {
.into_iter()
.enumerate()
.map(|(i, dist)| (dist, i as u64))
.sorted_by(|a, b| a.0.partial_cmp(&b.0).unwrap())
.sorted_by(|a, b| a.0.total_cmp(&b.0))
.take(k)
.collect()
}
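The ground-truth sort above switches from `partial_cmp(...).unwrap()` to `total_cmp` for the same reason as the distance.rs change: `f32::total_cmp` implements the IEEE 754 totalOrder predicate, so a NaN distance sorts deterministically instead of panicking. A small illustration:

```rust
fn main() {
    let mut dists = vec![0.3_f32, f32::NAN, 0.1];
    // sort_by(|a, b| a.partial_cmp(b).unwrap()) would panic on the NaN;
    // total_cmp orders it deterministically (positive NaN sorts last).
    dists.sort_by(|a, b| a.total_cmp(b));
    assert_eq!(dists[0], 0.1);
    assert!(dists[2].is_nan());
}
```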
@@ -991,6 +991,8 @@
}

async fn test_index_multivec(params: VectorIndexParams, nlist: usize, recall_requirement: f32) {
// We introduced XTR for performance, which reduces recall slightly.
let recall_requirement = recall_requirement * 0.9;
match params.metric_type {
DistanceType::Hamming => {
test_index_multivec_impl::<UInt8Type>(params, nlist, recall_requirement, 0..2)
@@ -1061,7 +1063,7 @@
let gt = multivec_ground_truth(&vectors, &query, k, params.metric_type);
let gt_set = gt.iter().map(|r| r.1).collect::<HashSet<_>>();

let recall = row_ids.intersection(&gt_set).count() as f32 / 10.0;
let recall = row_ids.intersection(&gt_set).count() as f32 / 100.0;
Contributor Author: The recall was calculated incorrectly; the previous algorithm required refine_factor=5 to reach good enough recall. (See the recall sketch below.)

assert!(
recall >= recall_requirement,
"recall: {}\n results: {:?}\n\ngt: {:?}",
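The divisor change from 10.0 to 100.0 matches recall@k with k = 100, assuming the test issues a top-100 query and `gt_set` holds 100 ground-truth row ids. A minimal sketch of the corrected metric:

```rust
use std::collections::HashSet;

/// recall@k: the fraction of the k ground-truth row ids that were retrieved.
fn recall_at_k(retrieved: &HashSet<u64>, ground_truth: &HashSet<u64>, k: usize) -> f32 {
    retrieved.intersection(ground_truth).count() as f32 / k as f32
}
```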