diff --git a/bench/index.py b/bench/index.py index 97b67a0..bd0322d 100644 --- a/bench/index.py +++ b/bench/index.py @@ -61,15 +61,15 @@ def get_ivf_ops_config(metric, k, name=None): residual_quantization = true spherical_centroids = false """ - elif metric == "cos": - metric_ops = "vector_cos_ops" + elif metric == "cosine": + metric_ops = "vector_cosine_ops" ivf_config = f""" nlist = {k} residual_quantization = false spherical_centroids = true """ - elif metric == "dot": - metric_ops = "vector_dot_ops" + elif metric == "ip": + metric_ops = "vector_ip_ops" ivf_config = f""" nlist = {k} residual_quantization = false diff --git a/src/algorithm/build.rs b/src/algorithm/build.rs index 9f3eef8..6c111d0 100644 --- a/src/algorithm/build.rs +++ b/src/algorithm/build.rs @@ -1,17 +1,16 @@ use crate::algorithm::rabitq; use crate::algorithm::tuples::*; -use crate::index::am_options::PgDistanceKind; -use crate::index::utils::load_table_vectors; +use crate::index::am_options::Opfamily; use crate::postgres::BufferWriteGuard; use crate::postgres::Relation; -use crate::types::ExternalCentroids; +use crate::types::RabbitholeBuildOptions; +use crate::types::RabbitholeExternalBuildOptions; use crate::types::RabbitholeIndexingOptions; +use crate::types::RabbitholeInternalBuildOptions; use base::distance::DistanceKind; use base::index::VectorOptions; use base::scalar::ScalarLike; use base::search::Pointer; -use base::vector::VectBorrowed; -use base::vector::VectorBorrowed; use common::vec2::Vec2; use rand::Rng; use rkyv::ser::serializers::AllocSerializer; @@ -23,6 +22,7 @@ pub trait HeapRelation { fn traverse(&self, callback: F) where F: FnMut((Pointer, Vec)); + fn opfamily(&self) -> Opfamily; } pub trait Reporter { @@ -35,28 +35,26 @@ pub fn build( rabbithole_options: RabbitholeIndexingOptions, heap_relation: T, relation: Relation, - pg_distance: PgDistanceKind, mut reporter: R, ) { let dims = vector_options.dims; let is_residual = rabbithole_options.residual_quantization && vector_options.d == DistanceKind::L2; - let structure = match &rabbithole_options.external_centroids { - Some(_) => Structure::load( + let structure = match rabbithole_options.build { + RabbitholeBuildOptions::External(external_build) => Structure::extern_build( vector_options.clone(), - rabbithole_options.clone(), - pg_distance, + heap_relation.opfamily(), + external_build.clone(), ), - None => { + RabbitholeBuildOptions::Internal(internal_build) => { let mut tuples_total = 0_usize; let samples = { let mut rand = rand::thread_rng(); - let max_number_of_samples = rabbithole_options.nlist.saturating_mul(256); + let max_number_of_samples = internal_build.nlist.saturating_mul(256); let mut samples = Vec::new(); let mut number_of_samples = 0_u32; heap_relation.traverse(|(_, vector)| { - assert_eq!(dims as usize, vector.len(), "invalid vector dimensions",); - let vector = rabitq::project(&vector); + assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); if number_of_samples < max_number_of_samples { samples.extend(vector); number_of_samples += 1; @@ -71,7 +69,7 @@ pub fn build( Vec2::from_vec((number_of_samples as _, dims as _), samples) }; reporter.tuples_total(tuples_total); - Structure::compute(vector_options.clone(), rabbithole_options.clone(), samples) + Structure::internal_build(vector_options.clone(), internal_build.clone(), samples) } }; let h2_len = structure.h2_len(); @@ -182,33 +180,37 @@ pub fn build( } struct Structure { - h2_mean: Vec, - h2_children: Vec, + h2_means: Vec>, + h2_children: Vec>, h1_means: Vec>, h1_children: Vec>, } impl Structure { - fn compute( + fn internal_build( vector_options: VectorOptions, - rabbithole_options: RabbitholeIndexingOptions, - samples: Vec2, + internal_build: RabbitholeInternalBuildOptions, + mut samples: Vec2, ) -> Self { let dims = vector_options.dims; + for i in 0..samples.shape_0() { + let vector = &mut samples[(i,)]; + vector.copy_from_slice(&rabitq::project(vector)); + } let h1_means = base::parallelism::RayonParallelism::scoped( - rabbithole_options.build_threads as _, + internal_build.build_threads as _, Arc::new(AtomicBool::new(false)), |parallelism| { let raw = k_means::k_means( parallelism, - rabbithole_options.nlist as usize, + internal_build.nlist as usize, samples, - rabbithole_options.spherical_centroids, + internal_build.spherical_centroids, 10, false, ); let mut centroids = Vec::new(); - for i in 0..rabbithole_options.nlist { + for i in 0..internal_build.nlist { centroids.push(raw[(i as usize,)].to_vec()); } centroids @@ -218,116 +220,161 @@ impl Structure { .expect("k_means interrupted"); let h2_mean = { let mut centroid = vec![0.0; dims as _]; - for i in 0..rabbithole_options.nlist { + for i in 0..internal_build.nlist { for j in 0..dims { centroid[j as usize] += h1_means[i as usize][j as usize]; } } for j in 0..dims { - centroid[j as usize] /= rabbithole_options.nlist as f32; + centroid[j as usize] /= internal_build.nlist as f32; } centroid }; Structure { - h2_mean, - h2_children: (0..rabbithole_options.nlist).collect(), + h2_means: vec![h2_mean], + h2_children: vec![(0..internal_build.nlist).collect()], h1_means, - h1_children: (0..rabbithole_options.nlist).map(|_| Vec::new()).collect(), + h1_children: (0..internal_build.nlist).map(|_| Vec::new()).collect(), } } - fn load( + fn extern_build( vector_options: VectorOptions, - rabbithole_options: RabbitholeIndexingOptions, - pg_distance: PgDistanceKind, + _opfamily: Opfamily, + external_build: RabbitholeExternalBuildOptions, ) -> Self { - let dims = vector_options.dims; - let preprocess_data = match pg_distance { - PgDistanceKind::L2 | PgDistanceKind::Dot => { - |b: VectBorrowed| rabitq::project(b.slice()) + use std::collections::BTreeMap; + let RabbitholeExternalBuildOptions { table } = external_build; + let query = format!("SELECT id, parent, vector FROM {table};"); + let mut parents = BTreeMap::new(); + let mut vectors = BTreeMap::new(); + pgrx::spi::Spi::connect(|client| { + use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; + use base::vector::VectorBorrowed; + use pgrx::pg_sys::panic::ErrorReportable; + let table = client.select(&query, None, None).unwrap_or_report(); + for row in table { + let id: Option = row.get_by_name("id").unwrap(); + let parent: Option = row.get_by_name("parent").unwrap(); + let vector: Option = row.get_by_name("vector").unwrap(); + let id = id.expect("extern build: id could not be NULL"); + let vector = vector.expect("extern build: vector could not be NULL"); + let pop = parents.insert(id, parent); + if pop.is_some() { + pgrx::error!( + "external build: there are at least two lines have same id, id = {id}" + ); + } + if vector_options.dims != vector.as_borrowed().dims() { + pgrx::error!("extern build: incorrect dimension, id = {id}"); + } + vectors.insert(id, rabitq::project(vector.slice())); } - PgDistanceKind::Cos => { - |b: VectBorrowed| rabitq::project(b.function_normalize().slice()) + }); + let mut children = parents + .keys() + .map(|x| (*x, Vec::new())) + .collect::>(); + let mut root = None; + for (&id, &parent) in parents.iter() { + if let Some(parent) = parent { + if let Some(parent) = children.get_mut(&parent) { + parent.push(id); + } else { + pgrx::error!( + "external build: parent does not exist, id = {id}, parent = {parent}" + ); + } + } else { + if let Some(root) = root { + pgrx::error!("external build: two root, id = {root}, id = {id}"); + } else { + root = Some(id); + } } + } + let Some(root) = root else { + pgrx::error!("extern build: there are no root"); }; - let preprocess_index = |b: VectBorrowed| b.slice().to_vec(); - - let h1_means = match &rabbithole_options.external_centroids { - Some(ExternalCentroids { - table, - h1_means_column: h1, - .. - }) => load_table_vectors( - table, - h1, - rabbithole_options.nlist, - vector_options.dims, - preprocess_data, - ), - - _ => unreachable!(), - }; - let h1_children = match &rabbithole_options.external_centroids { - Some(ExternalCentroids { - table, - h1_children_column: Some(h1), - .. - }) => load_table_vectors(table, h1, 1, vector_options.dims, preprocess_index) - .into_iter() - .map(|v| v.into_iter().map(|f| f as u32).collect()) - .collect(), - _ => (0..rabbithole_options.nlist).map(|_| Vec::new()).collect(), - }; - let h2_mean = match &rabbithole_options.external_centroids { - Some(ExternalCentroids { - table, - h2_mean_column: Some(h2), - .. - }) => load_table_vectors(table, h2, 1, vector_options.dims, preprocess_data) - .pop() - .expect("load h2_mean panic"), - _ => { - let mut centroid = vec![0.0; dims as _]; - for i in 0..rabbithole_options.nlist { - for j in 0..dims { - centroid[j as usize] += h1_means[i as usize][j as usize]; + let mut heights = BTreeMap::<_, _>::new(); + fn dfs_for_heights( + heights: &mut BTreeMap>, + children: &BTreeMap>, + u: i32, + ) { + if heights.contains_key(&u) { + pgrx::error!("extern build: detect a cycle, id = {u}"); + } + heights.insert(u, None); + let mut height = None; + for &v in children[&u].iter() { + dfs_for_heights(heights, children, v); + let new = heights[&v].unwrap() + 1; + if let Some(height) = height { + if height != new { + pgrx::error!("extern build: two heights, id = {u}"); } + } else { + height = Some(new); } - for j in 0..dims { - centroid[j as usize] /= rabbithole_options.nlist as f32; - } - centroid } - }; - let h2_children = match &rabbithole_options.external_centroids { - Some(ExternalCentroids { - table, - h2_children_column: Some(h2), - .. - }) => load_table_vectors(table, h2, 1, vector_options.dims, preprocess_index) - .pop() - .expect("load h2_children panic") - .into_iter() - .map(|f| f as u32) - .collect(), - _ => (0..rabbithole_options.nlist).collect(), - }; - Structure { - h2_mean, + if height.is_none() { + height = Some(1); + } + heights.insert(u, height); + } + dfs_for_heights(&mut heights, &children, root); + let heights = heights + .into_iter() + .map(|(k, v)| (k, v.expect("not a connected graph"))) + .collect::>(); + if heights[&root] != 2 { + pgrx::error!( + "extern build: unexpected tree height, height = {}", + heights[&root] + ); + } + let mut cursors = vec![0_u32; 1 + heights[&root] as usize]; + let mut labels = BTreeMap::new(); + for id in parents.keys().copied() { + let height = heights[&id]; + let cursor = cursors[height as usize]; + labels.insert(id, (height, cursor)); + cursors[height as usize] += 1; + } + fn extract( + height: u32, + labels: &BTreeMap, + vectors: &BTreeMap>, + children: &BTreeMap>, + ) -> (Vec>, Vec>) { + labels + .iter() + .filter(|(_, &(h, _))| h == height) + .map(|(id, _)| { + ( + vectors[id].clone(), + children[id].iter().map(|id| labels[id].1).collect(), + ) + }) + .unzip() + } + let (h2_means, h2_children) = extract(2, &labels, &vectors, &children); + let (h1_means, h1_children) = extract(1, &labels, &vectors, &children); + Self { + h2_means, h2_children, h1_means, h1_children, } } fn h2_len(&self) -> u32 { - 1 + self.h2_means.len() as _ } fn h2_means(&self, i: u32) -> &Vec { - assert!(i == 0); - &self.h2_mean + &self.h2_means[i as usize] } fn h2_children(&self, i: u32) -> &Vec { - assert!(i == 0); - &self.h2_children + &self.h2_children[i as usize] } fn h1_len(&self) -> u32 { self.h1_means.len() as _ diff --git a/src/algorithm/scan.rs b/src/algorithm/scan.rs index dec0cbe..56950f1 100644 --- a/src/algorithm/scan.rs +++ b/src/algorithm/scan.rs @@ -15,9 +15,9 @@ pub fn scan( relation: Relation, vector: Vec, distance_kind: DistanceKind, - h1_nprobe: u32, + nprobe_1: u32, ) -> impl Iterator { - assert!(h1_nprobe >= 1); + assert!(nprobe_1 >= 1); let meta_guard = relation.read(0); let meta_tuple = meta_guard .get() @@ -28,7 +28,7 @@ pub fn scan( let dims = meta_tuple.dims; assert_eq!(dims as usize, vector.len(), "invalid vector dimensions"); let vector = rabitq::project(&vector); - let is_residual = meta_tuple.is_residual && distance_kind == DistanceKind::L2; + let is_residual = meta_tuple.is_residual; let default_lut = if !is_residual { Some(rabitq::fscan_preprocess(&vector)) } else { @@ -118,7 +118,7 @@ pub fn scan( let (_, AlwaysEqual(first), AlwaysEqual(mean)) = cache.pop()?; Some((first, mean)) }) - .take(h1_nprobe as usize) + .take(nprobe_1 as usize) .collect() }; { diff --git a/src/datatype/operators_pgvector_vector.rs b/src/datatype/operators_pgvector_vector.rs index 44e17b8..1d8ac1a 100644 --- a/src/datatype/operators_pgvector_vector.rs +++ b/src/datatype/operators_pgvector_vector.rs @@ -1,5 +1,4 @@ use crate::datatype::memory_pgvector_vector::*; -use base::scalar::ScalarLike; use base::vector::{VectBorrowed, VectorBorrowed}; use std::num::NonZero; @@ -21,11 +20,14 @@ fn _rabbithole_pgvector_vector_sphere_l2_in( Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), Err(_) => unreachable!(), }; - f32::reduce_sum_of_d2(lhs.slice(), center.slice()).to_f32() < radius + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + let d = VectBorrowed::operator_l2(lhs, center).to_f32().sqrt(); + d < radius } #[pgrx::pg_extern(immutable, strict, parallel_safe)] -fn _rabbithole_pgvector_vector_sphere_dot_in( +fn _rabbithole_pgvector_vector_sphere_ip_in( lhs: PgvectorVectorInput<'_>, rhs: pgrx::composite_type!("sphere_vector"), ) -> bool { @@ -42,11 +44,14 @@ fn _rabbithole_pgvector_vector_sphere_dot_in( Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), Err(_) => unreachable!(), }; - -f32::reduce_sum_of_xy(lhs.slice(), center.slice()) < radius + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + let d = VectBorrowed::operator_dot(lhs, center).to_f32(); + d < radius } #[pgrx::pg_extern(immutable, strict, parallel_safe)] -fn _rabbithole_pgvector_vector_sphere_cos_in( +fn _rabbithole_pgvector_vector_sphere_cosine_in( lhs: PgvectorVectorInput<'_>, rhs: pgrx::composite_type!("sphere_vector"), ) -> bool { @@ -63,5 +68,8 @@ fn _rabbithole_pgvector_vector_sphere_cos_in( Ok(None) => pgrx::error!("Bad input: empty radius at sphere"), Err(_) => unreachable!(), }; - VectBorrowed::operator_cos(lhs.as_borrowed(), center.as_borrowed()).to_f32() < radius + let lhs = lhs.as_borrowed(); + let center = center.as_borrowed(); + let d = VectBorrowed::operator_cos(lhs, center).to_f32(); + d < radius } diff --git a/src/index/am.rs b/src/index/am.rs index 04f78e8..005a9ff 100644 --- a/src/index/am.rs +++ b/src/index/am.rs @@ -166,15 +166,11 @@ pub unsafe extern "C" fn ambuild( { use base::vector::OwnedVector; let state = unsafe { &mut *state.cast::>() }; - let vector = unsafe { - state - .this - .opfamily - .datum_to_vector(*values.add(0), *is_null.add(0)) - }; + let opfamily = state.this.opfamily; + let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; let pointer = unsafe { ctid_to_pointer(ctid.read()) }; if let Some(vector) = vector { - let vector = match vector { + let vector = match opfamily.preprocess(vector.as_borrowed()) { OwnedVector::Vecf32(x) => x, OwnedVector::Vecf16(_) => unreachable!(), OwnedVector::SVecf32(_) => unreachable!(), @@ -204,6 +200,10 @@ pub unsafe extern "C" fn ambuild( ); } } + + fn opfamily(&self) -> Opfamily { + self.opfamily + } } #[derive(Debug, Clone)] pub struct PgReporter {} @@ -225,12 +225,13 @@ pub unsafe extern "C" fn ambuild( } } } - let (vector_options, rabbithole_options, pg_distance) = unsafe { am_options::options(index) }; + let (vector_options, rabbithole_options) = unsafe { am_options::options(index) }; + let opfamily = unsafe { am_options::opfamily(index) }; let heap_relation = Heap { heap, index, index_info, - opfamily: unsafe { am_options::opfamily(index) }, + opfamily, }; let mut reporter = PgReporter {}; let index_relation = unsafe { Relation::new(index) }; @@ -239,7 +240,6 @@ pub unsafe extern "C" fn ambuild( rabbithole_options, heap_relation.clone(), index_relation.clone(), - pg_distance, reporter.clone(), ); if let Some(leader) = @@ -269,7 +269,7 @@ pub unsafe extern "C" fn ambuild( index_relation.clone(), payload, vector, - pg_distance.to_distance(), + opfamily.distance_kind(), ); tuples_done += 1; reporter.tuples_done(tuples_done); @@ -516,15 +516,11 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main( { use base::vector::OwnedVector; let state = unsafe { &mut *state.cast::>() }; - let vector = unsafe { - state - .this - .opfamily - .datum_to_vector(*values.add(0), *is_null.add(0)) - }; + let opfamily = state.this.opfamily; + let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; let pointer = unsafe { ctid_to_pointer(ctid.read()) }; if let Some(vector) = vector { - let vector = match vector { + let vector = match opfamily.preprocess(vector.as_borrowed()) { OwnedVector::Vecf32(x) => x, OwnedVector::Vecf16(_) => unreachable!(), OwnedVector::SVecf32(_) => unreachable!(), @@ -554,6 +550,10 @@ pub unsafe extern "C" fn rabbithole_parallel_build_main( ); } } + + fn opfamily(&self) -> Opfamily { + self.opfamily + } } let index_relation = unsafe { Relation::new(index) }; @@ -608,7 +608,7 @@ pub unsafe extern "C" fn aminsert( let opfamily = unsafe { am_options::opfamily(index) }; let vector = unsafe { opfamily.datum_to_vector(*values.add(0), *is_null.add(0)) }; if let Some(vector) = vector { - let vector = match vector { + let vector = match opfamily.preprocess(vector.as_borrowed()) { OwnedVector::Vecf32(x) => x, OwnedVector::Vecf16(_) => unreachable!(), OwnedVector::SVecf32(_) => unreachable!(), diff --git a/src/index/am_options.rs b/src/index/am_options.rs index f9803e6..1939653 100644 --- a/src/index/am_options.rs +++ b/src/index/am_options.rs @@ -88,8 +88,8 @@ pub fn convert_opfamily_to_vd( fn convert_name_to_vd(name: &str) -> Option<(VectorKind, PgDistanceKind)> { match name.strip_suffix("_ops") { Some("vector_l2") => Some((VectorKind::Vecf32, PgDistanceKind::L2)), - Some("vector_dot") => Some((VectorKind::Vecf32, PgDistanceKind::Dot)), - Some("vector_cos") => Some((VectorKind::Vecf32, PgDistanceKind::Cos)), + Some("vector_ip") => Some((VectorKind::Vecf32, PgDistanceKind::Dot)), + Some("vector_cosine") => Some((VectorKind::Vecf32, PgDistanceKind::Cos)), _ => None, } } @@ -114,9 +114,7 @@ unsafe fn convert_reloptions_to_options( } } -pub unsafe fn options( - index: pgrx::pg_sys::Relation, -) -> (VectorOptions, RabbitholeIndexingOptions, PgDistanceKind) { +pub unsafe fn options(index: pgrx::pg_sys::Relation) -> (VectorOptions, RabbitholeIndexingOptions) { let opfamily = unsafe { (*index).rd_opfamily.read() }; let att = unsafe { &mut *(*index).rd_att }; let atts = unsafe { att.attrs.as_slice(att.natts as _) }; @@ -144,7 +142,7 @@ pub unsafe fn options( }; // get indexing, segment, optimizing let rabitq = unsafe { convert_reloptions_to_options((*index).rd_options) }; - (vector, rabitq, pg_d) + (vector, rabitq) } #[derive(Debug, Clone, Copy)] @@ -205,6 +203,7 @@ impl Opfamily { pub fn process(self, x: Distance) -> f32 { match self.pg_distance { PgDistanceKind::Cos => f32::from(x) + 1.0f32, + PgDistanceKind::L2 => f32::from(x).sqrt(), _ => f32::from(x), } } diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs index f248c00..0c6406d 100644 --- a/src/index/am_scan.rs +++ b/src/index/am_scan.rs @@ -74,7 +74,7 @@ pub fn scan_next(scanner: &mut Scanner, relation: Relation) -> Option<(Pointer, if let Some((vector, opfamily)) = vector.as_ref() { let vbase = scan( relation, - match vector { + match opfamily.preprocess(vector.as_borrowed()) { OwnedVector::Vecf32(x) => x.slice().to_vec(), OwnedVector::Vecf16(_) => unreachable!(), OwnedVector::SVecf32(_) => unreachable!(), diff --git a/src/index/utils.rs b/src/index/utils.rs index 68df340..35697ab 100644 --- a/src/index/utils.rs +++ b/src/index/utils.rs @@ -1,11 +1,6 @@ use base::distance::{Distance, DistanceKind}; use base::scalar::ScalarLike; use base::search::*; -use base::vector::{VectBorrowed, VectorBorrowed}; -use pgrx::pg_sys::panic::ErrorReportable; -use pgrx::{error, Spi}; - -use crate::datatype::memory_pgvector_vector::PgvectorVectorOutput; pub fn pointer_to_ctid(pointer: Pointer) -> pgrx::pg_sys::ItemPointerData { let value = pointer.as_u64(); @@ -26,37 +21,6 @@ pub fn ctid_to_pointer(ctid: pgrx::pg_sys::ItemPointerData) -> Pointer { Pointer::new(value) } -pub fn load_table_vectors( - table_name: &str, - column_name: &str, - rows: u32, - dims: u32, - preprocess: F, -) -> Vec> -where - F: Fn(VectBorrowed) -> Vec, -{ - let query = format!("SELECT {column_name} FROM {table_name};"); - let mut centroids = Vec::new(); - - Spi::connect(|client| { - let tup_table = client.select(&query, None, None).unwrap_or_report(); - assert_eq!(tup_table.len(), rows as usize); - - for row in tup_table { - let vector = row[column_name].value::(); - if let Ok(Some(v)) = vector { - let borrowed = v.as_borrowed(); - assert_eq!(borrowed.dims(), dims); - centroids.push(preprocess(borrowed)); - } else { - error!("load vectors from column is not valid") - } - } - centroids - }) -} - pub fn distance(d: DistanceKind, lhs: &[f32], rhs: &[f32]) -> Distance { match d { DistanceKind::L2 => Distance::from_f32(f32::reduce_sum_of_d2(lhs, rhs)), diff --git a/src/lib.rs b/src/lib.rs index 347aee7..db5d53c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ -#![allow(clippy::too_many_arguments)] +#![allow(clippy::collapsible_else_if)] +#![allow(clippy::identity_op)] #![allow(clippy::needless_range_loop)] +#![allow(clippy::too_many_arguments)] #![allow(clippy::type_complexity)] -#![allow(clippy::identity_op)] mod algorithm; mod datatype; diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index efcff72..259b779 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -14,18 +14,18 @@ CREATE OPERATOR <<->> ( COMMUTATOR = <<->> ); -CREATE OPERATOR <<=>> ( - PROCEDURE = _rabbithole_pgvector_vector_sphere_cos_in, +CREATE OPERATOR <<#>> ( + PROCEDURE = _rabbithole_pgvector_vector_sphere_ip_in, LEFTARG = vector, RIGHTARG = sphere_vector, - COMMUTATOR = <<=>> + COMMUTATOR = <<#>> ); -CREATE OPERATOR <<#>> ( - PROCEDURE = _rabbithole_pgvector_vector_sphere_dot_in, +CREATE OPERATOR <<=>> ( + PROCEDURE = _rabbithole_pgvector_vector_sphere_cosine_in, LEFTARG = vector, RIGHTARG = sphere_vector, - COMMUTATOR = <<#>> + COMMUTATOR = <<=>> ); -- List of functions @@ -41,8 +41,8 @@ COMMENT ON ACCESS METHOD rabbithole IS 'rabbithole index access method'; -- List of operator families CREATE OPERATOR FAMILY vector_l2_ops USING rabbithole; -CREATE OPERATOR FAMILY vector_dot_ops USING rabbithole; -CREATE OPERATOR FAMILY vector_cos_ops USING rabbithole; +CREATE OPERATOR FAMILY vector_ip_ops USING rabbithole; +CREATE OPERATOR FAMILY vector_cosine_ops USING rabbithole; -- List of operator classes @@ -51,12 +51,12 @@ CREATE OPERATOR CLASS vector_l2_ops OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, OPERATOR 2 <<->> (vector, sphere_vector) FOR SEARCH; -CREATE OPERATOR CLASS vector_dot_ops - FOR TYPE vector USING rabbithole FAMILY vector_dot_ops AS +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING rabbithole FAMILY vector_ip_ops AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, OPERATOR 2 <<#>> (vector, sphere_vector) FOR SEARCH; -CREATE OPERATOR CLASS vector_cos_ops - FOR TYPE vector USING rabbithole FAMILY vector_cos_ops AS +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING rabbithole FAMILY vector_cosine_ops AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, - OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH; \ No newline at end of file + OPERATOR 2 <<=>> (vector, sphere_vector) FOR SEARCH; diff --git a/src/types.rs b/src/types.rs index b60fa14..eb977b6 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,55 +1,81 @@ use serde::{Deserialize, Serialize}; use validator::Validate; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct ExternalCentroids { - pub table: String, - pub h1_means_column: String, - pub h1_children_column: Option, - pub h2_mean_column: Option, - pub h2_children_column: Option, -} - #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] -pub struct RabbitholeIndexingOptions { - #[serde(default = "RabbitholeIndexingOptions::default_nlist")] +pub struct RabbitholeInternalBuildOptions { + #[serde(default = "RabbitholeInternalBuildOptions::default_nlist")] #[validate(range(min = 1, max = 1_000_000))] pub nlist: u32, - #[serde(default = "RabbitholeIndexingOptions::default_spherical_centroids")] + #[serde(default = "RabbitholeInternalBuildOptions::default_spherical_centroids")] pub spherical_centroids: bool, - #[serde(default = "RabbitholeIndexingOptions::default_residual_quantization")] - pub residual_quantization: bool, - #[serde(default = "RabbitholeIndexingOptions::default_build_threads")] + #[serde(default = "RabbitholeInternalBuildOptions::default_build_threads")] #[validate(range(min = 1, max = 255))] pub build_threads: u16, - pub external_centroids: Option, } -impl RabbitholeIndexingOptions { +impl RabbitholeInternalBuildOptions { fn default_nlist() -> u32 { 1000 } fn default_spherical_centroids() -> bool { false } - fn default_residual_quantization() -> bool { - false - } fn default_build_threads() -> u16 { 1 } } -impl Default for RabbitholeIndexingOptions { +impl Default for RabbitholeInternalBuildOptions { fn default() -> Self { Self { nlist: Self::default_nlist(), spherical_centroids: Self::default_spherical_centroids(), - residual_quantization: Self::default_residual_quantization(), build_threads: Self::default_build_threads(), - external_centroids: None, } } } + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct RabbitholeExternalBuildOptions { + pub table: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(rename_all = "snake_case")] +pub enum RabbitholeBuildOptions { + Internal(RabbitholeInternalBuildOptions), + External(RabbitholeExternalBuildOptions), +} + +impl Default for RabbitholeBuildOptions { + fn default() -> Self { + Self::Internal(Default::default()) + } +} + +impl Validate for RabbitholeBuildOptions { + fn validate(&self) -> Result<(), validator::ValidationErrors> { + use RabbitholeBuildOptions::*; + match self { + Internal(internal_build) => internal_build.validate(), + External(external_build) => external_build.validate(), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct RabbitholeIndexingOptions { + #[serde(default = "RabbitholeIndexingOptions::default_residual_quantization")] + pub residual_quantization: bool, + pub build: RabbitholeBuildOptions, +} + +impl RabbitholeIndexingOptions { + fn default_residual_quantization() -> bool { + false + } +}