Skip to content

Commit

Permalink
fix: cosine, l2 distance and external build for multilevel kmeans (#44)
Browse files Browse the repository at this point in the history
* fix: cosine distance and external build for multilevel kmeans

Signed-off-by: usamoi <[email protected]>

* chore: use pgvector names

Signed-off-by: usamoi <[email protected]>

---------

Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Nov 7, 2024
1 parent 9161a67 commit 784edb8
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 219 deletions.
8 changes: 4 additions & 4 deletions bench/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ def get_ivf_ops_config(metric, k, name=None):
residual_quantization = true
spherical_centroids = false
"""
elif metric == "cos":
metric_ops = "vector_cos_ops"
elif metric == "cosine":
metric_ops = "vector_cosine_ops"
ivf_config = f"""
nlist = {k}
residual_quantization = false
spherical_centroids = true
"""
elif metric == "dot":
metric_ops = "vector_dot_ops"
elif metric == "ip":
metric_ops = "vector_ip_ops"
ivf_config = f"""
nlist = {k}
residual_quantization = false
Expand Down
255 changes: 151 additions & 104 deletions src/algorithm/build.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
use crate::algorithm::rabitq;
use crate::algorithm::tuples::*;
use crate::index::am_options::PgDistanceKind;
use crate::index::utils::load_table_vectors;
use crate::index::am_options::Opfamily;
use crate::postgres::BufferWriteGuard;
use crate::postgres::Relation;
use crate::types::ExternalCentroids;
use crate::types::RabbitholeBuildOptions;
use crate::types::RabbitholeExternalBuildOptions;
use crate::types::RabbitholeIndexingOptions;
use crate::types::RabbitholeInternalBuildOptions;
use base::distance::DistanceKind;
use base::index::VectorOptions;
use base::scalar::ScalarLike;
use base::search::Pointer;
use base::vector::VectBorrowed;
use base::vector::VectorBorrowed;
use common::vec2::Vec2;
use rand::Rng;
use rkyv::ser::serializers::AllocSerializer;
Expand All @@ -23,6 +22,7 @@ pub trait HeapRelation {
fn traverse<F>(&self, callback: F)
where
F: FnMut((Pointer, Vec<f32>));
fn opfamily(&self) -> Opfamily;
}

pub trait Reporter {
Expand All @@ -35,28 +35,26 @@ pub fn build<T: HeapRelation, R: Reporter>(
rabbithole_options: RabbitholeIndexingOptions,
heap_relation: T,
relation: Relation,
pg_distance: PgDistanceKind,
mut reporter: R,
) {
let dims = vector_options.dims;
let is_residual =
rabbithole_options.residual_quantization && vector_options.d == DistanceKind::L2;
let structure = match &rabbithole_options.external_centroids {
Some(_) => Structure::load(
let structure = match rabbithole_options.build {
RabbitholeBuildOptions::External(external_build) => Structure::extern_build(
vector_options.clone(),
rabbithole_options.clone(),
pg_distance,
heap_relation.opfamily(),
external_build.clone(),
),
None => {
RabbitholeBuildOptions::Internal(internal_build) => {
let mut tuples_total = 0_usize;
let samples = {
let mut rand = rand::thread_rng();
let max_number_of_samples = rabbithole_options.nlist.saturating_mul(256);
let max_number_of_samples = internal_build.nlist.saturating_mul(256);
let mut samples = Vec::new();
let mut number_of_samples = 0_u32;
heap_relation.traverse(|(_, vector)| {
assert_eq!(dims as usize, vector.len(), "invalid vector dimensions",);
let vector = rabitq::project(&vector);
assert_eq!(dims as usize, vector.len(), "invalid vector dimensions");
if number_of_samples < max_number_of_samples {
samples.extend(vector);
number_of_samples += 1;
Expand All @@ -71,7 +69,7 @@ pub fn build<T: HeapRelation, R: Reporter>(
Vec2::from_vec((number_of_samples as _, dims as _), samples)
};
reporter.tuples_total(tuples_total);
Structure::compute(vector_options.clone(), rabbithole_options.clone(), samples)
Structure::internal_build(vector_options.clone(), internal_build.clone(), samples)
}
};
let h2_len = structure.h2_len();
Expand Down Expand Up @@ -182,33 +180,37 @@ pub fn build<T: HeapRelation, R: Reporter>(
}

struct Structure {
h2_mean: Vec<f32>,
h2_children: Vec<u32>,
h2_means: Vec<Vec<f32>>,
h2_children: Vec<Vec<u32>>,
h1_means: Vec<Vec<f32>>,
h1_children: Vec<Vec<u32>>,
}

impl Structure {
fn compute(
fn internal_build(
vector_options: VectorOptions,
rabbithole_options: RabbitholeIndexingOptions,
samples: Vec2<f32>,
internal_build: RabbitholeInternalBuildOptions,
mut samples: Vec2<f32>,
) -> Self {
let dims = vector_options.dims;
for i in 0..samples.shape_0() {
let vector = &mut samples[(i,)];
vector.copy_from_slice(&rabitq::project(vector));
}
let h1_means = base::parallelism::RayonParallelism::scoped(
rabbithole_options.build_threads as _,
internal_build.build_threads as _,
Arc::new(AtomicBool::new(false)),
|parallelism| {
let raw = k_means::k_means(
parallelism,
rabbithole_options.nlist as usize,
internal_build.nlist as usize,
samples,
rabbithole_options.spherical_centroids,
internal_build.spherical_centroids,
10,
false,
);
let mut centroids = Vec::new();
for i in 0..rabbithole_options.nlist {
for i in 0..internal_build.nlist {
centroids.push(raw[(i as usize,)].to_vec());
}
centroids
Expand All @@ -218,116 +220,161 @@ impl Structure {
.expect("k_means interrupted");
let h2_mean = {
let mut centroid = vec![0.0; dims as _];
for i in 0..rabbithole_options.nlist {
for i in 0..internal_build.nlist {
for j in 0..dims {
centroid[j as usize] += h1_means[i as usize][j as usize];
}
}
for j in 0..dims {
centroid[j as usize] /= rabbithole_options.nlist as f32;
centroid[j as usize] /= internal_build.nlist as f32;
}
centroid
};
Structure {
h2_mean,
h2_children: (0..rabbithole_options.nlist).collect(),
h2_means: vec![h2_mean],
h2_children: vec![(0..internal_build.nlist).collect()],
h1_means,
h1_children: (0..rabbithole_options.nlist).map(|_| Vec::new()).collect(),
h1_children: (0..internal_build.nlist).map(|_| Vec::new()).collect(),
}
}
fn load(
fn extern_build(
vector_options: VectorOptions,
rabbithole_options: RabbitholeIndexingOptions,
pg_distance: PgDistanceKind,
_opfamily: Opfamily,
external_build: RabbitholeExternalBuildOptions,
) -> Self {
let dims = vector_options.dims;
let preprocess_data = match pg_distance {
PgDistanceKind::L2 | PgDistanceKind::Dot => {
|b: VectBorrowed<f32>| rabitq::project(b.slice())
use std::collections::BTreeMap;
let RabbitholeExternalBuildOptions { table } = external_build;
let query = format!("SELECT id, parent, vector FROM {table};");
let mut parents = BTreeMap::new();
let mut vectors = BTreeMap::new();
pgrx::spi::Spi::connect(|client| {
use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput;
use base::vector::VectorBorrowed;
use pgrx::pg_sys::panic::ErrorReportable;
let table = client.select(&query, None, None).unwrap_or_report();
for row in table {
let id: Option<i32> = row.get_by_name("id").unwrap();
let parent: Option<i32> = row.get_by_name("parent").unwrap();
let vector: Option<PgvectorVectorOutput> = row.get_by_name("vector").unwrap();
let id = id.expect("extern build: id could not be NULL");
let vector = vector.expect("extern build: vector could not be NULL");
let pop = parents.insert(id, parent);
if pop.is_some() {
pgrx::error!(
"external build: there are at least two lines have same id, id = {id}"
);
}
if vector_options.dims != vector.as_borrowed().dims() {
pgrx::error!("extern build: incorrect dimension, id = {id}");
}
vectors.insert(id, rabitq::project(vector.slice()));
}
PgDistanceKind::Cos => {
|b: VectBorrowed<f32>| rabitq::project(b.function_normalize().slice())
});
let mut children = parents
.keys()
.map(|x| (*x, Vec::new()))
.collect::<BTreeMap<_, _>>();
let mut root = None;
for (&id, &parent) in parents.iter() {
if let Some(parent) = parent {
if let Some(parent) = children.get_mut(&parent) {
parent.push(id);
} else {
pgrx::error!(
"external build: parent does not exist, id = {id}, parent = {parent}"
);
}
} else {
if let Some(root) = root {
pgrx::error!("external build: two root, id = {root}, id = {id}");
} else {
root = Some(id);
}
}
}
let Some(root) = root else {
pgrx::error!("extern build: there are no root");
};
let preprocess_index = |b: VectBorrowed<f32>| b.slice().to_vec();

let h1_means = match &rabbithole_options.external_centroids {
Some(ExternalCentroids {
table,
h1_means_column: h1,
..
}) => load_table_vectors(
table,
h1,
rabbithole_options.nlist,
vector_options.dims,
preprocess_data,
),

_ => unreachable!(),
};
let h1_children = match &rabbithole_options.external_centroids {
Some(ExternalCentroids {
table,
h1_children_column: Some(h1),
..
}) => load_table_vectors(table, h1, 1, vector_options.dims, preprocess_index)
.into_iter()
.map(|v| v.into_iter().map(|f| f as u32).collect())
.collect(),
_ => (0..rabbithole_options.nlist).map(|_| Vec::new()).collect(),
};
let h2_mean = match &rabbithole_options.external_centroids {
Some(ExternalCentroids {
table,
h2_mean_column: Some(h2),
..
}) => load_table_vectors(table, h2, 1, vector_options.dims, preprocess_data)
.pop()
.expect("load h2_mean panic"),
_ => {
let mut centroid = vec![0.0; dims as _];
for i in 0..rabbithole_options.nlist {
for j in 0..dims {
centroid[j as usize] += h1_means[i as usize][j as usize];
let mut heights = BTreeMap::<_, _>::new();
fn dfs_for_heights(
heights: &mut BTreeMap<i32, Option<u32>>,
children: &BTreeMap<i32, Vec<i32>>,
u: i32,
) {
if heights.contains_key(&u) {
pgrx::error!("extern build: detect a cycle, id = {u}");
}
heights.insert(u, None);
let mut height = None;
for &v in children[&u].iter() {
dfs_for_heights(heights, children, v);
let new = heights[&v].unwrap() + 1;
if let Some(height) = height {
if height != new {
pgrx::error!("extern build: two heights, id = {u}");
}
} else {
height = Some(new);
}
for j in 0..dims {
centroid[j as usize] /= rabbithole_options.nlist as f32;
}
centroid
}
};
let h2_children = match &rabbithole_options.external_centroids {
Some(ExternalCentroids {
table,
h2_children_column: Some(h2),
..
}) => load_table_vectors(table, h2, 1, vector_options.dims, preprocess_index)
.pop()
.expect("load h2_children panic")
.into_iter()
.map(|f| f as u32)
.collect(),
_ => (0..rabbithole_options.nlist).collect(),
};
Structure {
h2_mean,
if height.is_none() {
height = Some(1);
}
heights.insert(u, height);
}
dfs_for_heights(&mut heights, &children, root);
let heights = heights
.into_iter()
.map(|(k, v)| (k, v.expect("not a connected graph")))
.collect::<BTreeMap<_, _>>();
if heights[&root] != 2 {
pgrx::error!(
"extern build: unexpected tree height, height = {}",
heights[&root]
);
}
let mut cursors = vec![0_u32; 1 + heights[&root] as usize];
let mut labels = BTreeMap::new();
for id in parents.keys().copied() {
let height = heights[&id];
let cursor = cursors[height as usize];
labels.insert(id, (height, cursor));
cursors[height as usize] += 1;
}
// Collects one level of the externally supplied centroid tree.
//
// For every id in `labels` whose height component equals `height`, yields a
// pair of (a clone of that id's vector, the index components of its
// children's labels) and unzips the pairs into two parallel vectors: entry
// `k` of the first is the mean, entry `k` of the second is its child list.
//
// `labels` maps id -> (height, index). `BTreeMap::iter` visits keys in
// ascending order, so the output is ordered by ascending id within the level.
// NOTE(review): the child indices are only meaningful as positions into the
// next level's output if the caller assigned each level's indices
// consecutively in ascending-id order as well — confirm at the call site.
//
// Panics if an id at this level is missing from `vectors` or `children`, or
// if a child id is missing from `labels` (map indexing panics on a missing key).
fn extract(
height: u32,
labels: &BTreeMap<i32, (u32, u32)>,
vectors: &BTreeMap<i32, Vec<f32>>,
children: &BTreeMap<i32, Vec<i32>>,
) -> (Vec<Vec<f32>>, Vec<Vec<u32>>) {
labels
.iter()
// Keep only the nodes that sit on the requested level.
.filter(|(_, &(h, _))| h == height)
.map(|(id, _)| {
(
vectors[id].clone(),
// Translate each child id into the index component of its label.
children[id].iter().map(|id| labels[id].1).collect(),
)
})
.unzip()
}
let (h2_means, h2_children) = extract(2, &labels, &vectors, &children);
let (h1_means, h1_children) = extract(1, &labels, &vectors, &children);
Self {
h2_means,
h2_children,
h1_means,
h1_children,
}
}
fn h2_len(&self) -> u32 {
1
self.h2_means.len() as _
}
fn h2_means(&self, i: u32) -> &Vec<f32> {
assert!(i == 0);
&self.h2_mean
&self.h2_means[i as usize]
}
fn h2_children(&self, i: u32) -> &Vec<u32> {
assert!(i == 0);
&self.h2_children
&self.h2_children[i as usize]
}
fn h1_len(&self) -> u32 {
self.h1_means.len() as _
Expand Down
Loading

0 comments on commit 784edb8

Please sign in to comment.