Skip to content

Commit c27617b

Browse files
committed
fix: index on fp16
Signed-off-by: usamoi <[email protected]>
1 parent e187cbd commit c27617b

File tree

10 files changed

+82
-53
lines changed

10 files changed

+82
-53
lines changed

crates/service/src/index/optimizing/indexing.rs

Lines changed: 15 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -33,8 +33,7 @@ impl<S: G> OptimizerIndexing<S> {
3333
let Some(index) = weak_index.upgrade() else {
3434
return;
3535
};
36-
let cont = pool.install(|| optimizing_indexing(index.clone()));
37-
if cont {
36+
if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) {
3837
continue;
3938
}
4039
}
@@ -77,36 +76,33 @@ impl<S: G> Seg<S> {
7776
}
7877
}
7978

80-
pub fn optimizing_indexing<S: G>(index: Arc<Index<S>>) -> bool {
79+
pub fn optimizing_indexing<S: G>(index: Arc<Index<S>>) -> Result<(), ()> {
8180
use Seg::*;
8281
let segs = {
83-
let mut all_segs = {
84-
let protect = index.protect.lock();
85-
let mut all_segs = Vec::new();
86-
all_segs.extend(protect.growing.values().map(|x| Growing(x.clone())));
87-
all_segs.extend(protect.sealed.values().map(|x| Sealed(x.clone())));
88-
all_segs.sort_by_key(|case| Reverse(case.len()));
89-
all_segs
90-
};
91-
let mut segs = Vec::new();
82+
let protect = index.protect.lock();
83+
let mut segs_0 = Vec::new();
84+
segs_0.extend(protect.growing.values().map(|x| Growing(x.clone())));
85+
segs_0.extend(protect.sealed.values().map(|x| Sealed(x.clone())));
86+
segs_0.sort_by_key(|case| Reverse(case.len()));
87+
let mut segs_1 = Vec::new();
9288
let mut total = 0u64;
9389
let mut count = 0;
94-
while let Some(seg) = all_segs.pop() {
90+
while let Some(seg) = segs_0.pop() {
9591
if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 {
9692
total += seg.len() as u64;
9793
if let Growing(_) = seg {
9894
count += 1;
9995
}
100-
segs.push(seg);
96+
segs_1.push(seg);
10197
} else {
10298
break;
10399
}
104100
}
105-
if segs.is_empty() || (segs.len() == 1 && count == 0) {
101+
if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) {
106102
index.instant_index.store(Instant::now());
107-
return true;
103+
return Err(());
108104
}
109-
segs
105+
segs_1
110106
};
111107
let sealed_segment = merge(&index, &segs);
112108
{
@@ -118,7 +114,7 @@ pub fn optimizing_indexing<S: G>(index: Arc<Index<S>>) -> bool {
118114
if protect.growing.contains_key(&seg.uuid()) {
119115
continue;
120116
}
121-
return false;
117+
return Ok(());
122118
}
123119
for seg in segs.iter() {
124120
protect.sealed.remove(&seg.uuid());
@@ -127,7 +123,7 @@ pub fn optimizing_indexing<S: G>(index: Arc<Index<S>>) -> bool {
127123
protect.sealed.insert(sealed_segment.uuid(), sealed_segment);
128124
protect.maintain(index.options.clone(), index.delete.clone(), &index.view);
129125
}
130-
false
126+
Ok(())
131127
}
132128

133129
fn merge<S: G>(index: &Arc<Index<S>>, segs: &[Seg<S>]) -> Arc<SealedSegment<S>> {

crates/service/src/worker/instance.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ impl Instance {
3333
(Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path, options)),
3434
(Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path, options)),
3535
(Distance::L2, Kind::F32) => Self::F32L2(Index::open(path, options)),
36-
(Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)),
37-
(Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)),
38-
(Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)),
36+
(Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path, options)),
37+
(Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path, options)),
38+
(Distance::L2, Kind::F16) => Self::F16L2(Index::open(path, options)),
3939
}
4040
}
4141
pub fn options(&self) -> &IndexOptions {

src/datatype/vecf16.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ CREATE TYPE vecf16 (
3939
#[repr(C, align(8))]
4040
pub struct Vecf16 {
4141
varlena: u32,
42+
kind: u8,
4243
len: u16,
4344
phantom: [F16; 0],
4445
}
@@ -60,6 +61,7 @@ impl Vecf16 {
6061
let layout = Vecf16::layout(slice.len());
6162
let ptr = std::alloc::alloc(layout) as *mut Vecf16;
6263
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size()));
64+
std::ptr::addr_of_mut!((*ptr).kind).write(16);
6365
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
6466
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
6567
Box::from_raw(ptr)
@@ -71,6 +73,7 @@ impl Vecf16 {
7173
let layout = Vecf16::layout(slice.len());
7274
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16;
7375
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size()));
76+
std::ptr::addr_of_mut!((*ptr).kind).write(16);
7477
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
7578
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
7679
Vecf16Output(NonNull::new(ptr).unwrap())
@@ -82,6 +85,7 @@ impl Vecf16 {
8285
let layout = Vecf16::layout(len);
8386
let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf16;
8487
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size()));
88+
std::ptr::addr_of_mut!((*ptr).kind).write(16);
8589
std::ptr::addr_of_mut!((*ptr).len).write(len as u16);
8690
Box::from_raw(ptr)
8791
}
@@ -93,6 +97,7 @@ impl Vecf16 {
9397
let layout = Vecf16::layout(len);
9498
let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf16;
9599
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size()));
100+
std::ptr::addr_of_mut!((*ptr).kind).write(16);
96101
std::ptr::addr_of_mut!((*ptr).len).write(len as u16);
97102
Vecf16Output(NonNull::new(ptr).unwrap())
98103
}
@@ -102,10 +107,12 @@ impl Vecf16 {
102107
}
103108
pub fn data(&self) -> &[F16] {
104109
debug_assert_eq!(self.varlena & 3, 0);
110+
debug_assert_eq!(self.kind, 16);
105111
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) }
106112
}
107113
pub fn data_mut(&mut self) -> &mut [F16] {
108114
debug_assert_eq!(self.varlena & 3, 0);
115+
debug_assert_eq!(self.kind, 16);
109116
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) }
110117
}
111118
#[allow(dead_code)]

src/datatype/vecf32.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ CREATE TYPE vector (
3939
#[repr(C, align(8))]
4040
pub struct Vecf32 {
4141
varlena: u32,
42+
kind: u8,
4243
len: u16,
4344
phantom: [F32; 0],
4445
}
@@ -60,6 +61,7 @@ impl Vecf32 {
6061
let layout = Vecf32::layout(slice.len());
6162
let ptr = std::alloc::alloc(layout) as *mut Vecf32;
6263
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size()));
64+
std::ptr::addr_of_mut!((*ptr).kind).write(32);
6365
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
6466
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
6567
Box::from_raw(ptr)
@@ -71,6 +73,7 @@ impl Vecf32 {
7173
let layout = Vecf32::layout(slice.len());
7274
let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32;
7375
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size()));
76+
std::ptr::addr_of_mut!((*ptr).kind).write(32);
7477
std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16);
7578
std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len());
7679
Vecf32Output(NonNull::new(ptr).unwrap())
@@ -82,6 +85,7 @@ impl Vecf32 {
8285
let layout = Vecf32::layout(len);
8386
let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf32;
8487
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size()));
88+
std::ptr::addr_of_mut!((*ptr).kind).write(32);
8589
std::ptr::addr_of_mut!((*ptr).len).write(len as u16);
8690
Box::from_raw(ptr)
8791
}
@@ -93,6 +97,7 @@ impl Vecf32 {
9397
let layout = Vecf32::layout(len);
9498
let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf32;
9599
std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size()));
100+
std::ptr::addr_of_mut!((*ptr).kind).write(32);
96101
std::ptr::addr_of_mut!((*ptr).len).write(len as u16);
97102
Vecf32Output(NonNull::new(ptr).unwrap())
98103
}
@@ -102,10 +107,12 @@ impl Vecf32 {
102107
}
103108
pub fn data(&self) -> &[F32] {
104109
debug_assert_eq!(self.varlena & 3, 0);
110+
debug_assert_eq!(self.kind, 32);
105111
unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) }
106112
}
107113
pub fn data_mut(&mut self) -> &mut [F32] {
108114
debug_assert_eq!(self.varlena & 3, 0);
115+
debug_assert_eq!(self.kind, 32);
109116
unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) }
110117
}
111118
#[allow(dead_code)]

src/index/am.rs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use super::am_build;
22
use super::am_scan;
33
use super::am_setup;
44
use super::am_update;
5-
use crate::datatype::vecf32::Vecf32Input;
65
use crate::gucs::ENABLE_VECTOR_INDEX;
6+
use crate::index::utils::from_datum;
77
use crate::prelude::*;
88
use crate::utils::cells::PgCell;
99
use service::prelude::*;
@@ -198,18 +198,16 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) {
198198
pub unsafe extern "C" fn aminsert(
199199
index_relation: pgrx::pg_sys::Relation,
200200
values: *mut pgrx::pg_sys::Datum,
201-
is_null: *mut bool,
201+
_is_null: *mut bool,
202202
heap_tid: pgrx::pg_sys::ItemPointer,
203203
_heap_relation: pgrx::pg_sys::Relation,
204204
_check_unique: pgrx::pg_sys::IndexUniqueCheck,
205205
_index_info: *mut pgrx::pg_sys::IndexInfo,
206206
) -> bool {
207-
use pgrx::FromDatum;
208207
let oid = (*index_relation).rd_node.relNode;
209208
let id = Id::from_sys(oid);
210-
let vector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap();
211-
let vector = vector.data().to_vec();
212-
am_update::update_insert(id, vector.into(), *heap_tid);
209+
let vector = from_datum(*values.add(0));
210+
am_update::update_insert(id, vector, *heap_tid);
213211
true
214212
}
215213

@@ -218,22 +216,20 @@ pub unsafe extern "C" fn aminsert(
218216
pub unsafe extern "C" fn aminsert(
219217
index_relation: pgrx::pg_sys::Relation,
220218
values: *mut pgrx::pg_sys::Datum,
221-
is_null: *mut bool,
219+
_is_null: *mut bool,
222220
heap_tid: pgrx::pg_sys::ItemPointer,
223221
_heap_relation: pgrx::pg_sys::Relation,
224222
_check_unique: pgrx::pg_sys::IndexUniqueCheck,
225223
_index_unchanged: bool,
226224
_index_info: *mut pgrx::pg_sys::IndexInfo,
227225
) -> bool {
228-
use pgrx::FromDatum;
229226
#[cfg(any(feature = "pg14", feature = "pg15"))]
230227
let oid = (*index_relation).rd_node.relNode;
231228
#[cfg(feature = "pg16")]
232229
let oid = (*index_relation).rd_locator.relNumber;
233230
let id = Id::from_sys(oid);
234-
let vector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap();
235-
let vector = vector.data().to_vec();
236-
am_update::update_insert(id, vector.into(), *heap_tid);
231+
let vector = from_datum(*values.add(0));
232+
am_update::update_insert(id, vector, *heap_tid);
237233
true
238234
}
239235

src/index/am_build.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::{client::ClientGuard, hook_transaction::flush_if_commit};
2-
use crate::datatype::vecf32::Vecf32Input;
32
use crate::index::am_setup::options;
3+
use crate::index::utils::from_datum;
44
use crate::prelude::*;
55
use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData};
66
use service::prelude::*;
@@ -48,17 +48,16 @@ unsafe extern "C" fn callback(
4848
index_relation: pgrx::pg_sys::Relation,
4949
htup: pgrx::pg_sys::HeapTuple,
5050
values: *mut pgrx::pg_sys::Datum,
51-
is_null: *mut bool,
51+
_is_null: *mut bool,
5252
_tuple_is_alive: bool,
5353
state: *mut std::os::raw::c_void,
5454
) {
55-
use pgrx::FromDatum;
5655
let ctid = &(*htup).t_self;
5756
let oid = (*index_relation).rd_node.relNode;
5857
let id = Id::from_sys(oid);
5958
let state = &mut *(state as *mut Builder);
60-
let pgvector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap();
61-
let data = (pgvector.to_vec().into(), Pointer::from_sys(*ctid));
59+
let vector = from_datum(*values.add(0));
60+
let data = (vector, Pointer::from_sys(*ctid));
6261
state.client.insert(id, data);
6362
(*state.result).heap_tuples += 1.0;
6463
(*state.result).index_tuples += 1.0;
@@ -70,19 +69,18 @@ unsafe extern "C" fn callback(
7069
index_relation: pgrx::pg_sys::Relation,
7170
ctid: pgrx::pg_sys::ItemPointer,
7271
values: *mut pgrx::pg_sys::Datum,
73-
is_null: *mut bool,
72+
_is_null: *mut bool,
7473
_tuple_is_alive: bool,
7574
state: *mut std::os::raw::c_void,
7675
) {
77-
use pgrx::FromDatum;
7876
#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))]
7977
let oid = (*index_relation).rd_node.relNode;
8078
#[cfg(feature = "pg16")]
8179
let oid = (*index_relation).rd_locator.relNumber;
8280
let id = Id::from_sys(oid);
8381
let state = &mut *(state as *mut Builder);
84-
let pgvector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap();
85-
let data = (pgvector.to_vec().into(), Pointer::from_sys(*ctid));
82+
let vector = from_datum(*values.add(0));
83+
let data = (vector, Pointer::from_sys(*ctid));
8684
state.client.insert(id, data);
8785
(*state.result).heap_tuples += 1.0;
8886
(*state.result).index_tuples += 1.0;

src/index/am_scan.rs

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
use crate::datatype::vecf32::Vecf32Input;
21
use crate::gucs::ENABLE_PREFILTER;
32
use crate::gucs::K;
3+
use crate::index::utils::from_datum;
44
use crate::prelude::*;
55
use pgrx::FromDatum;
66
use service::prelude::*;
@@ -9,7 +9,7 @@ use service::prelude::*;
99
pub enum Scanner {
1010
Initial {
1111
// fields to be filled by amhandler and hook
12-
vector: Option<Vec<F32>>,
12+
vector: Option<DynamicVector>,
1313
index_scan_state: Option<*mut pgrx::pg_sys::IndexScanState>,
1414
},
1515
Type0 {
@@ -81,8 +81,7 @@ pub unsafe fn start_scan(
8181
}
8282
let orderby = orderbys.add(0);
8383
let argument = (*orderby).sk_argument;
84-
let vector = Vecf32Input::from_datum(argument, false).unwrap();
85-
let vector = vector.to_vec();
84+
let vector = from_datum(argument);
8685

8786
let last = (*((*scan).opaque as *mut Scanner)).clone();
8887
let scanner = (*scan).opaque as *mut Scanner;
@@ -153,12 +152,7 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool {
153152
node: index_scan_state.unwrap(),
154153
};
155154

156-
let mut result = client.search(
157-
id,
158-
(vector.into(), k),
159-
ENABLE_PREFILTER.get(),
160-
client_search,
161-
);
155+
let mut result = client.search(id, (vector, k), ENABLE_PREFILTER.get(), client_search);
162156
result.reverse();
163157
*scanner = Scanner::Type1 {
164158
index_scan_state: index_scan_state.unwrap(),
@@ -175,7 +169,7 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool {
175169

176170
let client_search = ClientSearch {};
177171

178-
let mut result = client.search(id, (vector.into(), k), false, client_search);
172+
let mut result = client.search(id, (vector, k), false, client_search);
179173
result.reverse();
180174
*scanner = Scanner::Type0 { data: result };
181175
}

src/index/am_setup.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Dist
5959
result = (Distance::Dot, Kind::F32);
6060
} else if operator == regoperatorin("<=>(vector,vector)") {
6161
result = (Distance::Cos, Kind::F32);
62+
} else if operator == regoperatorin("<->(vecf16,vecf16)") {
63+
result = (Distance::L2, Kind::F16);
64+
} else if operator == regoperatorin("<#>(vecf16,vecf16)") {
65+
result = (Distance::Dot, Kind::F16);
66+
} else if operator == regoperatorin("<=>(vecf16,vecf16)") {
67+
result = (Distance::Cos, Kind::F16);
6268
} else {
6369
FriendlyError::BadOptions3.friendly();
6470
};

src/index/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ mod client;
99
mod hook_executor;
1010
mod hook_transaction;
1111
mod hooks;
12+
mod utils;
1213
mod views;
1314

1415
pub unsafe fn init() {

0 commit comments

Comments
 (0)