From 0360d7c099b969aa7870bcca001c14348782c9cc Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 20 Nov 2024 18:56:06 +0800 Subject: [PATCH] feat: vchordrqfscan (#105) Signed-off-by: usamoi --- src/lib.rs | 4 +- src/sql/finalize.sql | 30 + src/{vchordrq/algorithm => utils}/k_means.rs | 2 +- src/utils/mod.rs | 2 + .../algorithm => utils}/parallelism.rs | 0 src/vchordrq/algorithm/build.rs | 16 +- src/vchordrq/algorithm/mod.rs | 2 - src/vchordrq/algorithm/rabitq.rs | 9 + src/vchordrq/algorithm/vectors.rs | 2 +- src/vchordrq/index/am_options.rs | 2 +- src/vchordrq/index/utils.rs | 11 - src/vchordrq/mod.rs | 1 + src/{ => vchordrq}/types.rs | 0 src/vchordrqfscan/algorithm/build.rs | 412 +++++++++ src/vchordrqfscan/algorithm/insert.rs | 258 ++++++ src/vchordrqfscan/algorithm/mod.rs | 7 + src/vchordrqfscan/algorithm/prewarm.rs | 107 +++ src/vchordrqfscan/algorithm/rabitq.rs | 171 ++++ src/vchordrqfscan/algorithm/scan.rs | 197 ++++ src/vchordrqfscan/algorithm/tuples.rs | 139 +++ src/vchordrqfscan/algorithm/vacuum.rs | 141 +++ src/vchordrqfscan/gucs/executing.rs | 76 ++ src/vchordrqfscan/gucs/mod.rs | 14 + src/vchordrqfscan/gucs/prewarm.rs | 32 + src/vchordrqfscan/index/am.rs | 864 ++++++++++++++++++ src/vchordrqfscan/index/am_options.rs | 222 +++++ src/vchordrqfscan/index/am_scan.rs | 132 +++ src/vchordrqfscan/index/functions.rs | 26 + src/vchordrqfscan/index/mod.rs | 12 + src/vchordrqfscan/index/opclass.rs | 14 + src/vchordrqfscan/index/utils.rs | 20 + src/vchordrqfscan/mod.rs | 11 + src/vchordrqfscan/types.rs | 90 ++ 33 files changed, 3001 insertions(+), 25 deletions(-) rename src/{vchordrq/algorithm => utils}/k_means.rs (98%) create mode 100644 src/utils/mod.rs rename src/{vchordrq/algorithm => utils}/parallelism.rs (100%) rename src/{ => vchordrq}/types.rs (100%) create mode 100644 src/vchordrqfscan/algorithm/build.rs create mode 100644 src/vchordrqfscan/algorithm/insert.rs create mode 100644 src/vchordrqfscan/algorithm/mod.rs create mode 100644 src/vchordrqfscan/algorithm/prewarm.rs create mode 100644 src/vchordrqfscan/algorithm/rabitq.rs create mode 100644 src/vchordrqfscan/algorithm/scan.rs create mode 100644 src/vchordrqfscan/algorithm/tuples.rs create mode 100644 src/vchordrqfscan/algorithm/vacuum.rs create mode 100644 src/vchordrqfscan/gucs/executing.rs create mode 100644 src/vchordrqfscan/gucs/mod.rs create mode 100644 src/vchordrqfscan/gucs/prewarm.rs create mode 100644 src/vchordrqfscan/index/am.rs create mode 100644 src/vchordrqfscan/index/am_options.rs create mode 100644 src/vchordrqfscan/index/am_scan.rs create mode 100644 src/vchordrqfscan/index/functions.rs create mode 100644 src/vchordrqfscan/index/mod.rs create mode 100644 src/vchordrqfscan/index/opclass.rs create mode 100644 src/vchordrqfscan/index/utils.rs create mode 100644 src/vchordrqfscan/mod.rs create mode 100644 src/vchordrqfscan/types.rs diff --git a/src/lib.rs b/src/lib.rs index 61062f3..94741f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,9 +7,10 @@ mod datatype; mod postgres; mod projection; -mod types; mod upgrade; +mod utils; mod vchordrq; +mod vchordrqfscan; pgrx::pg_module_magic!(); pgrx::extension_sql_file!("./sql/bootstrap.sql", bootstrap); @@ -23,6 +24,7 @@ unsafe extern "C" fn _PG_init() { detect::init(); unsafe { vchordrq::init(); + vchordrqfscan::init(); #[cfg(any(feature = "pg13", feature = "pg14"))] pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchord".as_ptr()); diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index be75da3..32b356c 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ 
-39,16 +39,28 @@ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_amhan CREATE FUNCTION vchordrq_prewarm(regclass, integer default 0) RETURNS TEXT STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_prewarm_wrapper'; +CREATE FUNCTION vchordrqfscan_amhandler(internal) RETURNS index_am_handler +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrqfscan_amhandler_wrapper'; + +CREATE FUNCTION vchordrqfscan_prewarm(regclass, integer default 0) RETURNS TEXT +STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrqfscan_prewarm_wrapper'; + -- List of access methods CREATE ACCESS METHOD vchordrq TYPE INDEX HANDLER vchordrq_amhandler; +CREATE ACCESS METHOD vchordrqfscan TYPE INDEX HANDLER vchordrqfscan_amhandler; + -- List of operator families CREATE OPERATOR FAMILY vector_l2_ops USING vchordrq; CREATE OPERATOR FAMILY vector_ip_ops USING vchordrq; CREATE OPERATOR FAMILY vector_cosine_ops USING vchordrq; +CREATE OPERATOR FAMILY vector_l2_ops USING vchordrqfscan; +CREATE OPERATOR FAMILY vector_ip_ops USING vchordrqfscan; +CREATE OPERATOR FAMILY vector_cosine_ops USING vchordrqfscan; + -- List of operator classes CREATE OPERATOR CLASS vector_l2_ops @@ -68,3 +80,21 @@ CREATE OPERATOR CLASS vector_cosine_ops OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH, FUNCTION 1 _vchordrq_support_vector_cosine_ops(); + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING vchordrqfscan FAMILY vector_l2_ops AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + OPERATOR 2 <<->> (vector, sphere_vector) FOR SEARCH, + FUNCTION 1 _vchordrqfscan_support_vector_l2_ops(); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING vchordrqfscan FAMILY vector_ip_ops AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + OPERATOR 2 <<#>> (vector, sphere_vector) FOR SEARCH, + FUNCTION 1 _vchordrqfscan_support_vector_ip_ops(); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING vchordrqfscan FAMILY vector_cosine_ops AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH, + FUNCTION 1 _vchordrqfscan_support_vector_cosine_ops(); diff --git a/src/vchordrq/algorithm/k_means.rs b/src/utils/k_means.rs similarity index 98% rename from src/vchordrq/algorithm/k_means.rs rename to src/utils/k_means.rs index c9e7229..8b4fd9d 100644 --- a/src/vchordrq/algorithm/k_means.rs +++ b/src/utils/k_means.rs @@ -1,4 +1,4 @@ -use crate::vchordrq::algorithm::parallelism::{ParallelIterator, Parallelism}; +use super::parallelism::{ParallelIterator, Parallelism}; use base::scalar::*; use half::f16; use rand::rngs::StdRng; diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..1b07dc6 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,2 @@ +pub mod k_means; +pub mod parallelism; diff --git a/src/vchordrq/algorithm/parallelism.rs b/src/utils/parallelism.rs similarity index 100% rename from src/vchordrq/algorithm/parallelism.rs rename to src/utils/parallelism.rs diff --git a/src/vchordrq/algorithm/build.rs b/src/vchordrq/algorithm/build.rs index 5e44674..3b96f0c 100644 --- a/src/vchordrq/algorithm/build.rs +++ b/src/vchordrq/algorithm/build.rs @@ -1,14 +1,13 @@ use crate::postgres::BufferWriteGuard; use crate::postgres::Relation; -use crate::types::VchordrqBuildOptions; -use crate::types::VchordrqExternalBuildOptions; -use crate::types::VchordrqIndexingOptions; -use crate::types::VchordrqInternalBuildOptions; -use
crate::vchordrq::algorithm::k_means; use crate::vchordrq::algorithm::rabitq; use crate::vchordrq::algorithm::tuples::*; use crate::vchordrq::algorithm::vectors; use crate::vchordrq::index::am_options::Opfamily; +use crate::vchordrq::types::VchordrqBuildOptions; +use crate::vchordrq::types::VchordrqExternalBuildOptions; +use crate::vchordrq::types::VchordrqIndexingOptions; +use crate::vchordrq::types::VchordrqInternalBuildOptions; use base::distance::DistanceKind; use base::index::VectorOptions; use base::scalar::ScalarLike; @@ -158,13 +157,13 @@ impl Structure { } let mut result = Vec::::new(); for w in internal_build.lists.iter().rev().copied().chain(once(1)) { - let means = crate::vchordrq::algorithm::parallelism::RayonParallelism::scoped( + let means = crate::utils::parallelism::RayonParallelism::scoped( internal_build.build_threads as _, Arc::new(|| { pgrx::check_for_interrupts!(); }), |parallelism| { - k_means::k_means( + crate::utils::k_means::k_means( parallelism, w as usize, vector_options.dims as usize, @@ -182,7 +181,8 @@ impl Structure { if let Some(structure) = result.last() { let mut children = vec![Vec::new(); means.len()]; for i in 0..structure.len() as u32 { - let target = k_means::k_means_lookup(&structure.means[i as usize], &means); + let target = + crate::utils::k_means::k_means_lookup(&structure.means[i as usize], &means); children[target].push(i); } let (means, children) = std::iter::zip(means, children) diff --git a/src/vchordrq/algorithm/mod.rs b/src/vchordrq/algorithm/mod.rs index a160118..88239a8 100644 --- a/src/vchordrq/algorithm/mod.rs +++ b/src/vchordrq/algorithm/mod.rs @@ -1,7 +1,5 @@ pub mod build; pub mod insert; -pub mod k_means; -pub mod parallelism; pub mod prewarm; pub mod rabitq; pub mod scan; diff --git a/src/vchordrq/algorithm/rabitq.rs b/src/vchordrq/algorithm/rabitq.rs index e7bec93..b3746b8 100644 --- a/src/vchordrq/algorithm/rabitq.rs +++ b/src/vchordrq/algorithm/rabitq.rs @@ -141,3 +141,12 @@ fn asymmetric_binary_dot_product(x: &[u64], y: &(Vec, Vec, Vec, V } (t0 << 0) + (t1 << 1) + (t2 << 2) + (t3 << 3) } + +pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance { + match d { + DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)), + DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)), + DistanceKind::Hamming => unimplemented!(), + DistanceKind::Jaccard => unimplemented!(), + } +} diff --git a/src/vchordrq/algorithm/vectors.rs b/src/vchordrq/algorithm/vectors.rs index bf4d782..c1f8627 100644 --- a/src/vchordrq/algorithm/vectors.rs +++ b/src/vchordrq/algorithm/vectors.rs @@ -1,6 +1,6 @@ use crate::postgres::Relation; +use crate::vchordrq::algorithm::rabitq::distance; use crate::vchordrq::algorithm::tuples::VectorTuple; -use crate::vchordrq::index::utils::distance; use base::distance::Distance; use base::distance::DistanceKind; diff --git a/src/vchordrq/index/am_options.rs b/src/vchordrq/index/am_options.rs index 1f271ba..971273f 100644 --- a/src/vchordrq/index/am_options.rs +++ b/src/vchordrq/index/am_options.rs @@ -1,7 +1,7 @@ use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; use crate::datatype::typmod::Typmod; -use crate::types::VchordrqIndexingOptions; +use crate::vchordrq::types::VchordrqIndexingOptions; use base::distance::*; use base::index::*; use base::vector::*; diff --git a/src/vchordrq/index/utils.rs b/src/vchordrq/index/utils.rs index 35697ab..a5d85a3 100644 --- a/src/vchordrq/index/utils.rs +++ 
b/src/vchordrq/index/utils.rs @@ -1,5 +1,3 @@ -use base::distance::{Distance, DistanceKind}; -use base::scalar::ScalarLike; use base::search::*; pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData { @@ -20,12 +18,3 @@ pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> Pointer { value |= ctid.ip_posid as u64; Pointer::new(value) } - -pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance { - match d { - DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)), - DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)), - DistanceKind::Hamming => unimplemented!(), - DistanceKind::Jaccard => unimplemented!(), - } -} diff --git a/src/vchordrq/mod.rs b/src/vchordrq/mod.rs index d7ddff4..c2ae945 100644 --- a/src/vchordrq/mod.rs +++ b/src/vchordrq/mod.rs @@ -1,6 +1,7 @@ mod algorithm; mod gucs; mod index; +mod types; pub unsafe fn init() { unsafe { diff --git a/src/types.rs b/src/vchordrq/types.rs similarity index 100% rename from src/types.rs rename to src/vchordrq/types.rs diff --git a/src/vchordrqfscan/algorithm/build.rs b/src/vchordrqfscan/algorithm/build.rs new file mode 100644 index 0000000..3215e01 --- /dev/null +++ b/src/vchordrqfscan/algorithm/build.rs @@ -0,0 +1,412 @@ +use crate::postgres::BufferWriteGuard; +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::rabitq; +use crate::vchordrqfscan::algorithm::tuples::*; +use crate::vchordrqfscan::index::am_options::Opfamily; +use crate::vchordrqfscan::types::VchordrqfscanBuildOptions; +use crate::vchordrqfscan::types::VchordrqfscanExternalBuildOptions; +use crate::vchordrqfscan::types::VchordrqfscanIndexingOptions; +use crate::vchordrqfscan::types::VchordrqfscanInternalBuildOptions; +use base::distance::DistanceKind; +use base::index::VectorOptions; +use base::scalar::ScalarLike; +use base::search::Pointer; +use rand::Rng; +use rkyv::ser::serializers::AllocSerializer; +use std::marker::PhantomData; +use std::sync::Arc; + +pub trait HeapRelation { + fn traverse(&self, progress: bool, callback: F) + where + F: FnMut((Pointer, Vec)); + fn opfamily(&self) -> Opfamily; +} + +pub trait Reporter { + fn tuples_total(&mut self, tuples_total: u64); +} + +pub fn build( + vector_options: VectorOptions, + vchordrqfscan_options: VchordrqfscanIndexingOptions, + heap_relation: T, + relation: Relation, + mut reporter: R, +) { + let dims = vector_options.dims; + let is_residual = + vchordrqfscan_options.residual_quantization && vector_options.d == DistanceKind::L2; + let structures = match vchordrqfscan_options.build { + VchordrqfscanBuildOptions::External(external_build) => Structure::extern_build( + vector_options.clone(), + heap_relation.opfamily(), + external_build.clone(), + ), + VchordrqfscanBuildOptions::Internal(internal_build) => { + let mut tuples_total = 0_u64; + let samples = { + let mut rand = rand::thread_rng(); + let max_number_of_samples = + internal_build.lists.last().unwrap().saturating_mul(256); + let mut samples = Vec::new(); + let mut number_of_samples = 0_u32; + heap_relation.traverse(false, |(_, vector)| { + assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); + if number_of_samples < max_number_of_samples { + samples.push(vector); + number_of_samples += 1; + } else { + let index = rand.gen_range(0..max_number_of_samples) as usize; + samples[index] = vector; + } + tuples_total += 1; + }); + samples + }; + reporter.tuples_total(tuples_total); + Structure::internal_build(vector_options.clone(), internal_build.clone(), 
samples) + } + }; + let mut meta = Tape::create(&relation, false); + assert_eq!(meta.first(), 0); + let mut forwards = Tape::::create(&relation, false); + assert_eq!(forwards.first(), 1); + let mut vectors = Tape::create(&relation, true); + assert_eq!(vectors.first(), 2); + let mut pointer_of_means = Vec::>::new(); + for i in 0..structures.len() { + let mut level = Vec::new(); + for j in 0..structures[i].len() { + let pointer = vectors.push(&VectorTuple { + payload: None, + vector: structures[i].means[j].clone(), + }); + level.push(pointer); + } + pointer_of_means.push(level); + } + let mut pointer_of_firsts = Vec::>::new(); + for i in 0..structures.len() { + let mut level = Vec::new(); + for j in 0..structures[i].len() { + if i == 0 { + let tape = Tape::::create(&relation, false); + level.push(tape.first()); + } else { + let mut tape = Tape::::create(&relation, false); + let mut cache = Vec::new(); + let h2_mean = &structures[i].means[j]; + let h2_children = &structures[i].children[j]; + for child in h2_children.iter().copied() { + let h1_mean = &structures[i - 1].means[child as usize]; + let code = if is_residual { + rabitq::code(dims, &f32::vector_sub(h1_mean, h2_mean)) + } else { + rabitq::code(dims, h1_mean) + }; + cache.push((child, code)); + if cache.len() == 32 { + let group = std::mem::take(&mut cache); + let codes = std::array::from_fn(|k| group[k].1.clone()); + let packed = rabitq::pack_codes(dims, codes); + tape.push(&Height1Tuple { + mask: [true; 32], + mean: std::array::from_fn(|k| { + pointer_of_means[i - 1][group[k].0 as usize] + }), + first: std::array::from_fn(|k| { + pointer_of_firsts[i - 1][group[k].0 as usize] + }), + dis_u_2: packed.dis_u_2, + factor_ppc: packed.factor_ppc, + factor_ip: packed.factor_ip, + factor_err: packed.factor_err, + t: packed.t, + }); + } + } + if !cache.is_empty() { + let group = std::mem::take(&mut cache); + let codes = std::array::from_fn(|k| { + if k < group.len() { + group[k].1.clone() + } else { + rabitq::dummy_code(dims) + } + }); + let packed = rabitq::pack_codes(dims, codes); + tape.push(&Height1Tuple { + mask: std::array::from_fn(|k| k < group.len()), + mean: std::array::from_fn(|k| { + if k < group.len() { + pointer_of_means[i - 1][group[k].0 as usize] + } else { + Default::default() + } + }), + first: std::array::from_fn(|k| { + if k < group.len() { + pointer_of_firsts[i - 1][group[k].0 as usize] + } else { + Default::default() + } + }), + dis_u_2: packed.dis_u_2, + factor_ppc: packed.factor_ppc, + factor_ip: packed.factor_ip, + factor_err: packed.factor_err, + t: packed.t, + }); + } + level.push(tape.first()); + } + } + pointer_of_firsts.push(level); + } + forwards.head.get_mut().get_opaque_mut().fast_forward = vectors.first(); + meta.push(&MetaTuple { + dims, + height_of_root: structures.len() as u32, + is_residual, + vectors_first: vectors.first(), + forwards_first: forwards.first(), + mean: pointer_of_means.last().unwrap()[0], + first: pointer_of_firsts.last().unwrap()[0], + }); +} + +struct Structure { + means: Vec>, + children: Vec>, +} + +impl Structure { + fn len(&self) -> usize { + self.children.len() + } + fn internal_build( + vector_options: VectorOptions, + internal_build: VchordrqfscanInternalBuildOptions, + mut samples: Vec>, + ) -> Vec { + use std::iter::once; + for sample in samples.iter_mut() { + *sample = crate::projection::project(sample); + } + let mut result = Vec::::new(); + for w in internal_build.lists.iter().rev().copied().chain(once(1)) { + let means = crate::utils::parallelism::RayonParallelism::scoped( + 
internal_build.build_threads as _, + Arc::new(|| { + pgrx::check_for_interrupts!(); + }), + |parallelism| { + crate::utils::k_means::k_means( + parallelism, + w as usize, + vector_options.dims as usize, + if let Some(structure) = result.last() { + &structure.means + } else { + &samples + }, + internal_build.spherical_centroids, + 10, + ) + }, + ) + .expect("failed to create thread pool"); + if let Some(structure) = result.last() { + let mut children = vec![Vec::new(); means.len()]; + for i in 0..structure.len() as u32 { + let target = + crate::utils::k_means::k_means_lookup(&structure.means[i as usize], &means); + children[target].push(i); + } + let (means, children) = std::iter::zip(means, children) + .filter(|(_, x)| !x.is_empty()) + .unzip::<_, _, Vec<_>, Vec<_>>(); + result.push(Structure { means, children }); + } else { + let children = vec![Vec::new(); means.len()]; + result.push(Structure { means, children }); + } + } + result + } + fn extern_build( + vector_options: VectorOptions, + _opfamily: Opfamily, + external_build: VchordrqfscanExternalBuildOptions, + ) -> Vec { + use std::collections::BTreeMap; + let VchordrqfscanExternalBuildOptions { table } = external_build; + let query = format!("SELECT id, parent, vector FROM {table};"); + let mut parents = BTreeMap::new(); + let mut vectors = BTreeMap::new(); + pgrx::spi::Spi::connect(|client| { + use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; + use base::vector::VectorBorrowed; + use pgrx::pg_sys::panic::ErrorReportable; + let table = client.select(&query, None, None).unwrap_or_report(); + for row in table { + let id: Option = row.get_by_name("id").unwrap(); + let parent: Option = row.get_by_name("parent").unwrap(); + let vector: Option = row.get_by_name("vector").unwrap(); + let id = id.expect("extern build: id could not be NULL"); + let vector = vector.expect("extern build: vector could not be NULL"); + let pop = parents.insert(id, parent); + if pop.is_some() { + pgrx::error!( + "external build: there are at least two lines have same id, id = {id}" + ); + } + if vector_options.dims != vector.as_borrowed().dims() { + pgrx::error!("extern build: incorrect dimension, id = {id}"); + } + vectors.insert(id, crate::projection::project(vector.slice())); + } + }); + let mut children = parents + .keys() + .map(|x| (*x, Vec::new())) + .collect::>(); + let mut root = None; + for (&id, &parent) in parents.iter() { + if let Some(parent) = parent { + if let Some(parent) = children.get_mut(&parent) { + parent.push(id); + } else { + pgrx::error!( + "external build: parent does not exist, id = {id}, parent = {parent}" + ); + } + } else { + if let Some(root) = root { + pgrx::error!("external build: two root, id = {root}, id = {id}"); + } else { + root = Some(id); + } + } + } + let Some(root) = root else { + pgrx::error!("extern build: there are no root"); + }; + let mut heights = BTreeMap::<_, _>::new(); + fn dfs_for_heights( + heights: &mut BTreeMap>, + children: &BTreeMap>, + u: i32, + ) { + if heights.contains_key(&u) { + pgrx::error!("extern build: detect a cycle, id = {u}"); + } + heights.insert(u, None); + let mut height = None; + for &v in children[&u].iter() { + dfs_for_heights(heights, children, v); + let new = heights[&v].unwrap() + 1; + if let Some(height) = height { + if height != new { + pgrx::error!("extern build: two heights, id = {u}"); + } + } else { + height = Some(new); + } + } + if height.is_none() { + height = Some(1); + } + heights.insert(u, height); + } + dfs_for_heights(&mut heights, &children, root); + let 
heights = heights + .into_iter() + .map(|(k, v)| (k, v.expect("not a connected graph"))) + .collect::>(); + if !(1..=8).contains(&(heights[&root] - 1)) { + pgrx::error!( + "extern build: unexpected tree height, height = {}", + heights[&root] + ); + } + let mut cursors = vec![0_u32; 1 + heights[&root] as usize]; + let mut labels = BTreeMap::new(); + for id in parents.keys().copied() { + let height = heights[&id]; + let cursor = cursors[height as usize]; + labels.insert(id, (height, cursor)); + cursors[height as usize] += 1; + } + fn extract( + height: u32, + labels: &BTreeMap, + vectors: &BTreeMap>, + children: &BTreeMap>, + ) -> (Vec>, Vec>) { + labels + .iter() + .filter(|(_, &(h, _))| h == height) + .map(|(id, _)| { + ( + vectors[id].clone(), + children[id].iter().map(|id| labels[id].1).collect(), + ) + }) + .unzip() + } + let mut result = Vec::new(); + for height in 1..=heights[&root] { + let (means, children) = extract(height, &labels, &vectors, &children); + result.push(Structure { means, children }); + } + result + } +} + +struct Tape<'a, T> { + relation: &'a Relation, + head: BufferWriteGuard, + first: u32, + tracking_freespace: bool, + _phantom: PhantomData T>, +} + +impl<'a, T> Tape<'a, T> { + fn create(relation: &'a Relation, tracking_freespace: bool) -> Self { + let head = relation.extend(tracking_freespace); + let first = head.id(); + Self { + relation, + head, + first, + tracking_freespace, + _phantom: PhantomData, + } + } + fn first(&self) -> u32 { + self.first + } +} + +impl<'a, T> Tape<'a, T> +where + T: rkyv::Serialize>, +{ + fn push(&mut self, x: &T) -> (u32, u16) { + let bytes = rkyv::to_bytes(x).expect("failed to serialize"); + if let Some(i) = self.head.get_mut().alloc(&bytes) { + (self.head.id(), i) + } else { + let next = self.relation.extend(self.tracking_freespace); + self.head.get_mut().get_opaque_mut().next = next.id(); + self.head = next; + if let Some(i) = self.head.get_mut().alloc(&bytes) { + (self.head.id(), i) + } else { + panic!("tuple is too large to fit in a fresh page") + } + } + } +} diff --git a/src/vchordrqfscan/algorithm/insert.rs b/src/vchordrqfscan/algorithm/insert.rs new file mode 100644 index 0000000..d17c22e --- /dev/null +++ b/src/vchordrqfscan/algorithm/insert.rs @@ -0,0 +1,258 @@ +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::rabitq; +use crate::vchordrqfscan::algorithm::rabitq::distance; +use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound; +use crate::vchordrqfscan::algorithm::tuples::*; +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::distance::DistanceKind; +use base::scalar::ScalarLike; +use base::search::Pointer; +use std::cmp::Reverse; +use std::collections::BinaryHeap; + +pub fn insert(relation: Relation, payload: Pointer, vector: Vec, distance_kind: DistanceKind) { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard + .get() + .get(1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let dims = meta_tuple.dims; + assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); + let vector = crate::projection::project(&vector); + let is_residual = meta_tuple.is_residual; + let default_lut = if !is_residual { + Some(rabitq::fscan_preprocess(&vector)) + } else { + None + }; + let h0_vector = 'h0_vector: { + let tuple = rkyv::to_bytes::<_, 8192>(&VectorTuple { + vector: vector.clone(), + payload: Some(payload.as_u64()), + }) + .unwrap(); + if let Some(mut write) = relation.search(tuple.len()) { + let 
i = write.get_mut().alloc(&tuple).unwrap(); + break 'h0_vector (write.id(), i); + } + let mut current = relation.read(1).get().get_opaque().fast_forward; + let mut changed = false; + loop { + let read = relation.read(current); + let flag = 'flag: { + if read.get().freespace() as usize >= tuple.len() { + break 'flag true; + } + if read.get().get_opaque().next == u32::MAX { + break 'flag true; + } + false + }; + if flag { + drop(read); + let mut write = relation.write(current); + if let Some(i) = write.get_mut().alloc(&tuple) { + break (current, i); + } + if write.get().get_opaque().next == u32::MAX { + if changed { + relation.write(1).get_mut().get_opaque_mut().fast_forward = write.id(); + } + let mut extend = relation.extend(true); + write.get_mut().get_opaque_mut().next = extend.id(); + if let Some(i) = extend.get_mut().alloc(&tuple) { + break (extend.id(), i); + } else { + panic!("a tuple cannot even be fit in a fresh page"); + } + } + current = write.get().get_opaque().next; + } else { + current = read.get().get_opaque().next; + } + changed = true; + } + }; + let h0_payload = payload.as_u64(); + let mut list = ( + meta_tuple.first, + if is_residual { + let vector_guard = relation.read(meta_tuple.mean.0); + let vector_tuple = vector_guard + .get() + .get(meta_tuple.mean.1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + Some(vector_tuple.vector.to_vec()) + } else { + None + }, + ); + let make_list = |list: (u32, Option>)| { + let mut results = Vec::new(); + { + let lut = if is_residual { + &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) + } else { + default_lut.as_ref().unwrap() + }; + let mut current = list.0; + while current != u32::MAX { + let h1_guard = relation.read(current); + for i in 1..=h1_guard.get().len() { + let h1_tuple = h1_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let lowerbounds = fscan_process_lowerbound( + distance_kind, + dims, + lut, + ( + &h1_tuple.dis_u_2, + &h1_tuple.factor_ppc, + &h1_tuple.factor_ip, + &h1_tuple.factor_err, + &h1_tuple.t, + ), + 1.9, + ); + for j in 0..32 { + if h1_tuple.mask[j] { + results.push(( + Reverse(lowerbounds[j]), + AlwaysEqual(h1_tuple.mean[j]), + AlwaysEqual(h1_tuple.first[j]), + )); + } + } + } + current = h1_guard.get().get_opaque().next; + } + } + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); + { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); + let vector_guard = relation.read(mean.0); + let vector_tuple = vector_guard + .get() + .get(mean.1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let dis_u = distance(distance_kind, &vector, &vector_tuple.vector); + cache.push(( + Reverse(dis_u), + AlwaysEqual(first), + AlwaysEqual(if is_residual { + Some(vector_tuple.vector.to_vec()) + } else { + None + }), + )); + } + let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop().unwrap(); + (first, mean) + } + }; + for _ in (1..meta_tuple.height_of_root).rev() { + list = make_list(list); + } + let code = if is_residual { + rabitq::code(dims, &f32::vector_sub(&vector, list.1.as_ref().unwrap())) + } else { + rabitq::code(dims, &vector) + }; + let dummy = rkyv::to_bytes::<_, 8192>(&Height0Tuple { + mask: [false; 32], + mean: [(0, 0); 32], + payload: [0; 32], + dis_u_2: [0.0f32; 
32], + factor_ppc: [0.0f32; 32], + factor_ip: [0.0f32; 32], + factor_err: [0.0f32; 32], + t: vec![0; (dims.div_ceil(4) * 16) as usize], + }) + .unwrap(); + let first = list.0; + assert!(first != u32::MAX); + let mut current = first; + loop { + let read = relation.read(current); + let flag = 'flag: { + for i in 1..=read.get().len() { + let h0_tuple = read + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + if h0_tuple.mask.iter().any(|x| *x) { + break 'flag true; + } + } + if read.get().freespace() as usize >= dummy.len() { + break 'flag true; + } + if read.get().get_opaque().next == u32::MAX { + break 'flag true; + } + false + }; + if flag { + drop(read); + let mut write = relation.write(current); + for i in 1..=write.get().len() { + let flag = put( + write.get_mut().get_mut(i).expect("data corruption"), + dims, + &code, + h0_vector, + h0_payload, + ); + if flag { + return; + } + } + if let Some(i) = write.get_mut().alloc(&dummy) { + let flag = put( + write.get_mut().get_mut(i).expect("data corruption"), + dims, + &code, + h0_vector, + h0_payload, + ); + assert!(flag, "a put fails even on a fresh tuple"); + return; + } + if write.get().get_opaque().next == u32::MAX { + let mut extend = relation.extend(false); + write.get_mut().get_opaque_mut().next = extend.id(); + if let Some(i) = extend.get_mut().alloc(&dummy) { + let flag = put( + extend.get_mut().get_mut(i).expect("data corruption"), + dims, + &code, + h0_vector, + h0_payload, + ); + assert!(flag, "a put fails even on a fresh tuple"); + return; + } else { + panic!("a tuple cannot even be fit in a fresh page"); + } + } + current = write.get().get_opaque().next; + } else { + current = read.get().get_opaque().next; + } + } +} diff --git a/src/vchordrqfscan/algorithm/mod.rs b/src/vchordrqfscan/algorithm/mod.rs new file mode 100644 index 0000000..448d919 --- /dev/null +++ b/src/vchordrqfscan/algorithm/mod.rs @@ -0,0 +1,7 @@ +pub mod build; +pub mod insert; +pub mod prewarm; +pub mod rabitq; +pub mod scan; +pub mod tuples; +pub mod vacuum; diff --git a/src/vchordrqfscan/algorithm/prewarm.rs b/src/vchordrqfscan/algorithm/prewarm.rs new file mode 100644 index 0000000..ec7642a --- /dev/null +++ b/src/vchordrqfscan/algorithm/prewarm.rs @@ -0,0 +1,107 @@ +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::tuples::*; +use std::fmt::Write; + +pub fn prewarm(relation: Relation, height: i32) -> String { + let mut message = String::new(); + let meta_guard = relation.read(0); + let meta_tuple = meta_guard + .get() + .get(1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + writeln!(message, "height of root: {}", meta_tuple.height_of_root).unwrap(); + let prewarm_max_height = if height < 0 { 0 } else { height as u32 }; + if prewarm_max_height > meta_tuple.height_of_root { + return message; + } + let mut lists = { + let mut results = Vec::new(); + let counter = 1_usize; + { + let vector_guard = relation.read(meta_tuple.mean.0); + let vector_tuple = vector_guard + .get() + .get(meta_tuple.mean.1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let _ = vector_tuple; + results.push(meta_tuple.first); + } + writeln!(message, "number of tuples: {}", results.len()).unwrap(); + writeln!(message, "number of pages: {}", counter).unwrap(); + results + }; + let mut make_lists = |lists| { + let mut counter = 0_usize; + let mut results = Vec::new(); + for list in lists { + let mut current = 
list; + while current != u32::MAX { + counter += 1; + pgrx::check_for_interrupts!(); + let h1_guard = relation.read(current); + for i in 1..=h1_guard.get().len() { + let h1_tuple = h1_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + for j in 0..32 { + if h1_tuple.mask[j] { + results.push(h1_tuple.first[j]); + let mean = h1_tuple.mean[j]; + let vector_guard = relation.read(mean.0); + let vector_tuple = vector_guard + .get() + .get(mean.1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let _ = vector_tuple; + } + } + } + current = h1_guard.get().get_opaque().next; + } + } + writeln!(message, "number of tuples: {}", results.len()).unwrap(); + writeln!(message, "number of pages: {}", counter).unwrap(); + results + }; + for _ in (std::cmp::max(1, prewarm_max_height)..meta_tuple.height_of_root).rev() { + lists = make_lists(lists); + } + if prewarm_max_height == 0 { + let mut counter = 0_usize; + let mut results = Vec::new(); + for list in lists { + let mut current = list; + while current != u32::MAX { + counter += 1; + pgrx::check_for_interrupts!(); + let h0_guard = relation.read(current); + for i in 1..=h0_guard.get().len() { + let h0_tuple = h0_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + for j in 0..32 { + if h0_tuple.mask[j] { + results.push(()); + } + } + } + current = h0_guard.get().get_opaque().next; + } + } + writeln!(message, "number of tuples: {}", results.len()).unwrap(); + writeln!(message, "number of pages: {}", counter).unwrap(); + } + message +} diff --git a/src/vchordrqfscan/algorithm/rabitq.rs b/src/vchordrqfscan/algorithm/rabitq.rs new file mode 100644 index 0000000..cf72ca5 --- /dev/null +++ b/src/vchordrqfscan/algorithm/rabitq.rs @@ -0,0 +1,171 @@ +use base::distance::{Distance, DistanceKind}; +use base::scalar::ScalarLike; +use quantization::utils::InfiniteByteChunks; + +#[derive(Debug, Clone)] +pub struct Code { + pub dis_u_2: f32, + pub factor_ppc: f32, + pub factor_ip: f32, + pub factor_err: f32, + pub signs: Vec, +} + +pub fn code(dims: u32, vector: &[f32]) -> Code { + let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); + let sum_of_x_2 = f32::reduce_sum_of_x2(vector); + let dis_u = sum_of_x_2.sqrt(); + let x0 = sum_of_abs_x / (sum_of_x_2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + let mut signs = Vec::new(); + for i in 0..dims { + signs.push(vector[i as usize].is_sign_positive() as u8); + } + Code { + dis_u_2: sum_of_x_2, + factor_ppc, + factor_ip, + factor_err, + signs, + } +} + +pub fn dummy_code(dims: u32) -> Code { + Code { + dis_u_2: 0.0, + factor_ppc: 0.0, + factor_ip: 0.0, + factor_err: 0.0, + signs: vec![0; dims as _], + } +} + +pub struct PackedCodes { + pub dis_u_2: [f32; 32], + pub factor_ppc: [f32; 32], + pub factor_ip: [f32; 32], + pub factor_err: [f32; 32], + pub t: Vec, +} + +pub fn pack_codes(dims: u32, codes: [Code; 32]) -> PackedCodes { + PackedCodes { + dis_u_2: std::array::from_fn(|i| codes[i].dis_u_2), + 
factor_ppc: std::array::from_fn(|i| codes[i].factor_ppc), + factor_ip: std::array::from_fn(|i| codes[i].factor_ip), + factor_err: std::array::from_fn(|i| codes[i].factor_err), + t: { + let signs = codes.map(|code| { + InfiniteByteChunks::new(code.signs.into_iter()) + .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) + .take(dims.div_ceil(4) as usize) + .collect::>() + }); + quantization::fast_scan::b4::pack(dims.div_ceil(4), signs).collect() + }, + } +} + +pub fn fscan_preprocess(vector: &[f32]) -> (f32, f32, f32, f32, Vec) { + use quantization::quantize; + let dis_v_2 = f32::reduce_sum_of_x2(vector); + let (k, b, qvector) = quantize::quantize::<15>(vector); + let qvector_sum = if vector.len() <= 4369 { + quantize::reduce_sum_of_x_as_u16(&qvector) as f32 + } else { + quantize::reduce_sum_of_x_as_u32(&qvector) as f32 + }; + (dis_v_2, b, k, qvector_sum, compress(qvector)) +} + +pub fn fscan_process_lowerbound( + distance_kind: DistanceKind, + dims: u32, + lut: &(f32, f32, f32, f32, Vec), + (dis_u_2, factor_ppc, factor_ip, factor_err, t): ( + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[f32; 32], + &[u8], + ), + epsilon: f32, +) -> [Distance; 32] { + let &(dis_v_2, b, k, qvector_sum, ref s) = lut; + let r = quantization::fast_scan::b4::fast_scan_b4(dims.div_ceil(4), t, s); + match distance_kind { + DistanceKind::L2 => std::array::from_fn(|i| { + let rough = dis_u_2[i] + + dis_v_2 + + b * factor_ppc[i] + + ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }), + DistanceKind::Dot => std::array::from_fn(|i| { + let rough = 0.5 * b * factor_ppc[i] + + 0.5 * ((2.0 * r[i] as f32) - qvector_sum) * factor_ip[i] * k; + let err = 0.5 * factor_err[i] * dis_v_2.sqrt(); + Distance::from_f32(rough - epsilon * err) + }), + DistanceKind::Hamming => unreachable!(), + DistanceKind::Jaccard => unreachable!(), + } +} + +fn compress(mut qvector: Vec) -> Vec { + let dims = qvector.len() as u32; + let width = dims.div_ceil(4); + qvector.resize(qvector.len().next_multiple_of(4), 0); + let mut t = vec![0u8; width as usize * 16]; + for i in 0..width as usize { + unsafe { + // this hint is used to skip bound checks + std::hint::assert_unchecked(4 * i + 3 < qvector.len()); + std::hint::assert_unchecked(16 * i + 15 < t.len()); + } + let t0 = qvector[4 * i + 0]; + let t1 = qvector[4 * i + 1]; + let t2 = qvector[4 * i + 2]; + let t3 = qvector[4 * i + 3]; + t[16 * i + 0b0000] = 0; + t[16 * i + 0b0001] = t0; + t[16 * i + 0b0010] = t1; + t[16 * i + 0b0011] = t1 + t0; + t[16 * i + 0b0100] = t2; + t[16 * i + 0b0101] = t2 + t0; + t[16 * i + 0b0110] = t2 + t1; + t[16 * i + 0b0111] = t2 + t1 + t0; + t[16 * i + 0b1000] = t3; + t[16 * i + 0b1001] = t3 + t0; + t[16 * i + 0b1010] = t3 + t1; + t[16 * i + 0b1011] = t3 + t1 + t0; + t[16 * i + 0b1100] = t3 + t2; + t[16 * i + 0b1101] = t3 + t2 + t0; + t[16 * i + 0b1110] = t3 + t2 + t1; + t[16 * i + 0b1111] = t3 + t2 + t1 + t0; + } + t +} + +pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance { + match d { + DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)), + DistanceKind::Dot => Distance::from_f32(-f32::reduce_sum_of_xy(lhs, rhs)), + DistanceKind::Hamming => unimplemented!(), + DistanceKind::Jaccard => unimplemented!(), + } +} diff --git a/src/vchordrqfscan/algorithm/scan.rs b/src/vchordrqfscan/algorithm/scan.rs new file mode 100644 index 0000000..63264f1 --- /dev/null +++ b/src/vchordrqfscan/algorithm/scan.rs @@ -0,0 +1,197 @@ +use 
crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::rabitq; +use crate::vchordrqfscan::algorithm::rabitq::distance; +use crate::vchordrqfscan::algorithm::rabitq::fscan_process_lowerbound; +use crate::vchordrqfscan::algorithm::tuples::*; +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::distance::DistanceKind; +use base::scalar::ScalarLike; +use base::search::Pointer; +use std::cmp::Reverse; +use std::collections::BinaryHeap; + +pub fn scan( + relation: Relation, + vector: Vec, + distance_kind: DistanceKind, + probes: Vec, + epsilon: f32, +) -> impl Iterator { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard + .get() + .get(1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let dims = meta_tuple.dims; + let height_of_root = meta_tuple.height_of_root; + assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); + assert_eq!(height_of_root as usize, 1 + probes.len(), "invalid probes"); + let vector = crate::projection::project(&vector); + let is_residual = meta_tuple.is_residual; + let default_lut = if !is_residual { + Some(rabitq::fscan_preprocess(&vector)) + } else { + None + }; + let mut lists: Vec<_> = vec![( + meta_tuple.first, + if is_residual { + let vector_guard = relation.read(meta_tuple.mean.0); + let vector_tuple = vector_guard + .get() + .get(meta_tuple.mean.1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + Some(vector_tuple.vector.to_vec()) + } else { + None + }, + )]; + let make_lists = |lists: Vec<(u32, Option>)>, probes| { + let mut results = Vec::new(); + for list in lists { + let lut = if is_residual { + &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) + } else { + default_lut.as_ref().unwrap() + }; + let mut current = list.0; + while current != u32::MAX { + let h1_guard = relation.read(current); + for i in 1..=h1_guard.get().len() { + let h1_tuple = h1_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let lowerbounds = fscan_process_lowerbound( + distance_kind, + dims, + lut, + ( + &h1_tuple.dis_u_2, + &h1_tuple.factor_ppc, + &h1_tuple.factor_ip, + &h1_tuple.factor_err, + &h1_tuple.t, + ), + epsilon, + ); + for j in 0..32 { + if h1_tuple.mask[j] { + results.push(( + Reverse(lowerbounds[j]), + AlwaysEqual(h1_tuple.mean[j]), + AlwaysEqual(h1_tuple.first[j]), + )); + } + } + } + current = h1_guard.get().get_opaque().next; + } + } + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _, _)>::new(); + std::iter::from_fn(|| { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(first)) = heap.pop().unwrap(); + let vector_guard = relation.read(mean.0); + let vector_tuple = vector_guard + .get() + .get(mean.1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let dis_u = distance(distance_kind, &vector, &vector_tuple.vector); + cache.push(( + Reverse(dis_u), + AlwaysEqual(first), + AlwaysEqual(if is_residual { + Some(vector_tuple.vector.to_vec()) + } else { + None + }), + )); + } + let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?; + Some((first, mean)) + }) + .take(probes as usize) + .collect() + }; + for i in (1..meta_tuple.height_of_root).rev() { + lists = make_lists(lists, probes[i as usize - 1]); + } + { + let mut results = Vec::new(); + for list in 
lists { + let lut = if is_residual { + &rabitq::fscan_preprocess(&f32::vector_sub(&vector, list.1.as_ref().unwrap())) + } else { + default_lut.as_ref().unwrap() + }; + let mut current = list.0; + while current != u32::MAX { + let h0_guard = relation.read(current); + for i in 1..=h0_guard.get().len() { + let h0_tuple = h0_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let lowerbounds = fscan_process_lowerbound( + distance_kind, + dims, + lut, + ( + &h0_tuple.dis_u_2, + &h0_tuple.factor_ppc, + &h0_tuple.factor_ip, + &h0_tuple.factor_err, + &h0_tuple.t, + ), + epsilon, + ); + for j in 0..32 { + if h0_tuple.mask[j] { + results.push(( + Reverse(lowerbounds[j]), + AlwaysEqual(h0_tuple.mean[j]), + AlwaysEqual(h0_tuple.payload[j]), + )); + } + } + } + current = h0_guard.get().get_opaque().next; + } + } + let mut heap = BinaryHeap::from(results); + let mut cache = BinaryHeap::<(Reverse, _)>::new(); + std::iter::from_fn(move || { + while !heap.is_empty() && heap.peek().map(|x| x.0) > cache.peek().map(|x| x.0) { + let (_, AlwaysEqual(mean), AlwaysEqual(pay_u)) = heap.pop().unwrap(); + let vector_guard = relation.read(mean.0); + let Some(vector_tuple) = vector_guard.get().get(mean.1) else { + // fails consistency check + continue; + }; + let vector_tuple = rkyv::check_archived_root::(vector_tuple) + .expect("data corruption"); + if vector_tuple.payload != Some(pay_u) { + // fails consistency check + continue; + } + let dis_u = distance(distance_kind, &vector, &vector_tuple.vector); + cache.push((Reverse(dis_u), AlwaysEqual(pay_u))); + } + let (Reverse(dis_u), AlwaysEqual(pay_u)) = cache.pop()?; + Some((dis_u, Pointer::new(pay_u))) + }) + } +} diff --git a/src/vchordrqfscan/algorithm/tuples.rs b/src/vchordrqfscan/algorithm/tuples.rs new file mode 100644 index 0000000..3b43dac --- /dev/null +++ b/src/vchordrqfscan/algorithm/tuples.rs @@ -0,0 +1,139 @@ +use crate::vchordrqfscan::algorithm::rabitq; +use rkyv::{Archive, Deserialize, Serialize}; + +#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] +#[archive(check_bytes)] +pub struct MetaTuple { + pub dims: u32, + pub height_of_root: u32, + pub is_residual: bool, + pub vectors_first: u32, + pub forwards_first: u32, + // raw vector + pub mean: (u32, u16), + // for meta tuple, it's pointers to next level + pub first: u32, +} + +#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] +#[archive(check_bytes)] +pub struct VectorTuple { + pub vector: Vec, + // this field is saved only for vacuum + pub payload: Option, +} + +#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] +#[archive(check_bytes)] +pub struct Height1Tuple { + pub mask: [bool; 32], + // raw vector + pub mean: [(u32, u16); 32], + // for height 1 tuple, it's pointers to next level + pub first: [u32; 32], + // RaBitQ algorithm + pub dis_u_2: [f32; 32], + pub factor_ppc: [f32; 32], + pub factor_ip: [f32; 32], + pub factor_err: [f32; 32], + pub t: Vec, +} + +#[derive(Clone, PartialEq, Archive, Serialize, Deserialize)] +#[archive(check_bytes)] +pub struct Height0Tuple { + pub mask: [bool; 32], + // raw vector + pub mean: [(u32, u16); 32], + // for height 0 tuple, it's pointers to heap relation + pub payload: [u64; 32], + // RaBitQ algorithm + pub dis_u_2: [f32; 32], + pub factor_ppc: [f32; 32], + pub factor_ip: [f32; 32], + pub factor_err: [f32; 32], + pub t: Vec, +} + +pub fn put( + bytes: &mut [u8], + dims: u32, + code: &rabitq::Code, + vector: (u32, u16), + payload: u64, +) -> bool { + // todo: use 
mutable api + let mut x = rkyv::from_bytes::(bytes).expect("data corruption"); + for j in 0..32 { + if !x.mask[j] { + x.mean[j] = vector; + x.payload[j] = payload; + x.mask[j] = true; + x.dis_u_2[j] = code.dis_u_2; + x.factor_ppc[j] = code.factor_ppc; + x.factor_ip[j] = code.factor_ip; + x.factor_err[j] = code.factor_err; + let width = dims.div_ceil(4) as usize; + let table = [ + (0, 0), + (2, 0), + (4, 0), + (6, 0), + (8, 0), + (10, 0), + (12, 0), + (14, 0), + (1, 0), + (3, 0), + (5, 0), + (7, 0), + (9, 0), + (11, 0), + (13, 0), + (15, 0), + (0, 1), + (2, 1), + (4, 1), + (6, 1), + (8, 1), + (10, 1), + (12, 1), + (14, 1), + (1, 1), + (3, 1), + (5, 1), + (7, 1), + (9, 1), + (11, 1), + (13, 1), + (15, 1), + ]; + let pos = table[j].0; + let mask = match table[j].1 { + 0 => 0xf0, + 1 => 0x0f, + _ => unreachable!(), + }; + let shift = match table[j].1 { + 0 => 0, + 1 => 4, + _ => unreachable!(), + }; + let mut buffer = vec![0u8; width]; + for j in 0..width { + let b0 = code.signs.get(4 * j + 0).copied().unwrap_or_default(); + let b1 = code.signs.get(4 * j + 1).copied().unwrap_or_default(); + let b2 = code.signs.get(4 * j + 2).copied().unwrap_or_default(); + let b3 = code.signs.get(4 * j + 3).copied().unwrap_or_default(); + buffer[j] = b0 | b1 << 1 | b2 << 2 | b3 << 3; + } + for j in 0..width { + x.t[16 * j + pos] &= mask; + x.t[16 * j + pos] |= buffer[j] << shift; + } + bytes.copy_from_slice(&rkyv::to_bytes::<_, 8192>(&x).unwrap()); + return true; + } + } + false +} diff --git a/src/vchordrqfscan/algorithm/vacuum.rs b/src/vchordrqfscan/algorithm/vacuum.rs new file mode 100644 index 0000000..1a77b98 --- /dev/null +++ b/src/vchordrqfscan/algorithm/vacuum.rs @@ -0,0 +1,141 @@ +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::tuples::VectorTuple; +use crate::vchordrqfscan::algorithm::tuples::*; +use base::search::Pointer; + +pub fn vacuum(relation: Relation, delay: impl Fn(), callback: impl Fn(Pointer) -> bool) { + // step 1: vacuum height_0_tuple + { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard + .get() + .get(1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let mut firsts = vec![meta_tuple.first]; + let make_firsts = |firsts| { + let mut results = Vec::new(); + for first in firsts { + let mut current = first; + while current != u32::MAX { + let h1_guard = relation.read(current); + for i in 1..=h1_guard.get().len() { + let h1_tuple = h1_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + for j in 0..32 { + if h1_tuple.mask[j] { + results.push(h1_tuple.first[j]); + } + } + } + current = h1_guard.get().get_opaque().next; + } + } + results + }; + for _ in (1..meta_tuple.height_of_root).rev() { + firsts = make_firsts(firsts); + } + for first in firsts { + let mut current = first; + while current != u32::MAX { + delay(); + let mut h0_guard = relation.write(current); + for i in 1..=h0_guard.get().len() { + let h0_tuple = h0_guard + .get() + .get(i) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + let flag = 'flag: { + for j in 0..32 { + if h0_tuple.mask[j] && callback(Pointer::new(h0_tuple.payload[j])) { + break 'flag true; + } + } + false + }; + if flag { + // todo: use mutable API + let mut temp = h0_guard + .get() + .get(i) + .map(rkyv::from_bytes::) + .expect("data corruption") + .expect("data corruption"); + for j in 0..32 { + if temp.mask[j] && callback(Pointer::new(temp.payload[j])) { + 
temp.mask[j] = false; + } + } + let temp = rkyv::to_bytes::<_, 8192>(&temp).expect("failed to serialize"); + h0_guard + .get_mut() + .get_mut(i) + .expect("data corruption") + .copy_from_slice(&temp); + } + } + // todo: cross-tuple vacuum so that we can skip a tuple + current = h0_guard.get().get_opaque().next; + } + } + } + // step 2: vacuum vector_tuple + { + let mut current = { + let meta_guard = relation.read(0); + let meta_tuple = meta_guard + .get() + .get(1) + .map(rkyv::check_archived_root::) + .expect("data corruption") + .expect("data corruption"); + meta_tuple.vectors_first + }; + while current != u32::MAX { + delay(); + let read = relation.read(current); + let flag = 'flag: { + for i in 1..=read.get().len() { + let Some(vector_tuple) = read.get().get(i) else { + continue; + }; + let vector_tuple = rkyv::check_archived_root::(vector_tuple) + .expect("data corruption"); + if let Some(payload) = vector_tuple.payload.as_ref().copied() { + if callback(Pointer::new(payload)) { + break 'flag true; + } + } + } + false + }; + if flag { + drop(read); + let mut write = relation.write(current); + for i in 1..=write.get().len() { + let Some(vector_tuple) = write.get().get(i) else { + continue; + }; + let vector_tuple = rkyv::check_archived_root::(vector_tuple) + .expect("data corruption"); + if let Some(payload) = vector_tuple.payload.as_ref().copied() { + if callback(Pointer::new(payload)) { + write.get_mut().free(i); + } + } + } + current = write.get().get_opaque().next; + } else { + current = read.get().get_opaque().next; + } + } + } +} diff --git a/src/vchordrqfscan/gucs/executing.rs b/src/vchordrqfscan/gucs/executing.rs new file mode 100644 index 0000000..6ec186e --- /dev/null +++ b/src/vchordrqfscan/gucs/executing.rs @@ -0,0 +1,76 @@ +use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting}; +use std::ffi::CStr; + +static PROBES: GucSetting> = GucSetting::>::new(Some(c"10")); +static EPSILON: GucSetting = GucSetting::::new(1.9); +static MAX_SCAN_TUPLES: GucSetting = GucSetting::::new(-1); + +pub unsafe fn init() { + GucRegistry::define_string_guc( + "vchordrqfscan.probes", + "`probes` argument of vchordrqfscan.", + "`probes` argument of vchordrqfscan.", + &PROBES, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_float_guc( + "vchordrqfscan.epsilon", + "`epsilon` argument of vchordrqfscan.", + "`epsilon` argument of vchordrqfscan.", + &EPSILON, + 0.0, + 4.0, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_int_guc( + "vchordrqfscan.max_scan_tuples", + "`max_scan_tuples` argument of vchordrqfscan.", + "`max_scan_tuples` argument of vchordrqfscan.", + &MAX_SCAN_TUPLES, + -1, + u16::MAX as _, + GucContext::Userset, + GucFlags::default(), + ); +} + +pub fn probes() -> Vec { + match PROBES.get() { + None => Vec::new(), + Some(probes) => { + let mut result = Vec::new(); + let mut current = None; + for &c in probes.to_bytes() { + match c { + b' ' => continue, + b',' => result.push(current.take().expect("empty probes")), + b'0'..=b'9' => { + if let Some(x) = current.as_mut() { + *x = *x * 10 + (c - b'0') as u32; + } else { + current = Some((c - b'0') as u32); + } + } + c => pgrx::error!("unknown character in probes: ASCII = {c}"), + } + } + result.push(current.take().expect("empty probes")); + result + } + } +} + +pub fn epsilon() -> f32 { + EPSILON.get() as f32 +} + +pub fn max_scan_tuples() -> Option { + let x = MAX_SCAN_TUPLES.get(); + if x < 0 { + None + } else { + Some(x as u32) + } +} diff --git a/src/vchordrqfscan/gucs/mod.rs 
b/src/vchordrqfscan/gucs/mod.rs new file mode 100644 index 0000000..48cc060 --- /dev/null +++ b/src/vchordrqfscan/gucs/mod.rs @@ -0,0 +1,14 @@ +pub mod executing; +pub mod prewarm; + +pub unsafe fn init() { + unsafe { + executing::init(); + prewarm::init(); + prewarm::prewarm(); + #[cfg(any(feature = "pg13", feature = "pg14"))] + pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchordrqfscan".as_ptr()); + #[cfg(any(feature = "pg15", feature = "pg16", feature = "pg17"))] + pgrx::pg_sys::MarkGUCPrefixReserved(c"vchordrqfscan".as_ptr()); + } +} diff --git a/src/vchordrqfscan/gucs/prewarm.rs b/src/vchordrqfscan/gucs/prewarm.rs new file mode 100644 index 0000000..ae9180a --- /dev/null +++ b/src/vchordrqfscan/gucs/prewarm.rs @@ -0,0 +1,32 @@ +use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting}; +use std::ffi::CStr; + +static PREWARM_DIM: GucSetting> = + GucSetting::>::new(Some(c"64,128,256,384,512,768,1024,1536")); + +pub unsafe fn init() { + GucRegistry::define_string_guc( + "vchordrqfscan.prewarm_dim", + "prewarm_dim when the extension is loading.", + "prewarm_dim when the extension is loading.", + &PREWARM_DIM, + GucContext::Userset, + GucFlags::default(), + ); +} + +pub fn prewarm() { + if let Some(prewarm_dim) = PREWARM_DIM.get() { + if let Ok(prewarm_dim) = prewarm_dim.to_str() { + for dim in prewarm_dim.split(',') { + if let Ok(dim) = dim.trim().parse::() { + crate::projection::prewarm(dim as _); + } else { + pgrx::warning!("{dim:?} is not a valid integer"); + } + } + } else { + pgrx::warning!("vchordrqfscan.prewarm_dim is not a valid UTF-8 string"); + } + } +} diff --git a/src/vchordrqfscan/index/am.rs b/src/vchordrqfscan/index/am.rs new file mode 100644 index 0000000..e234414 --- /dev/null +++ b/src/vchordrqfscan/index/am.rs @@ -0,0 +1,864 @@ +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm; +use crate::vchordrqfscan::algorithm::build::{HeapRelation, Reporter}; +use crate::vchordrqfscan::index::am_options::{Opfamily, Reloption}; +use crate::vchordrqfscan::index::am_scan::Scanner; +use crate::vchordrqfscan::index::utils::{ctid_to_pointer, pointer_to_ctid}; +use crate::vchordrqfscan::index::{am_options, am_scan}; +use base::search::Pointer; +use pgrx::datum::Internal; +use pgrx::pg_sys::Datum; + +static mut RELOPT_KIND_VCHORDRQFSCAN: pgrx::pg_sys::relopt_kind::Type = 0; + +pub unsafe fn init() { + unsafe { + (&raw mut RELOPT_KIND_VCHORDRQFSCAN).write(pgrx::pg_sys::add_reloption_kind()); + pgrx::pg_sys::add_string_reloption( + (&raw const RELOPT_KIND_VCHORDRQFSCAN).read(), + c"options".as_ptr(), + c"Vector index options, represented as a TOML string.".as_ptr(), + c"".as_ptr(), + None, + pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, + ); + } +} + +#[pgrx::pg_extern(sql = "")] +fn _vchordrqfscan_amhandler(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Internal { + type T = pgrx::pg_sys::IndexAmRoutine; + unsafe { + let index_am_routine = pgrx::pg_sys::palloc0(size_of::()) as *mut T; + index_am_routine.write(AM_HANDLER); + Internal::from(Some(Datum::from(index_am_routine))) + } +} + +const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = { + let mut am_routine = + unsafe { std::mem::MaybeUninit::::zeroed().assume_init() }; + + am_routine.type_ = pgrx::pg_sys::NodeTag::T_IndexAmRoutine; + + am_routine.amsupport = 1; + am_routine.amcanorderbyop = true; + + #[cfg(feature = "pg17")] + { + am_routine.amcanbuildparallel = true; + } + + // Index access methods that set `amoptionalkey` to `false` + // must index all tuples, even if the first column is `NULL`. 
+ // However, PostgreSQL does not generate a path if there is no + // index clauses, even if there is a `ORDER BY` clause. + // So we have to set it to `true` and set costs of every path + // for vector index scans without `ORDER BY` clauses a large number + // and throw errors if someone really wants such a path. + am_routine.amoptionalkey = true; + + am_routine.amvalidate = Some(amvalidate); + am_routine.amoptions = Some(amoptions); + am_routine.amcostestimate = Some(amcostestimate); + + am_routine.ambuild = Some(ambuild); + am_routine.ambuildempty = Some(ambuildempty); + am_routine.aminsert = Some(aminsert); + am_routine.ambulkdelete = Some(ambulkdelete); + am_routine.amvacuumcleanup = Some(amvacuumcleanup); + + am_routine.ambeginscan = Some(ambeginscan); + am_routine.amrescan = Some(amrescan); + am_routine.amgettuple = Some(amgettuple); + am_routine.amendscan = Some(amendscan); + + am_routine +}; + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amvalidate(_opclass_oid: pgrx::pg_sys::Oid) -> bool { + true +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amoptions(reloptions: Datum, validate: bool) -> *mut pgrx::pg_sys::bytea { + let rdopts = unsafe { + pgrx::pg_sys::build_reloptions( + reloptions, + validate, + (&raw const RELOPT_KIND_VCHORDRQFSCAN).read(), + size_of::(), + Reloption::TAB.as_ptr(), + Reloption::TAB.len() as _, + ) + }; + rdopts as *mut pgrx::pg_sys::bytea +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amcostestimate( + _root: *mut pgrx::pg_sys::PlannerInfo, + path: *mut pgrx::pg_sys::IndexPath, + _loop_count: f64, + index_startup_cost: *mut pgrx::pg_sys::Cost, + index_total_cost: *mut pgrx::pg_sys::Cost, + index_selectivity: *mut pgrx::pg_sys::Selectivity, + index_correlation: *mut f64, + index_pages: *mut f64, +) { + unsafe { + if (*path).indexorderbys.is_null() && (*path).indexclauses.is_null() { + *index_startup_cost = f64::MAX; + *index_total_cost = f64::MAX; + *index_selectivity = 0.0; + *index_correlation = 0.0; + *index_pages = 0.0; + return; + } + *index_startup_cost = 0.0; + *index_total_cost = 0.0; + *index_selectivity = 1.0; + *index_correlation = 1.0; + *index_pages = 0.0; + } +} + +#[derive(Debug, Clone)] +struct PgReporter {} + +impl Reporter for PgReporter { + fn tuples_total(&mut self, tuples_total: u64) { + unsafe { + pgrx::pg_sys::pgstat_progress_update_param( + pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_TOTAL as _, + tuples_total as _, + ); + } + } +} + +impl PgReporter { + fn tuples_done(&mut self, tuples_done: u64) { + unsafe { + pgrx::pg_sys::pgstat_progress_update_param( + pgrx::pg_sys::PROGRESS_CREATEIDX_TUPLES_DONE as _, + tuples_done as _, + ); + } + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambuild( + heap: pgrx::pg_sys::Relation, + index: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, +) -> *mut pgrx::pg_sys::IndexBuildResult { + use validator::Validate; + #[derive(Debug, Clone)] + pub struct Heap { + heap: pgrx::pg_sys::Relation, + index: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, + opfamily: Opfamily, + } + impl HeapRelation for Heap { + fn traverse(&self, progress: bool, callback: F) + where + F: FnMut((Pointer, Vec)), + { + pub struct State<'a, F> { + pub this: &'a Heap, + pub callback: F, + } + #[pgrx::pg_guard] + unsafe extern "C" fn call( + _index: pgrx::pg_sys::Relation, + ctid: pgrx::pg_sys::ItemPointer, + values: *mut Datum, + is_null: *mut bool, + _tuple_is_alive: bool, + state: *mut core::ffi::c_void, + ) where + F: FnMut((Pointer, Vec)), + { + use base::vector::OwnedVector; + 
let state = unsafe { &mut *state.cast::>() }; + let opfamily = state.this.opfamily; + let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; + let pointer = unsafe { ctid_to_pointer(ctid.read()) }; + if let Some(vector) = vector { + let vector = match vector { + OwnedVector::Vecf32(x) => x, + OwnedVector::Vecf16(_) => unreachable!(), + OwnedVector::SVecf32(_) => unreachable!(), + OwnedVector::BVector(_) => unreachable!(), + }; + (state.callback)((pointer, vector.into_vec())); + } + } + let table_am = unsafe { &*(*self.heap).rd_tableam }; + let mut state = State { + this: self, + callback, + }; + unsafe { + table_am.index_build_range_scan.unwrap()( + self.heap, + self.index, + self.index_info, + true, + false, + progress, + 0, + pgrx::pg_sys::InvalidBlockNumber, + Some(call::), + (&mut state) as *mut State as *mut _, + std::ptr::null_mut(), + ); + } + } + + fn opfamily(&self) -> Opfamily { + self.opfamily + } + } + let (vector_options, vchordrqfscan_options) = unsafe { am_options::options(index) }; + if let Err(errors) = Validate::validate(&vector_options) { + pgrx::error!("error while validating options: {}", errors); + } + if vector_options.dims == 0 { + pgrx::error!("error while validating options: dimension cannot be 0"); + } + if vector_options.dims > 1600 { + pgrx::error!("error while validating options: dimension is too large"); + } + if let Err(errors) = Validate::validate(&vchordrqfscan_options) { + pgrx::error!("error while validating options: {}", errors); + } + let opfamily = unsafe { am_options::opfamily(index) }; + let heap_relation = Heap { + heap, + index, + index_info, + opfamily, + }; + let mut reporter = PgReporter {}; + let index_relation = unsafe { Relation::new(index) }; + algorithm::build::build( + vector_options, + vchordrqfscan_options, + heap_relation.clone(), + index_relation.clone(), + reporter.clone(), + ); + if let Some(leader) = + unsafe { VchordrqfscanLeader::enter(heap, index, (*index_info).ii_Concurrent) } + { + unsafe { + parallel_build( + index, + heap, + index_info, + leader.tablescandesc, + leader.vchordrqfscanshared, + Some(reporter), + ); + leader.wait(); + let nparticipants = leader.nparticipants; + loop { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*leader.vchordrqfscanshared).mutex); + if (*leader.vchordrqfscanshared).nparticipantsdone == nparticipants { + pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqfscanshared).mutex); + break; + } + pgrx::pg_sys::SpinLockRelease(&raw mut (*leader.vchordrqfscanshared).mutex); + pgrx::pg_sys::ConditionVariableSleep( + &raw mut (*leader.vchordrqfscanshared).workersdonecv, + pgrx::pg_sys::WaitEventIPC::WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN, + ); + } + pgrx::pg_sys::ConditionVariableCancelSleep(); + } + } else { + let mut indtuples = 0; + reporter.tuples_done(indtuples); + heap_relation.traverse(true, |(payload, vector)| { + algorithm::insert::insert( + index_relation.clone(), + payload, + vector, + opfamily.distance_kind(), + ); + indtuples += 1; + reporter.tuples_done(indtuples); + }); + } + unsafe { pgrx::pgbox::PgBox::::alloc0().into_pg() } +} + +struct VchordrqfscanShared { + /* Immutable state */ + heaprelid: pgrx::pg_sys::Oid, + indexrelid: pgrx::pg_sys::Oid, + isconcurrent: bool, + + /* Worker progress */ + workersdonecv: pgrx::pg_sys::ConditionVariable, + + /* Mutex for mutable state */ + mutex: pgrx::pg_sys::slock_t, + + /* Mutable state */ + nparticipantsdone: i32, + indtuples: u64, +} + +fn is_mvcc_snapshot(snapshot: *mut pgrx::pg_sys::SnapshotData) -> bool { + 
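+    // The build snapshot is either an MVCC snapshot registered for a concurrent
+    // build or the static `SnapshotAny`; only the registered MVCC snapshot may be
+    // passed to `UnregisterSnapshot`, so callers gate that call on this predicate.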
matches!( + unsafe { (*snapshot).snapshot_type }, + pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC + | pgrx::pg_sys::SnapshotType::SNAPSHOT_HISTORIC_MVCC + ) +} + +struct VchordrqfscanLeader { + pcxt: *mut pgrx::pg_sys::ParallelContext, + nparticipants: i32, + vchordrqfscanshared: *mut VchordrqfscanShared, + tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, + snapshot: pgrx::pg_sys::Snapshot, +} + +impl VchordrqfscanLeader { + pub unsafe fn enter( + heap: pgrx::pg_sys::Relation, + index: pgrx::pg_sys::Relation, + isconcurrent: bool, + ) -> Option { + unsafe fn compute_parallel_workers( + heap: pgrx::pg_sys::Relation, + index: pgrx::pg_sys::Relation, + ) -> i32 { + unsafe { + if pgrx::pg_sys::plan_create_index_workers((*heap).rd_id, (*index).rd_id) == 0 { + return 0; + } + if !(*heap).rd_options.is_null() { + let std_options = (*heap).rd_options.cast::(); + std::cmp::min( + (*std_options).parallel_workers, + pgrx::pg_sys::max_parallel_maintenance_workers, + ) + } else { + pgrx::pg_sys::max_parallel_maintenance_workers + } + } + } + + let request = unsafe { compute_parallel_workers(heap, index) }; + if request <= 0 { + return None; + } + + unsafe { + pgrx::pg_sys::EnterParallelMode(); + } + let pcxt = unsafe { + pgrx::pg_sys::CreateParallelContext( + c"vchord".as_ptr(), + c"vchordrqfscan_parallel_build_main".as_ptr(), + request, + ) + }; + + let snapshot = if isconcurrent { + unsafe { pgrx::pg_sys::RegisterSnapshot(pgrx::pg_sys::GetTransactionSnapshot()) } + } else { + &raw mut pgrx::pg_sys::SnapshotAnyData + }; + + fn estimate_chunk(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { + e.space_for_chunks += x.next_multiple_of(pgrx::pg_sys::ALIGNOF_BUFFER as _); + } + fn estimate_keys(e: &mut pgrx::pg_sys::shm_toc_estimator, x: usize) { + e.number_of_keys += x; + } + let est_tablescandesc = + unsafe { pgrx::pg_sys::table_parallelscan_estimate(heap, snapshot) }; + unsafe { + estimate_chunk(&mut (*pcxt).estimator, size_of::()); + estimate_keys(&mut (*pcxt).estimator, 1); + estimate_chunk(&mut (*pcxt).estimator, est_tablescandesc); + estimate_keys(&mut (*pcxt).estimator, 1); + } + + unsafe { + pgrx::pg_sys::InitializeParallelDSM(pcxt); + if (*pcxt).seg.is_null() { + if is_mvcc_snapshot(snapshot) { + pgrx::pg_sys::UnregisterSnapshot(snapshot); + } + pgrx::pg_sys::DestroyParallelContext(pcxt); + pgrx::pg_sys::ExitParallelMode(); + return None; + } + } + + let vchordrqfscanshared = unsafe { + let vchordrqfscanshared = + pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, size_of::()) + .cast::(); + vchordrqfscanshared.write(VchordrqfscanShared { + heaprelid: (*heap).rd_id, + indexrelid: (*index).rd_id, + isconcurrent, + workersdonecv: std::mem::zeroed(), + mutex: std::mem::zeroed(), + nparticipantsdone: 0, + indtuples: 0, + }); + pgrx::pg_sys::ConditionVariableInit(&raw mut (*vchordrqfscanshared).workersdonecv); + pgrx::pg_sys::SpinLockInit(&raw mut (*vchordrqfscanshared).mutex); + vchordrqfscanshared + }; + + let tablescandesc = unsafe { + let tablescandesc = pgrx::pg_sys::shm_toc_allocate((*pcxt).toc, est_tablescandesc) + .cast::(); + pgrx::pg_sys::table_parallelscan_initialize(heap, tablescandesc, snapshot); + tablescandesc + }; + + unsafe { + pgrx::pg_sys::shm_toc_insert( + (*pcxt).toc, + 0xA000000000000001, + vchordrqfscanshared.cast(), + ); + pgrx::pg_sys::shm_toc_insert((*pcxt).toc, 0xA000000000000002, tablescandesc.cast()); + } + + unsafe { + pgrx::pg_sys::LaunchParallelWorkers(pcxt); + } + + let nworkers_launched = unsafe { (*pcxt).nworkers_launched }; + + unsafe { + if nworkers_launched 
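+            // If no workers could actually be launched, parallelism gains nothing:
+            // tear down the parallel context and let `ambuild` fall back to the
+            // serial build path.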
== 0 { + pgrx::pg_sys::WaitForParallelWorkersToFinish(pcxt); + if is_mvcc_snapshot(snapshot) { + pgrx::pg_sys::UnregisterSnapshot(snapshot); + } + pgrx::pg_sys::DestroyParallelContext(pcxt); + pgrx::pg_sys::ExitParallelMode(); + return None; + } + } + + Some(Self { + pcxt, + nparticipants: nworkers_launched + 1, + vchordrqfscanshared, + tablescandesc, + snapshot, + }) + } + + pub fn wait(&self) { + unsafe { + pgrx::pg_sys::WaitForParallelWorkersToAttach(self.pcxt); + } + } +} + +impl Drop for VchordrqfscanLeader { + fn drop(&mut self) { + if !std::thread::panicking() { + unsafe { + pgrx::pg_sys::WaitForParallelWorkersToFinish(self.pcxt); + if is_mvcc_snapshot(self.snapshot) { + pgrx::pg_sys::UnregisterSnapshot(self.snapshot); + } + pgrx::pg_sys::DestroyParallelContext(self.pcxt); + pgrx::pg_sys::ExitParallelMode(); + } + } + } +} + +#[pgrx::pg_guard] +#[no_mangle] +pub unsafe extern "C" fn vchordrqfscan_parallel_build_main( + _seg: *mut pgrx::pg_sys::dsm_segment, + toc: *mut pgrx::pg_sys::shm_toc, +) { + let vchordrqfscanshared = unsafe { + pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000001, false).cast::() + }; + let tablescandesc = unsafe { + pgrx::pg_sys::shm_toc_lookup(toc, 0xA000000000000002, false) + .cast::() + }; + let heap_lockmode; + let index_lockmode; + if unsafe { !(*vchordrqfscanshared).isconcurrent } { + heap_lockmode = pgrx::pg_sys::ShareLock as pgrx::pg_sys::LOCKMODE; + index_lockmode = pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; + } else { + heap_lockmode = pgrx::pg_sys::ShareUpdateExclusiveLock as pgrx::pg_sys::LOCKMODE; + index_lockmode = pgrx::pg_sys::RowExclusiveLock as pgrx::pg_sys::LOCKMODE; + } + let heap = unsafe { pgrx::pg_sys::table_open((*vchordrqfscanshared).heaprelid, heap_lockmode) }; + let index = + unsafe { pgrx::pg_sys::index_open((*vchordrqfscanshared).indexrelid, index_lockmode) }; + let index_info = unsafe { pgrx::pg_sys::BuildIndexInfo(index) }; + unsafe { + (*index_info).ii_Concurrent = (*vchordrqfscanshared).isconcurrent; + } + + unsafe { + parallel_build( + index, + heap, + index_info, + tablescandesc, + vchordrqfscanshared, + None, + ); + } + + unsafe { + pgrx::pg_sys::index_close(index, index_lockmode); + pgrx::pg_sys::table_close(heap, heap_lockmode); + } +} + +unsafe fn parallel_build( + index: *mut pgrx::pg_sys::RelationData, + heap: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, + tablescandesc: *mut pgrx::pg_sys::ParallelTableScanDescData, + vchordrqfscanshared: *mut VchordrqfscanShared, + mut reporter: Option, +) { + #[derive(Debug, Clone)] + pub struct Heap { + heap: pgrx::pg_sys::Relation, + index: pgrx::pg_sys::Relation, + index_info: *mut pgrx::pg_sys::IndexInfo, + opfamily: Opfamily, + scan: *mut pgrx::pg_sys::TableScanDescData, + } + impl HeapRelation for Heap { + fn traverse(&self, progress: bool, callback: F) + where + F: FnMut((Pointer, Vec)), + { + pub struct State<'a, F> { + pub this: &'a Heap, + pub callback: F, + } + #[pgrx::pg_guard] + unsafe extern "C" fn call( + _index: pgrx::pg_sys::Relation, + ctid: pgrx::pg_sys::ItemPointer, + values: *mut Datum, + is_null: *mut bool, + _tuple_is_alive: bool, + state: *mut core::ffi::c_void, + ) where + F: FnMut((Pointer, Vec)), + { + use base::vector::OwnedVector; + let state = unsafe { &mut *state.cast::>() }; + let opfamily = state.this.opfamily; + let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; + let pointer = unsafe { ctid_to_pointer(ctid.read()) }; + if let Some(vector) = vector { + let vector = match vector 
{ + OwnedVector::Vecf32(x) => x, + OwnedVector::Vecf16(_) => unreachable!(), + OwnedVector::SVecf32(_) => unreachable!(), + OwnedVector::BVector(_) => unreachable!(), + }; + (state.callback)((pointer, vector.into_vec())); + } + } + let table_am = unsafe { &*(*self.heap).rd_tableam }; + let mut state = State { + this: self, + callback, + }; + unsafe { + table_am.index_build_range_scan.unwrap()( + self.heap, + self.index, + self.index_info, + true, + false, + progress, + 0, + pgrx::pg_sys::InvalidBlockNumber, + Some(call::), + (&mut state) as *mut State as *mut _, + self.scan, + ); + } + } + + fn opfamily(&self) -> Opfamily { + self.opfamily + } + } + + let index_relation = unsafe { Relation::new(index) }; + let scan = unsafe { pgrx::pg_sys::table_beginscan_parallel(heap, tablescandesc) }; + let opfamily = unsafe { am_options::opfamily(index) }; + let heap_relation = Heap { + heap, + index, + index_info, + opfamily, + scan, + }; + heap_relation.traverse(reporter.is_some(), |(payload, vector)| { + algorithm::insert::insert( + index_relation.clone(), + payload, + vector, + opfamily.distance_kind(), + ); + unsafe { + let indtuples; + { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqfscanshared).mutex); + (*vchordrqfscanshared).indtuples += 1; + indtuples = (*vchordrqfscanshared).indtuples; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqfscanshared).mutex); + } + if let Some(reporter) = reporter.as_mut() { + reporter.tuples_done(indtuples); + } + } + }); + + unsafe { + pgrx::pg_sys::SpinLockAcquire(&raw mut (*vchordrqfscanshared).mutex); + (*vchordrqfscanshared).nparticipantsdone += 1; + pgrx::pg_sys::SpinLockRelease(&raw mut (*vchordrqfscanshared).mutex); + pgrx::pg_sys::ConditionVariableSignal(&raw mut (*vchordrqfscanshared).workersdonecv); + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambuildempty(_index: pgrx::pg_sys::Relation) { + pgrx::error!("Unlogged indexes are not supported."); +} + +#[cfg(feature = "pg13")] +#[pgrx::pg_guard] +pub unsafe extern "C" fn aminsert( + index: pgrx::pg_sys::Relation, + values: *mut Datum, + is_null: *mut bool, + heap_tid: pgrx::pg_sys::ItemPointer, + _heap: pgrx::pg_sys::Relation, + _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, + _index_info: *mut pgrx::pg_sys::IndexInfo, +) -> bool { + use base::vector::OwnedVector; + let opfamily = unsafe { am_options::opfamily(index) }; + let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; + if let Some(vector) = vector { + let vector = match vector { + OwnedVector::Vecf32(x) => x, + OwnedVector::Vecf16(_) => unreachable!(), + OwnedVector::SVecf32(_) => unreachable!(), + OwnedVector::BVector(_) => unreachable!(), + }; + let pointer = ctid_to_pointer(unsafe { heap_tid.read() }); + algorithm::insert::insert( + unsafe { Relation::new(index) }, + pointer, + vector.into_vec(), + opfamily.distance_kind(), + ); + } + false +} + +#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16", feature = "pg17"))] +#[pgrx::pg_guard] +pub unsafe extern "C" fn aminsert( + index: pgrx::pg_sys::Relation, + values: *mut Datum, + is_null: *mut bool, + heap_tid: pgrx::pg_sys::ItemPointer, + _heap: pgrx::pg_sys::Relation, + _check_unique: pgrx::pg_sys::IndexUniqueCheck::Type, + _index_unchanged: bool, + _index_info: *mut pgrx::pg_sys::IndexInfo, +) -> bool { + use base::vector::OwnedVector; + let opfamily = unsafe { am_options::opfamily(index) }; + let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; + if let Some(vector) = vector { + let vector = match 
vector {
+            OwnedVector::Vecf32(x) => x,
+            OwnedVector::Vecf16(_) => unreachable!(),
+            OwnedVector::SVecf32(_) => unreachable!(),
+            OwnedVector::BVector(_) => unreachable!(),
+        };
+        let pointer = ctid_to_pointer(unsafe { heap_tid.read() });
+        algorithm::insert::insert(
+            unsafe { Relation::new(index) },
+            pointer,
+            vector.into_vec(),
+            opfamily.distance_kind(),
+        );
+    }
+    false
+}
+
+#[pgrx::pg_guard]
+pub unsafe extern "C" fn ambeginscan(
+    index: pgrx::pg_sys::Relation,
+    n_keys: std::os::raw::c_int,
+    n_orderbys: std::os::raw::c_int,
+) -> pgrx::pg_sys::IndexScanDesc {
+    use pgrx::memcxt::PgMemoryContexts::CurrentMemoryContext;
+
+    let scan = unsafe { pgrx::pg_sys::RelationGetIndexScan(index, n_keys, n_orderbys) };
+    unsafe {
+        let scanner = am_scan::scan_make(None, None, false);
+        (*scan).opaque = CurrentMemoryContext.leak_and_drop_on_delete(scanner).cast();
+    }
+    scan
+}
+
+#[pgrx::pg_guard]
+pub unsafe extern "C" fn amrescan(
+    scan: pgrx::pg_sys::IndexScanDesc,
+    keys: pgrx::pg_sys::ScanKey,
+    _n_keys: std::os::raw::c_int,
+    orderbys: pgrx::pg_sys::ScanKey,
+    _n_orderbys: std::os::raw::c_int,
+) {
+    unsafe {
+        if !keys.is_null() && (*scan).numberOfKeys > 0 {
+            std::ptr::copy(keys, (*scan).keyData, (*scan).numberOfKeys as _);
+        }
+        if !orderbys.is_null() && (*scan).numberOfOrderBys > 0 {
+            std::ptr::copy(orderbys, (*scan).orderByData, (*scan).numberOfOrderBys as _);
+        }
+        let opfamily = am_options::opfamily((*scan).indexRelation);
+        let (orderbys, spheres) = {
+            let mut orderbys = Vec::new();
+            let mut spheres = Vec::new();
+            if (*scan).numberOfOrderBys == 0 && (*scan).numberOfKeys == 0 {
+                pgrx::error!(
+                    "vector search with no WHERE clause and no ORDER BY clause is not supported"
+                );
+            }
+            for i in 0..(*scan).numberOfOrderBys {
+                let data = (*scan).orderByData.add(i as usize);
+                let value = (*data).sk_argument;
+                let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0;
+                match (*data).sk_strategy {
+                    1 => orderbys.push(opfamily.datum_to_vector(value, is_null)),
+                    _ => unreachable!(),
+                }
+            }
+            for i in 0..(*scan).numberOfKeys {
+                let data = (*scan).keyData.add(i as usize);
+                let value = (*data).sk_argument;
+                let is_null = ((*data).sk_flags & pgrx::pg_sys::SK_ISNULL as i32) != 0;
+                match (*data).sk_strategy {
+                    2 => spheres.push(opfamily.datum_to_sphere(value, is_null)),
+                    _ => unreachable!(),
+                }
+            }
+            (orderbys, spheres)
+        };
+        let (vector, threshold, recheck) = am_scan::scan_build(orderbys, spheres, opfamily);
+        let scanner = (*scan).opaque.cast::<Scanner>().as_mut().unwrap_unchecked();
+        let scanner = std::mem::replace(scanner, am_scan::scan_make(vector, threshold, recheck));
+        am_scan::scan_release(scanner);
+    }
+}
+
+#[pgrx::pg_guard]
+pub unsafe extern "C" fn amgettuple(
+    scan: pgrx::pg_sys::IndexScanDesc,
+    direction: pgrx::pg_sys::ScanDirection::Type,
+) -> bool {
+    if direction != pgrx::pg_sys::ScanDirection::ForwardScanDirection {
+        pgrx::error!("vector search without a forward scan direction is not supported");
+    }
+    // https://www.postgresql.org/docs/current/index-locking.html
+    // If heap entries referenced by physical pointers are deleted before
+    // they are consumed by PostgreSQL, PostgreSQL will receive wrong
+    // physical pointers: no rows or irrelevant rows are referenced.
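+    // Note that this check is stricter than `is_mvcc_snapshot` used during the
+    // parallel build: only a plain MVCC snapshot is accepted for scans.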
+ if unsafe { (*(*scan).xs_snapshot).snapshot_type } != pgrx::pg_sys::SnapshotType::SNAPSHOT_MVCC + { + pgrx::error!("scanning with a non-MVCC-compliant snapshot is not supported"); + } + let scanner = unsafe { (*scan).opaque.cast::().as_mut().unwrap_unchecked() }; + let relation = unsafe { Relation::new((*scan).indexRelation) }; + if let Some((pointer, recheck)) = am_scan::scan_next(scanner, relation) { + let ctid = pointer_to_ctid(pointer); + unsafe { + (*scan).xs_heaptid = ctid; + (*scan).xs_recheckorderby = false; + (*scan).xs_recheck = recheck; + } + true + } else { + false + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) { + unsafe { + let scanner = (*scan).opaque.cast::().as_mut().unwrap_unchecked(); + let scanner = std::mem::replace(scanner, am_scan::scan_make(None, None, false)); + am_scan::scan_release(scanner); + } +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn ambulkdelete( + info: *mut pgrx::pg_sys::IndexVacuumInfo, + stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, + callback: pgrx::pg_sys::IndexBulkDeleteCallback, + callback_state: *mut std::os::raw::c_void, +) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { + let mut stats = stats; + if stats.is_null() { + stats = unsafe { + pgrx::pg_sys::palloc0(size_of::()).cast() + }; + } + let callback = callback.unwrap(); + let callback = |p: Pointer| unsafe { callback(&mut pointer_to_ctid(p), callback_state) }; + algorithm::vacuum::vacuum( + unsafe { Relation::new((*info).index) }, + || unsafe { + pgrx::pg_sys::vacuum_delay_point(); + }, + callback, + ); + stats +} + +#[pgrx::pg_guard] +pub unsafe extern "C" fn amvacuumcleanup( + _info: *mut pgrx::pg_sys::IndexVacuumInfo, + _stats: *mut pgrx::pg_sys::IndexBulkDeleteResult, +) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { + std::ptr::null_mut() +} diff --git a/src/vchordrqfscan/index/am_options.rs b/src/vchordrqfscan/index/am_options.rs new file mode 100644 index 0000000..51a1009 --- /dev/null +++ b/src/vchordrqfscan/index/am_options.rs @@ -0,0 +1,222 @@ +use crate::datatype::memory_pgvector_vector::PgvectorVectorInput; +use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; +use crate::datatype::typmod::Typmod; +use crate::vchordrqfscan::types::VchordrqfscanIndexingOptions; +use base::distance::*; +use base::index::*; +use base::vector::*; +use pgrx::datum::FromDatum; +use pgrx::heap_tuple::PgHeapTuple; +use serde::Deserialize; +use std::ffi::CStr; +use std::num::NonZero; + +#[derive(Copy, Clone, Debug, Default)] +#[repr(C)] +pub struct Reloption { + vl_len_: i32, + pub options: i32, +} + +impl Reloption { + pub const TAB: &'static [pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { + optname: c"options".as_ptr(), + opttype: pgrx::pg_sys::relopt_type::RELOPT_TYPE_STRING, + offset: std::mem::offset_of!(Reloption, options) as i32, + }]; + unsafe fn options(&self) -> &CStr { + unsafe { + let ptr = std::ptr::addr_of!(*self) + .cast::() + .offset(self.options as _); + CStr::from_ptr(ptr) + } + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum PgDistanceKind { + L2, + Dot, + Cos, +} + +impl PgDistanceKind { + pub fn to_distance(self) -> DistanceKind { + match self { + PgDistanceKind::L2 => DistanceKind::L2, + PgDistanceKind::Dot | PgDistanceKind::Cos => DistanceKind::Dot, + } + } +} + +fn convert_name_to_vd(name: &str) -> Option<(VectorKind, PgDistanceKind)> { + match name.strip_suffix("_ops") { + Some("vector_l2") => Some((VectorKind::Vecf32, PgDistanceKind::L2)), + 
Some("vector_ip") => Some((VectorKind::Vecf32, PgDistanceKind::Dot)), + Some("vector_cosine") => Some((VectorKind::Vecf32, PgDistanceKind::Cos)), + _ => None, + } +} + +unsafe fn convert_reloptions_to_options( + reloptions: *const pgrx::pg_sys::varlena, +) -> VchordrqfscanIndexingOptions { + #[derive(Debug, Clone, Deserialize, Default)] + #[serde(deny_unknown_fields)] + struct Parsed { + #[serde(flatten)] + rabitq: VchordrqfscanIndexingOptions, + } + let reloption = reloptions as *const Reloption; + if reloption.is_null() || unsafe { (*reloption).options == 0 } { + return Default::default(); + } + let s = unsafe { (*reloption).options() }.to_string_lossy(); + match toml::from_str::(&s) { + Ok(p) => p.rabitq, + Err(e) => pgrx::error!("failed to parse options: {}", e), + } +} + +pub unsafe fn options( + index: pgrx::pg_sys::Relation, +) -> (VectorOptions, VchordrqfscanIndexingOptions) { + let att = unsafe { &mut *(*index).rd_att }; + let atts = unsafe { att.attrs.as_slice(att.natts as _) }; + if atts.is_empty() { + pgrx::error!("indexing on no columns is not supported"); + } + if atts.len() != 1 { + pgrx::error!("multicolumn index is not supported"); + } + // get dims + let typmod = Typmod::parse_from_i32(atts[0].type_mod()).unwrap(); + let dims = if let Some(dims) = typmod.dims() { + dims.get() + } else { + pgrx::error!( + "Dimensions type modifier of a vector column is needed for building the index." + ); + }; + // get v, d + let opfamily = unsafe { opfamily(index) }; + let vector = VectorOptions { + dims, + v: opfamily.vector, + d: opfamily.distance_kind(), + }; + // get indexing, segment, optimizing + let rabitq = unsafe { convert_reloptions_to_options((*index).rd_options) }; + (vector, rabitq) +} + +#[derive(Debug, Clone, Copy)] +pub struct Opfamily { + vector: VectorKind, + pg_distance: PgDistanceKind, +} + +impl Opfamily { + pub unsafe fn datum_to_vector( + self, + datum: pgrx::pg_sys::Datum, + is_null: bool, + ) -> Option { + if is_null || datum.is_null() { + return None; + } + let vector = match self.vector { + VectorKind::Vecf32 => { + let vector = unsafe { PgvectorVectorInput::from_datum(datum, false).unwrap() }; + self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed())) + } + _ => unreachable!(), + }; + Some(vector) + } + pub unsafe fn datum_to_sphere( + self, + datum: pgrx::pg_sys::Datum, + is_null: bool, + ) -> (Option, Option) { + if is_null || datum.is_null() { + return (None, None); + } + let tuple = unsafe { PgHeapTuple::from_composite_datum(datum) }; + let center = match self.vector { + VectorKind::Vecf32 => tuple + .get_by_index::(NonZero::new(1).unwrap()) + .unwrap() + .map(|vector| self.preprocess(BorrowedVector::Vecf32(vector.as_borrowed()))), + _ => unreachable!(), + }; + let radius = tuple.get_by_index::(NonZero::new(2).unwrap()).unwrap(); + (center, radius) + } + pub fn preprocess(self, vector: BorrowedVector<'_>) -> OwnedVector { + use BorrowedVector as B; + use OwnedVector as O; + match (vector, self.pg_distance) { + (B::Vecf32(x), PgDistanceKind::L2) => O::Vecf32(x.own()), + (B::Vecf32(x), PgDistanceKind::Dot) => O::Vecf32(x.own()), + (B::Vecf32(x), PgDistanceKind::Cos) => O::Vecf32(x.function_normalize()), + (B::Vecf16(x), _) => O::Vecf16(x.own()), + (B::SVecf32(x), _) => O::SVecf32(x.own()), + (B::BVector(x), _) => O::BVector(x.own()), + } + } + pub fn process(self, x: Distance) -> f32 { + match self.pg_distance { + PgDistanceKind::Cos => f32::from(x) + 1.0f32, + PgDistanceKind::L2 => f32::from(x).sqrt(), + _ => f32::from(x), + } + } + pub fn 
distance_kind(self) -> DistanceKind { + self.pg_distance.to_distance() + } +} + +pub unsafe fn opfamily(index: pgrx::pg_sys::Relation) -> Opfamily { + use pgrx::pg_sys::Oid; + + let proc = unsafe { pgrx::pg_sys::index_getprocid(index, 1, 1) }; + + if proc == Oid::INVALID { + pgrx::error!("support function 1 is not found"); + } + + let mut flinfo = pgrx::pg_sys::FmgrInfo::default(); + unsafe { + pgrx::pg_sys::fmgr_info(proc, &mut flinfo); + } + + let fn_addr = flinfo.fn_addr.expect("null function pointer"); + + let mut fcinfo = unsafe { std::mem::zeroed::() }; + fcinfo.flinfo = &mut flinfo; + fcinfo.fncollation = pgrx::pg_sys::DEFAULT_COLLATION_OID; + fcinfo.context = std::ptr::null_mut(); + fcinfo.resultinfo = std::ptr::null_mut(); + fcinfo.isnull = true; + fcinfo.nargs = 0; + + let result_datum = unsafe { pgrx::pg_sys::ffi::pg_guard_ffi_boundary(|| fn_addr(&mut fcinfo)) }; + + let result_option = unsafe { String::from_datum(result_datum, fcinfo.isnull) }; + + let result_string = result_option.expect("null string"); + + let (vector, pg_distance) = convert_name_to_vd(&result_string).unwrap(); + + unsafe { + pgrx::pg_sys::pfree(result_datum.cast_mut_ptr()); + } + + Opfamily { + vector, + pg_distance, + } +} diff --git a/src/vchordrqfscan/index/am_scan.rs b/src/vchordrqfscan/index/am_scan.rs new file mode 100644 index 0000000..7396bd1 --- /dev/null +++ b/src/vchordrqfscan/index/am_scan.rs @@ -0,0 +1,132 @@ +use super::am_options::Opfamily; +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::scan::scan; +use crate::vchordrqfscan::gucs::executing::epsilon; +use crate::vchordrqfscan::gucs::executing::max_scan_tuples; +use crate::vchordrqfscan::gucs::executing::probes; +use base::distance::Distance; +use base::search::*; +use base::vector::*; + +pub enum Scanner { + Initial { + vector: Option<(OwnedVector, Opfamily)>, + threshold: Option, + recheck: bool, + }, + Vbase { + vbase: Box>, + threshold: Option, + recheck: bool, + opfamily: Opfamily, + }, + Empty {}, +} + +pub fn scan_build( + orderbys: Vec>, + spheres: Vec<(Option, Option)>, + opfamily: Opfamily, +) -> (Option<(OwnedVector, Opfamily)>, Option, bool) { + let mut pair = None; + let mut threshold = None; + let mut recheck = false; + for orderby_vector in orderbys { + if pair.is_none() { + pair = orderby_vector; + } else if orderby_vector.is_some() && pair != orderby_vector { + pgrx::error!("vector search with multiple vectors is not supported"); + } + } + for (sphere_vector, sphere_threshold) in spheres { + if pair.is_none() { + pair = sphere_vector; + threshold = sphere_threshold; + } else if pair == sphere_vector { + if threshold.is_none() || sphere_threshold < threshold { + threshold = sphere_threshold; + } + } else { + recheck = true; + break; + } + } + (pair.map(|x| (x, opfamily)), threshold, recheck) +} + +pub fn scan_make( + vector: Option<(OwnedVector, Opfamily)>, + threshold: Option, + recheck: bool, +) -> Scanner { + Scanner::Initial { + vector, + threshold, + recheck, + } +} + +pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, bool)> { + if let Scanner::Initial { + vector, + threshold, + recheck, + } = scanner + { + if let Some((vector, opfamily)) = vector.as_ref() { + let vbase = scan( + relation, + match vector { + OwnedVector::Vecf32(x) => x.slice().to_vec(), + OwnedVector::Vecf16(_) => unreachable!(), + OwnedVector::SVecf32(_) => unreachable!(), + OwnedVector::BVector(_) => unreachable!(), + }, + opfamily.distance_kind(), + probes(), + epsilon(), + ); + *scanner = 
Scanner::Vbase { + vbase: if let Some(max_scan_tuples) = max_scan_tuples() { + Box::new(vbase.take(max_scan_tuples as usize)) + } else { + Box::new(vbase) + }, + threshold: *threshold, + recheck: *recheck, + opfamily: *opfamily, + }; + } else { + *scanner = Scanner::Empty {}; + } + } + match scanner { + Scanner::Initial { .. } => unreachable!(), + Scanner::Vbase { + vbase, + threshold, + recheck, + opfamily, + } => match ( + vbase.next().map(|(d, p)| (opfamily.process(d), p)), + threshold, + ) { + (Some((_, ptr)), None) => Some((ptr, *recheck)), + (Some((distance, ptr)), Some(t)) if distance < *t => Some((ptr, *recheck)), + _ => { + let scanner = std::mem::replace(scanner, Scanner::Empty {}); + scan_release(scanner); + None + } + }, + Scanner::Empty {} => None, + } +} + +pub fn scan_release(scanner: Scanner) { + match scanner { + Scanner::Initial { .. } => {} + Scanner::Vbase { .. } => {} + Scanner::Empty {} => {} + } +} diff --git a/src/vchordrqfscan/index/functions.rs b/src/vchordrqfscan/index/functions.rs new file mode 100644 index 0000000..98f0f25 --- /dev/null +++ b/src/vchordrqfscan/index/functions.rs @@ -0,0 +1,26 @@ +use crate::postgres::Relation; +use crate::vchordrqfscan::algorithm::prewarm::prewarm; +use pgrx::pg_sys::Oid; +use pgrx_catalog::{PgAm, PgClass}; + +#[pgrx::pg_extern(sql = "")] +fn _vchordrqfscan_prewarm(indexrelid: Oid, height: i32) -> String { + let pg_am = PgAm::search_amname(c"vchordrqfscan").unwrap(); + let Some(pg_am) = pg_am.get() else { + pgrx::error!("vchord is not installed"); + }; + let pg_class = PgClass::search_reloid(indexrelid).unwrap(); + let Some(pg_class) = pg_class.get() else { + pgrx::error!("there is no such index"); + }; + if pg_class.relam() != pg_am.oid() { + pgrx::error!("{:?} is not a vchordrqfscan index", pg_class.relname()); + } + let index = unsafe { pgrx::pg_sys::index_open(indexrelid, pgrx::pg_sys::ShareLock as _) }; + let relation = unsafe { Relation::new(index) }; + let message = prewarm(relation, height); + unsafe { + pgrx::pg_sys::index_close(index, pgrx::pg_sys::ShareLock as _); + } + message +} diff --git a/src/vchordrqfscan/index/mod.rs b/src/vchordrqfscan/index/mod.rs new file mode 100644 index 0000000..5203e4f --- /dev/null +++ b/src/vchordrqfscan/index/mod.rs @@ -0,0 +1,12 @@ +pub mod am; +pub mod am_options; +pub mod am_scan; +pub mod functions; +pub mod opclass; +pub mod utils; + +pub unsafe fn init() { + unsafe { + am::init(); + } +} diff --git a/src/vchordrqfscan/index/opclass.rs b/src/vchordrqfscan/index/opclass.rs new file mode 100644 index 0000000..d095b9a --- /dev/null +++ b/src/vchordrqfscan/index/opclass.rs @@ -0,0 +1,14 @@ +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrqfscan_support_vector_l2_ops() -> String { + "vector_l2_ops".to_string() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrqfscan_support_vector_ip_ops() -> String { + "vector_ip_ops".to_string() +} + +#[pgrx::pg_extern(immutable, strict, parallel_safe)] +fn _vchordrqfscan_support_vector_cosine_ops() -> String { + "vector_cosine_ops".to_string() +} diff --git a/src/vchordrqfscan/index/utils.rs b/src/vchordrqfscan/index/utils.rs new file mode 100644 index 0000000..a5d85a3 --- /dev/null +++ b/src/vchordrqfscan/index/utils.rs @@ -0,0 +1,20 @@ +use base::search::*; + +pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData { + let value = pointer.as_u64(); + pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { + bi_hi: ((value >> 32) & 0xffff) as u16, + bi_lo: ((value >> 16) & 
0xffff) as u16, + }, + ip_posid: (value & 0xffff) as u16, + } +} + +pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> Pointer { + let mut value = 0; + value |= (ctid.ip_blkid.bi_hi as u64) << 32; + value |= (ctid.ip_blkid.bi_lo as u64) << 16; + value |= ctid.ip_posid as u64; + Pointer::new(value) +} diff --git a/src/vchordrqfscan/mod.rs b/src/vchordrqfscan/mod.rs new file mode 100644 index 0000000..c2ae945 --- /dev/null +++ b/src/vchordrqfscan/mod.rs @@ -0,0 +1,11 @@ +mod algorithm; +mod gucs; +mod index; +mod types; + +pub unsafe fn init() { + unsafe { + index::init(); + gucs::init(); + } +} diff --git a/src/vchordrqfscan/types.rs b/src/vchordrqfscan/types.rs new file mode 100644 index 0000000..b3a1067 --- /dev/null +++ b/src/vchordrqfscan/types.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; +use validator::{Validate, ValidationError, ValidationErrors}; + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct VchordrqfscanInternalBuildOptions { + #[serde(default = "VchordrqfscanInternalBuildOptions::default_lists")] + #[validate(length(min = 1, max = 8), custom(function = VchordrqfscanInternalBuildOptions::validate_lists))] + pub lists: Vec, + #[serde(default = "VchordrqfscanInternalBuildOptions::default_spherical_centroids")] + pub spherical_centroids: bool, + #[serde(default = "VchordrqfscanInternalBuildOptions::default_build_threads")] + #[validate(range(min = 1, max = 255))] + pub build_threads: u16, +} + +impl VchordrqfscanInternalBuildOptions { + fn default_lists() -> Vec { + vec![1000] + } + fn validate_lists(lists: &[u32]) -> Result<(), ValidationError> { + if !lists.is_sorted() { + return Err(ValidationError::new("`lists` should be in ascending order")); + } + if !lists.iter().all(|x| (1..=1 << 24).contains(x)) { + return Err(ValidationError::new("list is too long or too short")); + } + Ok(()) + } + fn default_spherical_centroids() -> bool { + false + } + fn default_build_threads() -> u16 { + 1 + } +} + +impl Default for VchordrqfscanInternalBuildOptions { + fn default() -> Self { + Self { + lists: Self::default_lists(), + spherical_centroids: Self::default_spherical_centroids(), + build_threads: Self::default_build_threads(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct VchordrqfscanExternalBuildOptions { + pub table: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(rename_all = "snake_case")] +pub enum VchordrqfscanBuildOptions { + Internal(VchordrqfscanInternalBuildOptions), + External(VchordrqfscanExternalBuildOptions), +} + +impl Default for VchordrqfscanBuildOptions { + fn default() -> Self { + Self::Internal(Default::default()) + } +} + +impl Validate for VchordrqfscanBuildOptions { + fn validate(&self) -> Result<(), ValidationErrors> { + use VchordrqfscanBuildOptions::*; + match self { + Internal(internal_build) => internal_build.validate(), + External(external_build) => external_build.validate(), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct VchordrqfscanIndexingOptions { + #[serde(default = "VchordrqfscanIndexingOptions::default_residual_quantization")] + pub residual_quantization: bool, + pub build: VchordrqfscanBuildOptions, +} + +impl VchordrqfscanIndexingOptions { + fn default_residual_quantization() -> bool { + false + } +}
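
For reviewers, a quick way to sanity-check the reloption plumbing above: the string given
via `WITH (options = ...)` is parsed by `convert_reloptions_to_options` with `toml::from_str`
and then checked with `validator::Validate` in `ambuild`. The sketch below is not part of
the patch; it assumes it is placed at the bottom of `src/vchordrqfscan/types.rs` so the
option structs are in scope, and the module name, the TOML literal, and the `lists` values
are illustrative only.

    #[cfg(test)]
    mod reloptions_sketch {
        // Editorial sketch, not part of the patch: exercises the option types
        // defined above, assuming this module lives in types.rs next to them.
        use super::*;
        use validator::Validate;

        #[test]
        fn parses_a_full_options_string() {
            // Mirrors what `convert_reloptions_to_options` does with the reloption text.
            let text = "residual_quantization = true\n[build.internal]\nlists = [1000, 4096]\n";
            let options: VchordrqfscanIndexingOptions = toml::from_str(text).unwrap();
            assert!(options.residual_quantization);
            assert!(options.validate().is_ok());
        }

        #[test]
        fn rejects_lists_that_are_not_ascending() {
            // `validate_lists` requires `lists` to be in ascending order.
            let internal = VchordrqfscanInternalBuildOptions {
                lists: vec![4096, 1000],
                ..Default::default()
            };
            assert!(internal.validate().is_err());
        }
    }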