From 941bac6e46049c784108faa3f96830592a38657f Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 13:55:02 +0800 Subject: [PATCH 01/23] feat: fp16 vector Signed-off-by: usamoi --- Cargo.lock | 475 ++++++++----- Cargo.toml | 82 ++- crates/c/Cargo.toml | 7 + crates/c/build.rs | 6 + crates/c/src/c.c | 3 + crates/c/src/c.h | 0 crates/c/src/c.rs | 3 + crates/c/src/lib.rs | 3 + crates/service/Cargo.toml | 42 ++ .../algorithms/clustering/elkan_k_means.rs | 75 +- .../service/src}/algorithms/clustering/mod.rs | 0 .../service/src}/algorithms/flat.rs | 70 +- .../service/src}/algorithms/hnsw.rs | 145 ++-- .../service/src}/algorithms/ivf/ivf_naive.rs | 88 +-- .../service/src}/algorithms/ivf/ivf_pq.rs | 92 +-- .../service/src}/algorithms/ivf/mod.rs | 16 +- {src => crates/service/src}/algorithms/mod.rs | 1 - .../src}/algorithms/quantization/mod.rs | 40 +- .../src}/algorithms/quantization/product.rs | 52 +- .../src}/algorithms/quantization/scalar.rs | 32 +- .../src}/algorithms/quantization/trivial.rs | 18 +- {src => crates/service/src}/algorithms/raw.rs | 52 +- .../service/src/algorithms/vamana.rs.txt | 0 {src => crates/service/src}/index/delete.rs | 0 .../service/src}/index/indexing/flat.rs | 14 +- .../service/src}/index/indexing/hnsw.rs | 14 +- .../service/src}/index/indexing/ivf.rs | 15 +- .../service/src}/index/indexing/mod.rs | 28 +- crates/service/src/index/mod.rs | 415 +++++++++++ .../service/src}/index/optimizing/indexing.rs | 67 +- .../service/src}/index/optimizing/mod.rs | 18 +- .../service/src/index/optimizing/sealing.rs | 49 ++ .../service/src}/index/optimizing/vacuum.rs | 0 .../service/src}/index/segments/growing.rs | 53 +- .../service/src}/index/segments/mod.rs | 18 +- .../service/src}/index/segments/sealed.rs | 14 +- crates/service/src/lib.rs | 8 + {src => crates/service/src}/prelude/error.rs | 87 +-- {src => crates/service/src}/prelude/filter.rs | 0 crates/service/src/prelude/global/f16_cos.rs | 222 ++++++ crates/service/src/prelude/global/f16_dot.rs | 191 +++++ crates/service/src/prelude/global/f16_l2.rs | 152 ++++ crates/service/src/prelude/global/f32_cos.rs | 222 ++++++ crates/service/src/prelude/global/f32_dot.rs | 191 +++++ crates/service/src/prelude/global/f32_l2.rs | 151 ++++ crates/service/src/prelude/global/mod.rs | 121 ++++ {src => crates/service/src}/prelude/heap.rs | 6 +- crates/service/src/prelude/mod.rs | 16 + crates/service/src/prelude/scalar/f16.rs | 653 ++++++++++++++++++ crates/service/src/prelude/scalar/f32.rs | 632 +++++++++++++++++ crates/service/src/prelude/scalar/mod.rs | 5 + {src => crates/service/src}/prelude/sys.rs | 25 +- crates/service/src/utils/cells.rs | 26 + {src => crates/service/src}/utils/clean.rs | 0 {src => crates/service/src}/utils/dir_ops.rs | 0 .../service/src}/utils/file_atomic.rs | 0 {src => crates/service/src}/utils/file_wal.rs | 0 .../service/src}/utils/mmap_array.rs | 2 +- crates/service/src/utils/mod.rs | 7 + {src => crates/service/src}/utils/vec2.rs | 20 +- crates/service/src/worker/instance.rs | 204 ++++++ .../service/src/worker/mod.rs | 74 +- docs/get-started.md | 8 +- docs/indexing.md | 51 +- src/algorithms/diskann/mod.rs | 1 - src/bgworker/mod.rs | 21 +- src/datatype/casts_f32.rs | 26 + src/datatype/mod.rs | 6 + .../operators_f16.rs} | 80 +-- src/datatype/operators_f32.rs | 184 +++++ src/datatype/typmod.rs | 77 +++ src/datatype/vecf16.rs | 375 ++++++++++ .../datatype.rs => datatype/vecf32.rs} | 255 +++---- src/embedding/udf.rs | 14 +- src/{postgres => }/gucs.rs | 0 src/{postgres/index.rs => index/am.rs} | 49 +- .../index_build.rs => index/am_build.rs} | 61 +- .../index_scan.rs => index/am_scan.rs} | 99 ++- .../index_setup.rs => index/am_setup.rs} | 47 +- src/index/am_update.rs | 31 + src/index/client.rs | 42 ++ src/{postgres => index}/hook_executor.rs | 4 +- src/index/hook_transaction.rs | 26 + src/{postgres => index}/hooks.rs | 14 +- src/index/mod.rs | 465 +------------ src/index/views.rs | 46 ++ src/ipc/client.rs | 197 ++---- src/ipc/mod.rs | 22 +- src/ipc/packet.rs | 15 +- src/ipc/server.rs | 21 +- src/ipc/transport/mmap.rs | 70 +- src/ipc/transport/unix.rs | 12 +- src/lib.rs | 33 +- src/postgres/casts.rs | 21 - src/postgres/hook_transaction.rs | 48 -- src/postgres/index_update.rs | 31 - src/postgres/mod.rs | 19 - src/postgres/stat.rs | 38 - src/prelude.rs | 39 ++ src/prelude/distance.rs | 516 -------------- src/prelude/mod.rs | 15 - src/prelude/scalar.rs | 295 -------- src/prelude/stat.rs | 13 - src/sql/bootstrap.sql | 1 + src/sql/finalize.sql | 21 +- src/utils/cells.rs | 27 +- src/utils/mod.rs | 6 - src/utils/os.rs | 82 ++- 108 files changed, 5611 insertions(+), 2935 deletions(-) create mode 100644 crates/c/Cargo.toml create mode 100644 crates/c/build.rs create mode 100644 crates/c/src/c.c create mode 100644 crates/c/src/c.h create mode 100644 crates/c/src/c.rs create mode 100644 crates/c/src/lib.rs create mode 100644 crates/service/Cargo.toml rename {src => crates/service/src}/algorithms/clustering/elkan_k_means.rs (74%) rename {src => crates/service/src}/algorithms/clustering/mod.rs (100%) rename {src => crates/service/src}/algorithms/flat.rs (63%) rename {src => crates/service/src}/algorithms/hnsw.rs (82%) rename {src => crates/service/src}/algorithms/ivf/ivf_naive.rs (77%) rename {src => crates/service/src}/algorithms/ivf/ivf_pq.rs (77%) rename {src => crates/service/src}/algorithms/ivf/mod.rs (84%) rename {src => crates/service/src}/algorithms/mod.rs (84%) rename {src => crates/service/src}/algorithms/quantization/mod.rs (80%) rename {src => crates/service/src}/algorithms/quantization/product.rs (81%) rename {src => crates/service/src}/algorithms/quantization/scalar.rs (75%) rename {src => crates/service/src}/algorithms/quantization/trivial.rs (64%) rename {src => crates/service/src}/algorithms/raw.rs (74%) rename src/algorithms/diskann/vamana.rs => crates/service/src/algorithms/vamana.rs.txt (100%) rename {src => crates/service/src}/index/delete.rs (100%) rename {src => crates/service/src}/index/indexing/flat.rs (77%) rename {src => crates/service/src}/index/indexing/hnsw.rs (83%) rename {src => crates/service/src}/index/indexing/ivf.rs (88%) rename {src => crates/service/src}/index/indexing/mod.rs (84%) create mode 100644 crates/service/src/index/mod.rs rename {src => crates/service/src}/index/optimizing/indexing.rs (57%) rename {src => crates/service/src}/index/optimizing/mod.rs (67%) create mode 100644 crates/service/src/index/optimizing/sealing.rs rename {src => crates/service/src}/index/optimizing/vacuum.rs (100%) rename {src => crates/service/src}/index/segments/growing.rs (80%) rename {src => crates/service/src}/index/segments/mod.rs (67%) rename {src => crates/service/src}/index/segments/sealed.rs (81%) create mode 100644 crates/service/src/lib.rs rename {src => crates/service/src}/prelude/error.rs (61%) rename {src => crates/service/src}/prelude/filter.rs (100%) create mode 100644 crates/service/src/prelude/global/f16_cos.rs create mode 100644 crates/service/src/prelude/global/f16_dot.rs create mode 100644 crates/service/src/prelude/global/f16_l2.rs create mode 100644 crates/service/src/prelude/global/f32_cos.rs create mode 100644 crates/service/src/prelude/global/f32_dot.rs create mode 100644 crates/service/src/prelude/global/f32_l2.rs create mode 100644 crates/service/src/prelude/global/mod.rs rename {src => crates/service/src}/prelude/heap.rs (89%) create mode 100644 crates/service/src/prelude/mod.rs create mode 100644 crates/service/src/prelude/scalar/f16.rs create mode 100644 crates/service/src/prelude/scalar/f32.rs create mode 100644 crates/service/src/prelude/scalar/mod.rs rename {src => crates/service/src}/prelude/sys.rs (53%) create mode 100644 crates/service/src/utils/cells.rs rename {src => crates/service/src}/utils/clean.rs (100%) rename {src => crates/service/src}/utils/dir_ops.rs (100%) rename {src => crates/service/src}/utils/file_atomic.rs (100%) rename {src => crates/service/src}/utils/file_wal.rs (100%) rename {src => crates/service/src}/utils/mmap_array.rs (97%) create mode 100644 crates/service/src/utils/mod.rs rename {src => crates/service/src}/utils/vec2.rs (78%) create mode 100644 crates/service/src/worker/instance.rs rename src/bgworker/worker.rs => crates/service/src/worker/mod.rs (65%) delete mode 100644 src/algorithms/diskann/mod.rs create mode 100644 src/datatype/casts_f32.rs create mode 100644 src/datatype/mod.rs rename src/{postgres/operators.rs => datatype/operators_f16.rs} (57%) create mode 100644 src/datatype/operators_f32.rs create mode 100644 src/datatype/typmod.rs create mode 100644 src/datatype/vecf16.rs rename src/{postgres/datatype.rs => datatype/vecf32.rs} (55%) rename src/{postgres => }/gucs.rs (100%) rename src/{postgres/index.rs => index/am.rs} (87%) rename src/{postgres/index_build.rs => index/am_build.rs} (65%) rename src/{postgres/index_scan.rs => index/am_scan.rs} (76%) rename src/{postgres/index_setup.rs => index/am_setup.rs} (79%) create mode 100644 src/index/am_update.rs create mode 100644 src/index/client.rs rename src/{postgres => index}/hook_executor.rs (95%) create mode 100644 src/index/hook_transaction.rs rename src/{postgres => index}/hooks.rs (89%) create mode 100644 src/index/views.rs delete mode 100644 src/postgres/casts.rs delete mode 100644 src/postgres/hook_transaction.rs delete mode 100644 src/postgres/index_update.rs delete mode 100644 src/postgres/mod.rs delete mode 100644 src/postgres/stat.rs create mode 100644 src/prelude.rs delete mode 100644 src/prelude/distance.rs delete mode 100644 src/prelude/mod.rs delete mode 100644 src/prelude/scalar.rs delete mode 100644 src/prelude/stat.rs diff --git a/Cargo.lock b/Cargo.lock index 06be97a82..75eb3852d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,12 +100,12 @@ dependencies = [ [[package]] name = "async-channel" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d37875bd9915b7d67c2f117ea2c30a0989874d0b2cb694fe25403c85763c0c9e" +checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" dependencies = [ "concurrent-queue", - "event-listener 3.1.0", + "event-listener 4.0.0", "event-listener-strategy", "futures-core", "pin-project-lite", @@ -113,30 +113,30 @@ dependencies = [ [[package]] name = "async-executor" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc5ea910c42e5ab19012bab31f53cb4d63d54c3a27730f9a833a88efcf4bb52d" +checksum = "17ae5ebefcc48e7452b4987947920dac9450be1110cadf34d1b8c116bdbaf97c" dependencies = [ - "async-lock 3.1.1", + "async-lock 3.2.0", "async-task", "concurrent-queue", "fastrand 2.0.1", - "futures-lite 2.0.1", + "futures-lite 2.1.0", "slab", ] [[package]] name = "async-global-executor" -version = "2.3.1" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" +checksum = "9b4353121d5644cdf2beb5726ab752e79a8db1ebb52031770ec47db31d245526" dependencies = [ - "async-channel 1.9.0", + "async-channel 2.1.1", "async-executor", - "async-io 1.13.0", - "async-lock 2.8.0", + "async-io 2.2.1", + "async-lock 3.2.0", "blocking", - "futures-lite 1.13.0", + "futures-lite 2.1.0", "once_cell", ] @@ -162,22 +162,21 @@ dependencies = [ [[package]] name = "async-io" -version = "2.2.0" +version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41ed9d5715c2d329bf1b4da8d60455b99b187f27ba726df2883799af9af60997" +checksum = "d6d3b15875ba253d1110c740755e246537483f152fa334f91abd7fe84c88b3ff" dependencies = [ - "async-lock 3.1.1", + "async-lock 3.2.0", "cfg-if", "concurrent-queue", "futures-io", - "futures-lite 2.0.1", + "futures-lite 2.1.0", "parking", - "polling 3.3.0", - "rustix 0.38.25", + "polling 3.3.1", + "rustix 0.38.26", "slab", "tracing", - "waker-fn", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -191,11 +190,11 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655b9c7fe787d3b25cc0f804a1a8401790f0c5bc395beb5a64dc77d8de079105" +checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" dependencies = [ - "event-listener 3.1.0", + "event-listener 4.0.0", "event-listener-strategy", "pin-project-lite", ] @@ -222,8 +221,8 @@ dependencies = [ "cfg-if", "event-listener 3.1.0", "futures-lite 1.13.0", - "rustix 0.38.25", - "windows-sys", + "rustix 0.38.26", + "windows-sys 0.48.0", ] [[package]] @@ -232,16 +231,16 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" dependencies = [ - "async-io 2.2.0", + "async-io 2.2.1", "async-lock 2.8.0", "atomic-waker", "cfg-if", "futures-core", "futures-io", - "rustix 0.38.25", + "rustix 0.38.26", "signal-hook-registry", "slab", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -434,12 +433,12 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" dependencies = [ - "async-channel 2.1.0", - "async-lock 3.1.1", + "async-channel 2.1.1", + "async-lock 3.2.0", "async-task", "fastrand 2.0.1", "futures-io", - "futures-lite 2.0.1", + "futures-lite 2.1.0", "piper", "tracing", ] @@ -455,12 +454,26 @@ name = "bytemuck" version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +dependencies = [ + "bytemuck_derive", +] [[package]] -name = "byteorder" +name = "bytemuck_derive" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.39", +] + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" @@ -468,6 +481,13 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +[[package]] +name = "c" +version = "0.0.0" +dependencies = [ + "cc", +] + [[package]] name = "cargo_toml" version = "0.16.3" @@ -518,7 +538,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -534,9 +554,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.8" +version = "4.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" +checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" dependencies = [ "clap_builder", "clap_derive", @@ -554,9 +574,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.8" +version = "4.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" +checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" dependencies = [ "anstyle", "clap_lex", @@ -582,9 +602,9 @@ checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "concurrent-queue" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" dependencies = [ "crossbeam-utils", ] @@ -600,9 +620,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "cpufeatures" @@ -711,16 +731,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "cstr" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8aa998c33a6d3271e3678950a22134cd7dd27cef86dee1b611b5b14207d1d90b" -dependencies = [ - "proc-macro2", - "quote", -] - [[package]] name = "cty" version = "0.2.2" @@ -755,7 +765,7 @@ dependencies = [ "openssl-sys", "pkg-config", "vcpkg", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -800,7 +810,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", @@ -808,9 +818,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" dependencies = [ "powerfmt", "serde", @@ -867,7 +877,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -919,18 +929,18 @@ dependencies = [ [[package]] name = "enum-map" -version = "2.7.2" +version = "2.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09e6b4f374c071b18172e23134e01026653dc980636ee139e0dfe59c538c61e5" +checksum = "6866f3bfdf8207509a033af1a75a7b08abda06bbaaeae6669323fd5a097df2e9" dependencies = [ "enum-map-derive", ] [[package]] name = "enum-map-derive" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdb3d73d1beaf47c8593a1364e577fde072677cbfd103600345c0f547408cc0" +checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", @@ -958,12 +968,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f258a7194e7f7c2a7837a8913aeab7fd8c383457034fa20ce4dd3dcb813e8eb8" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -983,13 +993,24 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "event-listener-strategy" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96b852f1345da36d551b9473fa1e2b1eb5c5195585c6c018118bc92a8d91160" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" dependencies = [ - "event-listener 3.1.0", + "event-listener 4.0.0", "pin-project-lite", ] @@ -1063,9 +1084,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -1121,14 +1142,13 @@ dependencies = [ [[package]] name = "futures-lite" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3831c2651acb5177cbd83943f3d9c8912c5ad03c76afcc0e9511ba568ec5ebb" +checksum = "aeee267a1883f7ebef3700f262d2d54de95dfaf38189015a74fdc4e0c7ad8143" dependencies = [ "fastrand 2.0.1", "futures-core", "futures-io", - "memchr", "parking", "pin-project-lite", ] @@ -1194,9 +1214,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1222,6 +1242,19 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "serde", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1239,9 +1272,9 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.2" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" [[package]] name = "heapless" @@ -1413,6 +1446,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "if_chain" version = "1.0.2" @@ -1443,7 +1486,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "serde", ] @@ -1464,7 +1507,7 @@ checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1474,8 +1517,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.25", - "windows-sys", + "rustix 0.38.26", + "windows-sys 0.48.0", ] [[package]] @@ -1522,9 +1565,9 @@ checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "js-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" dependencies = [ "wasm-bindgen", ] @@ -1650,9 +1693,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" @@ -1736,7 +1779,7 @@ checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1875,9 +1918,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.95" +version = "0.9.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9" +checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f" dependencies = [ "cc", "libc", @@ -1923,7 +1966,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -1944,9 +1987,9 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" @@ -2170,21 +2213,21 @@ dependencies = [ "libc", "log", "pin-project-lite", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "polling" -version = "3.3.0" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53b6af1f60f36f8c2ac2aad5459d75a5a9b4be1e8cdd40264f315d78193e531" +checksum = "cf63fa624ab313c11656b4cda960bfc46c410187ad493c41f6ba2d8c1e991c9e" dependencies = [ "cfg-if", "concurrent-queue", "pin-project-lite", - "rustix 0.38.25", + "rustix 0.38.26", "tracing", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2304,9 +2347,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" dependencies = [ "unicode-ident", ] @@ -2468,16 +2511,16 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "ring" -version = "0.17.5" +version = "0.17.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +checksum = "684d5e6e18f669ccebf64a92236bb7db9a34f07be010e3627368182027180866" dependencies = [ "cc", "getrandom", "libc", "spin", "untrusted", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2521,20 +2564,20 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys 0.3.8", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "rustix" -version = "0.38.25" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc99bc2d4f1fed22595588a013687477aedf3cdcfb26558c559edb67b4d9b22e" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.11", - "windows-sys", + "linux-raw-sys 0.4.12", + "windows-sys 0.52.0", ] [[package]] @@ -2598,7 +2641,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2655,9 +2698,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] @@ -2668,15 +2711,15 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" dependencies = [ - "half", + "half 1.8.2", "serde", ] [[package]] name = "serde_derive" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", @@ -2742,6 +2785,39 @@ dependencies = [ "syn 2.0.39", ] +[[package]] +name = "service" +version = "0.0.0" +dependencies = [ + "arc-swap", + "arrayvec", + "bincode", + "bytemuck", + "byteorder", + "crc32fast", + "crossbeam", + "dashmap", + "half 2.3.1", + "libc", + "log", + "memmap2", + "memoffset", + "multiversion", + "num-traits", + "parking_lot", + "rand", + "rayon", + "rustix 0.38.26", + "serde", + "serde_json", + "serde_with", + "tempfile", + "thiserror", + "ulock-sys", + "uuid", + "validator", +] + [[package]] name = "sha2" version = "0.10.8" @@ -2823,7 +2899,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2847,12 +2923,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "string_cache" version = "0.8.7" @@ -2913,9 +2983,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.10" +version = "0.29.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a18d114d420ada3a891e6bc8e96a2023402203296a47cdd65083377dad18ba5" +checksum = "cd727fc423c2060f6c92d9534cef765c65a6ed3f428a03d7def74a8c4348e666" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2947,8 +3017,8 @@ dependencies = [ "cfg-if", "fastrand 2.0.1", "redox_syscall", - "rustix 0.38.25", - "windows-sys", + "rustix 0.38.26", + "windows-sys 0.48.0", ] [[package]] @@ -3065,7 +3135,7 @@ dependencies = [ "signal-hook-registry", "socket2 0.5.5", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3281,9 +3351,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5ccd538d4a604753ebc2f17cd9946e89b77bf87f6a8e2309667c6f2e87855e3" +checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97" dependencies = [ "base64", "flate2", @@ -3299,20 +3369,20 @@ dependencies = [ [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna", + "idna 0.5.0", "percent-encoding", ] [[package]] name = "uuid" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c58fe91d841bc04822c9801002db4ea904b9e4b8e6bbad25127b46eff8dc516b" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", "serde", @@ -3324,7 +3394,7 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd" dependencies = [ - "idna", + "idna 0.4.0", "lazy_static", "regex", "serde", @@ -3374,41 +3444,26 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vectors" -version = "0.1.1" +version = "0.0.0" dependencies = [ - "arc-swap", - "arrayvec", "bincode", - "bytemuck", "byteorder", - "crc32fast", - "crossbeam", - "cstr", - "dashmap", "env_logger", + "half 2.3.1", "httpmock", "libc", "log", - "memmap2", - "memoffset", "mockall", - "multiversion", + "num-traits", "openai_api_rust", - "parking_lot", "pgrx", "pgrx-tests", - "rand", - "rayon", - "rustix 0.38.25", + "rustix 0.38.26", "serde", "serde_json", - "serde_with", - "static_assertions", - "tempfile", + "service", "thiserror", "toml", - "ulock-sys", - "uuid", "validator", ] @@ -3460,9 +3515,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3470,9 +3525,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", @@ -3485,9 +3540,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9afec9963e3d0994cac82455b2b3502b81a7f40f9a0d32181f7528d9f4b43e02" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" dependencies = [ "cfg-if", "js-sys", @@ -3497,9 +3552,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3507,9 +3562,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", @@ -3520,15 +3575,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "web-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3536,9 +3591,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "whoami" @@ -3587,7 +3642,7 @@ version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -3596,7 +3651,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -3605,13 +3669,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -3620,42 +3699,84 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winnow" version = "0.5.19" diff --git a/Cargo.toml b/Cargo.toml index 403fb73fe..054f9ab13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vectors" -version = "0.1.1" -edition = "2021" +version.workspace = true +edition.workspace = true [lib] crate-type = ["cdylib"] @@ -16,45 +16,60 @@ pg16 = ["pgrx/pg16", "pgrx-tests/pg16"] pg_test = [] [dependencies] +libc.workspace = true +log.workspace = true +serde.workspace = true +serde_json.workspace = true +validator.workspace = true +rustix.workspace = true +thiserror.workspace = true +byteorder.workspace = true +bincode.workspace = true +half.workspace = true +num-traits.workspace = true +service = { path = "crates/service" } pgrx = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b", default-features = false, features = [ ] } openai_api_rust = { git = "https://github.com/tensorchord/openai-api.git", rev = "228d54b6002e98257b3c81501a054942342f585f" } -static_assertions = "1.1.0" -libc = "~0.2" -serde = "1.0.163" -bincode = "1.3.3" -rand = "0.8.5" -byteorder = "1.4.3" -crc32fast = "1.3.2" -log = "0.4.18" env_logger = "0.10.0" -crossbeam = "0.8.2" -dashmap = "5.4.0" -parking_lot = "0.12.1" -memoffset = "0.9.0" -serde_json = "1" -thiserror = "1.0.40" -tempfile = "3.6.0" -cstr = "0.2.11" -arrayvec = { version = "0.7.3", features = ["serde"] } -memmap2 = "0.9.0" -validator = { version = "0.16.1", features = ["derive"] } toml = "0.8.8" -rayon = "1.6.1" -uuid = { version = "1.4.1", features = ["serde"] } -rustix = { version = "0.38.20", features = ["net", "mm"] } -arc-swap = "1.6.0" -bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } -serde_with = "3.4.0" -multiversion = "0.7.3" [dev-dependencies] pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b" } httpmock = "0.6" mockall = "0.11.4" -[target.'cfg(target_os = "macos")'.dependencies] -ulock-sys = "0.1.0" +[lints] +clippy.too_many_arguments = "allow" +clippy.unnecessary_literal_unwrap = "allow" +clippy.unnecessary_unwrap = "allow" +rust.unsafe_op_in_unsafe_fn = "warn" + +[workspace] +resolver = "2" +members = ["crates/*"] + +[workspace.package] +version = "0.0.0" +edition = "2021" + +[workspace.dependencies] +libc = "~0.2" +log = "~0.4" +serde = "~1.0" +serde_json = "1" +thiserror = "~1.0" +bincode = "~1.3" +byteorder = "~1.4" +half = { version = "~2.3", features = [ + "bytemuck", + "num-traits", + "serde", + "use-intrinsics", +] } +num-traits = "~0.2" +validator = { version = "~0.16", features = ["derive"] } +rustix = { version = "~0.38", features = ["net", "mm"] } [profile.dev] panic = "unwind" @@ -65,10 +80,3 @@ opt-level = 3 lto = "fat" codegen-units = 1 debug = true - -[lints.clippy] -needless_range_loop = "allow" -derivable_impls = "allow" -unnecessary_literal_unwrap = "allow" -too_many_arguments = "allow" -unnecessary_unwrap = "allow" diff --git a/crates/c/Cargo.toml b/crates/c/Cargo.toml new file mode 100644 index 000000000..8dd8f339d --- /dev/null +++ b/crates/c/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "c" +version.workspace = true +edition.workspace = true + +[build-dependencies] +cc = "1.0" diff --git a/crates/c/build.rs b/crates/c/build.rs new file mode 100644 index 000000000..b86eeefc4 --- /dev/null +++ b/crates/c/build.rs @@ -0,0 +1,6 @@ +fn main() { + cc::Build::new() + .compiler("/usr/bin/clang") + .file("./src/c.c") + .compile("c"); +} diff --git a/crates/c/src/c.c b/crates/c/src/c.c new file mode 100644 index 000000000..6d8e57e97 --- /dev/null +++ b/crates/c/src/c.c @@ -0,0 +1,3 @@ +#include "c.h" + +void c_test() {} diff --git a/crates/c/src/c.h b/crates/c/src/c.h new file mode 100644 index 000000000..e69de29bb diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs new file mode 100644 index 000000000..01c3bb7b5 --- /dev/null +++ b/crates/c/src/c.rs @@ -0,0 +1,3 @@ +extern "C" { + pub fn c_test(); +} diff --git a/crates/c/src/lib.rs b/crates/c/src/lib.rs new file mode 100644 index 000000000..6f1d73975 --- /dev/null +++ b/crates/c/src/lib.rs @@ -0,0 +1,3 @@ +mod c; + +pub use self::c::*; diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml new file mode 100644 index 000000000..6b3d8e964 --- /dev/null +++ b/crates/service/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "service" +version.workspace = true +edition.workspace = true + +[dependencies] +libc.workspace = true +log.workspace = true +serde.workspace = true +serde_json.workspace = true +validator.workspace = true +rustix.workspace = true +thiserror.workspace = true +byteorder.workspace = true +bincode.workspace = true +half.workspace = true +num-traits.workspace = true +rand = "0.8.5" +crc32fast = "1.3.2" +crossbeam = "0.8.2" +dashmap = "5.4.0" +parking_lot = "0.12.1" +memoffset = "0.9.0" +tempfile = "3.6.0" +arrayvec = { version = "0.7.3", features = ["serde"] } +memmap2 = "0.9.0" +rayon = "1.6.1" +uuid = { version = "1.6.1", features = ["serde"] } +arc-swap = "1.6.0" +bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } +serde_with = "3.4.0" +multiversion = "0.7.3" + +[target.'cfg(target_os = "macos")'.dependencies] +ulock-sys = "0.1.0" + +[lints] +clippy.derivable_impls = "allow" +clippy.len_without_is_empty = "allow" +clippy.needless_range_loop = "allow" +clippy.too_many_arguments = "allow" +rust.unsafe_op_in_unsafe_fn = "warn" diff --git a/src/algorithms/clustering/elkan_k_means.rs b/crates/service/src/algorithms/clustering/elkan_k_means.rs similarity index 74% rename from src/algorithms/clustering/elkan_k_means.rs rename to crates/service/src/algorithms/clustering/elkan_k_means.rs index f30d84df3..494a10792 100644 --- a/src/algorithms/clustering/elkan_k_means.rs +++ b/crates/service/src/algorithms/clustering/elkan_k_means.rs @@ -4,38 +4,37 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::ops::{Index, IndexMut}; -pub struct ElkanKMeans { +pub struct ElkanKMeans { dims: u16, c: usize, - pub centroids: Vec2, + pub centroids: Vec2, lowerbound: Square, - upperbound: Vec, + upperbound: Vec, assign: Vec, rand: StdRng, - samples: Vec2, - d: Distance, + samples: Vec2, } const DELTA: f32 = 1.0 / 1024.0; -impl ElkanKMeans { - pub fn new(c: usize, samples: Vec2, d: Distance) -> Self { +impl ElkanKMeans { + pub fn new(c: usize, samples: Vec2) -> Self { let n = samples.len(); let dims = samples.dims(); let mut rand = StdRng::from_entropy(); let mut centroids = Vec2::new(dims, c); let mut lowerbound = Square::new(n, c); - let mut upperbound = vec![Scalar::Z; n]; + let mut upperbound = vec![F32::zero(); n]; let mut assign = vec![0usize; n]; centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]); - let mut weight = vec![Scalar::INFINITY; n]; + let mut weight = vec![F32::infinity(); n]; for i in 0..c { - let mut sum = Scalar::Z; + let mut sum = F32::zero(); for j in 0..n { - let dis = d.elkan_k_means_distance(&samples[j], ¢roids[i]); + let dis = S::elkan_k_means_distance(&samples[j], ¢roids[i]); lowerbound[(j, i)] = dis; if dis * dis < weight[j] { weight[j] = dis * dis; @@ -49,7 +48,7 @@ impl ElkanKMeans { let mut choice = sum * rand.gen_range(0.0..1.0); for j in 0..(n - 1) { choice -= weight[j]; - if choice <= Scalar::Z { + if choice <= F32::zero() { break 'a j; } } @@ -59,7 +58,7 @@ impl ElkanKMeans { } for i in 0..n { - let mut minimal = Scalar::INFINITY; + let mut minimal = F32::infinity(); let mut target = 0; for j in 0..c { let dis = lowerbound[(i, j)]; @@ -81,13 +80,11 @@ impl ElkanKMeans { assign, rand, samples, - d, } } pub fn iterate(&mut self) -> bool { let c = self.c; - let f = |lhs: &[Scalar], rhs: &[Scalar]| self.d.elkan_k_means_distance(lhs, rhs); let dims = self.dims; let samples = &self.samples; let rand = &mut self.rand; @@ -100,16 +97,16 @@ impl ElkanKMeans { // Step 1 let mut dist0 = Square::new(c, c); - let mut sp = vec![Scalar::Z; c]; + let mut sp = vec![F32::zero(); c]; for i in 0..c { for j in i + 1..c { - let dis = f(¢roids[i], ¢roids[j]) * 0.5; + let dis = S::elkan_k_means_distance(¢roids[i], ¢roids[j]) * 0.5; dist0[(i, j)] = dis; dist0[(j, i)] = dis; } } for i in 0..c { - let mut minimal = Scalar::INFINITY; + let mut minimal = F32::infinity(); for j in 0..c { if i == j { continue; @@ -127,7 +124,7 @@ impl ElkanKMeans { if upperbound[i] <= sp[assign[i]] { continue; } - let mut minimal = f(&samples[i], ¢roids[assign[i]]); + let mut minimal = S::elkan_k_means_distance(&samples[i], ¢roids[assign[i]]); lowerbound[(i, assign[i])] = minimal; upperbound[i] = minimal; // Step 3 @@ -142,7 +139,7 @@ impl ElkanKMeans { continue; } if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] { - let dis = f(&samples[i], ¢roids[j]); + let dis = S::elkan_k_means_distance(&samples[i], ¢roids[j]); lowerbound[(i, j)] = dis; if dis < minimal { minimal = dis; @@ -156,8 +153,8 @@ impl ElkanKMeans { // Step 4, 7 let old = std::mem::replace(centroids, Vec2::new(dims, c)); - let mut count = vec![Scalar::Z; c]; - centroids.fill(Scalar::Z); + let mut count = vec![F32::zero(); c]; + centroids.fill(S::Scalar::zero()); for i in 0..n { for j in 0..dims as usize { centroids[assign[i]][j] += samples[i][j]; @@ -165,21 +162,21 @@ impl ElkanKMeans { count[assign[i]] += 1.0; } for i in 0..c { - if count[i] == Scalar::Z { + if count[i] == F32::zero() { continue; } for dim in 0..dims as usize { - centroids[i][dim] /= count[i]; + centroids[i][dim] /= S::Scalar::from_f32(count[i].into()); } } for i in 0..c { - if count[i] != Scalar::Z { + if count[i] != F32::zero() { continue; } let mut o = 0; loop { - let alpha = Scalar(rand.gen_range(0.0..1.0)); - let beta = (count[o] - 1.0) / (n - c) as Float; + let alpha = F32::from_f32(rand.gen_range(0.0..1.0f32)); + let beta = (count[o] - 1.0) / (n - c) as f32; if alpha < beta { break; } @@ -188,28 +185,28 @@ impl ElkanKMeans { centroids.copy_within(o, i); for dim in 0..dims as usize { if dim % 2 == 0 { - centroids[i][dim] *= 1.0 + DELTA; - centroids[o][dim] *= 1.0 - DELTA; + centroids[i][dim] *= S::Scalar::from_f32(1.0 + DELTA); + centroids[o][dim] *= S::Scalar::from_f32(1.0 - DELTA); } else { - centroids[i][dim] *= 1.0 - DELTA; - centroids[o][dim] *= 1.0 + DELTA; + centroids[i][dim] *= S::Scalar::from_f32(1.0 - DELTA); + centroids[o][dim] *= S::Scalar::from_f32(1.0 + DELTA); } } count[i] = count[o] / 2.0; count[o] = count[o] - count[i]; } for i in 0..c { - self.d.elkan_k_means_normalize(&mut centroids[i]); + S::elkan_k_means_normalize(&mut centroids[i]); } // Step 5, 6 - let mut dist1 = vec![Scalar::Z; c]; + let mut dist1 = vec![F32::zero(); c]; for i in 0..c { - dist1[i] = f(&old[i], ¢roids[i]); + dist1[i] = S::elkan_k_means_distance(&old[i], ¢roids[i]); } for i in 0..n { for j in 0..c { - lowerbound[(i, j)] = (lowerbound[(i, j)] - dist1[j]).max(Scalar::Z); + lowerbound[(i, j)] = std::cmp::max(lowerbound[(i, j)] - dist1[j], F32::zero()); } } for i in 0..n { @@ -219,7 +216,7 @@ impl ElkanKMeans { change == 0 } - pub fn finish(self) -> Vec2 { + pub fn finish(self) -> Vec2 { self.centroids } } @@ -227,7 +224,7 @@ impl ElkanKMeans { pub struct Square { x: usize, y: usize, - v: Box<[Scalar]>, + v: Vec, } impl Square { @@ -235,13 +232,13 @@ impl Square { Self { x, y, - v: bytemuck::zeroed_slice_box(x * y), + v: bytemuck::zeroed_vec(x * y), } } } impl Index<(usize, usize)> for Square { - type Output = Scalar; + type Output = F32; fn index(&self, (x, y): (usize, usize)) -> &Self::Output { debug_assert!(x < self.x); diff --git a/src/algorithms/clustering/mod.rs b/crates/service/src/algorithms/clustering/mod.rs similarity index 100% rename from src/algorithms/clustering/mod.rs rename to crates/service/src/algorithms/clustering/mod.rs diff --git a/src/algorithms/flat.rs b/crates/service/src/algorithms/flat.rs similarity index 63% rename from src/algorithms/flat.rs rename to crates/service/src/algorithms/flat.rs index 0317a45f0..22650f4d0 100644 --- a/src/algorithms/flat.rs +++ b/crates/service/src/algorithms/flat.rs @@ -9,16 +9,16 @@ use std::fs::create_dir; use std::path::PathBuf; use std::sync::Arc; -pub struct Flat { - mmap: FlatMmap, +pub struct Flat { + mmap: FlatMmap, } -impl Flat { +impl Flat { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options.clone()); @@ -35,7 +35,7 @@ impl Flat { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -43,35 +43,33 @@ impl Flat { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for Flat {} -unsafe impl Sync for Flat {} +unsafe impl Send for Flat {} +unsafe impl Sync for Flat {} -pub struct FlatRam { - raw: Arc, - quantization: Quantization, - d: Distance, +pub struct FlatRam { + raw: Arc>, + quantization: Quantization, } -pub struct FlatMmap { - raw: Arc, - quantization: Quantization, - d: Distance, +pub struct FlatMmap { + raw: Arc>, + quantization: Quantization, } -unsafe impl Send for FlatMmap {} -unsafe impl Sync for FlatMmap {} +unsafe impl Send for FlatMmap {} +unsafe impl Sync for FlatMmap {} -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> FlatRam { +) -> FlatRam { let idx_opts = options.indexing.clone().unwrap_flat(); let raw = Arc::new(Raw::create( path.join("raw"), @@ -85,22 +83,17 @@ pub fn make( idx_opts.quantization, &raw, ); - FlatRam { - raw, - quantization, - d: options.vector.d, - } + FlatRam { raw, quantization } } -pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap { +pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap { FlatMmap { raw: ram.raw, quantization: ram.quantization, - d: ram.d, } } -pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { let idx_opts = options.indexing.clone().unwrap_flat(); let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = Quantization::open( @@ -109,17 +102,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { idx_opts.quantization, &raw, ); - FlatMmap { - raw, - quantization, - d: options.vector.d, - } + FlatMmap { raw, quantization } } -pub fn search(mmap: &FlatMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &FlatMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let mut result = Heap::new(k); for i in 0..mmap.raw.len() { - let distance = mmap.quantization.distance(mmap.d, vector, i); + let distance = mmap.quantization.distance(vector, i); let payload = mmap.raw.payload(i); if filter.check(payload) { result.push(HeapElement { distance, payload }); diff --git a/src/algorithms/hnsw.rs b/crates/service/src/algorithms/hnsw.rs similarity index 82% rename from src/algorithms/hnsw.rs rename to crates/service/src/algorithms/hnsw.rs index 7bc83a82d..7342a87f7 100644 --- a/src/algorithms/hnsw.rs +++ b/crates/service/src/algorithms/hnsw.rs @@ -3,7 +3,7 @@ use super::raw::Raw; use crate::index::indexing::hnsw::HnswIndexingOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::{IndexOptions, VectorOptions}; +use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::mmap_array::MmapArray; @@ -17,16 +17,16 @@ use std::ops::RangeInclusive; use std::path::PathBuf; use std::sync::Arc; -pub struct Hnsw { - mmap: HnswMmap, +pub struct Hnsw { + mmap: HnswMmap, } -impl Hnsw { +impl Hnsw { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options.clone()); @@ -43,7 +43,7 @@ impl Hnsw { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -51,19 +51,17 @@ impl Hnsw { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for Hnsw {} -unsafe impl Sync for Hnsw {} +unsafe impl Send for Hnsw {} +unsafe impl Sync for Hnsw {} -pub struct HnswRam { - raw: Arc, - quantization: Quantization, - // ---------------------- - d: Distance, +pub struct HnswRam { + raw: Arc>, + quantization: Quantization, // ---------------------- m: u32, // ---------------------- @@ -87,14 +85,12 @@ impl HnswRamVertex { } struct HnswRamLayer { - edges: Vec<(Scalar, u32)>, + edges: Vec<(F32, u32)>, } -pub struct HnswMmap { - raw: Arc, - quantization: Quantization, - // ---------------------- - d: Distance, +pub struct HnswMmap { + raw: Arc>, + quantization: Quantization, // ---------------------- m: u32, // ---------------------- @@ -106,20 +102,19 @@ pub struct HnswMmap { } #[derive(Debug, Clone, Copy, Default)] -struct HnswMmapEdge(Scalar, u32); +struct HnswMmapEdge(F32, u32); -unsafe impl Send for HnswMmap {} -unsafe impl Sync for HnswMmap {} +unsafe impl Send for HnswMmap {} +unsafe impl Sync for HnswMmap {} unsafe impl Pod for HnswMmapEdge {} unsafe impl Zeroable for HnswMmapEdge {} -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> HnswRam { - let VectorOptions { d, .. } = options.vector; +) -> HnswRam { let HnswIndexingOptions { m, ef_construction, @@ -151,23 +146,22 @@ pub fn make( let entry = RwLock::>::new(None); let visited = VisitedPool::new(raw.len()); (0..n).into_par_iter().for_each(|i| { - fn fast_search( - quantization: &Quantization, + fn fast_search( + quantization: &Quantization, graph: &HnswRamGraph, - d: Distance, levels: RangeInclusive, u: u32, - target: &[Scalar], + target: &[S::Scalar], ) -> u32 { let mut u = u; - let mut u_dis = quantization.distance(d, target, u); + let mut u_dis = quantization.distance(target, u); for i in levels.rev() { let mut changed = true; while changed { changed = false; let guard = graph.vertexs[u as usize].layers[i as usize].read(); for &(_, v) in guard.edges.iter() { - let v_dis = quantization.distance(d, target, v); + let v_dis = quantization.distance(target, v); if v_dis < u_dis { u = v; u_dis = v_dis; @@ -178,21 +172,20 @@ pub fn make( } u } - fn local_search( - quantization: &Quantization, + fn local_search( + quantization: &Quantization, graph: &HnswRamGraph, - d: Distance, visited: &mut VisitedGuard, - vector: &[Scalar], + vector: &[S::Scalar], s: u32, k: usize, i: u8, - ) -> Vec<(Scalar, u32)> { + ) -> Vec<(F32, u32)> { assert!(k > 0); let mut visited = visited.fetch(); - let mut candidates = BinaryHeap::>::new(); + let mut candidates = BinaryHeap::>::new(); let mut results = BinaryHeap::new(); - let s_dis = quantization.distance(d, vector, s); + let s_dis = quantization.distance(vector, s); visited.mark(s); candidates.push(Reverse((s_dis, s))); results.push((s_dis, s)); @@ -209,7 +202,7 @@ pub fn make( continue; } visited.mark(v); - let v_dis = quantization.distance(d, vector, v); + let v_dis = quantization.distance(vector, v); if results.len() < k || v_dis < results.peek().unwrap().0 { candidates.push(Reverse((v_dis, v))); results.push((v_dis, v)); @@ -221,12 +214,7 @@ pub fn make( } results.into_sorted_vec() } - fn select( - quantization: &Quantization, - d: Distance, - input: &mut Vec<(Scalar, u32)>, - size: u32, - ) { + fn select(quantization: &Quantization, input: &mut Vec<(F32, u32)>, size: u32) { if input.len() <= size as usize { return; } @@ -237,7 +225,7 @@ pub fn make( } let check = res .iter() - .map(|&(_, v)| quantization.distance2(d, u, v)) + .map(|&(_, v)| quantization.distance2(u, v)) .all(|dist| dist > u_dis); if check { res.push((u_dis, u)); @@ -282,14 +270,13 @@ pub fn make( }; let top = graph.vertexs[u as usize].levels(); if top > levels { - u = fast_search(&quantization, &graph, d, levels + 1..=top, u, target); + u = fast_search(&quantization, &graph, levels + 1..=top, u, target); } let mut result = Vec::with_capacity(1 + std::cmp::min(levels, top) as usize); for j in (0..=std::cmp::min(levels, top)).rev() { let mut edges = local_search( &quantization, &graph, - d, &mut visited, target, u, @@ -297,12 +284,7 @@ pub fn make( j, ); edges.sort(); - select( - &quantization, - d, - &mut edges, - count_max_edges_of_a_layer(m, j), - ); + select(&quantization, &mut edges, count_max_edges_of_a_layer(m, j)); u = edges.first().unwrap().1; result.push(edges); } @@ -317,7 +299,6 @@ pub fn make( write.edges.insert(index, element); select( &quantization, - d, &mut write.edges, count_max_edges_of_a_layer(m, j), ); @@ -330,14 +311,13 @@ pub fn make( HnswRam { raw, quantization, - d, m, graph, visited, } } -pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { +pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { let edges = MmapArray::create( path.join("edges"), ram.graph @@ -361,7 +341,6 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { HnswMmap { raw: ram.raw, quantization: ram.quantization, - d: ram.d, m: ram.m, edges, by_layer_id, @@ -370,7 +349,7 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { let idx_opts = options.indexing.clone().unwrap_hnsw(); let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = Quantization::open( @@ -387,7 +366,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { HnswMmap { raw, quantization, - d: options.vector.d, m: idx_opts.m, edges, by_layer_id, @@ -396,7 +374,12 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { } } -pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &HnswMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let Some(s) = entry(mmap, filter) else { return Heap::new(k); }; @@ -405,7 +388,7 @@ pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Fi local_search(mmap, k, u, vector, filter) } -pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option { +pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option { let m = mmap.m; let n = mmap.raw.len(); let mut shift = 1u64; @@ -433,15 +416,15 @@ pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option { None } -pub fn fast_search( - mmap: &HnswMmap, +pub fn fast_search( + mmap: &HnswMmap, levels: RangeInclusive, u: u32, - vector: &[Scalar], + vector: &[S::Scalar], filter: &mut impl Filter, ) -> u32 { let mut u = u; - let mut u_dis = mmap.quantization.distance(mmap.d, vector, u); + let mut u_dis = mmap.quantization.distance(vector, u); for i in levels.rev() { let mut changed = true; while changed { @@ -451,7 +434,7 @@ pub fn fast_search( if !filter.check(mmap.raw.payload(v)) { continue; } - let v_dis = mmap.quantization.distance(mmap.d, vector, v); + let v_dis = mmap.quantization.distance(vector, v); if v_dis < u_dis { u = v; u_dis = v_dis; @@ -463,20 +446,20 @@ pub fn fast_search( u } -pub fn local_search( - mmap: &HnswMmap, +pub fn local_search( + mmap: &HnswMmap, k: usize, s: u32, - vector: &[Scalar], + vector: &[S::Scalar], filter: &mut impl Filter, ) -> Heap { assert!(k > 0); let mut visited = mmap.visited.fetch(); let mut visited = visited.fetch(); - let mut candidates = BinaryHeap::>::new(); + let mut candidates = BinaryHeap::>::new(); let mut results = Heap::new(k); visited.mark(s); - let s_dis = mmap.quantization.distance(mmap.d, vector, s); + let s_dis = mmap.quantization.distance(vector, s); candidates.push(Reverse((s_dis, s))); results.push(HeapElement { distance: s_dis, @@ -495,7 +478,7 @@ pub fn local_search( if !filter.check(mmap.raw.payload(v)) { continue; } - let v_dis = mmap.quantization.distance(mmap.d, vector, v); + let v_dis = mmap.quantization.distance(vector, v); if !results.check(v_dis) { continue; } @@ -537,7 +520,7 @@ fn caluate_offsets(iter: impl Iterator) -> impl Iterator &[HnswMmapEdge] { +fn find_edges(mmap: &HnswMmap, u: u32, level: u8) -> &[HnswMmapEdge] { let offset = u as usize; let index = mmap.by_vertex_id[offset]..mmap.by_vertex_id[offset + 1]; let offset = index.start + level as usize; @@ -588,7 +571,7 @@ impl<'a> Drop for VisitedGuard<'a> { fn drop(&mut self) { let src = VisitedBuffer { version: 0, - data: Box::new([]), + data: Vec::new(), }; let buffer = std::mem::replace(&mut self.buffer, src); self.pool.locked_buffers.lock().push(buffer); @@ -610,14 +593,14 @@ impl<'a> VisitedChecker<'a> { struct VisitedBuffer { version: usize, - data: Box<[usize]>, + data: Vec, } impl VisitedBuffer { fn new(capacity: usize) -> Self { Self { version: 0, - data: bytemuck::zeroed_slice_box(capacity), + data: bytemuck::zeroed_vec(capacity), } } } diff --git a/src/algorithms/ivf/ivf_naive.rs b/crates/service/src/algorithms/ivf/ivf_naive.rs similarity index 77% rename from src/algorithms/ivf/ivf_naive.rs rename to crates/service/src/algorithms/ivf/ivf_naive.rs index 06920264a..2d4e0ea1d 100644 --- a/src/algorithms/ivf/ivf_naive.rs +++ b/crates/service/src/algorithms/ivf/ivf_naive.rs @@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; use std::sync::Arc; -pub struct IvfNaive { - mmap: IvfMmap, +pub struct IvfNaive { + mmap: IvfMmap, } -impl IvfNaive { +impl IvfNaive { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options); @@ -47,7 +47,7 @@ impl IvfNaive { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -55,65 +55,63 @@ impl IvfNaive { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for IvfNaive {} -unsafe impl Sync for IvfNaive {} +unsafe impl Send for IvfNaive {} +unsafe impl Sync for IvfNaive {} -pub struct IvfRam { - raw: Arc, - quantization: Quantization, +pub struct IvfRam { + raw: Arc>, + quantization: Quantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2, heads: Vec, nexts: Vec>, } -unsafe impl Send for IvfRam {} -unsafe impl Sync for IvfRam {} +unsafe impl Send for IvfRam {} +unsafe impl Sync for IvfRam {} -pub struct IvfMmap { - raw: Arc, - quantization: Quantization, +pub struct IvfMmap { + raw: Arc>, + quantization: Quantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: MmapArray, + centroids: MmapArray, heads: MmapArray, nexts: MmapArray, } -unsafe impl Send for IvfMmap {} -unsafe impl Sync for IvfMmap {} +unsafe impl Send for IvfMmap {} +unsafe impl Sync for IvfMmap {} -impl IvfMmap { - fn centroids(&self, i: u32) -> &[Scalar] { +impl IvfMmap { + fn centroids(&self, i: u32) -> &[S::Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.centroids[s..e] } } -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> IvfRam { - let VectorOptions { dims, d } = options.vector; +) -> IvfRam { + let VectorOptions { dims, .. } = options.vector; let IvfIndexingOptions { least_iterations, iterations, @@ -140,9 +138,9 @@ pub fn make( let mut samples = Vec2::new(dims, m as usize); for i in 0..m { samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); - d.elkan_k_means_normalize(&mut samples[i as usize]); + S::elkan_k_means_normalize(&mut samples[i as usize]); } - let mut k_means = ElkanKMeans::new(nlist as usize, samples, d); + let mut k_means = ElkanKMeans::new(nlist as usize, samples); for _ in 0..least_iterations { k_means.iterate(); } @@ -164,10 +162,10 @@ pub fn make( }; (0..n).into_par_iter().for_each(|i| { let mut vector = raw.vector(i).to_vec(); - d.elkan_k_means_normalize(&mut vector); - let mut result = (Scalar::INFINITY, 0); + S::elkan_k_means_normalize(&mut vector); + let mut result = (F32::infinity(), 0); for i in 0..nlist { - let dis = d.elkan_k_means_distance(&vector, ¢roids[i as usize]); + let dis = S::elkan_k_means_distance(&vector, ¢roids[i as usize]); result = std::cmp::min(result, (dis, i)); } let centroid_id = result.1; @@ -191,11 +189,10 @@ pub fn make( nprobe, nlist, dims, - d, } } -pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { +pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { let centroids = MmapArray::create( path.join("centroids"), (0..ram.nlist) @@ -214,7 +211,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { raw: ram.raw, quantization: ram.quantization, dims: ram.dims, - d: ram.d, nlist: ram.nlist, nprobe: ram.nprobe, centroids, @@ -223,7 +219,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = Quantization::open( path.join("quantization"), @@ -239,7 +235,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { raw, quantization, dims: options.vector.dims, - d: options.vector.d, nlist, nprobe, centroids, @@ -248,13 +243,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { } } -pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &IvfMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let mut target = vector.to_vec(); - mmap.d.elkan_k_means_normalize(&mut target); + S::elkan_k_means_normalize(&mut target); let mut lists = Heap::new(mmap.nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = mmap.d.elkan_k_means_distance(&target, centroid); + let distance = S::elkan_k_means_distance(&target, centroid); if lists.check(distance) { lists.push(HeapElement { distance, @@ -267,7 +267,7 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil for i in lists.iter().map(|e| e.payload as usize) { let mut j = mmap.heads[i]; while u32::MAX != j { - let distance = mmap.quantization.distance(mmap.d, vector, j); + let distance = mmap.quantization.distance(vector, j); let payload = mmap.raw.payload(j); if result.check(distance) && filter.check(payload) { result.push(HeapElement { distance, payload }); diff --git a/src/algorithms/ivf/ivf_pq.rs b/crates/service/src/algorithms/ivf/ivf_pq.rs similarity index 77% rename from src/algorithms/ivf/ivf_pq.rs rename to crates/service/src/algorithms/ivf/ivf_pq.rs index e1ac8b7de..33e5621cc 100644 --- a/src/algorithms/ivf/ivf_pq.rs +++ b/crates/service/src/algorithms/ivf/ivf_pq.rs @@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; use std::sync::Arc; -pub struct IvfPq { - mmap: IvfMmap, +pub struct IvfPq { + mmap: IvfMmap, } -impl IvfPq { +impl IvfPq { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options); @@ -47,7 +47,7 @@ impl IvfPq { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -55,65 +55,63 @@ impl IvfPq { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for IvfPq {} -unsafe impl Sync for IvfPq {} +unsafe impl Send for IvfPq {} +unsafe impl Sync for IvfPq {} -pub struct IvfRam { - raw: Arc, - quantization: ProductQuantization, +pub struct IvfRam { + raw: Arc>, + quantization: ProductQuantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2, heads: Vec, nexts: Vec>, } -unsafe impl Send for IvfRam {} -unsafe impl Sync for IvfRam {} +unsafe impl Send for IvfRam {} +unsafe impl Sync for IvfRam {} -pub struct IvfMmap { - raw: Arc, - quantization: ProductQuantization, +pub struct IvfMmap { + raw: Arc>, + quantization: ProductQuantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: MmapArray, + centroids: MmapArray, heads: MmapArray, nexts: MmapArray, } -unsafe impl Send for IvfMmap {} -unsafe impl Sync for IvfMmap {} +unsafe impl Send for IvfMmap {} +unsafe impl Sync for IvfMmap {} -impl IvfMmap { - fn centroids(&self, i: u32) -> &[Scalar] { +impl IvfMmap { + fn centroids(&self, i: u32) -> &[S::Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.centroids[s..e] } } -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> IvfRam { - let VectorOptions { dims, d } = options.vector; +) -> IvfRam { + let VectorOptions { dims, .. } = options.vector; let IvfIndexingOptions { least_iterations, iterations, @@ -134,9 +132,9 @@ pub fn make( let mut samples = Vec2::new(dims, m as usize); for i in 0..m { samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); - d.elkan_k_means_normalize(&mut samples[i as usize]); + S::elkan_k_means_normalize(&mut samples[i as usize]); } - let mut k_means = ElkanKMeans::new(nlist as usize, samples, d); + let mut k_means = ElkanKMeans::new(nlist as usize, samples); for _ in 0..least_iterations { k_means.iterate(); } @@ -163,10 +161,10 @@ pub fn make( &raw, |i, target| { let mut vector = target.to_vec(); - d.elkan_k_means_normalize(&mut vector); - let mut result = (Scalar::INFINITY, 0); + S::elkan_k_means_normalize(&mut vector); + let mut result = (F32::infinity(), 0); for i in 0..nlist { - let dis = d.elkan_k_means_distance(&vector, ¢roids[i as usize]); + let dis = S::elkan_k_means_distance(&vector, ¢roids[i as usize]); result = std::cmp::min(result, (dis, i)); } let centroid_id = result.1; @@ -194,11 +192,10 @@ pub fn make( nprobe, nlist, dims, - d, } } -pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { +pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { let centroids = MmapArray::create( path.join("centroids"), (0..ram.nlist) @@ -217,7 +214,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { raw: ram.raw, quantization: ram.quantization, dims: ram.dims, - d: ram.d, nlist: ram.nlist, nprobe: ram.nprobe, centroids, @@ -226,7 +222,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = ProductQuantization::open( path.join("quantization"), @@ -242,7 +238,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { raw, quantization, dims: options.vector.dims, - d: options.vector.d, nlist, nprobe, centroids, @@ -251,13 +246,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { } } -pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &IvfMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let mut target = vector.to_vec(); - mmap.d.elkan_k_means_normalize(&mut target); + S::elkan_k_means_normalize(&mut target); let mut lists = Heap::new(mmap.nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = mmap.d.elkan_k_means_distance(&target, centroid); + let distance = S::elkan_k_means_distance(&target, centroid); if lists.check(distance) { lists.push(HeapElement { distance, @@ -270,9 +270,9 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil for i in lists.iter().map(|e| e.payload as u32) { let mut j = mmap.heads[i as usize]; while u32::MAX != j { - let distance = - mmap.quantization - .distance_with_delta(mmap.d, vector, j, mmap.centroids(i)); + let distance = mmap + .quantization + .distance_with_delta(vector, j, mmap.centroids(i)); let payload = mmap.raw.payload(j); if result.check(distance) && filter.check(payload) { result.push(HeapElement { distance, payload }); diff --git a/src/algorithms/ivf/mod.rs b/crates/service/src/algorithms/ivf/mod.rs similarity index 84% rename from src/algorithms/ivf/mod.rs rename to crates/service/src/algorithms/ivf/mod.rs index 9bf645421..1fa73ee66 100644 --- a/src/algorithms/ivf/mod.rs +++ b/crates/service/src/algorithms/ivf/mod.rs @@ -10,17 +10,17 @@ use crate::prelude::*; use std::path::PathBuf; use std::sync::Arc; -pub enum Ivf { - Naive(IvfNaive), - Pq(IvfPq), +pub enum Ivf { + Naive(IvfNaive), + Pq(IvfPq), } -impl Ivf { +impl Ivf { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { if options .indexing @@ -56,7 +56,7 @@ impl Ivf { } } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { match self { Ivf::Naive(x) => x.vector(i), Ivf::Pq(x) => x.vector(i), @@ -70,7 +70,7 @@ impl Ivf { } } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { match self { Ivf::Naive(x) => x.search(k, vector, filter), Ivf::Pq(x) => x.search(k, vector, filter), diff --git a/src/algorithms/mod.rs b/crates/service/src/algorithms/mod.rs similarity index 84% rename from src/algorithms/mod.rs rename to crates/service/src/algorithms/mod.rs index 9c20f5655..a3c5ffd52 100644 --- a/src/algorithms/mod.rs +++ b/crates/service/src/algorithms/mod.rs @@ -1,5 +1,4 @@ pub mod clustering; -pub mod diskann; pub mod flat; pub mod hnsw; pub mod ivf; diff --git a/src/algorithms/quantization/mod.rs b/crates/service/src/algorithms/quantization/mod.rs similarity index 80% rename from src/algorithms/quantization/mod.rs rename to crates/service/src/algorithms/quantization/mod.rs index 5b5e790d6..5e23c284e 100644 --- a/src/algorithms/quantization/mod.rs +++ b/crates/service/src/algorithms/quantization/mod.rs @@ -56,35 +56,35 @@ impl QuantizationOptions { } } -pub trait Quan { +pub trait Quan { fn create( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self; fn open( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self; - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar; - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar; + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32; + fn distance2(&self, lhs: u32, rhs: u32) -> F32; } -pub enum Quantization { - Trivial(TrivialQuantization), - Scalar(ScalarQuantization), - Product(ProductQuantization), +pub enum Quantization { + Trivial(TrivialQuantization), + Scalar(ScalarQuantization), + Product(ProductQuantization), } -impl Quantization { +impl Quantization { pub fn create( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { match quantization_options { QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::create( @@ -112,7 +112,7 @@ impl Quantization { path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { match quantization_options { QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open( @@ -136,21 +136,21 @@ impl Quantization { } } - pub fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { + pub fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { use Quantization::*; match self { - Trivial(x) => x.distance(d, lhs, rhs), - Scalar(x) => x.distance(d, lhs, rhs), - Product(x) => x.distance(d, lhs, rhs), + Trivial(x) => x.distance(lhs, rhs), + Scalar(x) => x.distance(lhs, rhs), + Product(x) => x.distance(lhs, rhs), } } - pub fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { + pub fn distance2(&self, lhs: u32, rhs: u32) -> F32 { use Quantization::*; match self { - Trivial(x) => x.distance2(d, lhs, rhs), - Scalar(x) => x.distance2(d, lhs, rhs), - Product(x) => x.distance2(d, lhs, rhs), + Trivial(x) => x.distance2(lhs, rhs), + Scalar(x) => x.distance2(lhs, rhs), + Product(x) => x.distance2(lhs, rhs), } } } diff --git a/src/algorithms/quantization/product.rs b/crates/service/src/algorithms/quantization/product.rs similarity index 81% rename from src/algorithms/quantization/product.rs rename to crates/service/src/algorithms/quantization/product.rs index a87c529f3..50d073476 100644 --- a/src/algorithms/quantization/product.rs +++ b/crates/service/src/algorithms/quantization/product.rs @@ -55,17 +55,17 @@ impl Default for ProductQuantizationOptionsRatio { } } -pub struct ProductQuantization { +pub struct ProductQuantization { dims: u16, ratio: u16, - centroids: Vec, + centroids: Vec, codes: MmapArray, } -unsafe impl Send for ProductQuantization {} -unsafe impl Sync for ProductQuantization {} +unsafe impl Send for ProductQuantization {} +unsafe impl Sync for ProductQuantization {} -impl ProductQuantization { +impl ProductQuantization { fn codes(&self, i: u32) -> &[u8] { let width = self.dims.div_ceil(self.ratio); let s = i as usize * width as usize; @@ -74,12 +74,12 @@ impl ProductQuantization { } } -impl Quan for ProductQuantization { +impl Quan for ProductQuantization { fn create( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { Self::with_normalizer(path, options, quantization_options, raw, |_, _| ()) } @@ -88,7 +88,7 @@ impl Quan for ProductQuantization { path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - _: &Arc, + _: &Arc>, ) -> Self { let centroids = serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap(); @@ -101,32 +101,32 @@ impl Quan for ProductQuantization { } } - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); - d.product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs) + S::product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs) } - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { + fn distance2(&self, lhs: u32, rhs: u32) -> F32 { let dims = self.dims; let ratio = self.ratio; let lhs = self.codes(lhs); let rhs = self.codes(rhs); - d.product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs) + S::product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs) } } -impl ProductQuantization { +impl ProductQuantization { pub fn with_normalizer( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Raw, + raw: &Raw, normalizer: F, ) -> Self where - F: Fn(u32, &mut [Scalar]), + F: Fn(u32, &mut [S::Scalar]), { std::fs::create_dir(&path).unwrap(); let quantization_options = quantization_options.unwrap_product_quantization(); @@ -136,22 +136,22 @@ impl ProductQuantization { let m = std::cmp::min(n, quantization_options.sample); let samples = { let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); - let mut samples = Vec2::new(options.vector.dims, m as usize); + let mut samples = Vec2::::new(options.vector.dims, m as usize); for i in 0..m { samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); } samples }; let width = dims.div_ceil(ratio); - let mut centroids = vec![Scalar::Z; 256 * dims as usize]; + let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize]; for i in 0..width { let subdims = std::cmp::min(ratio, dims - ratio * i); - let mut subsamples = Vec2::new(subdims, m as usize); + let mut subsamples = Vec2::::new(subdims, m as usize); for j in 0..m { let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize]; subsamples[j as usize].copy_from_slice(src); } - let mut k_means = ElkanKMeans::new(256, subsamples, Distance::L2); + let mut k_means = ElkanKMeans::::new(256, subsamples); for _ in 0..25 { if k_means.iterate() { break; @@ -170,13 +170,13 @@ impl ProductQuantization { let mut result = Vec::with_capacity(width as usize); for i in 0..width { let subdims = std::cmp::min(ratio, dims - ratio * i); - let mut minimal = Scalar::INFINITY; + let mut minimal = F32::infinity(); let mut target = 0u8; let left = &vector[(i * ratio) as usize..][..subdims as usize]; for j in 0u8..=255 { let right = ¢roids[j as usize * dims as usize..][(i * ratio) as usize..] [..subdims as usize]; - let dis = Distance::L2.distance(left, right); + let dis = S::l2_distance(left, right); if dis < minimal { minimal = dis; target = j; @@ -201,16 +201,10 @@ impl ProductQuantization { } } - pub fn distance_with_delta( - &self, - d: Distance, - lhs: &[Scalar], - rhs: u32, - delta: &[Scalar], - ) -> Scalar { + pub fn distance_with_delta(&self, lhs: &[S::Scalar], rhs: u32, delta: &[S::Scalar]) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); - d.product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta) + S::product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta) } } diff --git a/src/algorithms/quantization/scalar.rs b/crates/service/src/algorithms/quantization/scalar.rs similarity index 75% rename from src/algorithms/quantization/scalar.rs rename to crates/service/src/algorithms/quantization/scalar.rs index 46d7f6b64..f6effab2b 100644 --- a/src/algorithms/quantization/scalar.rs +++ b/crates/service/src/algorithms/quantization/scalar.rs @@ -19,17 +19,17 @@ impl Default for ScalarQuantizationOptions { } } -pub struct ScalarQuantization { +pub struct ScalarQuantization { dims: u16, - max: Vec, - min: Vec, + max: Vec, + min: Vec, codes: MmapArray, } -unsafe impl Send for ScalarQuantization {} -unsafe impl Sync for ScalarQuantization {} +unsafe impl Send for ScalarQuantization {} +unsafe impl Sync for ScalarQuantization {} -impl ScalarQuantization { +impl ScalarQuantization { fn codes(&self, i: u32) -> &[u8] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; @@ -37,17 +37,17 @@ impl ScalarQuantization { } } -impl Quan for ScalarQuantization { +impl Quan for ScalarQuantization { fn create( path: PathBuf, options: IndexOptions, _: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { std::fs::create_dir(&path).unwrap(); let dims = options.vector.dims; - let mut max = vec![Scalar::NEG_INFINITY; dims as usize]; - let mut min = vec![Scalar::INFINITY; dims as usize]; + let mut max = vec![S::Scalar::neg_infinity(); dims as usize]; + let mut min = vec![S::Scalar::infinity(); dims as usize]; let n = raw.len(); for i in 0..n { let vector = raw.vector(i); @@ -62,7 +62,7 @@ impl Quan for ScalarQuantization { let vector = raw.vector(i); let mut result = vec![0u8; dims as usize]; for i in 0..dims as usize { - let w = ((vector[i] - min[i]) / (max[i] - min[i]) * 256.0).0 as u32; + let w = (((vector[i] - min[i]) / (max[i] - min[i])).to_f32() * 256.0) as u32; result[i] = w.clamp(0, 255) as u8; } result.into_iter() @@ -77,7 +77,7 @@ impl Quan for ScalarQuantization { } } - fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc) -> Self { + fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc>) -> Self { let dims = options.vector.dims; let max = serde_json::from_slice(&std::fs::read("max").unwrap()).unwrap(); let min = serde_json::from_slice(&std::fs::read("min").unwrap()).unwrap(); @@ -90,16 +90,16 @@ impl Quan for ScalarQuantization { } } - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { let dims = self.dims; let rhs = self.codes(rhs); - d.scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) + S::scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) } - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { + fn distance2(&self, lhs: u32, rhs: u32) -> F32 { let dims = self.dims; let lhs = self.codes(lhs); let rhs = self.codes(rhs); - d.scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs) + S::scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs) } } diff --git a/src/algorithms/quantization/trivial.rs b/crates/service/src/algorithms/quantization/trivial.rs similarity index 64% rename from src/algorithms/quantization/trivial.rs rename to crates/service/src/algorithms/quantization/trivial.rs index 7bff6510f..416e398a5 100644 --- a/src/algorithms/quantization/trivial.rs +++ b/crates/service/src/algorithms/quantization/trivial.rs @@ -17,24 +17,24 @@ impl Default for TrivialQuantizationOptions { } } -pub struct TrivialQuantization { - raw: Arc, +pub struct TrivialQuantization { + raw: Arc>, } -impl Quan for TrivialQuantization { - fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc) -> Self { +impl Quan for TrivialQuantization { + fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc>) -> Self { Self { raw: raw.clone() } } - fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc) -> Self { + fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc>) -> Self { Self { raw: raw.clone() } } - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { - d.distance(lhs, self.raw.vector(rhs)) + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { + S::distance(lhs, self.raw.vector(rhs)) } - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { - d.distance(self.raw.vector(lhs), self.raw.vector(rhs)) + fn distance2(&self, lhs: u32, rhs: u32) -> F32 { + S::distance(self.raw.vector(lhs), self.raw.vector(rhs)) } } diff --git a/src/algorithms/raw.rs b/crates/service/src/algorithms/raw.rs similarity index 74% rename from src/algorithms/raw.rs rename to crates/service/src/algorithms/raw.rs index ec91673aa..f94ba9d59 100644 --- a/src/algorithms/raw.rs +++ b/crates/service/src/algorithms/raw.rs @@ -6,16 +6,16 @@ use crate::utils::mmap_array::MmapArray; use std::path::PathBuf; use std::sync::Arc; -pub struct Raw { - mmap: RawMmap, +pub struct Raw { + mmap: RawMmap, } -impl Raw { +impl Raw { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { std::fs::create_dir(&path).unwrap(); let ram = make(sealed, growing, options); @@ -33,7 +33,7 @@ impl Raw { self.mmap.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.vector(i) } @@ -42,21 +42,21 @@ impl Raw { } } -unsafe impl Send for Raw {} -unsafe impl Sync for Raw {} +unsafe impl Send for Raw {} +unsafe impl Sync for Raw {} -struct RawRam { - sealed: Vec>, - growing: Vec>, +struct RawRam { + sealed: Vec>>, + growing: Vec>>, dims: u16, } -impl RawRam { +impl RawRam { fn len(&self) -> u32 { self.sealed.iter().map(|x| x.len()).sum::() + self.growing.iter().map(|x| x.len()).sum::() } - fn vector(&self, mut index: u32) -> &[Scalar] { + fn vector(&self, mut index: u32) -> &[S::Scalar] { for x in self.sealed.iter() { if index < x.len() { return x.vector(index); @@ -88,18 +88,18 @@ impl RawRam { } } -struct RawMmap { - vectors: MmapArray, +struct RawMmap { + vectors: MmapArray, payload: MmapArray, dims: u16, } -impl RawMmap { +impl RawMmap { fn len(&self) -> u32 { self.payload.len() as u32 } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.vectors[s..e] @@ -110,14 +110,14 @@ impl RawMmap { } } -unsafe impl Send for RawMmap {} -unsafe impl Sync for RawMmap {} +unsafe impl Send for RawMmap {} +unsafe impl Sync for RawMmap {} -fn make( - sealed: Vec>, - growing: Vec>, +fn make( + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> RawRam { +) -> RawRam { RawRam { sealed, growing, @@ -125,7 +125,7 @@ fn make( } } -fn save(ram: RawRam, path: PathBuf) -> RawMmap { +fn save(ram: RawRam, path: PathBuf) -> RawMmap { let n = ram.len(); let vectors_iter = (0..n).flat_map(|i| ram.vector(i)).copied(); let payload_iter = (0..n).map(|i| ram.payload(i)); @@ -138,8 +138,8 @@ fn save(ram: RawRam, path: PathBuf) -> RawMmap { } } -fn load(path: PathBuf, options: IndexOptions) -> RawMmap { - let vectors: MmapArray = MmapArray::open(path.join("vectors")); +fn load(path: PathBuf, options: IndexOptions) -> RawMmap { + let vectors = MmapArray::open(path.join("vectors")); let payload = MmapArray::open(path.join("payload")); RawMmap { vectors, diff --git a/src/algorithms/diskann/vamana.rs b/crates/service/src/algorithms/vamana.rs.txt similarity index 100% rename from src/algorithms/diskann/vamana.rs rename to crates/service/src/algorithms/vamana.rs.txt diff --git a/src/index/delete.rs b/crates/service/src/index/delete.rs similarity index 100% rename from src/index/delete.rs rename to crates/service/src/index/delete.rs diff --git a/src/index/indexing/flat.rs b/crates/service/src/index/indexing/flat.rs similarity index 77% rename from src/index/indexing/flat.rs rename to crates/service/src/index/indexing/flat.rs index becb27f67..d288216f2 100644 --- a/src/index/indexing/flat.rs +++ b/crates/service/src/index/indexing/flat.rs @@ -24,16 +24,16 @@ impl Default for FlatIndexingOptions { } } -pub struct FlatIndexing { - raw: crate::algorithms::flat::Flat, +pub struct FlatIndexing { + raw: crate::algorithms::flat::Flat, } -impl AbstractIndexing for FlatIndexing { +impl AbstractIndexing for FlatIndexing { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { let raw = Flat::create(path, options, sealed, growing); Self { raw } @@ -48,7 +48,7 @@ impl AbstractIndexing for FlatIndexing { self.raw.len() } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { self.raw.vector(i) } @@ -56,7 +56,7 @@ impl AbstractIndexing for FlatIndexing { self.raw.payload(i) } - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.raw.search(k, vector, filter) } } diff --git a/src/index/indexing/hnsw.rs b/crates/service/src/index/indexing/hnsw.rs similarity index 83% rename from src/index/indexing/hnsw.rs rename to crates/service/src/index/indexing/hnsw.rs index 38dddcb6c..7934a254c 100644 --- a/src/index/indexing/hnsw.rs +++ b/crates/service/src/index/indexing/hnsw.rs @@ -40,16 +40,16 @@ impl Default for HnswIndexingOptions { } } -pub struct HnswIndexing { - raw: Hnsw, +pub struct HnswIndexing { + raw: Hnsw, } -impl AbstractIndexing for HnswIndexing { +impl AbstractIndexing for HnswIndexing { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { let raw = Hnsw::create(path, options, sealed, growing); Self { raw } @@ -64,7 +64,7 @@ impl AbstractIndexing for HnswIndexing { self.raw.len() } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { self.raw.vector(i) } @@ -72,7 +72,7 @@ impl AbstractIndexing for HnswIndexing { self.raw.payload(i) } - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.raw.search(k, vector, filter) } } diff --git a/src/index/indexing/ivf.rs b/crates/service/src/index/indexing/ivf.rs similarity index 88% rename from src/index/indexing/ivf.rs rename to crates/service/src/index/indexing/ivf.rs index db83d52c9..9d074864a 100644 --- a/src/index/indexing/ivf.rs +++ b/crates/service/src/index/indexing/ivf.rs @@ -4,7 +4,6 @@ use crate::algorithms::quantization::QuantizationOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; use crate::index::IndexOptions; -use crate::prelude::Scalar; use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::path::PathBuf; @@ -64,16 +63,16 @@ impl Default for IvfIndexingOptions { } } -pub struct IvfIndexing { - raw: Ivf, +pub struct IvfIndexing { + raw: Ivf, } -impl AbstractIndexing for IvfIndexing { +impl AbstractIndexing for IvfIndexing { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { let raw = Ivf::create(path, options, sealed, growing); Self { raw } @@ -88,7 +87,7 @@ impl AbstractIndexing for IvfIndexing { self.raw.len() } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { self.raw.vector(i) } @@ -96,7 +95,7 @@ impl AbstractIndexing for IvfIndexing { self.raw.payload(i) } - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.raw.search(k, vector, filter) } } diff --git a/src/index/indexing/mod.rs b/crates/service/src/index/indexing/mod.rs similarity index 84% rename from src/index/indexing/mod.rs rename to crates/service/src/index/indexing/mod.rs index ccb0c9c5b..b7bc2123b 100644 --- a/src/index/indexing/mod.rs +++ b/crates/service/src/index/indexing/mod.rs @@ -59,32 +59,32 @@ impl Validate for IndexingOptions { } } -pub trait AbstractIndexing: Sized { +pub trait AbstractIndexing: Sized { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self; fn open(path: PathBuf, options: IndexOptions) -> Self; fn len(&self) -> u32; - fn vector(&self, i: u32) -> &[Scalar]; + fn vector(&self, i: u32) -> &[S::Scalar]; fn payload(&self, i: u32) -> Payload; - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap; + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap; } -pub enum DynamicIndexing { - Flat(FlatIndexing), - Ivf(IvfIndexing), - Hnsw(HnswIndexing), +pub enum DynamicIndexing { + Flat(FlatIndexing), + Ivf(IvfIndexing), + Hnsw(HnswIndexing), } -impl DynamicIndexing { +impl DynamicIndexing { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { match options.indexing { IndexingOptions::Flat(_) => { @@ -115,7 +115,7 @@ impl DynamicIndexing { } } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { match self { DynamicIndexing::Flat(x) => x.vector(i), DynamicIndexing::Ivf(x) => x.vector(i), @@ -131,7 +131,7 @@ impl DynamicIndexing { } } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { match self { DynamicIndexing::Flat(x) => x.search(k, vector, filter), DynamicIndexing::Ivf(x) => x.search(k, vector, filter), diff --git a/crates/service/src/index/mod.rs b/crates/service/src/index/mod.rs new file mode 100644 index 000000000..1d881b8cd --- /dev/null +++ b/crates/service/src/index/mod.rs @@ -0,0 +1,415 @@ +pub mod delete; +pub mod indexing; +pub mod optimizing; +pub mod segments; + +use self::delete::Delete; +use self::indexing::IndexingOptions; +use self::optimizing::OptimizingOptions; +use self::segments::growing::GrowingSegment; +use self::segments::growing::GrowingSegmentInsertError; +use self::segments::sealed::SealedSegment; +use self::segments::SegmentsOptions; +use crate::index::optimizing::indexing::OptimizerIndexing; +use crate::index::optimizing::sealing::OptimizerSealing; +use crate::prelude::*; +use crate::utils::clean::clean; +use crate::utils::dir_ops::sync_dir; +use crate::utils::file_atomic::FileAtomic; +use arc_swap::ArcSwap; +use crossbeam::atomic::AtomicCell; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::collections::HashMap; +use std::collections::HashSet; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; +use thiserror::Error; +use uuid::Uuid; +use validator::Validate; + +#[derive(Debug, Error)] +#[error("The index view is outdated.")] +pub struct OutdatedError(#[from] pub Option); + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +pub struct VectorOptions { + #[validate(range(min = 1, max = 65535))] + #[serde(rename = "dimensions")] + pub dims: u16, + #[serde(rename = "distance")] + pub d: Distance, + #[serde(rename = "kind")] + pub k: Kind, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +pub struct IndexOptions { + #[validate] + pub vector: VectorOptions, + #[validate] + pub segment: SegmentsOptions, + #[validate] + pub optimizing: OptimizingOptions, + #[validate] + pub indexing: IndexingOptions, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexStat { + pub indexing: bool, + pub sealed: Vec, + pub growing: Vec, + pub write: u32, + pub options: IndexOptions, +} + +pub struct Index { + path: PathBuf, + options: IndexOptions, + delete: Arc, + protect: Mutex>, + view: ArcSwap>, + instant_index: AtomicCell, + instant_write: AtomicCell, + _tracker: Arc, +} + +impl Index { + pub fn create(path: PathBuf, options: IndexOptions) -> Arc { + assert!(options.validate().is_ok()); + std::fs::create_dir(&path).unwrap(); + std::fs::create_dir(path.join("segments")).unwrap(); + let startup = FileAtomic::create( + path.join("startup"), + IndexStartup { + sealeds: HashSet::new(), + growings: HashSet::new(), + }, + ); + let delete = Delete::create(path.join("delete")); + sync_dir(&path); + let index = Arc::new(Index { + path: path.clone(), + options: options.clone(), + delete: delete.clone(), + protect: Mutex::new(IndexProtect { + startup, + sealed: HashMap::new(), + growing: HashMap::new(), + write: None, + }), + view: ArcSwap::new(Arc::new(IndexView { + options: options.clone(), + sealed: HashMap::new(), + growing: HashMap::new(), + delete: delete.clone(), + write: None, + })), + instant_index: AtomicCell::new(Instant::now()), + instant_write: AtomicCell::new(Instant::now()), + _tracker: Arc::new(IndexTracker { path }), + }); + OptimizerIndexing::new(index.clone()).spawn(); + OptimizerSealing::new(index.clone()).spawn(); + index + } + pub fn open(path: PathBuf, options: IndexOptions) -> Arc { + let tracker = Arc::new(IndexTracker { path: path.clone() }); + let startup = FileAtomic::::open(path.join("startup")); + clean( + path.join("segments"), + startup + .get() + .sealeds + .iter() + .map(|s| s.to_string()) + .chain(startup.get().growings.iter().map(|s| s.to_string())), + ); + let sealed = startup + .get() + .sealeds + .iter() + .map(|&uuid| { + ( + uuid, + SealedSegment::open( + tracker.clone(), + path.join("segments").join(uuid.to_string()), + uuid, + options.clone(), + ), + ) + }) + .collect::>(); + let growing = startup + .get() + .growings + .iter() + .map(|&uuid| { + ( + uuid, + GrowingSegment::open( + tracker.clone(), + path.join("segments").join(uuid.to_string()), + uuid, + ), + ) + }) + .collect::>(); + let delete = Delete::open(path.join("delete")); + let index = Arc::new(Index { + path: path.clone(), + options: options.clone(), + delete: delete.clone(), + protect: Mutex::new(IndexProtect { + startup, + sealed: sealed.clone(), + growing: growing.clone(), + write: None, + }), + view: ArcSwap::new(Arc::new(IndexView { + options: options.clone(), + delete: delete.clone(), + sealed, + growing, + write: None, + })), + instant_index: AtomicCell::new(Instant::now()), + instant_write: AtomicCell::new(Instant::now()), + _tracker: tracker, + }); + OptimizerIndexing::new(index.clone()).spawn(); + OptimizerSealing::new(index.clone()).spawn(); + index + } + pub fn options(&self) -> &IndexOptions { + &self.options + } + pub fn view(&self) -> Arc> { + self.view.load_full() + } + pub fn refresh(&self) { + let mut protect = self.protect.lock(); + if let Some((uuid, write)) = protect.write.clone() { + if !write.is_full() { + return; + } + write.seal(); + protect.growing.insert(uuid, write); + } + let write_segment_uuid = Uuid::new_v4(); + let write_segment = GrowingSegment::create( + self._tracker.clone(), + self.path + .join("segments") + .join(write_segment_uuid.to_string()), + write_segment_uuid, + self.options.clone(), + ); + protect.write = Some((write_segment_uuid, write_segment)); + protect.maintain(self.options.clone(), self.delete.clone(), &self.view); + self.instant_write.store(Instant::now()); + } + pub fn seal(&self, check: Uuid) { + let mut protect = self.protect.lock(); + if let Some((uuid, write)) = protect.write.clone() { + if check != uuid { + return; + } + write.seal(); + protect.growing.insert(uuid, write); + } + protect.write = None; + protect.maintain(self.options.clone(), self.delete.clone(), &self.view); + self.instant_write.store(Instant::now()); + } + pub fn stat(&self) -> IndexStat { + let view = self.view(); + IndexStat { + indexing: self.instant_index.load() < self.instant_write.load(), + sealed: view.sealed.values().map(|x| x.len()).collect(), + growing: view.growing.values().map(|x| x.len()).collect(), + write: view.write.as_ref().map(|(_, x)| x.len()).unwrap_or(0), + options: self.options().clone(), + } + } +} + +impl Drop for Index { + fn drop(&mut self) {} +} + +#[derive(Debug, Clone)] +pub struct IndexTracker { + path: PathBuf, +} + +impl Drop for IndexTracker { + fn drop(&mut self) { + std::fs::remove_dir_all(&self.path).unwrap(); + } +} + +pub struct IndexView { + pub options: IndexOptions, + pub delete: Arc, + pub sealed: HashMap>>, + pub growing: HashMap>>, + pub write: Option<(Uuid, Arc>)>, +} + +impl IndexView { + pub fn search bool>( + &self, + k: usize, + vector: &[S::Scalar], + mut filter: F, + ) -> Vec { + assert_eq!(self.options.vector.dims as usize, vector.len()); + + struct Comparer(BinaryHeap>); + + impl PartialEq for Comparer { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } + } + + impl Eq for Comparer {} + + impl PartialOrd for Comparer { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for Comparer { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.peek().cmp(&other.0.peek()).reverse() + } + } + + let mut filter = |payload| { + if let Some(p) = self.delete.check(payload) { + filter(p) + } else { + false + } + }; + let n = self.sealed.len() + self.growing.len() + 1; + let mut result = Heap::new(k); + let mut heaps = BinaryHeap::with_capacity(1 + n); + for (_, sealed) in self.sealed.iter() { + let p = sealed.search(k, vector, &mut filter).into_reversed_heap(); + heaps.push(Comparer(p)); + } + for (_, growing) in self.growing.iter() { + let p = growing.search(k, vector, &mut filter).into_reversed_heap(); + heaps.push(Comparer(p)); + } + if let Some((_, write)) = &self.write { + let p = write.search(k, vector, &mut filter).into_reversed_heap(); + heaps.push(Comparer(p)); + } + while let Some(Comparer(mut heap)) = heaps.pop() { + if let Some(Reverse(x)) = heap.pop() { + result.push(x); + heaps.push(Comparer(heap)); + } + } + result + .into_sorted_vec() + .iter() + .map(|x| Pointer::from_u48(x.payload >> 16)) + .collect() + } + pub fn insert(&self, vector: Vec, pointer: Pointer) -> Result<(), OutdatedError> { + assert_eq!(self.options.vector.dims as usize, vector.len()); + let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload; + if let Some((_, growing)) = self.write.as_ref() { + Ok(growing.insert(vector, payload)?) + } else { + Err(OutdatedError(None)) + } + } + pub fn delete bool>(&self, mut f: F) { + for (_, sealed) in self.sealed.iter() { + let n = sealed.len(); + for i in 0..n { + if let Some(p) = self.delete.check(sealed.payload(i)) { + if f(p) { + self.delete.delete(p); + } + } + } + } + for (_, growing) in self.growing.iter() { + let n = growing.len(); + for i in 0..n { + if let Some(p) = self.delete.check(growing.payload(i)) { + if f(p) { + self.delete.delete(p); + } + } + } + } + if let Some((_, write)) = &self.write { + let n = write.len(); + for i in 0..n { + if let Some(p) = self.delete.check(write.payload(i)) { + if f(p) { + self.delete.delete(p); + } + } + } + } + } + pub fn flush(&self) { + self.delete.flush(); + if let Some((_, write)) = &self.write { + write.flush(); + } + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct IndexStartup { + sealeds: HashSet, + growings: HashSet, +} + +struct IndexProtect { + startup: FileAtomic, + sealed: HashMap>>, + growing: HashMap>>, + write: Option<(Uuid, Arc>)>, +} + +impl IndexProtect { + fn maintain( + &mut self, + options: IndexOptions, + delete: Arc, + swap: &ArcSwap>, + ) { + let view = Arc::new(IndexView { + options, + delete, + sealed: self.sealed.clone(), + growing: self.growing.clone(), + write: self.write.clone(), + }); + let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid); + let startup_sealeds = self.sealed.keys().copied().collect(); + let startup_growings = self.growing.keys().copied().chain(startup_write).collect(); + self.startup.set(IndexStartup { + sealeds: startup_sealeds, + growings: startup_growings, + }); + swap.swap(view); + } +} diff --git a/src/index/optimizing/indexing.rs b/crates/service/src/index/optimizing/indexing.rs similarity index 57% rename from src/index/optimizing/indexing.rs rename to crates/service/src/index/optimizing/indexing.rs index 445496201..a9e62101d 100644 --- a/src/index/optimizing/indexing.rs +++ b/crates/service/src/index/optimizing/indexing.rs @@ -1,16 +1,54 @@ use crate::index::GrowingSegment; use crate::index::Index; use crate::index::SealedSegment; +use crate::prelude::*; use std::cmp::Reverse; use std::sync::Arc; +use std::time::Instant; use uuid::Uuid; -enum Seg { - Sealed(Arc), - Growing(Arc), +pub struct OptimizerIndexing { + index: Arc>, } -impl Seg { +impl OptimizerIndexing { + pub fn new(index: Arc>) -> Self { + Self { index } + } + pub fn spawn(self) { + std::thread::spawn(move || { + self.main(); + }); + } + pub fn main(self) { + let index = self.index; + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(index.options.optimizing.optimizing_threads) + .build() + .unwrap(); + let weak_index = Arc::downgrade(&index); + std::mem::drop(index); + loop { + { + let Some(index) = weak_index.upgrade() else { + return; + }; + let cont = pool.install(|| optimizing_indexing(index.clone())); + if cont { + continue; + } + } + std::thread::sleep(std::time::Duration::from_secs(60)); + } + } +} + +enum Seg { + Sealed(Arc>), + Growing(Arc>), +} + +impl Seg { fn uuid(&self) -> Uuid { use Seg::*; match self { @@ -25,13 +63,13 @@ impl Seg { Growing(x) => x.len(), } } - fn get_sealed(&self) -> Option> { + fn get_sealed(&self) -> Option>> { match self { Seg::Sealed(x) => Some(x.clone()), _ => None, } } - fn get_growing(&self) -> Option> { + fn get_growing(&self) -> Option>> { match self { Seg::Growing(x) => Some(x.clone()), _ => None, @@ -39,7 +77,7 @@ impl Seg { } } -pub fn optimizing_indexing(index: Arc) -> bool { +pub fn optimizing_indexing(index: Arc>) -> bool { use Seg::*; let segs = { let mut all_segs = { @@ -51,16 +89,21 @@ pub fn optimizing_indexing(index: Arc) -> bool { all_segs }; let mut segs = Vec::new(); - let mut segs_len = 0u64; + let mut total = 0u64; + let mut count = 0; while let Some(seg) = all_segs.pop() { - if segs_len + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 { - segs_len += seg.len() as u64; + if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 { + total += seg.len() as u64; + if let Growing(_) = seg { + count += 1; + } segs.push(seg); } else { break; } } - if segs_len < index.options.segment.min_sealed_segment_size as u64 || segs.len() < 3 { + if segs.len() == 0 || (segs.len() == 1 && count == 0) { + index.instant_index.store(Instant::now()); return true; } segs @@ -87,7 +130,7 @@ pub fn optimizing_indexing(index: Arc) -> bool { false } -fn merge(index: &Arc, segs: &[Seg]) -> Arc { +fn merge(index: &Arc>, segs: &[Seg]) -> Arc> { let sealed = segs.iter().filter_map(|x| x.get_sealed()).collect(); let growing = segs.iter().filter_map(|x| x.get_growing()).collect(); let sealed_segment_uuid = Uuid::new_v4(); diff --git a/src/index/optimizing/mod.rs b/crates/service/src/index/optimizing/mod.rs similarity index 67% rename from src/index/optimizing/mod.rs rename to crates/service/src/index/optimizing/mod.rs index a0426720d..b109948cf 100644 --- a/src/index/optimizing/mod.rs +++ b/crates/service/src/index/optimizing/mod.rs @@ -1,3 +1,4 @@ +pub mod sealing; pub mod indexing; pub mod vacuum; @@ -6,9 +7,12 @@ use validator::Validate; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] pub struct OptimizingOptions { - #[serde(default = "OptimizingOptions::default_waiting_secs", skip)] - #[validate(range(min = 0, max = 600))] - pub waiting_secs: u64, + #[serde(default = "OptimizingOptions::default_sealing_secs")] + #[validate(range(min = 0, max = 60))] + pub sealing_secs: u64, + #[serde(default = "OptimizingOptions::default_sealing_size")] + #[validate(range(min = 1, max = 4_000_000_000))] + pub sealing_size: u32, #[serde(default = "OptimizingOptions::default_deleted_threshold", skip)] #[validate(range(min = 0.01, max = 1.00))] pub deleted_threshold: f64, @@ -18,9 +22,12 @@ pub struct OptimizingOptions { } impl OptimizingOptions { - fn default_waiting_secs() -> u64 { + fn default_sealing_secs() -> u64 { 60 } + fn default_sealing_size() -> u32 { + 1 + } fn default_deleted_threshold() -> f64 { 0.2 } @@ -35,7 +42,8 @@ impl OptimizingOptions { impl Default for OptimizingOptions { fn default() -> Self { Self { - waiting_secs: Self::default_waiting_secs(), + sealing_secs: Self::default_sealing_secs(), + sealing_size: Self::default_sealing_size(), deleted_threshold: Self::default_deleted_threshold(), optimizing_threads: Self::default_optimizing_threads(), } diff --git a/crates/service/src/index/optimizing/sealing.rs b/crates/service/src/index/optimizing/sealing.rs new file mode 100644 index 000000000..0e02ead88 --- /dev/null +++ b/crates/service/src/index/optimizing/sealing.rs @@ -0,0 +1,49 @@ +use crate::index::Index; +use crate::prelude::*; +use std::sync::Arc; +use std::time::Duration; + +pub struct OptimizerSealing { + index: Arc>, +} + +impl OptimizerSealing { + pub fn new(index: Arc>) -> Self { + Self { index } + } + pub fn spawn(self) { + std::thread::spawn(move || { + self.main(); + }); + } + pub fn main(self) { + let index = self.index; + let dur = Duration::from_secs(index.options.optimizing.sealing_secs); + let least = index.options.optimizing.sealing_size; + let weak_index = Arc::downgrade(&index); + std::mem::drop(index); + let mut check = None; + loop { + { + let Some(index) = weak_index.upgrade() else { + return; + }; + let view = index.view(); + let stamp = view + .write + .as_ref() + .map(|(uuid, segment)| (*uuid, segment.len())); + if stamp == check { + if let Some((uuid, len)) = stamp { + if len >= least { + index.seal(uuid); + } + } + } else { + check = stamp; + } + } + std::thread::sleep(dur); + } + } +} diff --git a/src/index/optimizing/vacuum.rs b/crates/service/src/index/optimizing/vacuum.rs similarity index 100% rename from src/index/optimizing/vacuum.rs rename to crates/service/src/index/optimizing/vacuum.rs diff --git a/src/index/segments/growing.rs b/crates/service/src/index/segments/growing.rs similarity index 80% rename from src/index/segments/growing.rs rename to crates/service/src/index/segments/growing.rs index e8ac0421f..5c20c2029 100644 --- a/src/index/segments/growing.rs +++ b/crates/service/src/index/segments/growing.rs @@ -1,7 +1,6 @@ use super::SegmentTracker; use crate::index::IndexOptions; use crate::index::IndexTracker; -use crate::index::VectorOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::file_wal::FileWal; @@ -19,17 +18,16 @@ use uuid::Uuid; #[error("`GrowingSegment` stopped growing.")] pub struct GrowingSegmentInsertError; -pub struct GrowingSegment { +pub struct GrowingSegment { uuid: Uuid, - options: VectorOptions, - vec: Vec>>, + vec: Vec>>>, wal: Mutex, len: AtomicUsize, pro: Mutex, _tracker: Arc, } -impl GrowingSegment { +impl GrowingSegment { pub fn create( _tracker: Arc, path: PathBuf, @@ -42,7 +40,6 @@ impl GrowingSegment { sync_dir(&path); Arc::new(Self { uuid, - options: options.vector, vec: unsafe { let mut vec = Vec::with_capacity(capacity as usize); vec.set_len(capacity as usize); @@ -57,23 +54,17 @@ impl GrowingSegment { _tracker: Arc::new(SegmentTracker { path, _tracker }), }) } - pub fn open( - _tracker: Arc, - path: PathBuf, - uuid: Uuid, - options: IndexOptions, - ) -> Arc { + pub fn open(_tracker: Arc, path: PathBuf, uuid: Uuid) -> Arc { let mut wal = FileWal::open(path.join("wal")); let mut vec = Vec::new(); while let Some(log) = wal.read() { - let log = bincode::deserialize::(&log).unwrap(); + let log = bincode::deserialize::>(&log).unwrap(); vec.push(UnsafeCell::new(MaybeUninit::new(log))); } wal.truncate(); let n = vec.len(); Arc::new(Self { uuid, - options: options.vector, vec, wal: { Mutex::new(wal) }, len: AtomicUsize::new(n), @@ -87,6 +78,20 @@ impl GrowingSegment { pub fn uuid(&self) -> Uuid { self.uuid } + pub fn is_full(&self) -> bool { + let n; + { + let pro = self.pro.lock(); + if pro.inflight < pro.capacity { + return false; + } + n = pro.inflight; + } + while self.len.load(Ordering::Acquire) != n { + std::hint::spin_loop(); + } + true + } pub fn seal(&self) { let n; { @@ -104,7 +109,7 @@ impl GrowingSegment { } pub fn insert( &self, - vector: Vec, + vector: Vec, payload: Payload, ) -> Result<(), GrowingSegmentInsertError> { let log = Log { vector, payload }; @@ -126,13 +131,13 @@ impl GrowingSegment { self.len.store(1 + i, Ordering::Release); self.wal .lock() - .write(&bincode::serialize::(&log).unwrap()); + .write(&bincode::serialize::>(&log).unwrap()); Ok(()) } pub fn len(&self) -> u32 { self.len.load(Ordering::Acquire) as u32 } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { let i = i as usize; if i >= self.len.load(Ordering::Acquire) { panic!("Out of bound."); @@ -148,12 +153,12 @@ impl GrowingSegment { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; log.payload } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { let n = self.len.load(Ordering::Acquire); let mut heap = Heap::new(k); for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; - let distance = self.options.d.distance(vector, &log.vector); + let distance = S::distance(vector, &log.vector); if heap.check(distance) && filter.check(log.payload) { heap.push(HeapElement { distance, @@ -165,10 +170,10 @@ impl GrowingSegment { } } -unsafe impl Send for GrowingSegment {} -unsafe impl Sync for GrowingSegment {} +unsafe impl Send for GrowingSegment {} +unsafe impl Sync for GrowingSegment {} -impl Drop for GrowingSegment { +impl Drop for GrowingSegment { fn drop(&mut self) { let n = *self.len.get_mut(); for i in 0..n { @@ -180,8 +185,8 @@ impl Drop for GrowingSegment { } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -struct Log { - vector: Vec, +struct Log { + vector: Vec, payload: Payload, } diff --git a/src/index/segments/mod.rs b/crates/service/src/index/segments/mod.rs similarity index 67% rename from src/index/segments/mod.rs rename to crates/service/src/index/segments/mod.rs index 83b0682af..b09c1f939 100644 --- a/src/index/segments/mod.rs +++ b/crates/service/src/index/segments/mod.rs @@ -10,14 +10,10 @@ use validator::ValidationError; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[validate(schema(function = "Self::validate_0"))] -#[validate(schema(function = "Self::validate_1"))] pub struct SegmentsOptions { #[serde(default = "SegmentsOptions::default_max_growing_segment_size")] #[validate(range(min = 1, max = 4_000_000_000))] pub max_growing_segment_size: u32, - #[serde(default = "SegmentsOptions::default_min_sealed_segment_size")] - #[validate(range(min = 1, max = 4_000_000_000))] - pub min_sealed_segment_size: u32, #[serde(default = "SegmentsOptions::default_max_sealed_segment_size")] #[validate(range(min = 1, max = 4_000_000_000))] pub max_sealed_segment_size: u32, @@ -27,22 +23,11 @@ impl SegmentsOptions { fn default_max_growing_segment_size() -> u32 { 20_000 } - fn default_min_sealed_segment_size() -> u32 { - 1_000 - } fn default_max_sealed_segment_size() -> u32 { 1_000_000 } - // min_sealed_segment_size <= max_growing_segment_size <= max_sealed_segment_size + // max_growing_segment_size <= max_sealed_segment_size fn validate_0(&self) -> Result<(), ValidationError> { - if self.min_sealed_segment_size > self.max_growing_segment_size { - return Err(ValidationError::new( - "`min_sealed_segment_size` must be less than or equal to `max_growing_segment_size`", - )); - } - Ok(()) - } - fn validate_1(&self) -> Result<(), ValidationError> { if self.max_growing_segment_size > self.max_sealed_segment_size { return Err(ValidationError::new( "`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`", @@ -56,7 +41,6 @@ impl Default for SegmentsOptions { fn default() -> Self { Self { max_growing_segment_size: Self::default_max_growing_segment_size(), - min_sealed_segment_size: Self::default_min_sealed_segment_size(), max_sealed_segment_size: Self::default_max_sealed_segment_size(), } } diff --git a/src/index/segments/sealed.rs b/crates/service/src/index/segments/sealed.rs similarity index 81% rename from src/index/segments/sealed.rs rename to crates/service/src/index/segments/sealed.rs index 52a101558..068c31c03 100644 --- a/src/index/segments/sealed.rs +++ b/crates/service/src/index/segments/sealed.rs @@ -8,20 +8,20 @@ use std::path::PathBuf; use std::sync::Arc; use uuid::Uuid; -pub struct SealedSegment { +pub struct SealedSegment { uuid: Uuid, - indexing: DynamicIndexing, + indexing: DynamicIndexing, _tracker: Arc, } -impl SealedSegment { +impl SealedSegment { pub fn create( _tracker: Arc, path: PathBuf, uuid: Uuid, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Arc { std::fs::create_dir(&path).unwrap(); let indexing = DynamicIndexing::create(path.join("indexing"), options, sealed, growing); @@ -51,13 +51,13 @@ impl SealedSegment { pub fn len(&self) -> u32 { self.indexing.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.indexing.vector(i) } pub fn payload(&self, i: u32) -> Payload { self.indexing.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.indexing.search(k, vector, filter) } } diff --git a/crates/service/src/lib.rs b/crates/service/src/lib.rs new file mode 100644 index 000000000..bf1f5018a --- /dev/null +++ b/crates/service/src/lib.rs @@ -0,0 +1,8 @@ +#![feature(core_intrinsics)] + +pub mod algorithms; +pub mod index; +pub mod prelude; +pub mod worker; + +mod utils; diff --git a/src/prelude/error.rs b/crates/service/src/prelude/error.rs similarity index 61% rename from src/prelude/error.rs rename to crates/service/src/prelude/error.rs index e3d8b8b45..fbefffe71 100644 --- a/src/prelude/error.rs +++ b/crates/service/src/prelude/error.rs @@ -1,5 +1,3 @@ -use crate::ipc::IpcError; -use crate::prelude::*; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -15,100 +13,77 @@ or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_ ")] BadInit, #[error("\ -The given index option is invalid. -INFORMATION: reason = {0:?}\ -")] - BadOption(String), - #[error("\ -The given vector is invalid for input. -INFORMATION: vector = {0:?} -ADVICE: Check if dimensions of the vector is matched with the index.\ +Bad literal. +INFORMATION: hint = {hint}\ ")] - BadVector(Vec), + BadLiteral { + hint: String, + }, #[error("\ Modifier of the type is invalid. ADVICE: Check if modifier of the type is an integer among 1 and 65535.\ ")] - BadTypmod, + BadTypeDimensions, #[error("\ Dimensions of the vector is invalid. ADVICE: Check if dimensions of the vector are among 1 and 65535.\ ")] - BadVecForDims, + BadValueDimensions, #[error("\ -Dimensions of the vector is unmatched with the type modifier. -INFORMATION: type_dimensions = {type_dimensions}, value_dimensions = {value_dimensions}\ +The given index option is invalid. +INFORMATION: reason = {validation:?}\ ")] - BadVecForUnmatchedDims { - value_dimensions: u16, - type_dimensions: u16, - }, + BadOption { validation: String }, #[error("\ -Operands of the operator differs in dimensions. -INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\ +Dimensions type modifier of a vector column is needed for building the index.\ ")] - DifferentVectorDims { - left_dimensions: u16, - right_dimensions: u16, - }, + BadOption2, #[error("\ Indexes can only be built on built-in distance functions. ADVICE: If you want pgvecto.rs to support more distance functions, \ visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas.\ ")] - UnsupportedOperator, + BadOptions3, #[error("\ The index is not existing in the background worker. ADVICE: Drop or rebuild the index.\ ")] - Index404, + UnknownIndex, #[error("\ -Dimensions type modifier of a vector column is needed for building the index.\ -")] - DimsIsNeeded, - #[error("\ -Bad vector string. -INFORMATION: hint = {hint}\ +Operands of the operator differs in dimensions or scalar type. +INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\ ")] - BadVectorString { - hint: String, + Unmatched { + left_dimensions: u16, + right_dimensions: u16, }, #[error("\ -`mmap` transport is not supported by MacOS.\ +The given vector is invalid for input. +ADVICE: Check if dimensions and scalar type of the vector is matched with the index.\ ")] - MmapTransportNotSupported, + Unmatched2, } -impl FriendlyError { - pub fn friendly(self) -> ! { - panic!("pgvecto.rs: {}", self); - } +pub trait FriendlyErrorLike { + fn friendly(self) -> !; } -impl IpcError { - pub fn friendly(self) -> ! { +impl FriendlyErrorLike for FriendlyError { + fn friendly(self) -> ! { panic!("pgvecto.rs: {}", self); } } -pub trait Friendly { +pub trait FriendlyResult { type Output; fn friendly(self) -> Self::Output; } -impl Friendly for Result { - type Output = T; - - fn friendly(self) -> T { - match self { - Ok(x) => x, - Err(e) => e.friendly(), - } - } -} - -impl Friendly for Result { +impl FriendlyResult for Result +where + E: FriendlyErrorLike, +{ type Output = T; fn friendly(self) -> T { diff --git a/src/prelude/filter.rs b/crates/service/src/prelude/filter.rs similarity index 100% rename from src/prelude/filter.rs rename to crates/service/src/prelude/filter.rs diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/service/src/prelude/global/f16_cos.rs new file mode 100644 index 000000000..28bba1c3f --- /dev/null +++ b/crates/service/src/prelude/global/f16_cos.rs @@ -0,0 +1,222 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F16Cos {} + +impl G for F16Cos { + type Scalar = F16; + + const DISTANCE: Distance = Distance::Cos; + + type L2 = F16L2; + + #[multiversion::multiversion(targets = "simd")] + fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { + cosine(lhs, rhs) * (-1.0) + } + + fn l2_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16_l2::distance_squared_l2(lhs, rhs) + } + + #[multiversion::multiversion(targets = "simd")] + fn elkan_k_means_normalize(vector: &mut [F16]) { + l2_normalize(vector) + } + + #[multiversion::multiversion(targets = "simd")] + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16_dot::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance2( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + delta: &[F16], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn length(vector: &[F16]) -> F16 { + let n = vector.len(); + let mut dot = F16::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn l2_normalize(vector: &mut [F16]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); + } + (xy, x2, y2) +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f()); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += (rhs[i].to_f() + del[i].to_f()) * (rhs[i].to_f() + del[i].to_f()); + } + (xy, x2, y2) +} diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/service/src/prelude/global/f16_dot.rs new file mode 100644 index 000000000..8f6581032 --- /dev/null +++ b/crates/service/src/prelude/global/f16_dot.rs @@ -0,0 +1,191 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F16Dot {} + +impl G for F16Dot { + type Scalar = F16; + + const DISTANCE: Distance = Distance::Dot; + + type L2 = F16L2; + + fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { + dot(lhs, rhs) * (-1.0) + } + + fn l2_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16_l2::distance_squared_l2(lhs, rhs) + } + + fn elkan_k_means_normalize(vector: &mut [F16]) { + l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16_dot::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance2( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + delta: &[F16], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let _xy = dot_delta(lhs, rhs, del); + xy += _xy; + } + xy * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn length(vector: &[F16]) -> F16 { + let n = vector.len(); + let mut dot = F16::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn l2_normalize(vector: &mut [F16]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + } + xy +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n: usize = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f()); + } + xy +} diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs new file mode 100644 index 000000000..fd706bc07 --- /dev/null +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -0,0 +1,152 @@ +use super::G; +use crate::prelude::scalar::F16; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F16L2 {} + +impl G for F16L2 { + type Scalar = F16; + + const DISTANCE: Distance = Distance::L2; + + type L2 = F16L2; + + fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { + distance_squared_l2(lhs, rhs) + } + + fn l2_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + distance_squared_l2(lhs, rhs) + } + + fn elkan_k_means_normalize(_: &mut [F16]) {} + + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + distance_squared_l2(lhs, rhs).sqrt() + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i].to_f(); + let _y = (F32(rhs[i] as f32) / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance2( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + delta: &[F16], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2_delta(lhs, rhs, del); + } + result + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i].to_f() - rhs[i].to_f(); + d2 += d * d; + } + d2 +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i].to_f() - (rhs[i].to_f() + del[i].to_f()); + d2 += d * d; + } + d2 +} diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/service/src/prelude/global/f32_cos.rs new file mode 100644 index 000000000..43ae83717 --- /dev/null +++ b/crates/service/src/prelude/global/f32_cos.rs @@ -0,0 +1,222 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F32Cos {} + +impl G for F32Cos { + type Scalar = F32; + + const DISTANCE: Distance = Distance::Cos; + + type L2 = F32L2; + + #[multiversion::multiversion(targets = "simd")] + fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { + cosine(lhs, rhs) * (-1.0) + } + + fn l2_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32_l2::distance_squared_l2(lhs, rhs) + } + + #[multiversion::multiversion(targets = "simd")] + fn elkan_k_means_normalize(vector: &mut [F32]) { + l2_normalize(vector) + } + + #[multiversion::multiversion(targets = "simd")] + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32_dot::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance2( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + delta: &[F32], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn length(vector: &[F32]) -> F32 { + let n = vector.len(); + let mut dot = F32::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn l2_normalize(vector: &mut [F32]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + (xy, x2, y2) +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * (rhs[i] + del[i]); + x2 += lhs[i] * lhs[i]; + y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]); + } + (xy, x2, y2) +} diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/service/src/prelude/global/f32_dot.rs new file mode 100644 index 000000000..2ed3be71f --- /dev/null +++ b/crates/service/src/prelude/global/f32_dot.rs @@ -0,0 +1,191 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F32Dot {} + +impl G for F32Dot { + type Scalar = F32; + + const DISTANCE: Distance = Distance::Dot; + + type L2 = F32L2; + + fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { + dot(lhs, rhs) * (-1.0) + } + + fn l2_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32_l2::distance_squared_l2(lhs, rhs) + } + + fn elkan_k_means_normalize(vector: &mut [F32]) { + l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32_dot::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance2( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + delta: &[F32], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let _xy = dot_delta(lhs, rhs, del); + xy += _xy; + } + xy * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn length(vector: &[F32]) -> F32 { + let n = vector.len(); + let mut dot = F32::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn l2_normalize(vector: &mut [F32]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + } + xy +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n: usize = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i] * (rhs[i] + del[i]); + } + xy +} diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/service/src/prelude/global/f32_l2.rs new file mode 100644 index 000000000..aca474de8 --- /dev/null +++ b/crates/service/src/prelude/global/f32_l2.rs @@ -0,0 +1,151 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F32L2 {} + +impl G for F32L2 { + type Scalar = F32; + + const DISTANCE: Distance = Distance::L2; + + type L2 = F32L2; + + fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { + distance_squared_l2(lhs, rhs) + } + + fn l2_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + distance_squared_l2(lhs, rhs) + } + + fn elkan_k_means_normalize(_: &mut [F32]) {} + + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + distance_squared_l2(lhs, rhs).sqrt() + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn scalar_quantization_distance2( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets = "simd")] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + delta: &[F32], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2_delta(lhs, rhs, del); + } + result + } +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i] - rhs[i]; + d2 += d * d; + } + d2 +} + +#[inline(always)] +#[multiversion::multiversion(targets = "simd")] +fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i] - (rhs[i] + del[i]); + d2 += d * d; + } + d2 +} diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs new file mode 100644 index 000000000..f0936ddfe --- /dev/null +++ b/crates/service/src/prelude/global/mod.rs @@ -0,0 +1,121 @@ +mod f16_cos; +mod f16_dot; +mod f16_l2; +mod f32_cos; +mod f32_dot; +mod f32_l2; + +pub use f16_cos::F16Cos; +pub use f16_dot::F16Dot; +pub use f16_l2::F16L2; +pub use f32_cos::F32Cos; +pub use f32_dot::F32Dot; +pub use f32_l2::F32L2; + +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +pub trait G: Copy + std::fmt::Debug + 'static { + type Scalar: Copy + + Send + + Sync + + std::fmt::Debug + + std::fmt::Display + + serde::Serialize + + for<'a> serde::Deserialize<'a> + + Ord + + bytemuck::Zeroable + + bytemuck::Pod + + num_traits::Float + + num_traits::NumOps + + num_traits::NumAssignOps + + FloatCast; + const DISTANCE: Distance; + type L2: G; + + fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; + fn l2_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; + fn elkan_k_means_normalize(vector: &mut [Self::Scalar]); + fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; + fn scalar_quantization_distance( + dims: u16, + max: &[Self::Scalar], + min: &[Self::Scalar], + lhs: &[Self::Scalar], + rhs: &[u8], + ) -> F32; + fn scalar_quantization_distance2( + dims: u16, + max: &[Self::Scalar], + min: &[Self::Scalar], + lhs: &[u8], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[Self::Scalar], + lhs: &[Self::Scalar], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[Self::Scalar], + lhs: &[u8], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[Self::Scalar], + lhs: &[Self::Scalar], + rhs: &[u8], + delta: &[Self::Scalar], + ) -> F32; +} + +pub trait FloatCast: Sized { + fn from_f32(x: f32) -> Self; + fn to_f32(self) -> f32; + fn from_f(x: F32) -> Self { + Self::from_f32(x.0) + } + fn to_f(self) -> F32 { + F32(Self::to_f32(self)) + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub enum DynamicVector { + F32(Vec), + F16(Vec), +} + +impl From> for DynamicVector { + fn from(value: Vec) -> Self { + Self::F32(value) + } +} + +impl From> for DynamicVector { + fn from(value: Vec) -> Self { + Self::F16(value) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Distance { + L2, + Cos, + Dot, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Kind { + F32, + F16, +} diff --git a/src/prelude/heap.rs b/crates/service/src/prelude/heap.rs similarity index 89% rename from src/prelude/heap.rs rename to crates/service/src/prelude/heap.rs index f543bdeb5..0431541b6 100644 --- a/src/prelude/heap.rs +++ b/crates/service/src/prelude/heap.rs @@ -1,9 +1,9 @@ -use crate::prelude::{Payload, Scalar}; +use crate::prelude::{Payload, F32}; use std::{cmp::Reverse, collections::BinaryHeap}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct HeapElement { - pub distance: Scalar, + pub distance: F32, pub payload: Payload, } @@ -20,7 +20,7 @@ impl Heap { k, } } - pub fn check(&self, distance: Scalar) -> bool { + pub fn check(&self, distance: F32) -> bool { self.binary_heap.len() < self.k || distance < self.binary_heap.peek().unwrap().distance } pub fn push(&mut self, element: HeapElement) { diff --git a/crates/service/src/prelude/mod.rs b/crates/service/src/prelude/mod.rs new file mode 100644 index 000000000..5b66102f2 --- /dev/null +++ b/crates/service/src/prelude/mod.rs @@ -0,0 +1,16 @@ +mod error; +mod filter; +mod global; +mod heap; +mod scalar; +mod sys; + +pub use self::error::{FriendlyError, FriendlyErrorLike, FriendlyResult}; +pub use self::global::*; +pub use self::scalar::{F16, F32}; + +pub use self::filter::{Filter, Payload}; +pub use self::heap::{Heap, HeapElement}; +pub use self::sys::{Id, Pointer}; + +pub use num_traits::{Float, Zero}; diff --git a/crates/service/src/prelude/scalar/f16.rs b/crates/service/src/prelude/scalar/f16.rs new file mode 100644 index 000000000..d8c9ee977 --- /dev/null +++ b/crates/service/src/prelude/scalar/f16.rs @@ -0,0 +1,653 @@ +use crate::prelude::global::FloatCast; +use half::f16; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::fmt::{Debug, Display}; +use std::num::ParseFloatError; +use std::ops::*; +use std::str::FromStr; + +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +#[repr(transparent)] +#[serde(transparent)] +pub struct F16(pub f16); + +impl Debug for F16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.0, f) + } +} + +impl Display for F16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl PartialEq for F16 { + fn eq(&self, other: &Self) -> bool { + self.0.total_cmp(&other.0) == Ordering::Equal + } +} + +impl Eq for F16 {} + +impl PartialOrd for F16 { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(Ord::cmp(self, other)) + } +} + +impl Ord for F16 { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + self.0.total_cmp(&other.0) + } +} + +unsafe impl bytemuck::Zeroable for F16 {} + +unsafe impl bytemuck::Pod for F16 {} + +impl num_traits::Zero for F16 { + fn zero() -> Self { + Self(f16::zero()) + } + + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl num_traits::One for F16 { + fn one() -> Self { + Self(f16::one()) + } +} + +impl num_traits::FromPrimitive for F16 { + fn from_i64(n: i64) -> Option { + f16::from_i64(n).map(Self) + } + + fn from_u64(n: u64) -> Option { + f16::from_u64(n).map(Self) + } + + fn from_isize(n: isize) -> Option { + f16::from_isize(n).map(Self) + } + + fn from_i8(n: i8) -> Option { + f16::from_i8(n).map(Self) + } + + fn from_i16(n: i16) -> Option { + f16::from_i16(n).map(Self) + } + + fn from_i32(n: i32) -> Option { + f16::from_i32(n).map(Self) + } + + fn from_i128(n: i128) -> Option { + f16::from_i128(n).map(Self) + } + + fn from_usize(n: usize) -> Option { + f16::from_usize(n).map(Self) + } + + fn from_u8(n: u8) -> Option { + f16::from_u8(n).map(Self) + } + + fn from_u16(n: u16) -> Option { + f16::from_u16(n).map(Self) + } + + fn from_u32(n: u32) -> Option { + f16::from_u32(n).map(Self) + } + + fn from_u128(n: u128) -> Option { + f16::from_u128(n).map(Self) + } + + fn from_f32(n: f32) -> Option { + Some(Self(f16::from_f32(n))) + } + + fn from_f64(n: f64) -> Option { + Some(Self(f16::from_f64(n))) + } +} + +impl num_traits::ToPrimitive for F16 { + fn to_i64(&self) -> Option { + self.0.to_i64() + } + + fn to_u64(&self) -> Option { + self.0.to_u64() + } + + fn to_isize(&self) -> Option { + self.0.to_isize() + } + + fn to_i8(&self) -> Option { + self.0.to_i8() + } + + fn to_i16(&self) -> Option { + self.0.to_i16() + } + + fn to_i32(&self) -> Option { + self.0.to_i32() + } + + fn to_i128(&self) -> Option { + self.0.to_i128() + } + + fn to_usize(&self) -> Option { + self.0.to_usize() + } + + fn to_u8(&self) -> Option { + self.0.to_u8() + } + + fn to_u16(&self) -> Option { + self.0.to_u16() + } + + fn to_u32(&self) -> Option { + self.0.to_u32() + } + + fn to_u128(&self) -> Option { + self.0.to_u128() + } + + fn to_f32(&self) -> Option { + Some(self.0.to_f32()) + } + + fn to_f64(&self) -> Option { + Some(self.0.to_f64()) + } +} + +impl num_traits::NumCast for F16 { + fn from(n: T) -> Option { + num_traits::NumCast::from(n).map(Self) + } +} + +impl num_traits::Num for F16 { + type FromStrRadixErr = ::FromStrRadixErr; + + fn from_str_radix(str: &str, radix: u32) -> Result { + f16::from_str_radix(str, radix).map(Self) + } +} + +impl num_traits::Float for F16 { + fn nan() -> Self { + Self(f16::nan()) + } + + fn infinity() -> Self { + Self(f16::infinity()) + } + + fn neg_infinity() -> Self { + Self(f16::neg_infinity()) + } + + fn neg_zero() -> Self { + Self(f16::neg_zero()) + } + + fn min_value() -> Self { + Self(f16::min_value()) + } + + fn min_positive_value() -> Self { + Self(f16::min_positive_value()) + } + + fn max_value() -> Self { + Self(f16::max_value()) + } + + fn is_nan(self) -> bool { + self.0.is_nan() + } + + fn is_infinite(self) -> bool { + self.0.is_infinite() + } + + fn is_finite(self) -> bool { + self.0.is_finite() + } + + fn is_normal(self) -> bool { + self.0.is_normal() + } + + fn classify(self) -> std::num::FpCategory { + self.0.classify() + } + + fn floor(self) -> Self { + Self(self.0.floor()) + } + + fn ceil(self) -> Self { + Self(self.0.ceil()) + } + + fn round(self) -> Self { + Self(self.0.round()) + } + + fn trunc(self) -> Self { + Self(self.0.trunc()) + } + + fn fract(self) -> Self { + Self(self.0.fract()) + } + + fn abs(self) -> Self { + Self(self.0.abs()) + } + + fn signum(self) -> Self { + Self(self.0.signum()) + } + + fn is_sign_positive(self) -> bool { + self.0.is_sign_positive() + } + + fn is_sign_negative(self) -> bool { + self.0.is_sign_negative() + } + + fn mul_add(self, a: Self, b: Self) -> Self { + Self(self.0.mul_add(a.0, b.0)) + } + + fn recip(self) -> Self { + Self(self.0.recip()) + } + + fn powi(self, n: i32) -> Self { + Self(self.0.powi(n)) + } + + fn powf(self, n: Self) -> Self { + Self(self.0.powf(n.0)) + } + + fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + fn exp(self) -> Self { + Self(self.0.exp()) + } + + fn exp2(self) -> Self { + Self(self.0.exp2()) + } + + fn ln(self) -> Self { + Self(self.0.ln()) + } + + fn log(self, base: Self) -> Self { + Self(self.0.log(base.0)) + } + + fn log2(self) -> Self { + Self(self.0.log2()) + } + + fn log10(self) -> Self { + Self(self.0.log10()) + } + + fn max(self, other: Self) -> Self { + Self(self.0.max(other.0)) + } + + fn min(self, other: Self) -> Self { + Self(self.0.min(other.0)) + } + + fn abs_sub(self, _: Self) -> Self { + unimplemented!() + } + + fn cbrt(self) -> Self { + Self(self.0.cbrt()) + } + + fn hypot(self, other: Self) -> Self { + Self(self.0.hypot(other.0)) + } + + fn sin(self) -> Self { + Self(self.0.sin()) + } + + fn cos(self) -> Self { + Self(self.0.cos()) + } + + fn tan(self) -> Self { + Self(self.0.tan()) + } + + fn asin(self) -> Self { + Self(self.0.asin()) + } + + fn acos(self) -> Self { + Self(self.0.acos()) + } + + fn atan(self) -> Self { + Self(self.0.atan()) + } + + fn atan2(self, other: Self) -> Self { + Self(self.0.atan2(other.0)) + } + + fn sin_cos(self) -> (Self, Self) { + let (_0, _1) = self.0.sin_cos(); + (Self(_0), Self(_1)) + } + + fn exp_m1(self) -> Self { + Self(self.0.exp_m1()) + } + + fn ln_1p(self) -> Self { + Self(self.0.ln_1p()) + } + + fn sinh(self) -> Self { + Self(self.0.sinh()) + } + + fn cosh(self) -> Self { + Self(self.0.cosh()) + } + + fn tanh(self) -> Self { + Self(self.0.tanh()) + } + + fn asinh(self) -> Self { + Self(self.0.asinh()) + } + + fn acosh(self) -> Self { + Self(self.0.acosh()) + } + + fn atanh(self) -> Self { + Self(self.0.atanh()) + } + + fn integer_decode(self) -> (u64, i16, i8) { + self.0.integer_decode() + } + + fn epsilon() -> Self { + Self(f16::EPSILON) + } + + fn is_subnormal(self) -> bool { + self.0.classify() == std::num::FpCategory::Subnormal + } + + fn to_degrees(self) -> Self { + Self(self.0.to_degrees()) + } + + fn to_radians(self) -> Self { + Self(self.0.to_radians()) + } + + fn copysign(self, sign: Self) -> Self { + Self(self.0.copysign(sign.0)) + } +} + +impl Add for F16 { + type Output = F16; + + #[inline(always)] + fn add(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fadd_fast(self.0, rhs.0).into() } + } +} + +impl AddAssign for F16 { + #[inline(always)] + fn add_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs.0) } + } +} + +impl Sub for F16 { + type Output = F16; + + #[inline(always)] + fn sub(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fsub_fast(self.0, rhs.0).into() } + } +} + +impl SubAssign for F16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs.0) } + } +} + +impl Mul for F16 { + type Output = F16; + + #[inline(always)] + fn mul(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fmul_fast(self.0, rhs.0).into() } + } +} + +impl MulAssign for F16 { + #[inline(always)] + fn mul_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs.0) } + } +} + +impl Div for F16 { + type Output = F16; + + #[inline(always)] + fn div(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fdiv_fast(self.0, rhs.0).into() } + } +} + +impl DivAssign for F16 { + #[inline(always)] + fn div_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs.0) } + } +} + +impl Rem for F16 { + type Output = F16; + + #[inline(always)] + fn rem(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::frem_fast(self.0, rhs.0).into() } + } +} + +impl RemAssign for F16 { + #[inline(always)] + fn rem_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs.0) } + } +} + +impl Neg for F16 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(self.0.neg()) + } +} + +impl FromStr for F16 { + type Err = ParseFloatError; + + fn from_str(s: &str) -> Result { + f16::from_str(s).map(|x| x.into()) + } +} + +impl FloatCast for F16 { + fn from_f32(x: f32) -> Self { + Self(f16::from_f32(x)) + } + + fn to_f32(self) -> f32 { + f16::to_f32(self.0) + } +} + +impl From for F16 { + fn from(value: f16) -> Self { + Self(value) + } +} + +impl From for f16 { + fn from(F16(float): F16) -> Self { + float + } +} + +impl Add for F16 { + type Output = F16; + + #[inline(always)] + fn add(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fadd_fast(self.0, rhs).into() } + } +} + +impl AddAssign for F16 { + fn add_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs) } + } +} + +impl Sub for F16 { + type Output = F16; + + #[inline(always)] + fn sub(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fsub_fast(self.0, rhs).into() } + } +} + +impl SubAssign for F16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs) } + } +} + +impl Mul for F16 { + type Output = F16; + + #[inline(always)] + fn mul(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fmul_fast(self.0, rhs).into() } + } +} + +impl MulAssign for F16 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs) } + } +} + +impl Div for F16 { + type Output = F16; + + #[inline(always)] + fn div(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fdiv_fast(self.0, rhs).into() } + } +} + +impl DivAssign for F16 { + #[inline(always)] + fn div_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs) } + } +} + +impl Rem for F16 { + type Output = F16; + + #[inline(always)] + fn rem(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::frem_fast(self.0, rhs).into() } + } +} + +impl RemAssign for F16 { + #[inline(always)] + fn rem_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs) } + } +} + +mod intrinsics { + use half::f16; + + pub unsafe fn fadd_fast(lhs: f16, rhs: f16) -> f16 { + lhs + rhs + } + pub unsafe fn fsub_fast(lhs: f16, rhs: f16) -> f16 { + lhs - rhs + } + pub unsafe fn fmul_fast(lhs: f16, rhs: f16) -> f16 { + lhs * rhs + } + pub unsafe fn fdiv_fast(lhs: f16, rhs: f16) -> f16 { + lhs / rhs + } + pub unsafe fn frem_fast(lhs: f16, rhs: f16) -> f16 { + lhs % rhs + } +} diff --git a/crates/service/src/prelude/scalar/f32.rs b/crates/service/src/prelude/scalar/f32.rs new file mode 100644 index 000000000..f80700b12 --- /dev/null +++ b/crates/service/src/prelude/scalar/f32.rs @@ -0,0 +1,632 @@ +use crate::prelude::global::FloatCast; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::fmt::{Debug, Display}; +use std::num::ParseFloatError; +use std::ops::*; +use std::str::FromStr; + +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +#[repr(transparent)] +#[serde(transparent)] +pub struct F32(pub f32); + +impl Debug for F32 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.0, f) + } +} + +impl Display for F32 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl PartialEq for F32 { + fn eq(&self, other: &Self) -> bool { + self.0.total_cmp(&other.0) == Ordering::Equal + } +} + +impl Eq for F32 {} + +impl PartialOrd for F32 { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(Ord::cmp(self, other)) + } +} + +impl Ord for F32 { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + self.0.total_cmp(&other.0) + } +} + +unsafe impl bytemuck::Zeroable for F32 {} + +unsafe impl bytemuck::Pod for F32 {} + +impl num_traits::Zero for F32 { + fn zero() -> Self { + Self(f32::zero()) + } + + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl num_traits::One for F32 { + fn one() -> Self { + Self(f32::one()) + } +} + +impl num_traits::FromPrimitive for F32 { + fn from_i64(n: i64) -> Option { + f32::from_i64(n).map(Self) + } + + fn from_u64(n: u64) -> Option { + f32::from_u64(n).map(Self) + } + + fn from_isize(n: isize) -> Option { + f32::from_isize(n).map(Self) + } + + fn from_i8(n: i8) -> Option { + f32::from_i8(n).map(Self) + } + + fn from_i16(n: i16) -> Option { + f32::from_i16(n).map(Self) + } + + fn from_i32(n: i32) -> Option { + f32::from_i32(n).map(Self) + } + + fn from_i128(n: i128) -> Option { + f32::from_i128(n).map(Self) + } + + fn from_usize(n: usize) -> Option { + f32::from_usize(n).map(Self) + } + + fn from_u8(n: u8) -> Option { + f32::from_u8(n).map(Self) + } + + fn from_u16(n: u16) -> Option { + f32::from_u16(n).map(Self) + } + + fn from_u32(n: u32) -> Option { + f32::from_u32(n).map(Self) + } + + fn from_u128(n: u128) -> Option { + f32::from_u128(n).map(Self) + } + + fn from_f32(n: f32) -> Option { + f32::from_f32(n).map(Self) + } + + fn from_f64(n: f64) -> Option { + f32::from_f64(n).map(Self) + } +} + +impl num_traits::ToPrimitive for F32 { + fn to_i64(&self) -> Option { + self.0.to_i64() + } + + fn to_u64(&self) -> Option { + self.0.to_u64() + } + + fn to_isize(&self) -> Option { + self.0.to_isize() + } + + fn to_i8(&self) -> Option { + self.0.to_i8() + } + + fn to_i16(&self) -> Option { + self.0.to_i16() + } + + fn to_i32(&self) -> Option { + self.0.to_i32() + } + + fn to_i128(&self) -> Option { + self.0.to_i128() + } + + fn to_usize(&self) -> Option { + self.0.to_usize() + } + + fn to_u8(&self) -> Option { + self.0.to_u8() + } + + fn to_u16(&self) -> Option { + self.0.to_u16() + } + + fn to_u32(&self) -> Option { + self.0.to_u32() + } + + fn to_u128(&self) -> Option { + self.0.to_u128() + } + + fn to_f32(&self) -> Option { + self.0.to_f32() + } + + fn to_f64(&self) -> Option { + self.0.to_f64() + } +} + +impl num_traits::NumCast for F32 { + fn from(n: T) -> Option { + num_traits::NumCast::from(n).map(Self) + } +} + +impl num_traits::Num for F32 { + type FromStrRadixErr = ::FromStrRadixErr; + + fn from_str_radix(str: &str, radix: u32) -> Result { + f32::from_str_radix(str, radix).map(Self) + } +} + +impl num_traits::Float for F32 { + fn nan() -> Self { + Self(f32::nan()) + } + + fn infinity() -> Self { + Self(f32::infinity()) + } + + fn neg_infinity() -> Self { + Self(f32::neg_infinity()) + } + + fn neg_zero() -> Self { + Self(f32::neg_zero()) + } + + fn min_value() -> Self { + Self(f32::min_value()) + } + + fn min_positive_value() -> Self { + Self(f32::min_positive_value()) + } + + fn max_value() -> Self { + Self(f32::max_value()) + } + + fn is_nan(self) -> bool { + self.0.is_nan() + } + + fn is_infinite(self) -> bool { + self.0.is_infinite() + } + + fn is_finite(self) -> bool { + self.0.is_finite() + } + + fn is_normal(self) -> bool { + self.0.is_normal() + } + + fn classify(self) -> std::num::FpCategory { + self.0.classify() + } + + fn floor(self) -> Self { + Self(self.0.floor()) + } + + fn ceil(self) -> Self { + Self(self.0.ceil()) + } + + fn round(self) -> Self { + Self(self.0.round()) + } + + fn trunc(self) -> Self { + Self(self.0.trunc()) + } + + fn fract(self) -> Self { + Self(self.0.fract()) + } + + fn abs(self) -> Self { + Self(self.0.abs()) + } + + fn signum(self) -> Self { + Self(self.0.signum()) + } + + fn is_sign_positive(self) -> bool { + self.0.is_sign_positive() + } + + fn is_sign_negative(self) -> bool { + self.0.is_sign_negative() + } + + fn mul_add(self, a: Self, b: Self) -> Self { + Self(self.0.mul_add(a.0, b.0)) + } + + fn recip(self) -> Self { + Self(self.0.recip()) + } + + fn powi(self, n: i32) -> Self { + Self(self.0.powi(n)) + } + + fn powf(self, n: Self) -> Self { + Self(self.0.powf(n.0)) + } + + fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + fn exp(self) -> Self { + Self(self.0.exp()) + } + + fn exp2(self) -> Self { + Self(self.0.exp2()) + } + + fn ln(self) -> Self { + Self(self.0.ln()) + } + + fn log(self, base: Self) -> Self { + Self(self.0.log(base.0)) + } + + fn log2(self) -> Self { + Self(self.0.log2()) + } + + fn log10(self) -> Self { + Self(self.0.log10()) + } + + fn max(self, other: Self) -> Self { + Self(self.0.max(other.0)) + } + + fn min(self, other: Self) -> Self { + Self(self.0.min(other.0)) + } + + fn abs_sub(self, _: Self) -> Self { + unimplemented!() + } + + fn cbrt(self) -> Self { + Self(self.0.cbrt()) + } + + fn hypot(self, other: Self) -> Self { + Self(self.0.hypot(other.0)) + } + + fn sin(self) -> Self { + Self(self.0.sin()) + } + + fn cos(self) -> Self { + Self(self.0.cos()) + } + + fn tan(self) -> Self { + Self(self.0.tan()) + } + + fn asin(self) -> Self { + Self(self.0.asin()) + } + + fn acos(self) -> Self { + Self(self.0.acos()) + } + + fn atan(self) -> Self { + Self(self.0.atan()) + } + + fn atan2(self, other: Self) -> Self { + Self(self.0.atan2(other.0)) + } + + fn sin_cos(self) -> (Self, Self) { + let (_0, _1) = self.0.sin_cos(); + (Self(_0), Self(_1)) + } + + fn exp_m1(self) -> Self { + Self(self.0.exp_m1()) + } + + fn ln_1p(self) -> Self { + Self(self.0.ln_1p()) + } + + fn sinh(self) -> Self { + Self(self.0.sinh()) + } + + fn cosh(self) -> Self { + Self(self.0.cosh()) + } + + fn tanh(self) -> Self { + Self(self.0.tanh()) + } + + fn asinh(self) -> Self { + Self(self.0.asinh()) + } + + fn acosh(self) -> Self { + Self(self.0.acosh()) + } + + fn atanh(self) -> Self { + Self(self.0.atanh()) + } + + fn integer_decode(self) -> (u64, i16, i8) { + self.0.integer_decode() + } + + fn epsilon() -> Self { + Self(f32::EPSILON) + } + + fn is_subnormal(self) -> bool { + self.0.classify() == std::num::FpCategory::Subnormal + } + + fn to_degrees(self) -> Self { + Self(self.0.to_degrees()) + } + + fn to_radians(self) -> Self { + Self(self.0.to_radians()) + } + + fn copysign(self, sign: Self) -> Self { + Self(self.0.copysign(sign.0)) + } +} + +impl Add for F32 { + type Output = F32; + + #[inline(always)] + fn add(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fadd_fast(self.0, rhs.0).into() } + } +} + +impl AddAssign for F32 { + #[inline(always)] + fn add_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs.0) } + } +} + +impl Sub for F32 { + type Output = F32; + + #[inline(always)] + fn sub(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fsub_fast(self.0, rhs.0).into() } + } +} + +impl SubAssign for F32 { + #[inline(always)] + fn sub_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs.0) } + } +} + +impl Mul for F32 { + type Output = F32; + + #[inline(always)] + fn mul(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fmul_fast(self.0, rhs.0).into() } + } +} + +impl MulAssign for F32 { + #[inline(always)] + fn mul_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs.0) } + } +} + +impl Div for F32 { + type Output = F32; + + #[inline(always)] + fn div(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fdiv_fast(self.0, rhs.0).into() } + } +} + +impl DivAssign for F32 { + #[inline(always)] + fn div_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs.0) } + } +} + +impl Rem for F32 { + type Output = F32; + + #[inline(always)] + fn rem(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::frem_fast(self.0, rhs.0).into() } + } +} + +impl RemAssign for F32 { + #[inline(always)] + fn rem_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs.0) } + } +} + +impl Neg for F32 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(self.0.neg()) + } +} + +impl FromStr for F32 { + type Err = ParseFloatError; + + fn from_str(s: &str) -> Result { + f32::from_str(s).map(|x| x.into()) + } +} + +impl FloatCast for F32 { + fn from_f32(x: f32) -> Self { + Self(x) + } + + fn to_f32(self) -> f32 { + self.0 + } +} + +impl From for F32 { + fn from(value: f32) -> Self { + Self(value) + } +} + +impl From for f32 { + fn from(F32(float): F32) -> Self { + float + } +} + +impl Add for F32 { + type Output = F32; + + #[inline(always)] + fn add(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fadd_fast(self.0, rhs).into() } + } +} + +impl AddAssign for F32 { + fn add_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs) } + } +} + +impl Sub for F32 { + type Output = F32; + + #[inline(always)] + fn sub(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fsub_fast(self.0, rhs).into() } + } +} + +impl SubAssign for F32 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs) } + } +} + +impl Mul for F32 { + type Output = F32; + + #[inline(always)] + fn mul(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fmul_fast(self.0, rhs).into() } + } +} + +impl MulAssign for F32 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs) } + } +} + +impl Div for F32 { + type Output = F32; + + #[inline(always)] + fn div(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fdiv_fast(self.0, rhs).into() } + } +} + +impl DivAssign for F32 { + #[inline(always)] + fn div_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs) } + } +} + +impl Rem for F32 { + type Output = F32; + + #[inline(always)] + fn rem(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::frem_fast(self.0, rhs).into() } + } +} + +impl RemAssign for F32 { + #[inline(always)] + fn rem_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) } + } +} diff --git a/crates/service/src/prelude/scalar/mod.rs b/crates/service/src/prelude/scalar/mod.rs new file mode 100644 index 000000000..1894a906f --- /dev/null +++ b/crates/service/src/prelude/scalar/mod.rs @@ -0,0 +1,5 @@ +mod f16; +mod f32; + +pub use f16::F16; +pub use f32::F32; diff --git a/src/prelude/sys.rs b/crates/service/src/prelude/sys.rs similarity index 53% rename from src/prelude/sys.rs rename to crates/service/src/prelude/sys.rs index 0318149ef..640ff7a72 100644 --- a/src/prelude/sys.rs +++ b/crates/service/src/prelude/sys.rs @@ -3,15 +3,10 @@ use std::{fmt::Display, num::ParseIntError, str::FromStr}; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Id { - newtype: u32, + pub newtype: u32, } impl Id { - pub fn from_sys(sys: pgrx::pg_sys::Oid) -> Self { - Self { - newtype: sys.as_u32(), - } - } pub fn as_u32(self) -> u32 { self.newtype } @@ -35,26 +30,10 @@ impl FromStr for Id { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct Pointer { - newtype: u64, + pub newtype: u64, } impl Pointer { - pub fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self { - let mut newtype = 0; - newtype |= (sys.ip_blkid.bi_hi as u64) << 32; - newtype |= (sys.ip_blkid.bi_lo as u64) << 16; - newtype |= sys.ip_posid as u64; - Self { newtype } - } - pub fn into_sys(self) -> pgrx::pg_sys::ItemPointerData { - pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { - bi_hi: ((self.newtype >> 32) & 0xffff) as u16, - bi_lo: ((self.newtype >> 16) & 0xffff) as u16, - }, - ip_posid: (self.newtype & 0xffff) as u16, - } - } pub fn from_u48(value: u64) -> Self { assert!(value < (1u64 << 48)); Self { newtype: value } diff --git a/crates/service/src/utils/cells.rs b/crates/service/src/utils/cells.rs new file mode 100644 index 000000000..83a0a7a57 --- /dev/null +++ b/crates/service/src/utils/cells.rs @@ -0,0 +1,26 @@ +use std::cell::UnsafeCell; + +#[repr(transparent)] +pub struct SyncUnsafeCell { + value: UnsafeCell, +} + +unsafe impl Sync for SyncUnsafeCell {} + +impl SyncUnsafeCell { + pub const fn new(value: T) -> Self { + Self { + value: UnsafeCell::new(value), + } + } +} + +impl SyncUnsafeCell { + pub fn get(&self) -> *mut T { + self.value.get() + } + + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } +} diff --git a/src/utils/clean.rs b/crates/service/src/utils/clean.rs similarity index 100% rename from src/utils/clean.rs rename to crates/service/src/utils/clean.rs diff --git a/src/utils/dir_ops.rs b/crates/service/src/utils/dir_ops.rs similarity index 100% rename from src/utils/dir_ops.rs rename to crates/service/src/utils/dir_ops.rs diff --git a/src/utils/file_atomic.rs b/crates/service/src/utils/file_atomic.rs similarity index 100% rename from src/utils/file_atomic.rs rename to crates/service/src/utils/file_atomic.rs diff --git a/src/utils/file_wal.rs b/crates/service/src/utils/file_wal.rs similarity index 100% rename from src/utils/file_wal.rs rename to crates/service/src/utils/file_wal.rs diff --git a/src/utils/mmap_array.rs b/crates/service/src/utils/mmap_array.rs similarity index 97% rename from src/utils/mmap_array.rs rename to crates/service/src/utils/mmap_array.rs index 133f8589a..8648d5641 100644 --- a/src/utils/mmap_array.rs +++ b/crates/service/src/utils/mmap_array.rs @@ -111,5 +111,5 @@ fn read_information(mut file: &File) -> Information { unsafe fn read_mmap(file: &File, len: usize) -> memmap2::Mmap { let len = len.next_multiple_of(4096); - memmap2::MmapOptions::new().len(len).map(file).unwrap() + unsafe { memmap2::MmapOptions::new().len(len).map(file).unwrap() } } diff --git a/crates/service/src/utils/mod.rs b/crates/service/src/utils/mod.rs new file mode 100644 index 000000000..55f717a88 --- /dev/null +++ b/crates/service/src/utils/mod.rs @@ -0,0 +1,7 @@ +pub mod cells; +pub mod clean; +pub mod dir_ops; +pub mod file_atomic; +pub mod file_wal; +pub mod mmap_array; +pub mod vec2; diff --git a/src/utils/vec2.rs b/crates/service/src/utils/vec2.rs similarity index 78% rename from src/utils/vec2.rs rename to crates/service/src/utils/vec2.rs index 1a338ddea..2640f51c8 100644 --- a/src/utils/vec2.rs +++ b/crates/service/src/utils/vec2.rs @@ -2,16 +2,16 @@ use crate::prelude::*; use std::ops::{Deref, DerefMut, Index, IndexMut}; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Vec2 { +pub struct Vec2 { dims: u16, - v: Box<[Scalar]>, + v: Vec, } -impl Vec2 { +impl Vec2 { pub fn new(dims: u16, n: usize) -> Self { Self { dims, - v: bytemuck::zeroed_slice_box(dims as usize * n), + v: bytemuck::zeroed_vec(dims as usize * n), } } pub fn dims(&self) -> u16 { @@ -32,29 +32,29 @@ impl Vec2 { } } -impl Index for Vec2 { - type Output = [Scalar]; +impl Index for Vec2 { + type Output = [S::Scalar]; fn index(&self, index: usize) -> &Self::Output { &self.v[self.dims as usize * index..][..self.dims as usize] } } -impl IndexMut for Vec2 { +impl IndexMut for Vec2 { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.v[self.dims as usize * index..][..self.dims as usize] } } -impl Deref for Vec2 { - type Target = [Scalar]; +impl Deref for Vec2 { + type Target = [S::Scalar]; fn deref(&self) -> &Self::Target { self.v.deref() } } -impl DerefMut for Vec2 { +impl DerefMut for Vec2 { fn deref_mut(&mut self) -> &mut Self::Target { self.v.deref_mut() } diff --git a/crates/service/src/worker/instance.rs b/crates/service/src/worker/instance.rs new file mode 100644 index 000000000..13dc2587d --- /dev/null +++ b/crates/service/src/worker/instance.rs @@ -0,0 +1,204 @@ +use crate::index::Index; +use crate::index::IndexOptions; +use crate::index::IndexStat; +use crate::index::IndexView; +use crate::index::OutdatedError; +use crate::prelude::*; +use std::path::PathBuf; +use std::sync::Arc; + +#[derive(Clone)] +pub enum Instance { + F32Cos(Arc>), + F32Dot(Arc>), + F32L2(Arc>), + F16Cos(Arc>), + F16Dot(Arc>), + F16L2(Arc>), +} + +impl Instance { + pub fn create(path: PathBuf, options: IndexOptions) -> Self { + match (options.vector.d, options.vector.k) { + (Distance::Cos, Kind::F32) => Self::F32Cos(Index::create(path, options)), + (Distance::Dot, Kind::F32) => Self::F32Dot(Index::create(path, options)), + (Distance::L2, Kind::F32) => Self::F32L2(Index::create(path, options)), + (Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)), + (Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)), + (Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)), + } + } + pub fn open(path: PathBuf, options: IndexOptions) -> Self { + match (options.vector.d, options.vector.k) { + (Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path, options)), + (Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path, options)), + (Distance::L2, Kind::F32) => Self::F32L2(Index::open(path, options)), + (Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)), + (Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)), + (Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)), + } + } + pub fn options(&self) -> &IndexOptions { + match self { + Instance::F32Cos(x) => x.options(), + Instance::F32Dot(x) => x.options(), + Instance::F32L2(x) => x.options(), + Instance::F16Cos(x) => x.options(), + Instance::F16Dot(x) => x.options(), + Instance::F16L2(x) => x.options(), + } + } + pub fn refresh(&self) { + match self { + Instance::F32Cos(x) => x.refresh(), + Instance::F32Dot(x) => x.refresh(), + Instance::F32L2(x) => x.refresh(), + Instance::F16Cos(x) => x.refresh(), + Instance::F16Dot(x) => x.refresh(), + Instance::F16L2(x) => x.refresh(), + } + } + pub fn view(&self) -> InstanceView { + match self { + Instance::F32Cos(x) => InstanceView::F32Cos(x.view()), + Instance::F32Dot(x) => InstanceView::F32Dot(x.view()), + Instance::F32L2(x) => InstanceView::F32L2(x.view()), + Instance::F16Cos(x) => InstanceView::F16Cos(x.view()), + Instance::F16Dot(x) => InstanceView::F16Dot(x.view()), + Instance::F16L2(x) => InstanceView::F16L2(x.view()), + } + } + pub fn stat(&self) -> IndexStat { + match self { + Instance::F32Cos(x) => x.stat(), + Instance::F32Dot(x) => x.stat(), + Instance::F32L2(x) => x.stat(), + Instance::F16Cos(x) => x.stat(), + Instance::F16Dot(x) => x.stat(), + Instance::F16L2(x) => x.stat(), + } + } +} + +pub enum InstanceView { + F32Cos(Arc>), + F32Dot(Arc>), + F32L2(Arc>), + F16Cos(Arc>), + F16Dot(Arc>), + F16L2(Arc>), +} + +impl InstanceView { + pub fn search bool>( + &self, + k: usize, + vector: DynamicVector, + filter: F, + ) -> Result, FriendlyError> { + match (self, vector) { + (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + _ => Err(FriendlyError::Unmatched2), + } + } + pub fn insert( + &self, + vector: DynamicVector, + pointer: Pointer, + ) -> Result, FriendlyError> { + match (self, vector) { + (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + _ => Err(FriendlyError::Unmatched2), + } + } + pub fn delete bool>(&self, f: F) { + match self { + InstanceView::F32Cos(x) => x.delete(f), + InstanceView::F32Dot(x) => x.delete(f), + InstanceView::F32L2(x) => x.delete(f), + InstanceView::F16Cos(x) => x.delete(f), + InstanceView::F16Dot(x) => x.delete(f), + InstanceView::F16L2(x) => x.delete(f), + } + } + pub fn flush(&self) { + match self { + InstanceView::F32Cos(x) => x.flush(), + InstanceView::F32Dot(x) => x.flush(), + InstanceView::F32L2(x) => x.flush(), + InstanceView::F16Cos(x) => x.flush(), + InstanceView::F16Dot(x) => x.flush(), + InstanceView::F16L2(x) => x.flush(), + } + } +} diff --git a/src/bgworker/worker.rs b/crates/service/src/worker/mod.rs similarity index 65% rename from src/bgworker/worker.rs rename to crates/service/src/worker/mod.rs index 424a97d27..2118f7eb7 100644 --- a/src/bgworker/worker.rs +++ b/crates/service/src/worker/mod.rs @@ -1,7 +1,9 @@ -use crate::index::Index; -use crate::index::IndexInsertError; +pub mod instance; + +use self::instance::Instance; use crate::index::IndexOptions; -use crate::index::IndexSearchError; +use crate::index::IndexStat; +use crate::index::OutdatedError; use crate::prelude::*; use crate::utils::clean::clean; use crate::utils::dir_ops::sync_dir; @@ -57,7 +59,7 @@ impl Worker { let mut indexes = HashMap::new(); for (&id, options) in startup.get().indexes.iter() { let path = path.join("indexes").join(id.to_string()); - let index = Index::open(path, options.clone()); + let index = Instance::open(path, options.clone()); indexes.insert(id, index); } let view = Arc::new(WorkerView { @@ -72,7 +74,7 @@ impl Worker { } pub fn call_create(&self, id: Id, options: IndexOptions) { let mut protect = self.protect.lock(); - let index = Index::create(self.path.join("indexes").join(id.to_string()), options); + let index = Instance::create(self.path.join("indexes").join(id.to_string()), options); if protect.indexes.insert(id, index).is_some() { panic!("index {} already exists", id) } @@ -81,29 +83,29 @@ impl Worker { pub fn call_search( &self, id: Id, - search: (Vec, usize), + search: (DynamicVector, usize), filter: F, ) -> Result, FriendlyError> where F: FnMut(Pointer) -> bool, { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; let view = index.view(); - match view.search(search.1, &search.0, filter) { - Ok(x) => Ok(x), - Err(IndexSearchError::InvalidVector(x)) => Err(FriendlyError::BadVector(x)), - } + view.search(search.1, search.0, filter) } - pub fn call_insert(&self, id: Id, insert: (Vec, Pointer)) -> Result<(), FriendlyError> { + pub fn call_insert( + &self, + id: Id, + insert: (DynamicVector, Pointer), + ) -> Result<(), FriendlyError> { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; loop { let view = index.view(); - match view.insert(insert.0.clone(), insert.1) { + match view.insert(insert.0.clone(), insert.1)? { Ok(()) => break Ok(()), - Err(IndexInsertError::InvalidVector(x)) => break Err(FriendlyError::BadVector(x)), - Err(IndexInsertError::OutdatedView(_)) => index.refresh(), + Err(OutdatedError(_)) => index.refresh(), } } } @@ -112,16 +114,16 @@ impl Worker { F: FnMut(Pointer) -> bool, { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; let view = index.view(); view.delete(f); Ok(()) } pub fn call_flush(&self, id: Id) -> Result<(), FriendlyError> { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; let view = index.view(); - view.flush().unwrap(); + view.flush(); Ok(()) } pub fn call_destory(&self, ids: Vec) { @@ -134,44 +136,20 @@ impl Worker { protect.maintain(&self.view); } } - pub fn call_stat(&self, id: Id) -> Result { + pub fn call_stat(&self, id: Id) -> Result { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; - let view = index.view(); - let idx_sealed_len = view.sealed_len(); - let idx_growing_len = view.growing_len(); - let idx_write = view.write_len(); - let res = VectorIndexInfo { - indexing: index.indexing(), - idx_tuples: (idx_write + idx_sealed_len + idx_growing_len) - .try_into() - .unwrap(), - idx_sealed_len: idx_sealed_len.try_into().unwrap(), - idx_growing_len: idx_growing_len.try_into().unwrap(), - idx_write: idx_write.try_into().unwrap(), - idx_sealed: view - .sealed_len_vec() - .into_iter() - .map(|x| x.try_into().unwrap()) - .collect(), - idx_growing: view - .growing_len_vec() - .into_iter() - .map(|x| x.try_into().unwrap()) - .collect(), - idx_config: serde_json::to_string(index.options()).unwrap(), - }; - Ok(res) + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; + Ok(index.stat()) } } struct WorkerView { - indexes: HashMap>, + indexes: HashMap, } struct WorkerProtect { startup: FileAtomic, - indexes: HashMap>, + indexes: HashMap, } impl WorkerProtect { diff --git a/docs/get-started.md b/docs/get-started.md index 0321795a2..24b5ca3eb 100644 --- a/docs/get-started.md +++ b/docs/get-started.md @@ -46,9 +46,9 @@ We support three operators to calculate the distance between two vectors. -- squared Euclidean distance SELECT '[1, 2, 3]'::vector <-> '[3, 2, 1]'::vector; -- negative dot product -SELECT '[1, 2, 3]' <#> '[3, 2, 1]'; +SELECT '[1, 2, 3]'::vector <#> '[3, 2, 1]'::vector; -- negative cosine similarity -SELECT '[1, 2, 3]' <=> '[3, 2, 1]'; +SELECT '[1, 2, 3]'::vector <=> '[3, 2, 1]'::vector; ``` You can search for a vector simply like this. @@ -58,6 +58,10 @@ You can search for a vector simply like this. SELECT * FROM items ORDER BY embedding <-> '[3,2,1]' LIMIT 5; ``` +## Half-precision floating-point + +`vecf16` type is the same with `vector` in anything but the scalar type. It stores 16-bit floating point numbers. If you want to reduce the memory usage to get better performace, you can try to replace `vector` type with `vecf16` type. + ## Things You Need to Know `vector(n)` is a valid data type only if $1 \leq n \leq 65535$. Due to limits of PostgreSQL, it's possible to create a value of type `vector(3)` of $5$ dimensions and `vector` is also a valid data. However, you cannot still put $0$ scalar or more than $65535$ scalars to a vector. If you use `vector` for a column or there is some values mismatched with dimension denoted by the column, you won't able to create an index on it. diff --git a/docs/indexing.md b/docs/indexing.md index 9fcd8a834..b46ea41be 100644 --- a/docs/indexing.md +++ b/docs/indexing.md @@ -8,8 +8,16 @@ Assuming there is a table `items` and there is a column named `embedding` of typ CREATE INDEX ON items USING vectors (embedding l2_ops); ``` -For negative dot product, replace `l2_ops` with `dot_ops`. -For negative cosine similarity, replace `l2_ops` with `cosine_ops`. +There is a table for you to choose a proper operator class for creating indexes. + +| Vector type | Distance type | Operator class | +| ----------- | -------------------------- | ----------------- | +| vector | squared Euclidean distance | l2_ops | +| vector | negative dot product | dot_ops | +| vector | negative cosine similarity | cosine_ops | +| vecf16 | squared Euclidean distance | vecf16_l2_ops | +| vecf16 | negative dot product | vecf16_dot_ops | +| vecf16 | negative cosine similarity | vecf16_cosine_ops | Now you can perform a KNN search with the following SQL again, but this time the vector index is used for searching. @@ -36,14 +44,15 @@ Options for table `segment`. | Key | Type | Description | | ------------------------ | ------- | ------------------------------------------------------------------- | | max_growing_segment_size | integer | Maximum size of unindexed vectors. Default value is `20_000`. | -| min_sealed_segment_size | integer | Minimum size of vectors for indexing. Default value is `1_000`. | | max_sealed_segment_size | integer | Maximum size of vectors for indexing. Default value is `1_000_000`. | Options for table `optimizing`. -| Key | Type | Description | -| ------------------ | ------- | --------------------------------------------------------------------------- | -| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. | +| Key | Type | Description | +| ------------------ | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. | +| sealing_secs | integer | If a writing segment larger than `sealing_size` do not accept new data for `sealing_secs` seconds, the writing segment will be turned to a sealed segment. | +| sealing_size | integer | See above. | Options for table `indexing`. @@ -99,23 +108,19 @@ Options for table `product`. ## Progress View We also provide a view `pg_vector_index_info` to monitor the progress of indexing. -Note that whether idx_sealed_len is equal to idx_tuples doesn't relate to the completion of indexing. -It may do further optimization after indexing. It may also stop indexing because there are too few tuples left. - -| Column | Type | Description | -| --------------- | ------ | --------------------------------------------- | -| tablerelid | oid | The oid of the table. | -| indexrelid | oid | The oid of the index. | -| tablename | name | The name of the table. | -| indexname | name | The name of the index. | -| indexing | bool | Whether the background thread is indexing. | -| idx_tuples | int4 | The number of tuples. | -| idx_sealed_len | int4 | The number of tuples in sealed segments. | -| idx_growing_len | int4 | The number of tuples in growing segments. | -| idx_write | int4 | The number of tuples in write buffer. | -| idx_sealed | int4[] | The number of tuples in each sealed segment. | -| idx_growing | int4[] | The number of tuples in each growing segment. | -| idx_config | text | The configuration of the index. | + +| Column | Type | Description | +| ------------ | ------ | --------------------------------------------- | +| tablerelid | oid | The oid of the table. | +| indexrelid | oid | The oid of the index. | +| tablename | name | The name of the table. | +| indexname | name | The name of the index. | +| idx_indexing | bool | Whether the background thread is indexing. | +| idx_tuples | int8 | The number of tuples. | +| idx_sealed | int8[] | The number of tuples in each sealed segment. | +| idx_growing | int8[] | The number of tuples in each growing segment. | +| idx_write | int8 | The number of tuples in write buffer. | +| idx_config | text | The configuration of the index. | ## Examples diff --git a/src/algorithms/diskann/mod.rs b/src/algorithms/diskann/mod.rs deleted file mode 100644 index 87a515a9d..000000000 --- a/src/algorithms/diskann/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod vamana; diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 4a2153bd4..0ca91e557 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -1,11 +1,26 @@ -pub mod worker; - -use self::worker::Worker; use crate::ipc::server::RpcHandler; use crate::ipc::IpcError; +use service::worker::Worker; use std::path::{Path, PathBuf}; use std::sync::Arc; +pub unsafe fn init() { + use pgrx::bgworkers::BackgroundWorkerBuilder; + use pgrx::bgworkers::BgWorkerStartTime; + BackgroundWorkerBuilder::new("vectors") + .set_function("vectors_main") + .set_library("vectors") + .set_argument(None) + .enable_shmem_access(None) + .set_start_time(BgWorkerStartTime::PostmasterStart) + .load(); +} + +#[no_mangle] +extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) { + let _ = std::panic::catch_unwind(crate::bgworker::main); +} + pub fn main() { { let mut builder = env_logger::builder(); diff --git a/src/datatype/casts_f32.rs b/src/datatype/casts_f32.rs new file mode 100644 index 000000000..53261a744 --- /dev/null +++ b/src/datatype/casts_f32.rs @@ -0,0 +1,26 @@ +use crate::datatype::typmod::Typmod; +use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; +use service::prelude::*; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf32_cast_array_to_vector( + array: pgrx::Array, + typmod: i32, + _explicit: bool, +) -> Vecf32Output { + assert!(!array.is_empty()); + assert!(array.len() <= 65535); + assert!(!array.contains_nulls()); + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + let len = typmod.dims().unwrap_or(array.len() as u16); + let mut data = Vecf32::new_zeroed_in_postgres(len as usize); + for (i, x) in array.iter().enumerate() { + data[i] = F32(x.unwrap_or(f32::NAN)); + } + data +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf32_cast_vector_to_array(vector: Vecf32Input<'_>, _typmod: i32, _explicit: bool) -> Vec { + vector.data().iter().map(|x| x.to_f32()).collect() +} diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs new file mode 100644 index 000000000..1b0ef7a78 --- /dev/null +++ b/src/datatype/mod.rs @@ -0,0 +1,6 @@ +pub mod casts_f32; +pub mod operators_f16; +pub mod operators_f32; +pub mod typmod; +pub mod vecf16; +pub mod vecf32; diff --git a/src/postgres/operators.rs b/src/datatype/operators_f16.rs similarity index 57% rename from src/postgres/operators.rs rename to src/datatype/operators_f16.rs index 4f7a63654..72c6367fe 100644 --- a/src/postgres/operators.rs +++ b/src/datatype/operators_f16.rs @@ -1,53 +1,53 @@ -use crate::postgres::datatype::{Vector, VectorInput, VectorOutput}; -use crate::prelude::*; +use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output}; +use service::prelude::*; use std::ops::Deref; -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(+)] #[pgrx::commutator(+)] -fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { +fn vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } let n = lhs.len(); - let mut v = Vector::new_zeroed(n); + let mut v = Vecf16::new_zeroed(n); for i in 0..n { v[i] = lhs[i] + rhs[i]; } v.copy_into_postgres() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(-)] -fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { +fn vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } let n = lhs.len(); - let mut v = Vector::new_zeroed(n); + let mut v = Vecf16::new_zeroed(n); for i in 0..n { v[i] = lhs[i] - rhs[i]; } v.copy_into_postgres() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<)] #[pgrx::negator(>=)] #[pgrx::commutator(>)] #[pgrx::restrict(scalarltsel)] #[pgrx::join(scalarltjoinsel)] -fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -56,15 +56,15 @@ fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() < rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<=)] #[pgrx::negator(>)] #[pgrx::commutator(>=)] #[pgrx::restrict(scalarltsel)] #[pgrx::join(scalarltjoinsel)] -fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_lte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -73,15 +73,15 @@ fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() <= rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(>)] #[pgrx::negator(<=)] #[pgrx::commutator(<)] #[pgrx::restrict(scalargtsel)] #[pgrx::join(scalargtjoinsel)] -fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_gt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -90,15 +90,15 @@ fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() > rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(>=)] #[pgrx::negator(<)] #[pgrx::commutator(<=)] #[pgrx::restrict(scalargtsel)] #[pgrx::join(scalargtjoinsel)] -fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_gte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -107,15 +107,15 @@ fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() >= rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(=)] #[pgrx::negator(<>)] #[pgrx::commutator(=)] #[pgrx::restrict(eqsel)] #[pgrx::join(eqjoinsel)] -fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_eq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -124,15 +124,15 @@ fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() == rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<>)] #[pgrx::negator(=)] #[pgrx::commutator(<>)] #[pgrx::restrict(eqsel)] #[pgrx::join(eqjoinsel)] -fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_neq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -141,44 +141,44 @@ fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() != rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<=>)] #[pgrx::commutator(<=>)] -fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { +fn vecf16_operator_cosine(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } - Distance::Cosine.distance(&lhs, &rhs) + F16Cos::distance(&lhs, &rhs).to_f32() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<#>)] #[pgrx::commutator(<#>)] -fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { +fn vecf16_operator_dot(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } - Distance::Dot.distance(&lhs, &rhs) + F16Dot::distance(&lhs, &rhs).to_f32() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<->)] #[pgrx::commutator(<->)] -fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { +fn vecf16_operator_l2(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } - Distance::L2.distance(&lhs, &rhs) + F16L2::distance(&lhs, &rhs).to_f32() } diff --git a/src/datatype/operators_f32.rs b/src/datatype/operators_f32.rs new file mode 100644 index 000000000..307055e1a --- /dev/null +++ b/src/datatype/operators_f32.rs @@ -0,0 +1,184 @@ +use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; +use service::prelude::*; +use std::ops::Deref; + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(+)] +#[pgrx::commutator(+)] +fn vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + let n = lhs.len(); + let mut v = Vecf32::new_zeroed(n); + for i in 0..n { + v[i] = lhs[i] + rhs[i]; + } + v.copy_into_postgres() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(-)] +fn vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + let n = lhs.len(); + let mut v = Vecf32::new_zeroed(n); + for i in 0..n { + v[i] = lhs[i] - rhs[i]; + } + v.copy_into_postgres() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<)] +#[pgrx::negator(>=)] +#[pgrx::commutator(>)] +#[pgrx::restrict(scalarltsel)] +#[pgrx::join(scalarltjoinsel)] +fn vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() < rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<=)] +#[pgrx::negator(>)] +#[pgrx::commutator(>=)] +#[pgrx::restrict(scalarltsel)] +#[pgrx::join(scalarltjoinsel)] +fn vecf32_operator_lte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() <= rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(>)] +#[pgrx::negator(<=)] +#[pgrx::commutator(<)] +#[pgrx::restrict(scalargtsel)] +#[pgrx::join(scalargtjoinsel)] +fn vecf32_operator_gt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() > rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(>=)] +#[pgrx::negator(<)] +#[pgrx::commutator(<=)] +#[pgrx::restrict(scalargtsel)] +#[pgrx::join(scalargtjoinsel)] +fn vecf32_operator_gte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() >= rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(=)] +#[pgrx::negator(<>)] +#[pgrx::commutator(=)] +#[pgrx::restrict(eqsel)] +#[pgrx::join(eqjoinsel)] +fn vecf32_operator_eq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() == rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<>)] +#[pgrx::negator(=)] +#[pgrx::commutator(<>)] +#[pgrx::restrict(eqsel)] +#[pgrx::join(eqjoinsel)] +fn vecf32_operator_neq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() != rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<=>)] +#[pgrx::commutator(<=>)] +fn vecf32_operator_cosine(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + F32Cos::distance(&lhs, &rhs).to_f32() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<#>)] +#[pgrx::commutator(<#>)] +fn vecf32_operator_dot(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + F32Dot::distance(&lhs, &rhs).to_f32() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<->)] +#[pgrx::commutator(<->)] +fn vecf32_operator_l2(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + F32L2::distance(&lhs, &rhs).to_f32() +} diff --git a/src/datatype/typmod.rs b/src/datatype/typmod.rs new file mode 100644 index 000000000..4760efa83 --- /dev/null +++ b/src/datatype/typmod.rs @@ -0,0 +1,77 @@ +use pgrx::Array; +use serde::{Deserialize, Serialize}; +use service::prelude::*; +use std::ffi::{CStr, CString}; +use std::num::NonZeroU16; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum Typmod { + Any, + Dims(NonZeroU16), +} + +impl Typmod { + pub fn parse_from_str(s: &str) -> Option { + use Typmod::*; + if let Ok(x) = s.parse::() { + Some(Dims(x)) + } else { + None + } + } + pub fn parse_from_i32(x: i32) -> Option { + use Typmod::*; + if x == -1 { + Some(Any) + } else if 1 <= x && x <= u16::MAX as i32 { + Some(Dims(NonZeroU16::new(x as u16).unwrap())) + } else { + None + } + } + pub fn into_option_string(self) -> Option { + use Typmod::*; + match self { + Any => None, + Dims(x) => Some(i32::from(x.get()).to_string()), + } + } + pub fn into_i32(self) -> i32 { + use Typmod::*; + match self { + Any => -1, + Dims(x) => i32::from(x.get()), + } + } + pub fn dims(self) -> Option { + use Typmod::*; + match self { + Any => None, + Dims(dims) => Some(dims.get()), + } + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn typmod_in(list: Array<&CStr>) -> i32 { + if list.is_empty() { + -1 + } else if list.len() == 1 { + let s = list.get(0).unwrap().unwrap().to_str().unwrap(); + let typmod = Typmod::parse_from_str(s) + .ok_or(FriendlyError::BadTypeDimensions) + .friendly(); + typmod.into_i32() + } else { + FriendlyError::BadTypeDimensions.friendly(); + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn typmod_out(typmod: i32) -> CString { + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + match typmod.into_option_string() { + Some(s) => CString::new(format!("({})", s)).unwrap(), + None => CString::new("()").unwrap(), + } +} diff --git a/src/datatype/vecf16.rs b/src/datatype/vecf16.rs new file mode 100644 index 000000000..94e7ccb2f --- /dev/null +++ b/src/datatype/vecf16.rs @@ -0,0 +1,375 @@ +use crate::datatype::typmod::Typmod; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use service::prelude::*; +use std::alloc::Layout; +use std::cmp::Ordering; +use std::ffi::CStr; +use std::ffi::CString; +use std::ops::Deref; +use std::ops::DerefMut; +use std::ops::Index; +use std::ops::IndexMut; +use std::ptr::NonNull; + +pgrx::extension_sql!( + r#" +CREATE TYPE vecf16 ( + INPUT = vecf16_in, + OUTPUT = vecf16_out, + TYPMOD_IN = typmod_in, + TYPMOD_OUT = typmod_out, + STORAGE = EXTENDED, + INTERNALLENGTH = VARIABLE, + ALIGNMENT = double +); +"#, + name = "vecf16", + creates = [Type(Vecf16)], + requires = [vecf16_in, vecf16_out, typmod_in, typmod_out], +); + +#[repr(C, align(8))] +pub struct Vecf16 { + varlena: u32, + len: u16, + phantom: [F16; 0], +} + +impl Vecf16 { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout_alpha = Layout::new::(); + let layout_beta = Layout::array::(len).unwrap(); + let layout = layout_alpha.extend(layout_beta).unwrap().0; + layout.pad_to_align() + } + pub fn new(slice: &[F16]) -> Box { + unsafe { + assert!(u16::try_from(slice.len()).is_ok()); + let layout = Vecf16::layout(slice.len()); + let ptr = std::alloc::alloc(layout) as *mut Vecf16; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + Box::from_raw(ptr) + } + } + pub fn new_in_postgres(slice: &[F16]) -> Vecf16Output { + unsafe { + assert!(u16::try_from(slice.len()).is_ok()); + let layout = Vecf16::layout(slice.len()); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + Vecf16Output(NonNull::new(ptr).unwrap()) + } + } + pub fn new_zeroed(len: usize) -> Box { + unsafe { + assert!(u16::try_from(len).is_ok()); + let layout = Vecf16::layout(len); + let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf16; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).len).write(len as u16); + Box::from_raw(ptr) + } + } + #[allow(dead_code)] + pub fn new_zeroed_in_postgres(len: usize) -> Vecf16Output { + unsafe { + assert!(u64::try_from(len).is_ok()); + let layout = Vecf16::layout(len); + let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf16; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).len).write(len as u16); + Vecf16Output(NonNull::new(ptr).unwrap()) + } + } + pub fn len(&self) -> usize { + self.len as usize + } + pub fn data(&self) -> &[F16] { + debug_assert_eq!(self.varlena & 3, 0); + unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } + } + pub fn data_mut(&mut self) -> &mut [F16] { + debug_assert_eq!(self.varlena & 3, 0); + unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } + } + #[allow(dead_code)] + pub fn copy(&self) -> Box { + Vecf16::new(self.data()) + } + pub fn copy_into_postgres(&self) -> Vecf16Output { + Vecf16::new_in_postgres(self.data()) + } +} + +impl Deref for Vecf16 { + type Target = [F16]; + + fn deref(&self) -> &Self::Target { + self.data() + } +} + +impl DerefMut for Vecf16 { + fn deref_mut(&mut self) -> &mut Self::Target { + self.data_mut() + } +} + +impl AsRef<[F16]> for Vecf16 { + fn as_ref(&self) -> &[F16] { + self.data() + } +} + +impl AsMut<[F16]> for Vecf16 { + fn as_mut(&mut self) -> &mut [F16] { + self.data_mut() + } +} + +impl Index for Vecf16 { + type Output = F16; + + fn index(&self, index: usize) -> &Self::Output { + self.data().index(index) + } +} + +impl IndexMut for Vecf16 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.data_mut().index_mut(index) + } +} + +impl PartialEq for Vecf16 { + fn eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + let n = self.len(); + for i in 0..n { + if self[i] != other[i] { + return false; + } + } + true + } +} + +impl Eq for Vecf16 {} + +impl PartialOrd for Vecf16 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Vecf16 { + fn cmp(&self, other: &Self) -> Ordering { + use Ordering::*; + if let x @ Less | x @ Greater = self.len().cmp(&other.len()) { + return x; + } + let n = self.len(); + for i in 0..n { + if let x @ Less | x @ Greater = self[i].cmp(&other[i]) { + return x; + } + } + Equal + } +} + +pub enum Vecf16Input<'a> { + Owned(Vecf16Output), + Borrowed(&'a Vecf16), +} + +impl<'a> Vecf16Input<'a> { + pub unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Vecf16Input::Owned(Vecf16Output(q)) + } else { + unsafe { Vecf16Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for Vecf16Input<'_> { + type Target = Vecf16; + + fn deref(&self) -> &Self::Target { + match self { + Vecf16Input::Owned(x) => x, + Vecf16Input::Borrowed(x) => x, + } + } +} + +pub struct Vecf16Output(NonNull); + +impl Vecf16Output { + pub fn into_raw(self) -> *mut Vecf16 { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for Vecf16Output { + type Target = Vecf16; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl DerefMut for Vecf16Output { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { self.0.as_mut() } + } +} + +impl Drop for Vecf16Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for Vecf16Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Vecf16Input::new(ptr)) } + } + } +} + +impl IntoDatum for Vecf16Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vecf16") + } +} + +unsafe impl SqlTranslatable for Vecf16Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vecf16"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) + } +} + +unsafe impl SqlTranslatable for Vecf16Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vecf16"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output { + fn solve(option: Option, hint: &str) -> T { + if let Some(x) = option { + x + } else { + FriendlyError::BadLiteral { + hint: hint.to_string(), + } + .friendly() + } + } + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum State { + MatchingLeft, + Reading, + MatchedRight, + } + use State::*; + let input = input.to_bytes(); + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + let mut vector = Vec::::with_capacity(typmod.dims().unwrap_or(0) as usize); + let mut state = MatchingLeft; + let mut token: Option = None; + for &c in input { + match (state, c) { + (MatchingLeft, b'[') => { + state = Reading; + } + (Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => { + let token = token.get_or_insert(String::new()); + token.push(char::from_u32(c as u32).unwrap()); + } + (Reading, b',') => { + let token = solve(token.take(), "Expect a number."); + vector.push(solve(token.parse().ok(), "Bad number.")); + } + (Reading, b']') => { + if let Some(token) = token.take() { + vector.push(solve(token.parse().ok(), "Bad number.")); + } + state = MatchedRight; + } + (_, b' ') => {} + _ => { + FriendlyError::BadLiteral { + hint: format!("Bad charactor with ascii {:#x}.", c), + } + .friendly(); + } + } + } + if state != MatchedRight { + FriendlyError::BadLiteral { + hint: "Bad sequence.".to_string(), + } + .friendly(); + } + if vector.is_empty() || vector.len() > 65535 { + FriendlyError::BadValueDimensions.friendly(); + } + Vecf16::new_in_postgres(&vector) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf16_out(vector: Vecf16Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + if let Some(&x) = vector.data().first() { + buffer.push_str(format!("{}", x).as_str()); + } + for &x in vector.data().iter().skip(1) { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/postgres/datatype.rs b/src/datatype/vecf32.rs similarity index 55% rename from src/postgres/datatype.rs rename to src/datatype/vecf32.rs index 46b8056a0..d04f9312d 100644 --- a/src/postgres/datatype.rs +++ b/src/datatype/vecf32.rs @@ -1,4 +1,4 @@ -use crate::prelude::*; +use crate::datatype::typmod::Typmod; use pgrx::pg_sys::Datum; use pgrx::pg_sys::Oid; use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; @@ -6,208 +6,158 @@ use pgrx::pgrx_sql_entity_graph::metadata::Returns; use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::Array; use pgrx::FromDatum; use pgrx::IntoDatum; -use serde::{Deserialize, Serialize}; +use service::prelude::*; use std::alloc::Layout; use std::cmp::Ordering; use std::ffi::CStr; use std::ffi::CString; -use std::num::NonZeroU16; use std::ops::Deref; use std::ops::DerefMut; use std::ops::Index; use std::ops::IndexMut; use std::ptr::NonNull; -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub enum VectorTypmod { - Any, - Dims(NonZeroU16), -} - -impl VectorTypmod { - pub fn parse_from_str(s: &str) -> Option { - use VectorTypmod::*; - if let Ok(x) = s.parse::() { - Some(Dims(x)) - } else { - None - } - } - pub fn parse_from_i32(x: i32) -> Option { - use VectorTypmod::*; - if x == -1 { - Some(Any) - } else if 1 <= x && x <= u16::MAX as i32 { - Some(Dims(NonZeroU16::new(x as u16).unwrap())) - } else { - None - } - } - pub fn into_option_string(self) -> Option { - use VectorTypmod::*; - match self { - Any => None, - Dims(x) => Some(i32::from(x.get()).to_string()), - } - } - pub fn into_i32(self) -> i32 { - use VectorTypmod::*; - match self { - Any => -1, - Dims(x) => i32::from(x.get()), - } - } - pub fn dims(self) -> Option { - use VectorTypmod::*; - match self { - Any => None, - Dims(dims) => Some(dims.get()), - } - } -} - pgrx::extension_sql!( r#" CREATE TYPE vector ( - INPUT = vector_in, - OUTPUT = vector_out, - TYPMOD_IN = vector_typmod_in, - TYPMOD_OUT = vector_typmod_out, + INPUT = vecf32_in, + OUTPUT = vecf32_out, + TYPMOD_IN = typmod_in, + TYPMOD_OUT = typmod_out, STORAGE = EXTENDED, INTERNALLENGTH = VARIABLE, ALIGNMENT = double ); "#, - name = "vector", - creates = [Type(Vector)], - requires = [vector_in, vector_out, vector_typmod_in, vector_typmod_out], + name = "vecf32", + creates = [Type(Vecf32)], + requires = [vecf32_in, vecf32_out, typmod_in, typmod_out], ); #[repr(C, align(8))] -pub struct Vector { +pub struct Vecf32 { varlena: u32, len: u16, - phantom: [Scalar; 0], + phantom: [F32; 0], } -impl Vector { +impl Vecf32 { fn varlena(size: usize) -> u32 { (size << 2) as u32 } fn layout(len: usize) -> Layout { u16::try_from(len).expect("Vector is too large."); - let layout_alpha = Layout::new::(); - let layout_beta = Layout::array::(len).unwrap(); + let layout_alpha = Layout::new::(); + let layout_beta = Layout::array::(len).unwrap(); let layout = layout_alpha.extend(layout_beta).unwrap().0; layout.pad_to_align() } - pub fn new(slice: &[Scalar]) -> Box { + pub fn new(slice: &[F32]) -> Box { unsafe { assert!(u16::try_from(slice.len()).is_ok()); - let layout = Vector::layout(slice.len()); - let ptr = std::alloc::alloc(layout) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); + let layout = Vecf32::layout(slice.len()); + let ptr = std::alloc::alloc(layout) as *mut Vecf32; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Box::from_raw(ptr) } } - pub fn new_in_postgres(slice: &[Scalar]) -> VectorOutput { + pub fn new_in_postgres(slice: &[F32]) -> Vecf32Output { unsafe { assert!(u16::try_from(slice.len()).is_ok()); - let layout = Vector::layout(slice.len()); - let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); + let layout = Vecf32::layout(slice.len()); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - VectorOutput(NonNull::new(ptr).unwrap()) + Vecf32Output(NonNull::new(ptr).unwrap()) } } pub fn new_zeroed(len: usize) -> Box { unsafe { assert!(u16::try_from(len).is_ok()); - let layout = Vector::layout(len); - let ptr = std::alloc::alloc_zeroed(layout) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); + let layout = Vecf32::layout(len); + let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf32; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); std::ptr::addr_of_mut!((*ptr).len).write(len as u16); Box::from_raw(ptr) } } #[allow(dead_code)] - pub fn new_zeroed_in_postgres(len: usize) -> VectorOutput { + pub fn new_zeroed_in_postgres(len: usize) -> Vecf32Output { unsafe { assert!(u64::try_from(len).is_ok()); - let layout = Vector::layout(len); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); + let layout = Vecf32::layout(len); + let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf32; + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - VectorOutput(NonNull::new(ptr).unwrap()) + Vecf32Output(NonNull::new(ptr).unwrap()) } } pub fn len(&self) -> usize { self.len as usize } - pub fn data(&self) -> &[Scalar] { + pub fn data(&self) -> &[F32] { debug_assert_eq!(self.varlena & 3, 0); unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } } - pub fn data_mut(&mut self) -> &mut [Scalar] { + pub fn data_mut(&mut self) -> &mut [F32] { debug_assert_eq!(self.varlena & 3, 0); unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } } #[allow(dead_code)] - pub fn copy(&self) -> Box { - Vector::new(self.data()) + pub fn copy(&self) -> Box { + Vecf32::new(self.data()) } - pub fn copy_into_postgres(&self) -> VectorOutput { - Vector::new_in_postgres(self.data()) + pub fn copy_into_postgres(&self) -> Vecf32Output { + Vecf32::new_in_postgres(self.data()) } } -impl Deref for Vector { - type Target = [Scalar]; +impl Deref for Vecf32 { + type Target = [F32]; fn deref(&self) -> &Self::Target { self.data() } } -impl DerefMut for Vector { +impl DerefMut for Vecf32 { fn deref_mut(&mut self) -> &mut Self::Target { self.data_mut() } } -impl AsRef<[Scalar]> for Vector { - fn as_ref(&self) -> &[Scalar] { +impl AsRef<[F32]> for Vecf32 { + fn as_ref(&self) -> &[F32] { self.data() } } -impl AsMut<[Scalar]> for Vector { - fn as_mut(&mut self) -> &mut [Scalar] { +impl AsMut<[F32]> for Vecf32 { + fn as_mut(&mut self) -> &mut [F32] { self.data_mut() } } -impl Index for Vector { - type Output = Scalar; +impl Index for Vecf32 { + type Output = F32; fn index(&self, index: usize) -> &Self::Output { self.data().index(index) } } -impl IndexMut for Vector { +impl IndexMut for Vecf32 { fn index_mut(&mut self, index: usize) -> &mut Self::Output { self.data_mut().index_mut(index) } } -impl PartialEq for Vector { +impl PartialEq for Vecf32 { fn eq(&self, other: &Self) -> bool { if self.len() != other.len() { return false; @@ -222,15 +172,15 @@ impl PartialEq for Vector { } } -impl Eq for Vector {} +impl Eq for Vecf32 {} -impl PartialOrd for Vector { +impl PartialOrd for Vecf32 { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for Vector { +impl Ord for Vecf32 { fn cmp(&self, other: &Self) -> Ordering { use Ordering::*; if let x @ Less | x @ Greater = self.len().cmp(&other.len()) { @@ -246,58 +196,60 @@ impl Ord for Vector { } } -pub enum VectorInput<'a> { - Owned(VectorOutput), - Borrowed(&'a Vector), +pub enum Vecf32Input<'a> { + Owned(Vecf32Output), + Borrowed(&'a Vecf32), } -impl<'a> VectorInput<'a> { - pub unsafe fn new(p: NonNull) -> Self { - let q = NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap(); +impl<'a> Vecf32Input<'a> { + pub unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; if p != q { - VectorInput::Owned(VectorOutput(q)) + Vecf32Input::Owned(Vecf32Output(q)) } else { - VectorInput::Borrowed(p.as_ref()) + unsafe { Vecf32Input::Borrowed(p.as_ref()) } } } } -impl Deref for VectorInput<'_> { - type Target = Vector; +impl Deref for Vecf32Input<'_> { + type Target = Vecf32; fn deref(&self) -> &Self::Target { match self { - VectorInput::Owned(x) => x, - VectorInput::Borrowed(x) => x, + Vecf32Input::Owned(x) => x, + Vecf32Input::Borrowed(x) => x, } } } -pub struct VectorOutput(NonNull); +pub struct Vecf32Output(NonNull); -impl VectorOutput { - pub fn into_raw(self) -> *mut Vector { +impl Vecf32Output { + pub fn into_raw(self) -> *mut Vecf32 { let result = self.0.as_ptr(); std::mem::forget(self); result } } -impl Deref for VectorOutput { - type Target = Vector; +impl Deref for Vecf32Output { + type Target = Vecf32; fn deref(&self) -> &Self::Target { unsafe { self.0.as_ref() } } } -impl DerefMut for VectorOutput { +impl DerefMut for Vecf32Output { fn deref_mut(&mut self) -> &mut Self::Target { unsafe { self.0.as_mut() } } } -impl Drop for VectorOutput { +impl Drop for Vecf32Output { fn drop(&mut self) { unsafe { pgrx::pg_sys::pfree(self.0.as_ptr() as _); @@ -305,18 +257,18 @@ impl Drop for VectorOutput { } } -impl<'a> FromDatum for VectorInput<'a> { +impl<'a> FromDatum for Vecf32Input<'a> { unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { if is_null { None } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - Some(VectorInput::new(ptr)) + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Vecf32Input::new(ptr)) } } } } -impl IntoDatum for VectorOutput { +impl IntoDatum for Vecf32Output { fn into_datum(self) -> Option { Some(Datum::from(self.into_raw() as *mut ())) } @@ -326,7 +278,7 @@ impl IntoDatum for VectorOutput { } } -unsafe impl SqlTranslatable for VectorInput<'_> { +unsafe impl SqlTranslatable for Vecf32Input<'_> { fn argument_sql() -> Result { Ok(SqlMapping::As(String::from("vector"))) } @@ -335,7 +287,7 @@ unsafe impl SqlTranslatable for VectorInput<'_> { } } -unsafe impl SqlTranslatable for VectorOutput { +unsafe impl SqlTranslatable for Vecf32Output { fn argument_sql() -> Result { Ok(SqlMapping::As(String::from("vector"))) } @@ -345,12 +297,12 @@ unsafe impl SqlTranslatable for VectorOutput { } #[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_in(input: &CStr, _oid: Oid, typmod: i32) -> VectorOutput { +fn vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output { fn solve(option: Option, hint: &str) -> T { if let Some(x) = option { x } else { - FriendlyError::BadVectorString { + FriendlyError::BadLiteral { hint: hint.to_string(), } .friendly() @@ -364,8 +316,8 @@ fn vector_in(input: &CStr, _oid: Oid, typmod: i32) -> VectorOutput { } use State::*; let input = input.to_bytes(); - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - let mut vector = Vec::::with_capacity(typmod.dims().unwrap_or(0) as usize); + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + let mut vector = Vec::::with_capacity(typmod.dims().unwrap_or(0) as usize); let mut state = MatchingLeft; let mut token: Option = None; for &c in input { @@ -389,7 +341,7 @@ fn vector_in(input: &CStr, _oid: Oid, typmod: i32) -> VectorOutput { } (_, b' ') => {} _ => { - FriendlyError::BadVectorString { + FriendlyError::BadLiteral { hint: format!("Bad charactor with ascii {:#x}.", c), } .friendly(); @@ -397,28 +349,19 @@ fn vector_in(input: &CStr, _oid: Oid, typmod: i32) -> VectorOutput { } } if state != MatchedRight { - FriendlyError::BadVectorString { + FriendlyError::BadLiteral { hint: "Bad sequence.".to_string(), } .friendly(); } if vector.is_empty() || vector.len() > 65535 { - FriendlyError::BadVecForDims.friendly(); - } - if let Some(dims) = typmod.dims() { - if dims as usize != vector.len() { - FriendlyError::BadVecForUnmatchedDims { - value_dimensions: dims, - type_dimensions: vector.len() as u16, - } - .friendly(); - } + FriendlyError::BadValueDimensions.friendly(); } - Vector::new_in_postgres(&vector) + Vecf32::new_in_postgres(&vector) } #[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_out(vector: VectorInput<'_>) -> CString { +fn vecf32_out(vector: Vecf32Input<'_>) -> CString { let mut buffer = String::new(); buffer.push('['); if let Some(&x) = vector.data().first() { @@ -430,27 +373,3 @@ fn vector_out(vector: VectorInput<'_>) -> CString { buffer.push(']'); CString::new(buffer).unwrap() } - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_typmod_in(list: Array<&CStr>) -> i32 { - if list.is_empty() { - -1 - } else if list.len() == 1 { - let s = list.get(0).unwrap().unwrap().to_str().unwrap(); - let typmod = VectorTypmod::parse_from_str(s) - .ok_or(FriendlyError::BadTypmod) - .friendly(); - typmod.into_i32() - } else { - FriendlyError::BadTypmod.friendly(); - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_typmod_out(typmod: i32) -> CString { - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - match typmod.into_option_string() { - Some(s) => CString::new(format!("({})", s)).unwrap(), - None => CString::new("()").unwrap(), - } -} diff --git a/src/embedding/udf.rs b/src/embedding/udf.rs index cc6bfb14c..97131bb82 100644 --- a/src/embedding/udf.rs +++ b/src/embedding/udf.rs @@ -1,14 +1,12 @@ use super::openai::{EmbeddingCreator, OpenAIEmbedding}; use super::Embedding; -use crate::postgres::datatype::Vector; -use crate::postgres::datatype::VectorOutput; -use crate::postgres::gucs::OPENAI_API_KEY_GUC; -use crate::prelude::Float; -use crate::prelude::Scalar; +use crate::datatype::vecf32::{Vecf32, Vecf32Output}; +use crate::gucs::OPENAI_API_KEY_GUC; use pgrx::prelude::*; +use service::prelude::F32; #[pg_extern] -fn ai_embedding_vector(input: String) -> VectorOutput { +fn ai_embedding_vector(input: String) -> Vecf32Output { let api_key = match OPENAI_API_KEY_GUC.get() { Some(key) => key .to_str() @@ -26,9 +24,9 @@ fn ai_embedding_vector(input: String) -> VectorOutput { Ok(embedding) => { let embedding = embedding .into_iter() - .map(|x| Scalar(x as Float)) + .map(|x| F32(x as f32)) .collect::>(); - Vector::new_in_postgres(&embedding) + Vecf32::new_in_postgres(&embedding) } Err(e) => { error!("{}", e) diff --git a/src/postgres/gucs.rs b/src/gucs.rs similarity index 100% rename from src/postgres/gucs.rs rename to src/gucs.rs diff --git a/src/postgres/index.rs b/src/index/am.rs similarity index 87% rename from src/postgres/index.rs rename to src/index/am.rs index e8ec4152a..82925d2ad 100644 --- a/src/postgres/index.rs +++ b/src/index/am.rs @@ -1,11 +1,12 @@ -use super::index_build; -use super::index_scan; -use super::index_setup; -use super::index_update; -use crate::postgres::datatype::VectorInput; -use crate::postgres::gucs::ENABLE_VECTOR_INDEX; +use super::am_build; +use super::am_scan; +use super::am_setup; +use super::am_update; +use crate::datatype::vecf32::Vecf32Input; +use crate::gucs::ENABLE_VECTOR_INDEX; use crate::prelude::*; use crate::utils::cells::PgCell; +use service::prelude::*; static RELOPT_KIND: PgCell = unsafe { PgCell::new(0) }; @@ -28,9 +29,7 @@ pub unsafe fn init() { #[pgrx::pg_extern(sql = " CREATE OR REPLACE FUNCTION vectors_amhandler(internal) RETURNS index_am_handler PARALLEL SAFE IMMUTABLE STRICT LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@'; - CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; - COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; -", requires = ["vector"])] +", requires = ["vecf32"])] fn vectors_amhandler( _fcinfo: pgrx::pg_sys::FunctionCallInfo, ) -> pgrx::PgBox { @@ -85,7 +84,7 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = { #[pgrx::pg_guard] pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool { - index_setup::convert_opclass_to_distance(opclass_oid); + am_setup::convert_opclass_to_distance(opclass_oid); true } @@ -99,7 +98,7 @@ pub unsafe extern "C" fn amoptions( let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { optname: "options".as_pg_cstr(), opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: index_setup::helper_offset() as i32, + offset: am_setup::helper_offset() as i32, }]; let mut noptions = 0; let options = @@ -111,10 +110,10 @@ pub unsafe extern "C" fn amoptions( relopt.gen.as_mut().unwrap().lockmode = pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; } - let rdopts = pgrx::pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions); + let rdopts = pgrx::pg_sys::allocateReloptStruct(am_setup::helper_size(), options, noptions); pgrx::pg_sys::fillRelOptions( rdopts, - index_setup::helper_size(), + am_setup::helper_size(), options, noptions, validate, @@ -136,13 +135,13 @@ pub unsafe extern "C" fn amoptions( let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { optname: "options".as_pg_cstr(), opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: index_setup::helper_offset() as i32, + offset: am_setup::helper_offset() as i32, }]; let rdopts = pgrx::pg_sys::build_reloptions( reloptions, validate, RELOPT_KIND.get(), - index_setup::helper_size(), + am_setup::helper_size(), tab.as_ptr(), tab.len() as _, ); @@ -182,7 +181,7 @@ pub unsafe extern "C" fn ambuild( index_info: *mut pgrx::pg_sys::IndexInfo, ) -> *mut pgrx::pg_sys::IndexBuildResult { let result = pgrx::PgBox::::alloc0(); - index_build::build( + am_build::build( index_relation, Some((heap_relation, index_info, result.as_ptr())), ); @@ -191,7 +190,7 @@ pub unsafe extern "C" fn ambuild( #[pgrx::pg_guard] pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) { - index_build::build(index_relation, None); + am_build::build(index_relation, None); } #[cfg(any(feature = "pg12", feature = "pg13"))] @@ -208,9 +207,9 @@ pub unsafe extern "C" fn aminsert( use pgrx::FromDatum; let oid = (*index_relation).rd_node.relNode; let id = Id::from_sys(oid); - let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let vector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = vector.data().to_vec(); - index_update::update_insert(id, vector, *heap_tid); + am_update::update_insert(id, vector.into(), *heap_tid); true } @@ -232,9 +231,9 @@ pub unsafe extern "C" fn aminsert( #[cfg(feature = "pg16")] let oid = (*index_relation).rd_locator.relNumber; let id = Id::from_sys(oid); - let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let vector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); let vector = vector.data().to_vec(); - index_update::update_insert(id, vector, *heap_tid); + am_update::update_insert(id, vector.into(), *heap_tid); true } @@ -244,7 +243,7 @@ pub unsafe extern "C" fn ambeginscan( n_keys: std::os::raw::c_int, n_order_bys: std::os::raw::c_int, ) -> pgrx::pg_sys::IndexScanDesc { - index_scan::make_scan(index_relation, n_keys, n_order_bys) + am_scan::make_scan(index_relation, n_keys, n_order_bys) } #[pgrx::pg_guard] @@ -255,7 +254,7 @@ pub unsafe extern "C" fn amrescan( orderbys: pgrx::pg_sys::ScanKey, n_orderbys: std::os::raw::c_int, ) { - index_scan::start_scan(scan, keys, n_keys, orderbys, n_orderbys); + am_scan::start_scan(scan, keys, n_keys, orderbys, n_orderbys); } #[pgrx::pg_guard] @@ -264,7 +263,7 @@ pub unsafe extern "C" fn amgettuple( direction: pgrx::pg_sys::ScanDirection, ) -> bool { assert!(direction == pgrx::pg_sys::ScanDirection_ForwardScanDirection); - index_scan::next_scan(scan) + am_scan::next_scan(scan) } #[pgrx::pg_guard] @@ -283,7 +282,7 @@ pub unsafe extern "C" fn ambulkdelete( let oid = (*(*info).index).rd_locator.relNumber; let id = Id::from_sys(oid); if let Some(callback) = callback { - index_update::update_delete(id, |pointer| { + am_update::update_delete(id, |pointer| { callback( &mut pointer.into_sys() as *mut pgrx::pg_sys::ItemPointerData, callback_state, diff --git a/src/postgres/index_build.rs b/src/index/am_build.rs similarity index 65% rename from src/postgres/index_build.rs rename to src/index/am_build.rs index f6b41b114..104b0dfd4 100644 --- a/src/postgres/index_build.rs +++ b/src/index/am_build.rs @@ -1,11 +1,12 @@ -use super::hook_transaction::{client, flush_if_commit}; -use crate::ipc::client::Rpc; -use crate::postgres::index_setup::options; +use super::{client::ClientGuard, hook_transaction::flush_if_commit}; +use crate::datatype::vecf32::Vecf32Input; +use crate::index::am_setup::options; use crate::prelude::*; use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData}; +use service::prelude::*; pub struct Builder { - pub rpc: Rpc, + pub client: ClientGuard, pub heap_relation: *mut RelationData, pub index_info: *mut IndexInfo, pub result: *mut IndexBuildResult, @@ -22,27 +23,22 @@ pub unsafe fn build( let id = Id::from_sys(oid); flush_if_commit(id); let options = options(index); - client(|mut rpc| { - rpc.create(id, options).friendly(); - rpc - }); + let mut client = super::client::borrow_mut(); + client.create(id, options); if let Some((heap_relation, index_info, result)) = data { - client(|rpc| { - let mut builder = Builder { - rpc, - heap_relation, - index_info, - result, - }; - pgrx::pg_sys::IndexBuildHeapScan( - heap_relation, - index, - index_info, - Some(callback), - &mut builder, - ); - builder.rpc - }); + let mut builder = Builder { + client, + heap_relation, + index_info, + result, + }; + pgrx::pg_sys::IndexBuildHeapScan( + heap_relation, + index, + index_info, + Some(callback), + &mut builder, + ); } } @@ -56,17 +52,14 @@ unsafe extern "C" fn callback( _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - use super::datatype::VectorInput; use pgrx::FromDatum; - let ctid = &(*htup).t_self; - let oid = (*index_relation).rd_node.relNode; let id = Id::from_sys(oid); let state = &mut *(state as *mut Builder); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = (pgvector.to_vec(), Pointer::from_sys(*ctid)); - state.rpc.insert(id, data).friendly().friendly(); + let pgvector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let data = (pgvector.to_vec().into(), Pointer::from_sys(*ctid)); + state.client.insert(id, data); (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; } @@ -81,18 +74,16 @@ unsafe extern "C" fn callback( _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - use super::datatype::VectorInput; use pgrx::FromDatum; - #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))] let oid = (*index_relation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*index_relation).rd_locator.relNumber; let id = Id::from_sys(oid); let state = &mut *(state as *mut Builder); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = (pgvector.to_vec(), Pointer::from_sys(*ctid)); - state.rpc.insert(id, data).friendly().friendly(); + let pgvector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); + let data = (pgvector.to_vec().into(), Pointer::from_sys(*ctid)); + state.client.insert(id, data); (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; } diff --git a/src/postgres/index_scan.rs b/src/index/am_scan.rs similarity index 76% rename from src/postgres/index_scan.rs rename to src/index/am_scan.rs index b5251d900..a887f9c67 100644 --- a/src/postgres/index_scan.rs +++ b/src/index/am_scan.rs @@ -1,15 +1,15 @@ -use super::gucs::ENABLE_PREFILTER; -use super::hook_transaction::client; -use crate::postgres::datatype::VectorInput; -use crate::postgres::gucs::K; +use crate::datatype::vecf32::Vecf32Input; +use crate::gucs::ENABLE_PREFILTER; +use crate::gucs::K; use crate::prelude::*; use pgrx::FromDatum; +use service::prelude::*; #[derive(Debug, Clone)] pub enum Scanner { Initial { // fields to be filled by amhandler and hook - vector: Option>, + vector: Option>, index_scan_state: Option<*mut pgrx::pg_sys::IndexScanState>, }, Type0 { @@ -81,7 +81,7 @@ pub unsafe fn start_scan( } let orderby = orderbys.add(0); let argument = (*orderby).sk_argument; - let vector = VectorInput::from_datum(argument, false).unwrap(); + let vector = Vecf32Input::from_datum(argument, false).unwrap(); let vector = vector.to_vec(); let last = (*((*scan).opaque as *mut Scanner)).clone(); @@ -136,53 +136,48 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { let oid = (*(*scan).indexRelation).rd_locator.relNumber; let id = Id::from_sys(oid); let vector = vector.expect("`rescan` is never called."); - if index_scan_state.is_some() && ENABLE_PREFILTER.get() { - client(|rpc| { - let index_scan_state = index_scan_state.unwrap(); - let k = K.get() as _; - let mut handler = rpc.search(id, (vector, k), true).friendly(); - let mut res; - let rpc = loop { - use crate::ipc::client::SearchHandle::*; - match handler.handle().friendly() { - Check { p, x } => { - let result = check(index_scan_state, p); - handler = x.leave(result).friendly(); - } - Leave { result, x } => { - res = result.friendly(); - break x; - } - } - }; - res.reverse(); - *scanner = Scanner::Type1 { - index_scan_state, - data: res, - }; - rpc - }); + let mut client = super::client::borrow_mut(); + let k = K.get() as _; + if index_scan_state.is_some() { + struct ClientSearch { + node: *mut pgrx::pg_sys::IndexScanState, + } + + impl crate::ipc::client::ClientSearch for ClientSearch { + fn check(&mut self, p: Pointer) -> bool { + unsafe { check(self.node, p) } + } + } + + let client_search = ClientSearch { + node: index_scan_state.unwrap(), + }; + + let mut result = client.search( + id, + (vector.into(), k), + ENABLE_PREFILTER.get(), + client_search, + ); + result.reverse(); + *scanner = Scanner::Type1 { + index_scan_state: index_scan_state.unwrap(), + data: result, + }; } else { - client(|rpc| { - let k = K.get() as _; - let handler = rpc.search(id, (vector, k), false).friendly(); - let mut res; - let rpc = { - use crate::ipc::client::SearchHandle::*; - match handler.handle().friendly() { - Check { .. } => { - unreachable!() - } - Leave { result, x } => { - res = result.friendly(); - x - } - } - }; - res.reverse(); - *scanner = Scanner::Type0 { data: res }; - rpc - }); + struct ClientSearch {} + + impl crate::ipc::client::ClientSearch for ClientSearch { + fn check(&mut self, _: Pointer) -> bool { + unreachable!() + } + } + + let client_search = ClientSearch {}; + + let mut result = client.search(id, (vector.into(), k), false, client_search); + result.reverse(); + *scanner = Scanner::Type0 { data: result }; } } match scanner { diff --git a/src/postgres/index_setup.rs b/src/index/am_setup.rs similarity index 79% rename from src/postgres/index_setup.rs rename to src/index/am_setup.rs index 433fd6819..c4c89d06e 100644 --- a/src/postgres/index_setup.rs +++ b/src/index/am_setup.rs @@ -1,22 +1,22 @@ -use crate::index::indexing::IndexingOptions; -use crate::index::optimizing::OptimizingOptions; -use crate::index::segments::SegmentsOptions; -use crate::index::{IndexOptions, VectorOptions}; -use crate::postgres::datatype::VectorTypmod; -use crate::prelude::*; +use crate::datatype::typmod::Typmod; use serde::{Deserialize, Serialize}; +use service::index::indexing::IndexingOptions; +use service::index::optimizing::OptimizingOptions; +use service::index::segments::SegmentsOptions; +use service::index::{IndexOptions, VectorOptions}; +use service::prelude::*; use std::ffi::CStr; use validator::Validate; pub fn helper_offset() -> usize { - memoffset::offset_of!(Helper, offset) + std::mem::offset_of!(Helper, offset) } pub fn helper_size() -> usize { std::mem::size_of::() } -pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distance { +pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> (Distance, Kind) { let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _; let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into()); assert!( @@ -25,12 +25,12 @@ pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distanc ); let classform = pgrx::pg_sys::GETSTRUCT(tuple).cast::(); let opfamily = (*classform).opcfamily; - let distance = convert_opfamily_to_distance(opfamily); + let result = convert_opfamily_to_distance(opfamily); pgrx::pg_sys::ReleaseSysCache(tuple); - distance + result } -pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Distance { +pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Distance, Kind) { let opfamily_cache_id = pgrx::pg_sys::SysCacheIdentifier_OPFAMILYOID as _; let opstrategy_cache_id = pgrx::pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _; let tuple = pgrx::pg_sys::SearchSysCache1(opfamily_cache_id, opfamily.into()); @@ -52,19 +52,19 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Dista assert!((*amop).amopstrategy == 1); assert!((*amop).amoppurpose == pgrx::pg_sys::AMOP_ORDER as libc::c_char); let operator = (*amop).amopopr; - let distance; + let result; if operator == regoperatorin("<->(vector,vector)") { - distance = Distance::L2; + result = (Distance::L2, Kind::F32); } else if operator == regoperatorin("<#>(vector,vector)") { - distance = Distance::Dot; + result = (Distance::Dot, Kind::F32); } else if operator == regoperatorin("<=>(vector,vector)") { - distance = Distance::Cosine; + result = (Distance::Cos, Kind::F32); } else { - FriendlyError::UnsupportedOperator.friendly(); + FriendlyError::BadOptions3.friendly(); }; pgrx::pg_sys::ReleaseCatCacheList(list); pgrx::pg_sys::ReleaseSysCache(tuple); - distance + result } pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { @@ -72,22 +72,25 @@ pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { assert!(nkeysatts == 1, "Can not be built on multicolumns."); // get distance let opfamily = (*index_relation).rd_opfamily.read(); - let d = convert_opfamily_to_distance(opfamily); + let (d, k) = convert_opfamily_to_distance(opfamily); // get dims let attrs = (*(*index_relation).rd_att).attrs.as_slice(1); let attr = &attrs[0]; - let typmod = VectorTypmod::parse_from_i32(attr.type_mod()).unwrap(); - let dims = typmod.dims().ok_or(FriendlyError::DimsIsNeeded).friendly(); + let typmod = Typmod::parse_from_i32(attr.type_mod()).unwrap(); + let dims = typmod.dims().ok_or(FriendlyError::BadOption2).friendly(); // get other options let parsed = get_parsed_from_varlena((*index_relation).rd_options); let options = IndexOptions { - vector: VectorOptions { dims, d }, + vector: VectorOptions { dims, d, k }, segment: parsed.segment, optimizing: parsed.optimizing, indexing: parsed.indexing, }; if let Err(errors) = options.validate() { - FriendlyError::BadOption(errors.to_string()).friendly(); + FriendlyError::BadOption { + validation: errors.to_string(), + } + .friendly(); } options } diff --git a/src/index/am_update.rs b/src/index/am_update.rs new file mode 100644 index 000000000..78fd302a5 --- /dev/null +++ b/src/index/am_update.rs @@ -0,0 +1,31 @@ +use crate::index::hook_transaction::flush_if_commit; +use crate::prelude::*; +use service::prelude::*; + +pub fn update_insert(id: Id, vector: DynamicVector, tid: pgrx::pg_sys::ItemPointerData) { + flush_if_commit(id); + let p = Pointer::from_sys(tid); + let mut client = super::client::borrow_mut(); + client.insert(id, (vector, p)); +} + +pub fn update_delete(id: Id, hook: impl Fn(Pointer) -> bool) { + struct ClientDelete { + hook: H, + } + + impl crate::ipc::client::ClientDelete for ClientDelete + where + H: Fn(Pointer) -> bool, + { + fn test(&mut self, p: Pointer) -> bool { + (self.hook)(p) + } + } + + let client_delete = ClientDelete { hook }; + + flush_if_commit(id); + let mut client = super::client::borrow_mut(); + client.delete(id, client_delete); +} diff --git a/src/index/client.rs b/src/index/client.rs new file mode 100644 index 000000000..082d3939b --- /dev/null +++ b/src/index/client.rs @@ -0,0 +1,42 @@ +use crate::gucs::{Transport, TRANSPORT}; +use crate::ipc::client::Client; +use crate::utils::cells::PgRefCell; +use std::cell::RefMut; +use std::ops::{Deref, DerefMut}; + +static CLIENT: PgRefCell> = unsafe { PgRefCell::new(None) }; + +pub fn borrow_mut() -> ClientGuard { + let mut x = CLIENT.borrow_mut(); + if x.is_none() { + *x = Some(match TRANSPORT.get() { + Transport::unix => crate::ipc::connect_unix(), + Transport::mmap => crate::ipc::connect_mmap(), + }); + } + ClientGuard(x) +} + +pub struct ClientGuard(RefMut<'static, Option>); + +impl Drop for ClientGuard { + fn drop(&mut self) { + if std::thread::panicking() { + self.0.take(); + } + } +} + +impl Deref for ClientGuard { + type Target = Client; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().unwrap() + } +} + +impl DerefMut for ClientGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut().unwrap() + } +} diff --git a/src/postgres/hook_executor.rs b/src/index/hook_executor.rs similarity index 95% rename from src/postgres/hook_executor.rs rename to src/index/hook_executor.rs index 5ec684512..f4515f5cc 100644 --- a/src/postgres/hook_executor.rs +++ b/src/index/hook_executor.rs @@ -1,4 +1,4 @@ -use crate::postgres::index_scan::Scanner; +use crate::index::am_scan::Scanner; use std::ptr::null_mut; pub unsafe fn post_executor_start(query_desc: *mut pgrx::pg_sys::QueryDesc) { @@ -20,7 +20,7 @@ unsafe extern "C" fn rewrite_plan_state( if index_relation .as_ref() .and_then(|p| p.rd_indam.as_ref()) - .map(|p| p.amvalidate == Some(super::index::amvalidate)) + .map(|p| p.amvalidate == Some(super::am::amvalidate)) .unwrap_or(false) { // The logic is copied from Postgres source code. diff --git a/src/index/hook_transaction.rs b/src/index/hook_transaction.rs new file mode 100644 index 000000000..b80c0f60c --- /dev/null +++ b/src/index/hook_transaction.rs @@ -0,0 +1,26 @@ +use crate::utils::cells::PgRefCell; +use service::prelude::*; +use std::collections::BTreeSet; + +static FLUSH_IF_COMMIT: PgRefCell> = unsafe { PgRefCell::new(BTreeSet::new()) }; + +pub fn aborting() { + *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); +} + +pub fn committing() { + { + let flush_if_commit = FLUSH_IF_COMMIT.borrow(); + if flush_if_commit.len() != 0 { + let mut client = super::client::borrow_mut(); + for id in flush_if_commit.iter().copied() { + client.flush(id); + } + } + } + *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); +} + +pub fn flush_if_commit(id: Id) { + FLUSH_IF_COMMIT.borrow_mut().insert(id); +} diff --git a/src/postgres/hooks.rs b/src/index/hooks.rs similarity index 89% rename from src/postgres/hooks.rs rename to src/index/hooks.rs index f652bf0ea..951707b7d 100644 --- a/src/postgres/hooks.rs +++ b/src/index/hooks.rs @@ -1,5 +1,5 @@ -use crate::postgres::hook_transaction::client; use crate::prelude::*; +use service::prelude::*; static mut PREV_EXECUTOR_START: pgrx::pg_sys::ExecutorStart_hook_type = None; @@ -46,10 +46,8 @@ unsafe fn xact_delete() { .iter() .map(|node| Id::from_sys(node.relNode)) .collect::>(); - client(|mut rpc| { - rpc.destory(ids).friendly(); - rpc - }); + let mut client = super::client::borrow_mut(); + client.destory(ids); } } @@ -63,9 +61,7 @@ unsafe fn xact_delete() { .iter() .map(|node| Id::from_sys(node.relNumber)) .collect::>(); - client(|mut rpc| { - rpc.destory(ids).friendly(); - rpc - }); + let mut client = super::client::borrow_mut(); + client.destory(ids); } } diff --git a/src/index/mod.rs b/src/index/mod.rs index d2d1cc9dc..f04116b73 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -1,450 +1,17 @@ -pub mod delete; -pub mod indexing; -pub mod optimizing; -pub mod segments; - -use self::delete::Delete; -use self::indexing::IndexingOptions; -use self::optimizing::OptimizingOptions; -use self::segments::growing::GrowingSegment; -use self::segments::growing::GrowingSegmentInsertError; -use self::segments::sealed::SealedSegment; -use self::segments::SegmentsOptions; -use crate::prelude::*; -use crate::utils::clean::clean; -use crate::utils::dir_ops::sync_dir; -use crate::utils::file_atomic::FileAtomic; -use arc_swap::ArcSwap; -use crossbeam::sync::Parker; -use crossbeam::sync::Unparker; -use parking_lot::Mutex; -use serde::{Deserialize, Serialize}; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::collections::HashMap; -use std::collections::HashSet; -use std::path::PathBuf; -use std::sync::{Arc, Weak}; -use thiserror::Error; -use uuid::Uuid; -use validator::Validate; - -#[derive(Debug, Error)] -pub enum IndexInsertError { - #[error("The vector is invalid.")] - InvalidVector(Vec), - #[error("The index view is outdated.")] - OutdatedView(#[from] Option), -} - -#[derive(Debug, Error)] -pub enum IndexSearchError { - #[error("The vector is invalid.")] - InvalidVector(Vec), -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -pub struct VectorOptions { - #[validate(range(min = 1, max = 65535))] - #[serde(rename = "dimensions")] - pub dims: u16, - #[serde(rename = "distance")] - pub d: Distance, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -pub struct IndexOptions { - #[validate] - pub vector: VectorOptions, - #[validate] - pub segment: SegmentsOptions, - #[validate] - pub optimizing: OptimizingOptions, - #[validate] - pub indexing: IndexingOptions, -} - -pub struct Index { - path: PathBuf, - options: IndexOptions, - delete: Arc, - protect: Mutex, - view: ArcSwap, - optimize_unparker: Unparker, - indexing: Mutex, - _tracker: Arc, -} - -impl Index { - pub fn create(path: PathBuf, options: IndexOptions) -> Arc { - assert!(options.validate().is_ok()); - std::fs::create_dir(&path).unwrap(); - std::fs::create_dir(path.join("segments")).unwrap(); - let startup = FileAtomic::create( - path.join("startup"), - IndexStartup { - sealeds: HashSet::new(), - growings: HashSet::new(), - }, - ); - let delete = Delete::create(path.join("delete")); - sync_dir(&path); - let parker = Parker::new(); - let index = Arc::new(Index { - path: path.clone(), - options: options.clone(), - delete: delete.clone(), - protect: Mutex::new(IndexProtect { - startup, - sealed: HashMap::new(), - growing: HashMap::new(), - write: None, - }), - view: ArcSwap::new(Arc::new(IndexView { - options: options.clone(), - sealed: HashMap::new(), - growing: HashMap::new(), - delete: delete.clone(), - write: None, - })), - optimize_unparker: parker.unparker().clone(), - indexing: Mutex::new(true), - _tracker: Arc::new(IndexTracker { path }), - }); - IndexBackground { - index: Arc::downgrade(&index), - parker, - } - .spawn(); - index - } - pub fn open(path: PathBuf, options: IndexOptions) -> Arc { - let tracker = Arc::new(IndexTracker { path: path.clone() }); - let startup = FileAtomic::::open(path.join("startup")); - clean( - path.join("segments"), - startup - .get() - .sealeds - .iter() - .map(|s| s.to_string()) - .chain(startup.get().growings.iter().map(|s| s.to_string())), - ); - let sealed = startup - .get() - .sealeds - .iter() - .map(|&uuid| { - ( - uuid, - SealedSegment::open( - tracker.clone(), - path.join("segments").join(uuid.to_string()), - uuid, - options.clone(), - ), - ) - }) - .collect::>(); - let growing = startup - .get() - .growings - .iter() - .map(|&uuid| { - ( - uuid, - GrowingSegment::open( - tracker.clone(), - path.join("segments").join(uuid.to_string()), - uuid, - options.clone(), - ), - ) - }) - .collect::>(); - let delete = Delete::open(path.join("delete")); - let parker = Parker::new(); - let index = Arc::new(Index { - path: path.clone(), - options: options.clone(), - delete: delete.clone(), - protect: Mutex::new(IndexProtect { - startup, - sealed: sealed.clone(), - growing: growing.clone(), - write: None, - }), - view: ArcSwap::new(Arc::new(IndexView { - options: options.clone(), - delete: delete.clone(), - sealed, - growing, - write: None, - })), - optimize_unparker: parker.unparker().clone(), - indexing: Mutex::new(true), - _tracker: tracker, - }); - IndexBackground { - index: Arc::downgrade(&index), - parker, - } - .spawn(); - index - } - pub fn options(&self) -> &IndexOptions { - &self.options - } - pub fn view(&self) -> Arc { - self.view.load_full() - } - pub fn refresh(&self) { - let mut protect = self.protect.lock(); - if let Some((uuid, write)) = protect.write.clone() { - write.seal(); - protect.growing.insert(uuid, write); - } - let write_segment_uuid = Uuid::new_v4(); - let write_segment = GrowingSegment::create( - self._tracker.clone(), - self.path - .join("segments") - .join(write_segment_uuid.to_string()), - write_segment_uuid, - self.options.clone(), - ); - protect.write = Some((write_segment_uuid, write_segment)); - protect.maintain(self.options.clone(), self.delete.clone(), &self.view); - self.optimize_unparker.unpark(); - } - pub fn indexing(&self) -> bool { - *self.indexing.lock() - } -} - -impl Drop for Index { - fn drop(&mut self) { - self.optimize_unparker.unpark(); - } -} - -#[derive(Debug, Clone)] -pub struct IndexTracker { - path: PathBuf, -} - -impl Drop for IndexTracker { - fn drop(&mut self) { - std::fs::remove_dir_all(&self.path).unwrap(); - } -} - -pub struct IndexView { - options: IndexOptions, - delete: Arc, - sealed: HashMap>, - growing: HashMap>, - write: Option<(Uuid, Arc)>, -} - -impl IndexView { - pub fn sealed_len(&self) -> u32 { - self.sealed.values().map(|x| x.len()).sum::() - } - pub fn growing_len(&self) -> u32 { - self.growing.values().map(|x| x.len()).sum::() - } - pub fn write_len(&self) -> u32 { - self.write.as_ref().map(|x| x.1.len()).unwrap_or(0) - } - pub fn sealed_len_vec(&self) -> Vec { - self.sealed.values().map(|x| x.len()).collect() - } - pub fn growing_len_vec(&self) -> Vec { - self.growing.values().map(|x| x.len()).collect() - } - pub fn search bool>( - &self, - k: usize, - vector: &[Scalar], - mut filter: F, - ) -> Result, IndexSearchError> { - if self.options.vector.dims as usize != vector.len() { - return Err(IndexSearchError::InvalidVector(vector.to_vec())); - } - - struct Comparer(BinaryHeap>); - - impl PartialEq for Comparer { - fn eq(&self, other: &Self) -> bool { - self.cmp(other).is_eq() - } - } - - impl Eq for Comparer {} - - impl PartialOrd for Comparer { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - - impl Ord for Comparer { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.peek().cmp(&other.0.peek()).reverse() - } - } - - let mut filter = |payload| { - if let Some(p) = self.delete.check(payload) { - filter(p) - } else { - false - } - }; - let n = self.sealed.len() + self.growing.len() + 1; - let mut result = Heap::new(k); - let mut heaps = BinaryHeap::with_capacity(1 + n); - for (_, sealed) in self.sealed.iter() { - let p = sealed.search(k, vector, &mut filter).into_reversed_heap(); - heaps.push(Comparer(p)); - } - for (_, growing) in self.growing.iter() { - let p = growing.search(k, vector, &mut filter).into_reversed_heap(); - heaps.push(Comparer(p)); - } - if let Some((_, write)) = &self.write { - let p = write.search(k, vector, &mut filter).into_reversed_heap(); - heaps.push(Comparer(p)); - } - while let Some(Comparer(mut heap)) = heaps.pop() { - if let Some(Reverse(x)) = heap.pop() { - result.push(x); - heaps.push(Comparer(heap)); - } - } - Ok(result - .into_sorted_vec() - .iter() - .map(|x| Pointer::from_u48(x.payload >> 16)) - .collect()) - } - pub fn insert(&self, vector: Vec, pointer: Pointer) -> Result<(), IndexInsertError> { - if self.options.vector.dims as usize != vector.len() { - return Err(IndexInsertError::InvalidVector(vector)); - } - let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload; - if let Some((_, growing)) = self.write.as_ref() { - Ok(growing.insert(vector, payload)?) - } else { - Err(IndexInsertError::OutdatedView(None)) - } - } - pub fn delete bool>(&self, mut f: F) { - for (_, sealed) in self.sealed.iter() { - let n = sealed.len(); - for i in 0..n { - if let Some(p) = self.delete.check(sealed.payload(i)) { - if f(p) { - self.delete.delete(p); - } - } - } - } - for (_, growing) in self.growing.iter() { - let n = growing.len(); - for i in 0..n { - if let Some(p) = self.delete.check(growing.payload(i)) { - if f(p) { - self.delete.delete(p); - } - } - } - } - if let Some((_, write)) = &self.write { - let n = write.len(); - for i in 0..n { - if let Some(p) = self.delete.check(write.payload(i)) { - if f(p) { - self.delete.delete(p); - } - } - } - } - } - pub fn flush(&self) -> Result<(), IndexInsertError> { - self.delete.flush(); - if let Some((_, write)) = &self.write { - write.flush(); - } - Ok(()) - } -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -struct IndexStartup { - sealeds: HashSet, - growings: HashSet, -} - -struct IndexProtect { - startup: FileAtomic, - sealed: HashMap>, - growing: HashMap>, - write: Option<(Uuid, Arc)>, -} - -impl IndexProtect { - fn maintain(&mut self, options: IndexOptions, delete: Arc, swap: &ArcSwap) { - let view: Arc = Arc::new(IndexView { - options, - delete, - sealed: self.sealed.clone(), - growing: self.growing.clone(), - write: self.write.clone(), - }); - let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid); - let startup_sealeds = self.sealed.keys().copied().collect(); - let startup_growings = self.growing.keys().copied().chain(startup_write).collect(); - self.startup.set(IndexStartup { - sealeds: startup_sealeds, - growings: startup_growings, - }); - swap.swap(view); - } -} - -pub struct IndexBackground { - index: Weak, - parker: Parker, -} - -impl IndexBackground { - pub fn main(self) { - let pool; - if let Some(index) = self.index.upgrade() { - pool = rayon::ThreadPoolBuilder::new() - .num_threads(index.options.optimizing.optimizing_threads) - .build() - .unwrap(); - } else { - return; - } - while let Some(index) = self.index.upgrade() { - let done = pool.install(|| optimizing::indexing::optimizing_indexing(index.clone())); - if done { - *index.indexing.lock() = false; - drop(index); - self.parker.park(); - if let Some(index) = self.index.upgrade() { - *index.indexing.lock() = true; - } - } - } - } - pub fn spawn(self) { - std::thread::spawn(move || { - self.main(); - }); - } +#![allow(unsafe_op_in_unsafe_fn)] + +mod am; +mod am_build; +mod am_scan; +mod am_setup; +mod am_update; +mod client; +mod hook_executor; +mod hook_transaction; +mod hooks; +mod views; + +pub unsafe fn init() { + self::hooks::init(); + self::am::init(); } diff --git a/src/index/views.rs b/src/index/views.rs new file mode 100644 index 000000000..53106a286 --- /dev/null +++ b/src/index/views.rs @@ -0,0 +1,46 @@ +use crate::prelude::*; +use service::prelude::*; + +pgrx::extension_sql!( + "\ +CREATE TYPE VectorIndexStat AS ( + idx_indexing BOOL, + idx_tuples BIGINT, + idx_sealed BIGINT[], + idx_growing BIGINT[], + idx_write BIGINT, + idx_options TEXT +);", + name = "create_composites", +); + +#[pgrx::pg_extern(volatile, strict)] +fn vector_stat(oid: pgrx::pg_sys::Oid) -> pgrx::composite_type!("VectorIndexStat") { + let id = Id::from_sys(oid); + let mut res = pgrx::prelude::PgHeapTuple::new_composite_type("VectorIndexStat").unwrap(); + let mut client = super::client::borrow_mut(); + let stat = client.stat(id); + res.set_by_name("idx_indexing", stat.indexing).unwrap(); + res.set_by_name("idx_tuples", { + let mut tuples = 0; + tuples += stat.sealed.iter().map(|x| *x as i64).sum::(); + tuples += stat.growing.iter().map(|x| *x as i64).sum::(); + tuples += stat.write as i64; + tuples + }) + .unwrap(); + res.set_by_name("idx_sealed", { + let sealed = stat.sealed; + sealed.into_iter().map(|x| x as i64).collect::>() + }) + .unwrap(); + res.set_by_name("idx_growing", { + let growing = stat.growing; + growing.into_iter().map(|x| x as i64).collect::>() + }) + .unwrap(); + res.set_by_name("idx_write", stat.write as i64).unwrap(); + res.set_by_name("idx_options", serde_json::to_string(&stat.options)) + .unwrap(); + res +} diff --git a/src/ipc/client.rs b/src/ipc/client.rs index d344aabd4..fc7a17cda 100644 --- a/src/ipc/client.rs +++ b/src/ipc/client.rs @@ -1,168 +1,93 @@ -use crate::index::IndexOptions; -use crate::ipc::packet::*; -use crate::ipc::transport::Socket; -use crate::ipc::IpcError; -use crate::prelude::*; +use super::packet::*; +use super::transport::Socket; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::prelude::*; -pub struct Rpc { +pub struct Client { socket: Socket, } -impl Rpc { - pub(super) fn new(socket: Socket) -> Self { +impl Client { + pub fn new(socket: Socket) -> Self { Self { socket } } - pub fn create(&mut self, id: Id, options: IndexOptions) -> Result<(), IpcError> { + pub fn create(&mut self, id: Id, options: IndexOptions) { let packet = RpcPacket::Create { id, options }; - self.socket.send(packet)?; - let CreatePacket::Leave {} = self.socket.recv::()?; - Ok(()) + self.socket.send(packet).friendly(); + let CreatePacket::Leave {} = self.socket.recv::().friendly(); } pub fn search( - mut self, + &mut self, id: Id, - search: (Vec, usize), + search: (DynamicVector, usize), prefilter: bool, - ) -> Result { + mut t: impl ClientSearch, + ) -> Vec { let packet = RpcPacket::Search { id, search, prefilter, }; - self.socket.send(packet)?; - Ok(SearchHandler { - socket: self.socket, - }) + self.socket.send(packet).friendly(); + loop { + match self.socket.recv::().friendly() { + SearchPacket::Check { p } => { + self.socket + .send(SearchCheckPacket::Leave { result: t.check(p) }) + .friendly(); + } + SearchPacket::Leave { result } => { + return result.friendly(); + } + } + } } - pub fn delete(mut self, id: Id) -> Result { + pub fn delete(&mut self, id: Id, mut t: impl ClientDelete) { let packet = RpcPacket::Delete { id }; - self.socket.send(packet)?; - Ok(DeleteHandler { - socket: self.socket, - }) + self.socket.send(packet).friendly(); + loop { + match self.socket.recv::().friendly() { + DeletePacket::Test { p } => { + self.socket + .send(DeleteTestPacket::Leave { delete: t.test(p) }) + .friendly(); + } + DeletePacket::Leave { result } => { + return result.friendly(); + } + } + } } - pub fn insert( - &mut self, - id: Id, - insert: (Vec, Pointer), - ) -> Result, IpcError> { + pub fn insert(&mut self, id: Id, insert: (DynamicVector, Pointer)) { let packet = RpcPacket::Insert { id, insert }; - self.socket.send(packet)?; - let InsertPacket::Leave { result } = self.socket.recv::()?; - Ok(result) + self.socket.send(packet).friendly(); + let InsertPacket::Leave { result } = self.socket.recv::().friendly(); + result.friendly() } - pub fn flush(&mut self, id: Id) -> Result, IpcError> { + pub fn flush(&mut self, id: Id) { let packet = RpcPacket::Flush { id }; - self.socket.send(packet)?; - let FlushPacket::Leave { result } = self.socket.recv::()?; - Ok(result) + self.socket.send(packet).friendly(); + let FlushPacket::Leave { result } = self.socket.recv::().friendly(); + result.friendly() } - pub fn destory(&mut self, ids: Vec) -> Result<(), IpcError> { + pub fn destory(&mut self, ids: Vec) { let packet = RpcPacket::Destory { ids }; - self.socket.send(packet)?; - let DestoryPacket::Leave {} = self.socket.recv::()?; - Ok(()) + self.socket.send(packet).friendly(); + let DestoryPacket::Leave {} = self.socket.recv::().friendly(); } - pub fn stat(&mut self, id: Id) -> Result, IpcError> { + pub fn stat(&mut self, id: Id) -> IndexStat { let packet = RpcPacket::Stat { id }; - self.socket.send(packet)?; - let StatPacket::Leave { result } = self.socket.recv::()?; - Ok(result) - } -} - -pub enum SearchHandle { - Check { - p: Pointer, - x: SearchCheck, - }, - Leave { - result: Result, FriendlyError>, - x: Rpc, - }, -} - -pub struct SearchHandler { - socket: Socket, -} - -impl SearchHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - SearchPacket::Check { p } => SearchHandle::Check { - p, - x: SearchCheck { - socket: self.socket, - }, - }, - SearchPacket::Leave { result } => SearchHandle::Leave { - result, - x: Rpc { - socket: self.socket, - }, - }, - }) + self.socket.send(packet).friendly(); + let StatPacket::Leave { result } = self.socket.recv::().friendly(); + result.friendly() } } -pub struct SearchCheck { - socket: Socket, +pub trait ClientSearch { + fn check(&mut self, p: Pointer) -> bool; } -impl SearchCheck { - pub fn leave(mut self, result: bool) -> Result { - let packet = SearchCheckPacket::Leave { result }; - self.socket.send(packet)?; - Ok(SearchHandler { - socket: self.socket, - }) - } -} - -pub enum DeleteHandle { - Next { - p: Pointer, - x: DeleteNext, - }, - Leave { - result: Result<(), FriendlyError>, - x: Rpc, - }, -} - -pub struct DeleteHandler { - socket: Socket, -} - -impl DeleteHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - DeletePacket::Next { p } => DeleteHandle::Next { - p, - x: DeleteNext { - socket: self.socket, - }, - }, - DeletePacket::Leave { result } => DeleteHandle::Leave { - result, - x: Rpc { - socket: self.socket, - }, - }, - }) - } -} - -pub struct DeleteNext { - socket: Socket, -} - -impl DeleteNext { - pub fn leave(mut self, delete: bool) -> Result { - let packet = DeleteNextPacket::Leave { delete }; - self.socket.send(packet)?; - Ok(DeleteHandler { - socket: self.socket, - }) - } +pub trait ClientDelete { + fn test(&mut self, p: Pointer) -> bool; } diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 63682a3b7..b4db5b13d 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -3,8 +3,9 @@ mod packet; pub mod server; pub mod transport; -use self::client::Rpc; +use self::client::Client; use self::server::RpcHandler; +use service::prelude::*; use thiserror::Error; #[derive(Debug, Clone, Error)] @@ -18,6 +19,12 @@ Please check the full Postgresql log to get more information.\ Closed, } +impl FriendlyErrorLike for IpcError { + fn friendly(self) -> ! { + panic!("pgvecto.rs: {}", self); + } +} + pub fn listen_unix() -> impl Iterator { std::iter::from_fn(move || { let socket = self::transport::Socket::Unix(self::transport::unix::accept()); @@ -32,12 +39,17 @@ pub fn listen_mmap() -> impl Iterator { }) } -pub fn connect_unix() -> Rpc { +pub fn connect_unix() -> Client { let socket = self::transport::Socket::Unix(self::transport::unix::connect()); - self::client::Rpc::new(socket) + Client::new(socket) } -pub fn connect_mmap() -> Rpc { +pub fn connect_mmap() -> Client { let socket = self::transport::Socket::Mmap(self::transport::mmap::connect()); - self::client::Rpc::new(socket) + Client::new(socket) +} + +pub fn init() { + self::transport::mmap::init(); + self::transport::unix::init(); } diff --git a/src/ipc/packet.rs b/src/ipc/packet.rs index 6aa88bff7..2de6ff48a 100644 --- a/src/ipc/packet.rs +++ b/src/ipc/packet.rs @@ -1,6 +1,7 @@ -use crate::index::IndexOptions; -use crate::prelude::*; use serde::{Deserialize, Serialize}; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::prelude::*; #[derive(Debug, Serialize, Deserialize)] pub enum RpcPacket { @@ -16,14 +17,14 @@ pub enum RpcPacket { }, Insert { id: Id, - insert: (Vec, Pointer), + insert: (DynamicVector, Pointer), }, Delete { id: Id, }, Search { id: Id, - search: (Vec, usize), + search: (DynamicVector, usize), prefilter: bool, }, Stat { @@ -54,12 +55,12 @@ pub enum InsertPacket { #[derive(Debug, Serialize, Deserialize)] pub enum DeletePacket { - Next { p: Pointer }, + Test { p: Pointer }, Leave { result: Result<(), FriendlyError> }, } #[derive(Debug, Serialize, Deserialize)] -pub enum DeleteNextPacket { +pub enum DeleteTestPacket { Leave { delete: bool }, } @@ -81,6 +82,6 @@ pub enum SearchCheckPacket { #[derive(Debug, Serialize, Deserialize)] pub enum StatPacket { Leave { - result: Result, + result: Result, }, } diff --git a/src/ipc/server.rs b/src/ipc/server.rs index dab29da28..4c06ccced 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server.rs @@ -1,8 +1,9 @@ -use crate::index::IndexOptions; -use crate::ipc::packet::*; -use crate::ipc::transport::Socket; -use crate::ipc::IpcError; -use crate::prelude::*; +use super::packet::*; +use super::transport::Socket; +use super::IpcError; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::prelude::*; pub struct RpcHandler { socket: Socket, @@ -77,13 +78,13 @@ pub enum RpcHandle { }, Search { id: Id, - search: (Vec, usize), + search: (DynamicVector, usize), prefilter: bool, x: Search, }, Insert { id: Id, - insert: (Vec, Pointer), + insert: (DynamicVector, Pointer), x: Insert, }, Delete { @@ -139,9 +140,9 @@ pub struct Delete { impl Delete { pub fn next(&mut self, p: Pointer) -> Result { - let packet = DeletePacket::Next { p }; + let packet = DeletePacket::Test { p }; self.socket.send(packet)?; - let DeleteNextPacket::Leave { delete } = self.socket.recv::()?; + let DeleteTestPacket::Leave { delete } = self.socket.recv::()?; Ok(delete) } pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { @@ -211,7 +212,7 @@ pub struct Stat { impl Stat { pub fn leave( mut self, - result: Result, + result: Result, ) -> Result { let packet = StatPacket::Leave { result }; self.socket.send(packet)?; diff --git a/src/ipc/transport/mmap.rs b/src/ipc/transport/mmap.rs index d1bcbc5ca..a240dca11 100644 --- a/src/ipc/transport/mmap.rs +++ b/src/ipc/transport/mmap.rs @@ -1,4 +1,4 @@ -use crate::ipc::IpcError; +use super::IpcError; use crate::utils::file_socket::FileSocket; use crate::utils::os::{futex_wait, futex_wake, memfd_create, mmap_populate}; use rustix::fd::{AsFd, OwnedFd}; @@ -65,10 +65,7 @@ impl Socket { Err(e) => panic!("{:?}", e), } } - pub fn send(&mut self, packet: T) -> Result<(), IpcError> - where - T: Serialize, - { + pub fn send(&mut self, packet: T) -> Result<(), IpcError> { let buffer = bincode::serialize(&packet).expect("Failed to serialize"); unsafe { if self.is_server { @@ -79,10 +76,7 @@ impl Socket { } Ok(()) } - pub fn recv(&mut self) -> Result - where - T: for<'a> Deserialize<'a>, - { + pub fn recv Deserialize<'a>>(&mut self) -> Result { let buffer = unsafe { if self.is_server { (*self.addr).server_recv(|| self.test())? @@ -106,7 +100,7 @@ struct Channel { futex: AtomicU32, } -static_assertions::assert_eq_size!(Channel, [u8; BUFFER_SIZE]); +const _: () = assert!(std::mem::size_of::() == BUFFER_SIZE); impl Channel { unsafe fn client_recv(&self, test: impl Fn() -> bool) -> Result, IpcError> { @@ -132,30 +126,40 @@ impl Channel { { break; } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } Y => { if !test() { return Err(IpcError::Closed); } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } - _ => std::hint::unreachable_unchecked(), + _ => unsafe { std::hint::unreachable_unchecked() }, } } - let len = *self.len.get(); - let res = (*self.bytes.get())[0..len as usize].to_vec(); - Ok(res) + unsafe { + let len = *self.len.get(); + let res = (*self.bytes.get())[0..len as usize].to_vec(); + Ok(res) + } } unsafe fn client_send(&self, data: &[u8]) { const S: u32 = 0; const T: u32 = 1; const X: u32 = 2; debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X)); - *self.len.get() = data.len() as u32; - (*self.bytes.get())[0..data.len()].copy_from_slice(data); + unsafe { + *self.len.get() = data.len() as u32; + (*self.bytes.get())[0..data.len()].copy_from_slice(data); + } if X == self.futex.swap(T, Ordering::Release) { - futex_wake(&self.futex); + unsafe { + futex_wake(&self.futex); + } } } unsafe fn server_recv(&self, test: impl Fn() -> bool) -> Result, IpcError> { @@ -181,30 +185,40 @@ impl Channel { { break; } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } Y => { if !test() { return Err(IpcError::Closed); } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } - _ => std::hint::unreachable_unchecked(), + _ => unsafe { std::hint::unreachable_unchecked() }, } } - let len = *self.len.get(); - let res = (*self.bytes.get())[0..len as usize].to_vec(); - Ok(res) + unsafe { + let len = *self.len.get(); + let res = (*self.bytes.get())[0..len as usize].to_vec(); + Ok(res) + } } unsafe fn server_send(&self, data: &[u8]) { const S: u32 = 1; const T: u32 = 0; const X: u32 = 3; debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X)); - *self.len.get() = data.len() as u32; - (*self.bytes.get())[0..data.len()].copy_from_slice(data); + unsafe { + *self.len.get() = data.len() as u32; + (*self.bytes.get())[0..data.len()].copy_from_slice(data); + } if X == self.futex.swap(T, Ordering::Release) { - futex_wake(&self.futex); + unsafe { + futex_wake(&self.futex); + } } } } diff --git a/src/ipc/transport/unix.rs b/src/ipc/transport/unix.rs index 708afa3bf..8ba654baf 100644 --- a/src/ipc/transport/unix.rs +++ b/src/ipc/transport/unix.rs @@ -1,4 +1,4 @@ -use crate::ipc::IpcError; +use super::IpcError; use crate::utils::file_socket::FileSocket; use byteorder::{ReadBytesExt, WriteBytesExt}; use rustix::fd::AsFd; @@ -40,10 +40,7 @@ macro_rules! resolve_closed { } impl Socket { - pub fn send(&mut self, packet: T) -> Result<(), IpcError> - where - T: Serialize, - { + pub fn send(&mut self, packet: T) -> Result<(), IpcError> { use byteorder::NativeEndian as N; let buffer = bincode::serialize(&packet).expect("Failed to serialize"); let len = u32::try_from(buffer.len()).expect("Packet is too large."); @@ -51,10 +48,7 @@ impl Socket { resolve_closed!(self.stream.write_all(&buffer)); Ok(()) } - pub fn recv(&mut self) -> Result - where - T: for<'a> Deserialize<'a>, - { + pub fn recv Deserialize<'a>>(&mut self) -> Result { use byteorder::NativeEndian as N; let len = resolve_closed!(self.stream.read_u32::()); let mut buffer = vec![0u8; len as usize]; diff --git a/src/lib.rs b/src/lib.rs index 068d24b98..36593e6b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,14 @@ //! Postgres vector extension. //! //! Provides an easy-to-use extension for vector similarity search. -#![feature(core_intrinsics)] +#![feature(offset_of)] -mod algorithms; mod bgworker; +mod datatype; mod embedding; +mod gucs; mod index; mod ipc; -mod postgres; mod prelude; mod utils; @@ -19,27 +19,16 @@ pgrx::extension_sql_file!("./sql/finalize.sql", finalize); #[allow(non_snake_case)] #[pgrx::pg_guard] unsafe extern "C" fn _PG_init() { - use crate::prelude::*; - if pgrx::pg_sys::IsUnderPostmaster { + use service::prelude::*; + if unsafe { pgrx::pg_sys::IsUnderPostmaster } { FriendlyError::BadInit.friendly(); } - use pgrx::bgworkers::BackgroundWorkerBuilder; - use pgrx::bgworkers::BgWorkerStartTime; - BackgroundWorkerBuilder::new("vectors") - .set_function("vectors_main") - .set_library("vectors") - .set_argument(None) - .enable_shmem_access(None) - .set_start_time(BgWorkerStartTime::PostmasterStart) - .load(); - self::postgres::init(); - self::ipc::transport::unix::init(); - self::ipc::transport::mmap::init(); -} - -#[no_mangle] -extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) { - let _ = std::panic::catch_unwind(crate::bgworker::main); + unsafe { + self::gucs::init(); + self::index::init(); + self::ipc::init(); + self::bgworker::init(); + } } #[cfg(not(any(target_os = "linux", target_os = "macos")))] diff --git a/src/postgres/casts.rs b/src/postgres/casts.rs deleted file mode 100644 index 26d1cf924..000000000 --- a/src/postgres/casts.rs +++ /dev/null @@ -1,21 +0,0 @@ -use super::datatype::{Vector, VectorInput, VectorOutput, VectorTypmod}; -use crate::prelude::Scalar; - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn cast_array_to_vector(array: pgrx::Array, typmod: i32, _explicit: bool) -> VectorOutput { - assert!(!array.is_empty()); - assert!(array.len() <= 65535); - assert!(!array.contains_nulls()); - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - let len = typmod.dims().unwrap_or(array.len() as u16); - let mut data = Vector::new_zeroed_in_postgres(len as usize); - for (i, x) in array.iter().enumerate() { - data[i] = x.unwrap_or(Scalar::NAN); - } - data -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn cast_vector_to_array(vector: VectorInput<'_>, _typmod: i32, _explicit: bool) -> Vec { - vector.data().to_vec() -} diff --git a/src/postgres/hook_transaction.rs b/src/postgres/hook_transaction.rs deleted file mode 100644 index daa083f57..000000000 --- a/src/postgres/hook_transaction.rs +++ /dev/null @@ -1,48 +0,0 @@ -use super::gucs::Transport; -use super::gucs::TRANSPORT; -use crate::ipc::client::Rpc; -use crate::ipc::{connect_mmap, connect_unix}; -use crate::prelude::*; -use crate::utils::cells::PgRefCell; -use std::collections::BTreeSet; - -static FLUSH_IF_COMMIT: PgRefCell> = unsafe { PgRefCell::new(BTreeSet::new()) }; - -static CLIENT: PgRefCell> = unsafe { PgRefCell::new(None) }; - -pub fn aborting() { - *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); -} - -pub fn committing() { - { - let flush_if_commit = FLUSH_IF_COMMIT.borrow(); - if flush_if_commit.len() != 0 { - client(|mut rpc| { - for id in flush_if_commit.iter().copied() { - rpc.flush(id).friendly().friendly(); - } - - rpc - }); - } - } - *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); -} - -pub fn flush_if_commit(id: Id) { - FLUSH_IF_COMMIT.borrow_mut().insert(id); -} - -pub fn client(f: F) -where - F: FnOnce(Rpc) -> Rpc, -{ - let mut guard = CLIENT.borrow_mut(); - let client = guard.take().unwrap_or_else(|| match TRANSPORT.get() { - Transport::unix => connect_unix(), - Transport::mmap => connect_mmap(), - }); - let client = f(client); - *guard = Some(client); -} diff --git a/src/postgres/index_update.rs b/src/postgres/index_update.rs deleted file mode 100644 index e9d861920..000000000 --- a/src/postgres/index_update.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::postgres::hook_transaction::{client, flush_if_commit}; -use crate::prelude::*; - -pub fn update_insert(id: Id, vector: Vec, tid: pgrx::pg_sys::ItemPointerData) { - flush_if_commit(id); - let p = Pointer::from_sys(tid); - client(|mut rpc| { - rpc.insert(id, (vector, p)).friendly().friendly(); - rpc - }) -} - -pub fn update_delete(id: Id, hook: impl Fn(Pointer) -> bool) { - flush_if_commit(id); - client(|rpc| { - use crate::ipc::client::DeleteHandle; - let mut handler = rpc.delete(id).friendly(); - loop { - let handle = handler.handle().friendly(); - match handle { - DeleteHandle::Next { p, x } => { - handler = x.leave(hook(p)).friendly(); - } - DeleteHandle::Leave { result, x } => { - result.friendly(); - break x; - } - } - } - }) -} diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs deleted file mode 100644 index 0d240478d..000000000 --- a/src/postgres/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -mod casts; -pub mod datatype; -pub mod gucs; -mod hook_executor; -mod hook_transaction; -mod hooks; -mod index; -mod index_build; -mod index_scan; -mod index_setup; -mod index_update; -mod operators; -mod stat; - -pub unsafe fn init() { - self::gucs::init(); - self::hooks::init(); - self::index::init(); -} diff --git a/src/postgres/stat.rs b/src/postgres/stat.rs deleted file mode 100644 index 72f9b57d5..000000000 --- a/src/postgres/stat.rs +++ /dev/null @@ -1,38 +0,0 @@ -use super::hook_transaction::client; -use crate::prelude::*; - -pgrx::extension_sql!( - "\ -CREATE TYPE VectorIndexInfo AS ( - indexing BOOL, - idx_tuples INT, - idx_sealed_len INT, - idx_growing_len INT, - idx_write INT, - idx_sealed INT[], - idx_growing INT[], - idx_config TEXT -);", - name = "create_composites", -); - -#[pgrx::pg_extern(volatile, strict)] -fn vector_stat(oid: pgrx::pg_sys::Oid) -> pgrx::composite_type!("VectorIndexInfo") { - let id = Id::from_sys(oid); - let mut res = pgrx::prelude::PgHeapTuple::new_composite_type("VectorIndexInfo").unwrap(); - client(|mut rpc| { - let rpc_res = rpc.stat(id).unwrap().friendly(); - res.set_by_name("indexing", rpc_res.indexing).unwrap(); - res.set_by_name("idx_tuples", rpc_res.idx_tuples).unwrap(); - res.set_by_name("idx_sealed_len", rpc_res.idx_sealed_len) - .unwrap(); - res.set_by_name("idx_growing_len", rpc_res.idx_growing_len) - .unwrap(); - res.set_by_name("idx_write", rpc_res.idx_write).unwrap(); - res.set_by_name("idx_sealed", rpc_res.idx_sealed).unwrap(); - res.set_by_name("idx_growing", rpc_res.idx_growing).unwrap(); - res.set_by_name("idx_config", rpc_res.idx_config).unwrap(); - rpc - }); - res -} diff --git a/src/prelude.rs b/src/prelude.rs new file mode 100644 index 000000000..7a18d2954 --- /dev/null +++ b/src/prelude.rs @@ -0,0 +1,39 @@ +use service::prelude::*; + +pub trait FromSys { + fn from_sys(sys: T) -> Self; +} + +impl FromSys for Id { + fn from_sys(sys: pgrx::pg_sys::Oid) -> Self { + Self { + newtype: sys.as_u32(), + } + } +} + +impl FromSys for Pointer { + fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self { + let mut newtype = 0; + newtype |= (sys.ip_blkid.bi_hi as u64) << 32; + newtype |= (sys.ip_blkid.bi_lo as u64) << 16; + newtype |= sys.ip_posid as u64; + Self { newtype } + } +} + +pub trait IntoSys { + fn into_sys(self) -> T; +} + +impl IntoSys for Pointer { + fn into_sys(self) -> pgrx::pg_sys::ItemPointerData { + pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { + bi_hi: ((self.newtype >> 32) & 0xffff) as u16, + bi_lo: ((self.newtype >> 16) & 0xffff) as u16, + }, + ip_posid: (self.newtype & 0xffff) as u16, + } + } +} diff --git a/src/prelude/distance.rs b/src/prelude/distance.rs deleted file mode 100644 index 433b386ed..000000000 --- a/src/prelude/distance.rs +++ /dev/null @@ -1,516 +0,0 @@ -use crate::prelude::*; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; - -mod sealed { - pub trait Sealed {} -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Distance { - L2, - Cosine, - Dot, -} - -impl Distance { - pub fn distance(self, lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - match self { - Distance::L2 => distance_squared_l2(lhs, rhs), - Distance::Cosine => distance_cosine(lhs, rhs) * (-1.0), - Distance::Dot => distance_dot(lhs, rhs) * (-1.0), - } - } - pub fn elkan_k_means_normalize(self, vector: &mut [Scalar]) { - match self { - Distance::L2 => (), - Distance::Cosine => l2_normalize(vector), - Distance::Dot => l2_normalize(vector), - } - } - pub fn elkan_k_means_distance(self, lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - match self { - Distance::L2 => distance_squared_l2(lhs, rhs).sqrt(), - Distance::Cosine => distance_dot(lhs, rhs).acos(), - Distance::Dot => distance_dot(lhs, rhs).acos(), - } - } - pub fn scalar_quantization_distance( - self, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - ) -> Scalar { - scalar_quantization_distance(self, dims, max, min, lhs, rhs) - } - pub fn scalar_quantization_distance2( - self, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[u8], - rhs: &[u8], - ) -> Scalar { - scalar_quantization_distance2(self, dims, max, min, lhs, rhs) - } - pub fn product_quantization_distance( - self, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - ) -> Scalar { - product_quantization_distance(self, dims, ratio, centroids, lhs, rhs) - } - pub fn product_quantization_distance2( - self, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[u8], - rhs: &[u8], - ) -> Scalar { - product_quantization_distance2(self, dims, ratio, centroids, lhs, rhs) - } - pub fn product_quantization_distance_with_delta( - self, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - delta: &[Scalar], - ) -> Scalar { - product_quantization_distance_with_delta(self, dims, ratio, centroids, lhs, rhs, delta) - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_squared_l2(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut d2 = Scalar::Z; - for i in 0..n { - let d = lhs[i] - rhs[i]; - d2 += d * d; - } - d2 -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_cosine(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - xy += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - xy / (x2 * y2).sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_dot(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - for i in 0..n { - xy += lhs[i] * rhs[i]; - } - xy -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn xy_x2_y2(lhs: &[Scalar], rhs: &[Scalar]) -> (Scalar, Scalar, Scalar) { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - xy += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - (xy, x2, y2) -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn length(vector: &[Scalar]) -> Scalar { - let n = vector.len(); - let mut dot = Scalar::Z; - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn l2_normalize(vector: &mut [Scalar]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_squared_l2_delta(lhs: &[Scalar], rhs: &[Scalar], del: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut d2 = Scalar::Z; - for i in 0..n { - let d = lhs[i] - (rhs[i] + del[i]); - d2 += d * d; - } - d2 -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn xy_x2_y2_delta(lhs: &[Scalar], rhs: &[Scalar], del: &[Scalar]) -> (Scalar, Scalar, Scalar) { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - xy += lhs[i] * (rhs[i] + del[i]); - x2 += lhs[i] * lhs[i]; - y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]); - } - (xy, x2, y2) -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_dot_delta(lhs: &[Scalar], rhs: &[Scalar], del: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - for i in 0..n { - xy += lhs[i] * (rhs[i] + del[i]); - } - xy -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn scalar_quantization_distance( - distance: Distance, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let mut result = Scalar::Z; - for i in 0..dims as usize { - let _x = lhs[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - result += (_x - _y) * (_x - _y); - } - result - } - Distance::Cosine => { - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..dims as usize { - let _x = lhs[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - x2 += _x * _x; - y2 += _y * _y; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let mut xy = Scalar::Z; - for i in 0..dims as usize { - let _x = lhs[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn scalar_quantization_distance2( - distance: Distance, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[u8], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let mut result = Scalar::Z; - for i in 0..dims as usize { - let _x = Scalar(lhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - result += (_x - _y) * (_x - _y); - } - result - } - Distance::Cosine => { - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..dims as usize { - let _x = Scalar(lhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - x2 += _x * _x; - y2 += _y * _y; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let mut xy = Scalar::Z; - for i in 0..dims as usize { - let _x = Scalar(lhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn product_quantization_distance( - distance: Distance, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let width = dims.div_ceil(ratio); - let mut result = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); - } - result - } - Distance::Cosine => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); - xy += _xy; - x2 += _x2; - y2 += _y2; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = distance_dot(lhs, rhs); - xy += _xy; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn product_quantization_distance2( - distance: Distance, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[u8], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let width = dims.div_ceil(ratio); - let mut result = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhsp = lhs[i as usize] as usize * dims as usize; - let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); - } - result - } - Distance::Cosine => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhsp = lhs[i as usize] as usize * dims as usize; - let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); - xy += _xy; - x2 += _x2; - y2 += _y2; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhsp = lhs[i as usize] as usize * dims as usize; - let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = distance_dot(lhs, rhs); - xy += _xy; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn product_quantization_distance_with_delta( - distance: Distance, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - delta: &[Scalar], -) -> Scalar { - match distance { - Distance::L2 => { - let width = dims.div_ceil(ratio); - let mut result = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let del = &delta[(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2_delta(lhs, rhs, del); - } - result - } - Distance::Cosine => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let del = &delta[(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del); - xy += _xy; - x2 += _x2; - y2 += _y2; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let del = &delta[(i * ratio) as usize..][..k as usize]; - let _xy = distance_dot_delta(lhs, rhs, del); - xy += _xy; - } - xy * (-1.0) - } - } -} diff --git a/src/prelude/mod.rs b/src/prelude/mod.rs deleted file mode 100644 index d740d177d..000000000 --- a/src/prelude/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -mod distance; -mod error; -mod filter; -mod heap; -mod scalar; -mod stat; -mod sys; - -pub use self::distance::Distance; -pub use self::error::{Friendly, FriendlyError}; -pub use self::filter::{Filter, Payload}; -pub use self::heap::{Heap, HeapElement}; -pub use self::scalar::{Float, Scalar}; -pub use self::stat::VectorIndexInfo; -pub use self::sys::{Id, Pointer}; diff --git a/src/prelude/scalar.rs b/src/prelude/scalar.rs deleted file mode 100644 index 99a6338d1..000000000 --- a/src/prelude/scalar.rs +++ /dev/null @@ -1,295 +0,0 @@ -use bytemuck::{Pod, Zeroable}; -use pgrx::pg_sys::{Datum, Oid}; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::FunctionMetadataTypeEntity; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::{FromDatum, IntoDatum}; -use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -use std::fmt::{Debug, Display}; -use std::num::ParseFloatError; -use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign}; -use std::str::FromStr; - -pub type Float = f32; - -#[derive(Clone, Copy, Default, Serialize, Deserialize)] -#[repr(transparent)] -#[serde(transparent)] -pub struct Scalar(pub Float); - -impl Scalar { - pub const INFINITY: Self = Self(Float::INFINITY); - pub const NEG_INFINITY: Self = Self(Float::NEG_INFINITY); - pub const NAN: Self = Self(Float::NAN); - pub const Z: Self = Self(0.0); - - #[inline(always)] - pub fn acos(self) -> Self { - Self(self.0.acos()) - } - - #[inline(always)] - pub fn sqrt(self) -> Self { - Self(self.0.sqrt()) - } -} - -unsafe impl Zeroable for Scalar {} - -unsafe impl Pod for Scalar {} - -impl Debug for Scalar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Debug::fmt(&self.0, f) - } -} - -impl Display for Scalar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -impl From for Scalar { - fn from(value: Float) -> Self { - Self(value) - } -} - -impl From for Float { - fn from(Scalar(float): Scalar) -> Self { - float - } -} - -impl PartialEq for Scalar { - fn eq(&self, other: &Self) -> bool { - self.0.total_cmp(&other.0) == Ordering::Equal - } -} - -impl Eq for Scalar {} - -impl PartialOrd for Scalar { - #[inline(always)] - fn partial_cmp(&self, other: &Self) -> Option { - Some(Ord::cmp(self, other)) - } -} - -impl Ord for Scalar { - #[inline(always)] - fn cmp(&self, other: &Self) -> Ordering { - self.0.total_cmp(&other.0) - } -} - -impl Add for Scalar { - type Output = Scalar; - - #[inline(always)] - fn add(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fadd_fast(self.0, rhs).into() } - } -} - -impl AddAssign for Scalar { - fn add_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs) } - } -} - -impl Add for Scalar { - type Output = Scalar; - - #[inline(always)] - fn add(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fadd_fast(self.0, rhs.0).into() } - } -} - -impl AddAssign for Scalar { - #[inline(always)] - fn add_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs.0) } - } -} - -impl Sub for Scalar { - type Output = Scalar; - - #[inline(always)] - fn sub(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fsub_fast(self.0, rhs).into() } - } -} - -impl SubAssign for Scalar { - #[inline(always)] - fn sub_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs) } - } -} - -impl Sub for Scalar { - type Output = Scalar; - - #[inline(always)] - fn sub(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fsub_fast(self.0, rhs.0).into() } - } -} - -impl SubAssign for Scalar { - #[inline(always)] - fn sub_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs.0) } - } -} - -impl Mul for Scalar { - type Output = Scalar; - - #[inline(always)] - fn mul(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fmul_fast(self.0, rhs).into() } - } -} - -impl MulAssign for Scalar { - #[inline(always)] - fn mul_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs) } - } -} - -impl Mul for Scalar { - type Output = Scalar; - - #[inline(always)] - fn mul(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fmul_fast(self.0, rhs.0).into() } - } -} - -impl MulAssign for Scalar { - #[inline(always)] - fn mul_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs.0) } - } -} - -impl Div for Scalar { - type Output = Scalar; - - #[inline(always)] - fn div(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fdiv_fast(self.0, rhs).into() } - } -} - -impl DivAssign for Scalar { - #[inline(always)] - fn div_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs) } - } -} - -impl Div for Scalar { - type Output = Scalar; - - #[inline(always)] - fn div(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fdiv_fast(self.0, rhs.0).into() } - } -} - -impl DivAssign for Scalar { - #[inline(always)] - fn div_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs.0) } - } -} - -impl Rem for Scalar { - type Output = Scalar; - - #[inline(always)] - fn rem(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::frem_fast(self.0, rhs).into() } - } -} - -impl RemAssign for Scalar { - #[inline(always)] - fn rem_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) } - } -} - -impl Rem for Scalar { - type Output = Scalar; - - #[inline(always)] - fn rem(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::frem_fast(self.0, rhs.0).into() } - } -} - -impl RemAssign for Scalar { - #[inline(always)] - fn rem_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs.0) } - } -} - -impl FromStr for Scalar { - type Err = ParseFloatError; - - fn from_str(s: &str) -> Result { - Float::from_str(s).map(|x| x.into()) - } -} - -impl FromDatum for Scalar { - const GET_TYPOID: bool = false; - - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, typoid: Oid) -> Option { - Float::from_polymorphic_datum(datum, is_null, typoid).map(Self) - } -} - -impl IntoDatum for Scalar { - fn into_datum(self) -> Option { - Float::into_datum(self.0) - } - - fn type_oid() -> Oid { - Float::type_oid() - } -} - -unsafe impl SqlTranslatable for Scalar { - fn type_name() -> &'static str { - Float::type_name() - } - fn argument_sql() -> Result { - Float::argument_sql() - } - fn return_sql() -> Result { - Float::return_sql() - } - fn variadic() -> bool { - Float::variadic() - } - fn optional() -> bool { - Float::optional() - } - fn entity() -> FunctionMetadataTypeEntity { - Float::entity() - } -} diff --git a/src/prelude/stat.rs b/src/prelude/stat.rs deleted file mode 100644 index b325b60ed..000000000 --- a/src/prelude/stat.rs +++ /dev/null @@ -1,13 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorIndexInfo { - pub indexing: bool, - pub idx_tuples: i32, - pub idx_sealed_len: i32, - pub idx_growing_len: i32, - pub idx_write: i32, - pub idx_sealed: Vec, - pub idx_growing: Vec, - pub idx_config: String, -} diff --git a/src/sql/bootstrap.sql b/src/sql/bootstrap.sql index a12734ab2..5310d6794 100644 --- a/src/sql/bootstrap.sql +++ b/src/sql/bootstrap.sql @@ -1 +1,2 @@ CREATE TYPE vector; +CREATE TYPE vecf16; diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index bdc484b03..130647e4f 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -1,8 +1,11 @@ CREATE CAST (real[] AS vector) - WITH FUNCTION cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; + WITH FUNCTION vector_cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; CREATE CAST (vector AS real[]) - WITH FUNCTION cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; + WITH FUNCTION vector_cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; + +CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; +COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; CREATE OPERATOR CLASS l2_ops FOR TYPE vector USING vectors AS @@ -16,6 +19,18 @@ CREATE OPERATOR CLASS cosine_ops FOR TYPE vector USING vectors AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops; +CREATE OPERATOR CLASS vecf16_l2_ops + FOR TYPE vecf16 USING vectors AS + OPERATOR 1 <-> (vecf16, vecf16) FOR ORDER BY float_ops; + +CREATE OPERATOR CLASS vecf16_dot_ops + FOR TYPE vecf16 USING vectors AS + OPERATOR 1 <#> (vecf16, vecf16) FOR ORDER BY float_ops; + +CREATE OPERATOR CLASS vecf16_cosine_ops + FOR TYPE vecf16 USING vectors AS + OPERATOR 1 <=> (vecf16, vecf16) FOR ORDER BY float_ops; + CREATE VIEW pg_vector_index_info AS SELECT C.oid AS tablerelid, @@ -27,4 +42,4 @@ CREATE VIEW pg_vector_index_info AS pg_index X ON C.oid = X.indrelid JOIN pg_class I ON I.oid = X.indexrelid JOIN pg_am A ON A.oid = I.relam - WHERE A.amname = 'vectors'; \ No newline at end of file + WHERE A.amname = 'vectors'; diff --git a/src/utils/cells.rs b/src/utils/cells.rs index e9f004570..edaf4b22a 100644 --- a/src/utils/cells.rs +++ b/src/utils/cells.rs @@ -1,4 +1,4 @@ -use std::cell::{Cell, RefCell, UnsafeCell}; +use std::cell::{Cell, RefCell}; pub struct PgCell(Cell); @@ -36,28 +36,3 @@ impl PgRefCell { self.0.borrow() } } - -#[repr(transparent)] -pub struct SyncUnsafeCell { - value: UnsafeCell, -} - -unsafe impl Sync for SyncUnsafeCell {} - -impl SyncUnsafeCell { - pub const fn new(value: T) -> Self { - Self { - value: UnsafeCell::new(value), - } - } -} - -impl SyncUnsafeCell { - pub fn get(&self) -> *mut T { - self.value.get() - } - - pub fn get_mut(&mut self) -> &mut T { - self.value.get_mut() - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 813ce1d83..462c6ba81 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,9 +1,3 @@ pub mod cells; -pub mod clean; -pub mod dir_ops; -pub mod file_atomic; pub mod file_socket; -pub mod file_wal; -pub mod mmap_array; pub mod os; -pub mod vec2; diff --git a/src/utils/os.rs b/src/utils/os.rs index 44316be89..d77f71899 100644 --- a/src/utils/os.rs +++ b/src/utils/os.rs @@ -8,18 +8,22 @@ pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { tv_sec: 15, tv_nsec: 0, }; - libc::syscall( - libc::SYS_futex, - futex.as_ptr(), - libc::FUTEX_WAIT, - value, - &FUTEX_TIMEOUT, - ); + unsafe { + libc::syscall( + libc::SYS_futex, + futex.as_ptr(), + libc::FUTEX_WAIT, + value, + &FUTEX_TIMEOUT, + ); + } } #[cfg(target_os = "linux")] pub unsafe fn futex_wake(futex: &AtomicU32) { - libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX); + unsafe { + libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX); + } } #[cfg(target_os = "linux")] @@ -34,34 +38,40 @@ pub fn memfd_create() -> std::io::Result { #[cfg(target_os = "linux")] pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { use std::ptr::null_mut; - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED | MapFlags::POPULATE, - fd, - 0, - )?) + unsafe { + Ok(rustix::mm::mmap( + null_mut(), + len, + ProtFlags::READ | ProtFlags::WRITE, + MapFlags::SHARED | MapFlags::POPULATE, + fd, + 0, + )?) + } } #[cfg(target_os = "macos")] pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { const ULOCK_TIMEOUT: u32 = 15_000_000; - ulock_sys::__ulock_wait( - ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, - futex.as_ptr().cast(), - value as _, - ULOCK_TIMEOUT, - ); + unsafe { + ulock_sys::__ulock_wait( + ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, + futex.as_ptr().cast(), + value as _, + ULOCK_TIMEOUT, + ); + } } #[cfg(target_os = "macos")] pub unsafe fn futex_wake(futex: &AtomicU32) { - ulock_sys::__ulock_wake( - ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, - futex.as_ptr().cast(), - 0, - ); + unsafe { + ulock_sys::__ulock_wake( + ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, + futex.as_ptr().cast(), + 0, + ); + } } #[cfg(target_os = "macos")] @@ -87,12 +97,14 @@ pub fn memfd_create() -> std::io::Result { #[cfg(target_os = "macos")] pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { use std::ptr::null_mut; - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED, - fd, - 0, - )?) + unsafe { + Ok(rustix::mm::mmap( + null_mut(), + len, + ProtFlags::READ | ProtFlags::WRITE, + MapFlags::SHARED, + fd, + 0, + )?) + } } From 2c13e5d44cf93104f5e4ac8bf5d98bf86a7fa0cf Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 15:24:09 +0800 Subject: [PATCH 02/23] feat: detect avx512fp16 Signed-off-by: usamoi --- .github/workflows/check.yml | 2 + Cargo.lock | 11 +++ crates/c/src/c.c | 35 ++++++++- crates/c/src/c.h | 10 +++ crates/c/src/c.rs | 4 +- crates/service/Cargo.toml | 2 + .../src/algorithms/quantization/product.rs | 2 +- crates/service/src/index/optimizing/mod.rs | 2 +- .../service/src/prelude/global/avx512fp16.rs | 14 ++++ crates/service/src/prelude/global/f16.rs | 72 +++++++++++++++++++ crates/service/src/prelude/global/f16_cos.rs | 27 +------ crates/service/src/prelude/global/f16_dot.rs | 40 ++--------- crates/service/src/prelude/global/f16_l2.rs | 25 ++----- crates/service/src/prelude/global/f32_cos.rs | 7 -- crates/service/src/prelude/global/f32_dot.rs | 4 -- crates/service/src/prelude/global/f32_l2.rs | 4 -- crates/service/src/prelude/global/mod.rs | 3 +- 17 files changed, 162 insertions(+), 102 deletions(-) create mode 100644 crates/service/src/prelude/global/avx512fp16.rs create mode 100644 crates/service/src/prelude/global/f16.rs diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 1e981361c..d2f66e82e 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -6,6 +6,7 @@ on: paths: - ".cargo/**" - ".github/**" + - "crates/**" - "scripts/**" - "src/**" - "tests/**" @@ -18,6 +19,7 @@ on: paths: - ".cargo/**" - ".github/**" + - "crates/**" - "scripts/**" - "src/**" - "tests/**" diff --git a/Cargo.lock b/Cargo.lock index 75eb3852d..32e8d2eee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2794,6 +2794,7 @@ dependencies = [ "bincode", "bytemuck", "byteorder", + "c", "crc32fast", "crossbeam", "dashmap", @@ -2811,6 +2812,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "std_detect", "tempfile", "thiserror", "ulock-sys", @@ -2923,6 +2925,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "std_detect" +version = "0.1.5" +source = "git+https://github.com/usamoi/stdarch.git?rev=067a6e889f0ca995a9fe4114061ced6f67acfb95#067a6e889f0ca995a9fe4114061ced6f67acfb95" +dependencies = [ + "cfg-if", + "libc", +] + [[package]] name = "string_cache" version = "0.8.7" diff --git a/crates/c/src/c.c b/crates/c/src/c.c index 6d8e57e97..51299511b 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -1,3 +1,36 @@ #include "c.h" -void c_test() {} +__attribute__((target("avx512fp16,avx512vl,avx512f,bmi2"))) extern float +vectors_f16_cosine_axv512(_Float16 const *restrict a, + _Float16 const *restrict b, size_t n) { + _Float16 xy = 0.0; + _Float16 x2 = 0.0; + _Float16 y2 = 0.0; + for (size_t i = 0; i < n; i++) { + xy += a[i] * b[i]; + x2 += a[i] * a[i]; + y2 += b[i] * b[i]; + } + return xy / sqrt(x2 * y2); +} + +__attribute__((target("avx512fp16,avx512vl,avx512f,bmi2"))) extern float +vectors_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { + _Float16 result = 0.0; + for (size_t i = 0; i < n; i++) { + result += a[i] * b[i]; + } + return result; +} + +__attribute__((target("avx512fp16,avx512vl,avx512f,bmi2"))) extern float +vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a, + _Float16 const *restrict b, size_t n) { + _Float16 result = 0.0; + for (size_t i = 0; i < n; i++) { + _Float16 d = a[i] - b[i]; + result += d * d; + } + return result; +} diff --git a/crates/c/src/c.h b/crates/c/src/c.h index e69de29bb..79be0fc0f 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -0,0 +1,10 @@ +#include +#include +#include + +extern float vectors_f16_cosine_axv512(_Float16 const *, _Float16 const *, + size_t n); +extern float vectors_f16_dot_axv512(_Float16 const *, _Float16 const *, + size_t n); +extern float vectors_f16_distance_squared_l2_axv512(_Float16 const *, + _Float16 const *, size_t n); diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index 01c3bb7b5..c3efc17a1 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -1,3 +1,5 @@ extern "C" { - pub fn c_test(); + pub fn vectors_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn vectors_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn vectors_f16_distance_squared_l2_axv512(a: *const u16, b: *const u16, n: usize) -> f32; } diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 6b3d8e964..92650c3ec 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -15,6 +15,8 @@ byteorder.workspace = true bincode.workspace = true half.workspace = true num-traits.workspace = true +c = { path = "../c" } +std_detect = { git = "https://github.com/usamoi/stdarch.git", rev = "067a6e889f0ca995a9fe4114061ced6f67acfb95" } rand = "0.8.5" crc32fast = "1.3.2" crossbeam = "0.8.2" diff --git a/crates/service/src/algorithms/quantization/product.rs b/crates/service/src/algorithms/quantization/product.rs index 50d073476..96855632a 100644 --- a/crates/service/src/algorithms/quantization/product.rs +++ b/crates/service/src/algorithms/quantization/product.rs @@ -176,7 +176,7 @@ impl ProductQuantization { for j in 0u8..=255 { let right = ¢roids[j as usize * dims as usize..][(i * ratio) as usize..] [..subdims as usize]; - let dis = S::l2_distance(left, right); + let dis = S::L2::distance(left, right); if dis < minimal { minimal = dis; target = j; diff --git a/crates/service/src/index/optimizing/mod.rs b/crates/service/src/index/optimizing/mod.rs index b109948cf..d0c4b2ae0 100644 --- a/crates/service/src/index/optimizing/mod.rs +++ b/crates/service/src/index/optimizing/mod.rs @@ -1,5 +1,5 @@ -pub mod sealing; pub mod indexing; +pub mod sealing; pub mod vacuum; use serde::{Deserialize, Serialize}; diff --git a/crates/service/src/prelude/global/avx512fp16.rs b/crates/service/src/prelude/global/avx512fp16.rs new file mode 100644 index 000000000..1b65e5a27 --- /dev/null +++ b/crates/service/src/prelude/global/avx512fp16.rs @@ -0,0 +1,14 @@ +// avx512fp16,avx512vl,avx512f,bmi2 + +#[cfg(not(target_arch = "x86_64"))] +pub fn detect() -> bool { + false +} + +#[cfg(target_arch = "x86_64")] +pub fn detect() -> bool { + std_detect::is_x86_feature_detected!("avx512fp16") + && std_detect::is_x86_feature_detected!("avx512vl") + && std_detect::is_x86_feature_detected!("avx512f") + && std_detect::is_x86_feature_detected!("bmi2") +} diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs new file mode 100644 index 000000000..3ccd36f53 --- /dev/null +++ b/crates/service/src/prelude/global/f16.rs @@ -0,0 +1,72 @@ +use crate::prelude::*; + +pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { + #[inline(always)] + #[multiversion::multiversion(targets = "simd")] + pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); + } + xy / (x2 * y2).sqrt() + } + if super::avx512fp16::detect() { + unsafe { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n); + } + } + cosine(lhs, rhs) +} + +pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { + #[inline(always)] + #[multiversion::multiversion(targets = "simd")] + pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + } + xy + } + if super::avx512fp16::detect() { + unsafe { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n); + } + } + cosine(lhs, rhs) +} + +pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { + #[inline(always)] + #[multiversion::multiversion(targets = "simd")] + pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i].to_f() - rhs[i].to_f(); + d2 += d * d; + } + d2 + } + if super::avx512fp16::detect() { + unsafe { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + c::vectors_f16_distance_squared_l2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n); + } + } + distance_squared_l2(lhs, rhs) +} diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/service/src/prelude/global/f16_cos.rs index 28bba1c3f..2cfcacc4f 100644 --- a/crates/service/src/prelude/global/f16_cos.rs +++ b/crates/service/src/prelude/global/f16_cos.rs @@ -12,23 +12,16 @@ impl G for F16Cos { type L2 = F16L2; - #[multiversion::multiversion(targets = "simd")] fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - cosine(lhs, rhs) * (-1.0) - } - - fn l2_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16_l2::distance_squared_l2(lhs, rhs) + super::f16::cosine(lhs, rhs) * (-1.0) } - #[multiversion::multiversion(targets = "simd")] fn elkan_k_means_normalize(vector: &mut [F16]) { l2_normalize(vector) } - #[multiversion::multiversion(targets = "simd")] fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16_dot::dot(lhs, rhs).acos() + super::f16::dot(lhs, rhs).acos() } #[multiversion::multiversion(targets = "simd")] @@ -173,22 +166,6 @@ fn l2_normalize(vector: &mut [F16]) { } } -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - let mut x2 = F32::zero(); - let mut y2 = F32::zero(); - for i in 0..n { - xy += lhs[i].to_f() * rhs[i].to_f(); - x2 += lhs[i].to_f() * lhs[i].to_f(); - y2 += rhs[i].to_f() * rhs[i].to_f(); - } - xy / (x2 * y2).sqrt() -} - #[inline(always)] #[multiversion::multiversion(targets = "simd")] fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/service/src/prelude/global/f16_dot.rs index 8f6581032..bee46c56f 100644 --- a/crates/service/src/prelude/global/f16_dot.rs +++ b/crates/service/src/prelude/global/f16_dot.rs @@ -13,11 +13,7 @@ impl G for F16Dot { type L2 = F16L2; fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - dot(lhs, rhs) * (-1.0) - } - - fn l2_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16_l2::distance_squared_l2(lhs, rhs) + super::f16::dot(lhs, rhs) * (-1.0) } fn elkan_k_means_normalize(vector: &mut [F16]) { @@ -25,7 +21,7 @@ impl G for F16Dot { } fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16_dot::dot(lhs, rhs).acos() + super::f16::dot(lhs, rhs).acos() } #[multiversion::multiversion(targets = "simd")] @@ -77,7 +73,7 @@ impl G for F16Dot { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = dot(lhs, rhs); + let _xy = super::f16::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -99,7 +95,7 @@ impl G for F16Dot { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = dot(lhs, rhs); + let _xy = super::f16::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -150,34 +146,6 @@ fn l2_normalize(vector: &mut [F16]) { } } -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - let mut x2 = F32::zero(); - let mut y2 = F32::zero(); - for i in 0..n { - xy += lhs[i].to_f() * rhs[i].to_f(); - x2 += lhs[i].to_f() * lhs[i].to_f(); - y2 += rhs[i].to_f() * rhs[i].to_f(); - } - xy / (x2 * y2).sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - for i in 0..n { - xy += lhs[i].to_f() * rhs[i].to_f(); - } - xy -} - #[inline(always)] #[multiversion::multiversion(targets = "simd")] fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs index fd706bc07..2cb15b6c2 100644 --- a/crates/service/src/prelude/global/f16_l2.rs +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -14,17 +14,13 @@ impl G for F16L2 { type L2 = F16L2; fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - distance_squared_l2(lhs, rhs) - } - - fn l2_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - distance_squared_l2(lhs, rhs) + super::f16::distance_squared_l2(lhs, rhs) } fn elkan_k_means_normalize(_: &mut [F16]) {} fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - distance_squared_l2(lhs, rhs).sqrt() + super::f16::distance_squared_l2(lhs, rhs).sqrt() } #[multiversion::multiversion(targets = "simd")] @@ -76,7 +72,7 @@ impl G for F16L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); + result += super::f16::distance_squared_l2(lhs, rhs); } result } @@ -97,7 +93,7 @@ impl G for F16L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); + result += super::f16::distance_squared_l2(lhs, rhs); } result } @@ -125,19 +121,6 @@ impl G for F16L2 { } } -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut d2 = F32::zero(); - for i in 0..n { - let d = lhs[i].to_f() - rhs[i].to_f(); - d2 += d * d; - } - d2 -} - #[inline(always)] #[multiversion::multiversion(targets = "simd")] fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/service/src/prelude/global/f32_cos.rs index 43ae83717..4c48af279 100644 --- a/crates/service/src/prelude/global/f32_cos.rs +++ b/crates/service/src/prelude/global/f32_cos.rs @@ -12,21 +12,14 @@ impl G for F32Cos { type L2 = F32L2; - #[multiversion::multiversion(targets = "simd")] fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { cosine(lhs, rhs) * (-1.0) } - fn l2_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32_l2::distance_squared_l2(lhs, rhs) - } - - #[multiversion::multiversion(targets = "simd")] fn elkan_k_means_normalize(vector: &mut [F32]) { l2_normalize(vector) } - #[multiversion::multiversion(targets = "simd")] fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { super::f32_dot::dot(lhs, rhs).acos() } diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/service/src/prelude/global/f32_dot.rs index 2ed3be71f..081f4eb39 100644 --- a/crates/service/src/prelude/global/f32_dot.rs +++ b/crates/service/src/prelude/global/f32_dot.rs @@ -16,10 +16,6 @@ impl G for F32Dot { dot(lhs, rhs) * (-1.0) } - fn l2_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32_l2::distance_squared_l2(lhs, rhs) - } - fn elkan_k_means_normalize(vector: &mut [F32]) { l2_normalize(vector) } diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/service/src/prelude/global/f32_l2.rs index aca474de8..a5d6da5b5 100644 --- a/crates/service/src/prelude/global/f32_l2.rs +++ b/crates/service/src/prelude/global/f32_l2.rs @@ -16,10 +16,6 @@ impl G for F32L2 { distance_squared_l2(lhs, rhs) } - fn l2_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - distance_squared_l2(lhs, rhs) - } - fn elkan_k_means_normalize(_: &mut [F32]) {} fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs index f0936ddfe..3bc40198b 100644 --- a/crates/service/src/prelude/global/mod.rs +++ b/crates/service/src/prelude/global/mod.rs @@ -1,3 +1,5 @@ +mod avx512fp16; +mod f16; mod f16_cos; mod f16_dot; mod f16_l2; @@ -35,7 +37,6 @@ pub trait G: Copy + std::fmt::Debug + 'static { type L2: G; fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; - fn l2_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; fn elkan_k_means_normalize(vector: &mut [Self::Scalar]); fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; fn scalar_quantization_distance( From 4f7e9b84c56b610e45511eeb8df237e56a3825cd Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 16:12:04 +0800 Subject: [PATCH 03/23] fix: install clang-16 for ci Signed-off-by: usamoi --- .github/workflows/release.yml | 5 ++++- crates/c/build.rs | 3 ++- docs/installation.md | 11 ++++++++++- scripts/ci_setup.sh | 5 ++++- 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index edf49c804..c694c8c9c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -112,10 +112,13 @@ jobs: - uses: mozilla-actions/sccache-action@v0.0.3 - name: Prepare run: | - sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }} + sudo apt-get -y install clang-16 cargo install cargo-pgrx --git https://github.com/tensorchord/pgrx.git --rev $(cat Cargo.toml | grep "pgrx =" | awk -F'rev = "' '{print $2}' | cut -d'"' -f1) cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config if [[ "${{ matrix.arch }}" == "arm64" ]]; then diff --git a/crates/c/build.rs b/crates/c/build.rs index b86eeefc4..df22f5ee2 100644 --- a/crates/c/build.rs +++ b/crates/c/build.rs @@ -1,6 +1,7 @@ fn main() { cc::Build::new() - .compiler("/usr/bin/clang") + .compiler("/usr/bin/clang-16") .file("./src/c.c") + .opt_level(3) .compile("c"); } diff --git a/docs/installation.md b/docs/installation.md index 3b8610146..cd8d7a41f 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -31,12 +31,21 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh Install PostgreSQL. ```sh -sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' +sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - sudo apt-get update sudo apt-get -y install libpq-dev postgresql-15 postgresql-server-dev-15 ``` +Install clang-16. + +```sh +sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' +wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - +sudo apt-get update +sudo apt-get -y install clang-16 +``` + Clone the Repository. ```sh diff --git a/scripts/ci_setup.sh b/scripts/ci_setup.sh index f6043ed6b..6efaf81b0 100755 --- a/scripts/ci_setup.sh +++ b/scripts/ci_setup.sh @@ -6,10 +6,13 @@ if [ "$OS" == "ubuntu-latest" ]; then sudo pg_dropcluster 14 main fi sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION + sudo apt-get -y install clang-16 echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf From 87b63873a59d450d55e52d86775de6749d9ca127 Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 16:37:21 +0800 Subject: [PATCH 04/23] fix: clippy Signed-off-by: usamoi --- .gitignore | 3 ++- crates/service/src/index/optimizing/indexing.rs | 2 +- crates/service/src/index/segments/growing.rs | 2 ++ crates/service/src/prelude/scalar/f16.rs | 4 ++-- crates/service/src/prelude/scalar/f32.rs | 4 ++-- 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index ee9dad69c..b50e6dba4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ .vscode .ignore __pycache__ -.pytest_cache \ No newline at end of file +.pytest_cache +rustc-ice-*.txt diff --git a/crates/service/src/index/optimizing/indexing.rs b/crates/service/src/index/optimizing/indexing.rs index a9e62101d..d77fc1948 100644 --- a/crates/service/src/index/optimizing/indexing.rs +++ b/crates/service/src/index/optimizing/indexing.rs @@ -102,7 +102,7 @@ pub fn optimizing_indexing(index: Arc>) -> bool { break; } } - if segs.len() == 0 || (segs.len() == 1 && count == 0) { + if segs.is_empty() || (segs.len() == 1 && count == 0) { index.instant_index.store(Instant::now()); return true; } diff --git a/crates/service/src/index/segments/growing.rs b/crates/service/src/index/segments/growing.rs index 5c20c2029..3607a0353 100644 --- a/crates/service/src/index/segments/growing.rs +++ b/crates/service/src/index/segments/growing.rs @@ -1,3 +1,5 @@ +#![allow(clippy::all)] // Clippy bug. + use super::SegmentTracker; use crate::index::IndexOptions; use crate::index::IndexTracker; diff --git a/crates/service/src/prelude/scalar/f16.rs b/crates/service/src/prelude/scalar/f16.rs index d8c9ee977..467542f06 100644 --- a/crates/service/src/prelude/scalar/f16.rs +++ b/crates/service/src/prelude/scalar/f16.rs @@ -374,8 +374,8 @@ impl num_traits::Float for F16 { } fn sin_cos(self) -> (Self, Self) { - let (_0, _1) = self.0.sin_cos(); - (Self(_0), Self(_1)) + let (_x, _y) = self.0.sin_cos(); + (Self(_x), Self(_y)) } fn exp_m1(self) -> Self { diff --git a/crates/service/src/prelude/scalar/f32.rs b/crates/service/src/prelude/scalar/f32.rs index f80700b12..a4e70a10a 100644 --- a/crates/service/src/prelude/scalar/f32.rs +++ b/crates/service/src/prelude/scalar/f32.rs @@ -373,8 +373,8 @@ impl num_traits::Float for F32 { } fn sin_cos(self) -> (Self, Self) { - let (_0, _1) = self.0.sin_cos(); - (Self(_0), Self(_1)) + let (_x, _y) = self.0.sin_cos(); + (Self(_x), Self(_y)) } fn exp_m1(self) -> Self { From 0733c44c069d7a956940072701dba3528f296921 Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 17:08:49 +0800 Subject: [PATCH 05/23] fix: rename c to pgvectorsc Signed-off-by: usamoi --- crates/c/build.rs | 2 +- crates/c/src/c.rs | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/c/build.rs b/crates/c/build.rs index df22f5ee2..802de32de 100644 --- a/crates/c/build.rs +++ b/crates/c/build.rs @@ -3,5 +3,5 @@ fn main() { .compiler("/usr/bin/clang-16") .file("./src/c.c") .opt_level(3) - .compile("c"); + .compile("pgvectorsc"); } diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index c3efc17a1..a015bf6a7 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -1,3 +1,4 @@ +#[link(name = "pgvectorsc", kind = "static")] extern "C" { pub fn vectors_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; pub fn vectors_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32; From e187cbde4e2db0e0edbc49503d77b28252d0f54d Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 17:49:55 +0800 Subject: [PATCH 06/23] feat: hand-writing avx512fp16 Signed-off-by: usamoi --- crates/c/src/c.c | 77 ++++++++++++++----- crates/c/src/c.h | 1 - .../service/src/prelude/global/avx512fp16.rs | 4 - src/sql/finalize.sql | 4 +- 4 files changed, 59 insertions(+), 27 deletions(-) diff --git a/crates/c/src/c.c b/crates/c/src/c.c index 51299511b..486fa9fe0 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -1,36 +1,73 @@ #include "c.h" +#include +#include -__attribute__((target("avx512fp16,avx512vl,avx512f,bmi2"))) extern float +__attribute__((target("avx512fp16,bmi2"))) extern float vectors_f16_cosine_axv512(_Float16 const *restrict a, _Float16 const *restrict b, size_t n) { - _Float16 xy = 0.0; - _Float16 x2 = 0.0; - _Float16 y2 = 0.0; - for (size_t i = 0; i < n; i++) { - xy += a[i] * b[i]; - x2 += a[i] * a[i]; - y2 += b[i] * b[i]; + __m512h xy = _mm512_set1_ph(0); + __m512h xx = _mm512_set1_ph(0); + __m512h yy = _mm512_set1_ph(0); + + while (n >= 32) { + __m512h x = _mm512_loadu_ph(a); + __m512h y = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + xy = _mm512_fmadd_ph(x, y, xy); + xx = _mm512_fmadd_ph(x, x, xx); + yy = _mm512_fmadd_ph(y, y, yy); + } + if (n > 0) { + __mmask32 mask = _bzhi_u32(0xFFFFFFFF, n); + __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + xy = _mm512_fmadd_ph(x, y, xy); + xx = _mm512_fmadd_ph(x, x, xx); + yy = _mm512_fmadd_ph(y, y, yy); } - return xy / sqrt(x2 * y2); + return (float)(_mm512_reduce_add_ps(xy) / + sqrt(_mm512_reduce_add_ps(xx) * _mm512_reduce_add_ps(yy))); } -__attribute__((target("avx512fp16,avx512vl,avx512f,bmi2"))) extern float +__attribute__((target("avx512fp16,bmi2"))) extern float vectors_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, size_t n) { - _Float16 result = 0.0; - for (size_t i = 0; i < n; i++) { - result += a[i] * b[i]; + __m512h xy = _mm512_set1_ph(0); + + while (n >= 32) { + __m512h x = _mm512_loadu_ph(a); + __m512h y = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + xy = _mm512_fmadd_ph(x, y, xy); + } + if (n > 0) { + __mmask32 mask = _bzhi_u32(0xFFFFFFFF, n); + __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + xy = _mm512_fmadd_ph(x, y, xy); } - return result; + return (float)_mm512_reduce_add_ph(xy); } -__attribute__((target("avx512fp16,avx512vl,avx512f,bmi2"))) extern float +__attribute__((target("avx512fp16,bmi2"))) extern float vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a, _Float16 const *restrict b, size_t n) { - _Float16 result = 0.0; - for (size_t i = 0; i < n; i++) { - _Float16 d = a[i] - b[i]; - result += d * d; + __m512h dd = _mm512_set1_ph(0); + + while (n >= 32) { + __m512h x = _mm512_loadu_ph(a); + __m512h y = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + __m512h d = _mm512_sub_ph(x, y); + dd = _mm512_fmadd_ph(d, d, dd); + } + if (n > 0) { + __mmask32 mask = _bzhi_u32(0xFFFFFFFF, n); + __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + __m512h d = _mm512_sub_ph(x, y); + dd = _mm512_fmadd_ph(d, d, dd); } - return result; + + return (float)_mm512_reduce_add_ph(dd); } diff --git a/crates/c/src/c.h b/crates/c/src/c.h index 79be0fc0f..6da5e3a17 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -1,4 +1,3 @@ -#include #include #include diff --git a/crates/service/src/prelude/global/avx512fp16.rs b/crates/service/src/prelude/global/avx512fp16.rs index 1b65e5a27..28613f0d5 100644 --- a/crates/service/src/prelude/global/avx512fp16.rs +++ b/crates/service/src/prelude/global/avx512fp16.rs @@ -1,5 +1,3 @@ -// avx512fp16,avx512vl,avx512f,bmi2 - #[cfg(not(target_arch = "x86_64"))] pub fn detect() -> bool { false @@ -8,7 +6,5 @@ pub fn detect() -> bool { #[cfg(target_arch = "x86_64")] pub fn detect() -> bool { std_detect::is_x86_feature_detected!("avx512fp16") - && std_detect::is_x86_feature_detected!("avx512vl") - && std_detect::is_x86_feature_detected!("avx512f") && std_detect::is_x86_feature_detected!("bmi2") } diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index 130647e4f..62a1f8270 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -1,8 +1,8 @@ CREATE CAST (real[] AS vector) - WITH FUNCTION vector_cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; + WITH FUNCTION vecf32_cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; CREATE CAST (vector AS real[]) - WITH FUNCTION vector_cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; + WITH FUNCTION vecf32_cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; From c27617b22885913e283c780267b939fda213e0e4 Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 20:10:22 +0800 Subject: [PATCH 07/23] fix: index on fp16 Signed-off-by: usamoi --- .../service/src/index/optimizing/indexing.rs | 34 ++++++++----------- crates/service/src/worker/instance.rs | 6 ++-- src/datatype/vecf16.rs | 7 ++++ src/datatype/vecf32.rs | 7 ++++ src/index/am.rs | 18 ++++------ src/index/am_build.rs | 16 ++++----- src/index/am_scan.rs | 16 +++------ src/index/am_setup.rs | 6 ++++ src/index/mod.rs | 1 + src/index/utils.rs | 24 +++++++++++++ 10 files changed, 82 insertions(+), 53 deletions(-) create mode 100644 src/index/utils.rs diff --git a/crates/service/src/index/optimizing/indexing.rs b/crates/service/src/index/optimizing/indexing.rs index d77fc1948..d7d44e2b4 100644 --- a/crates/service/src/index/optimizing/indexing.rs +++ b/crates/service/src/index/optimizing/indexing.rs @@ -33,8 +33,7 @@ impl OptimizerIndexing { let Some(index) = weak_index.upgrade() else { return; }; - let cont = pool.install(|| optimizing_indexing(index.clone())); - if cont { + if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) { continue; } } @@ -77,36 +76,33 @@ impl Seg { } } -pub fn optimizing_indexing(index: Arc>) -> bool { +pub fn optimizing_indexing(index: Arc>) -> Result<(), ()> { use Seg::*; let segs = { - let mut all_segs = { - let protect = index.protect.lock(); - let mut all_segs = Vec::new(); - all_segs.extend(protect.growing.values().map(|x| Growing(x.clone()))); - all_segs.extend(protect.sealed.values().map(|x| Sealed(x.clone()))); - all_segs.sort_by_key(|case| Reverse(case.len())); - all_segs - }; - let mut segs = Vec::new(); + let protect = index.protect.lock(); + let mut segs_0 = Vec::new(); + segs_0.extend(protect.growing.values().map(|x| Growing(x.clone()))); + segs_0.extend(protect.sealed.values().map(|x| Sealed(x.clone()))); + segs_0.sort_by_key(|case| Reverse(case.len())); + let mut segs_1 = Vec::new(); let mut total = 0u64; let mut count = 0; - while let Some(seg) = all_segs.pop() { + while let Some(seg) = segs_0.pop() { if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 { total += seg.len() as u64; if let Growing(_) = seg { count += 1; } - segs.push(seg); + segs_1.push(seg); } else { break; } } - if segs.is_empty() || (segs.len() == 1 && count == 0) { + if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) { index.instant_index.store(Instant::now()); - return true; + return Err(()); } - segs + segs_1 }; let sealed_segment = merge(&index, &segs); { @@ -118,7 +114,7 @@ pub fn optimizing_indexing(index: Arc>) -> bool { if protect.growing.contains_key(&seg.uuid()) { continue; } - return false; + return Ok(()); } for seg in segs.iter() { protect.sealed.remove(&seg.uuid()); @@ -127,7 +123,7 @@ pub fn optimizing_indexing(index: Arc>) -> bool { protect.sealed.insert(sealed_segment.uuid(), sealed_segment); protect.maintain(index.options.clone(), index.delete.clone(), &index.view); } - false + Ok(()) } fn merge(index: &Arc>, segs: &[Seg]) -> Arc> { diff --git a/crates/service/src/worker/instance.rs b/crates/service/src/worker/instance.rs index 13dc2587d..5d4373779 100644 --- a/crates/service/src/worker/instance.rs +++ b/crates/service/src/worker/instance.rs @@ -33,9 +33,9 @@ impl Instance { (Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path, options)), (Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path, options)), (Distance::L2, Kind::F32) => Self::F32L2(Index::open(path, options)), - (Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)), - (Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)), - (Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)), + (Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path, options)), + (Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path, options)), + (Distance::L2, Kind::F16) => Self::F16L2(Index::open(path, options)), } } pub fn options(&self) -> &IndexOptions { diff --git a/src/datatype/vecf16.rs b/src/datatype/vecf16.rs index 94e7ccb2f..2d7c48136 100644 --- a/src/datatype/vecf16.rs +++ b/src/datatype/vecf16.rs @@ -39,6 +39,7 @@ CREATE TYPE vecf16 ( #[repr(C, align(8))] pub struct Vecf16 { varlena: u32, + kind: u8, len: u16, phantom: [F16; 0], } @@ -60,6 +61,7 @@ impl Vecf16 { let layout = Vecf16::layout(slice.len()); let ptr = std::alloc::alloc(layout) as *mut Vecf16; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(16); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Box::from_raw(ptr) @@ -71,6 +73,7 @@ impl Vecf16 { let layout = Vecf16::layout(slice.len()); let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(16); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Vecf16Output(NonNull::new(ptr).unwrap()) @@ -82,6 +85,7 @@ impl Vecf16 { let layout = Vecf16::layout(len); let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf16; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(16); std::ptr::addr_of_mut!((*ptr).len).write(len as u16); Box::from_raw(ptr) } @@ -93,6 +97,7 @@ impl Vecf16 { let layout = Vecf16::layout(len); let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf16; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(16); std::ptr::addr_of_mut!((*ptr).len).write(len as u16); Vecf16Output(NonNull::new(ptr).unwrap()) } @@ -102,10 +107,12 @@ impl Vecf16 { } pub fn data(&self) -> &[F16] { debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 16); unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } } pub fn data_mut(&mut self) -> &mut [F16] { debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 16); unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } } #[allow(dead_code)] diff --git a/src/datatype/vecf32.rs b/src/datatype/vecf32.rs index d04f9312d..6c9e20c4d 100644 --- a/src/datatype/vecf32.rs +++ b/src/datatype/vecf32.rs @@ -39,6 +39,7 @@ CREATE TYPE vector ( #[repr(C, align(8))] pub struct Vecf32 { varlena: u32, + kind: u8, len: u16, phantom: [F32; 0], } @@ -60,6 +61,7 @@ impl Vecf32 { let layout = Vecf32::layout(slice.len()); let ptr = std::alloc::alloc(layout) as *mut Vecf32; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(32); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Box::from_raw(ptr) @@ -71,6 +73,7 @@ impl Vecf32 { let layout = Vecf32::layout(slice.len()); let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(32); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Vecf32Output(NonNull::new(ptr).unwrap()) @@ -82,6 +85,7 @@ impl Vecf32 { let layout = Vecf32::layout(len); let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf32; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(32); std::ptr::addr_of_mut!((*ptr).len).write(len as u16); Box::from_raw(ptr) } @@ -93,6 +97,7 @@ impl Vecf32 { let layout = Vecf32::layout(len); let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf32; std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(32); std::ptr::addr_of_mut!((*ptr).len).write(len as u16); Vecf32Output(NonNull::new(ptr).unwrap()) } @@ -102,10 +107,12 @@ impl Vecf32 { } pub fn data(&self) -> &[F32] { debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 32); unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } } pub fn data_mut(&mut self) -> &mut [F32] { debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 32); unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } } #[allow(dead_code)] diff --git a/src/index/am.rs b/src/index/am.rs index 82925d2ad..2337f4d0f 100644 --- a/src/index/am.rs +++ b/src/index/am.rs @@ -2,8 +2,8 @@ use super::am_build; use super::am_scan; use super::am_setup; use super::am_update; -use crate::datatype::vecf32::Vecf32Input; use crate::gucs::ENABLE_VECTOR_INDEX; +use crate::index::utils::from_datum; use crate::prelude::*; use crate::utils::cells::PgCell; use service::prelude::*; @@ -198,18 +198,16 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) { pub unsafe extern "C" fn aminsert( index_relation: pgrx::pg_sys::Relation, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, heap_tid: pgrx::pg_sys::ItemPointer, _heap_relation: pgrx::pg_sys::Relation, _check_unique: pgrx::pg_sys::IndexUniqueCheck, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use pgrx::FromDatum; let oid = (*index_relation).rd_node.relNode; let id = Id::from_sys(oid); - let vector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let vector = vector.data().to_vec(); - am_update::update_insert(id, vector.into(), *heap_tid); + let vector = from_datum(*values.add(0)); + am_update::update_insert(id, vector, *heap_tid); true } @@ -218,22 +216,20 @@ pub unsafe extern "C" fn aminsert( pub unsafe extern "C" fn aminsert( index_relation: pgrx::pg_sys::Relation, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, heap_tid: pgrx::pg_sys::ItemPointer, _heap_relation: pgrx::pg_sys::Relation, _check_unique: pgrx::pg_sys::IndexUniqueCheck, _index_unchanged: bool, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use pgrx::FromDatum; #[cfg(any(feature = "pg14", feature = "pg15"))] let oid = (*index_relation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*index_relation).rd_locator.relNumber; let id = Id::from_sys(oid); - let vector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let vector = vector.data().to_vec(); - am_update::update_insert(id, vector.into(), *heap_tid); + let vector = from_datum(*values.add(0)); + am_update::update_insert(id, vector, *heap_tid); true } diff --git a/src/index/am_build.rs b/src/index/am_build.rs index 104b0dfd4..fd7bcb025 100644 --- a/src/index/am_build.rs +++ b/src/index/am_build.rs @@ -1,6 +1,6 @@ use super::{client::ClientGuard, hook_transaction::flush_if_commit}; -use crate::datatype::vecf32::Vecf32Input; use crate::index::am_setup::options; +use crate::index::utils::from_datum; use crate::prelude::*; use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData}; use service::prelude::*; @@ -48,17 +48,16 @@ unsafe extern "C" fn callback( index_relation: pgrx::pg_sys::Relation, htup: pgrx::pg_sys::HeapTuple, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - use pgrx::FromDatum; let ctid = &(*htup).t_self; let oid = (*index_relation).rd_node.relNode; let id = Id::from_sys(oid); let state = &mut *(state as *mut Builder); - let pgvector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = (pgvector.to_vec().into(), Pointer::from_sys(*ctid)); + let vector = from_datum(*values.add(0)); + let data = (vector, Pointer::from_sys(*ctid)); state.client.insert(id, data); (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; @@ -70,19 +69,18 @@ unsafe extern "C" fn callback( index_relation: pgrx::pg_sys::Relation, ctid: pgrx::pg_sys::ItemPointer, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - use pgrx::FromDatum; #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))] let oid = (*index_relation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*index_relation).rd_locator.relNumber; let id = Id::from_sys(oid); let state = &mut *(state as *mut Builder); - let pgvector = Vecf32Input::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = (pgvector.to_vec().into(), Pointer::from_sys(*ctid)); + let vector = from_datum(*values.add(0)); + let data = (vector, Pointer::from_sys(*ctid)); state.client.insert(id, data); (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs index a887f9c67..433b1ebd5 100644 --- a/src/index/am_scan.rs +++ b/src/index/am_scan.rs @@ -1,6 +1,6 @@ -use crate::datatype::vecf32::Vecf32Input; use crate::gucs::ENABLE_PREFILTER; use crate::gucs::K; +use crate::index::utils::from_datum; use crate::prelude::*; use pgrx::FromDatum; use service::prelude::*; @@ -9,7 +9,7 @@ use service::prelude::*; pub enum Scanner { Initial { // fields to be filled by amhandler and hook - vector: Option>, + vector: Option, index_scan_state: Option<*mut pgrx::pg_sys::IndexScanState>, }, Type0 { @@ -81,8 +81,7 @@ pub unsafe fn start_scan( } let orderby = orderbys.add(0); let argument = (*orderby).sk_argument; - let vector = Vecf32Input::from_datum(argument, false).unwrap(); - let vector = vector.to_vec(); + let vector = from_datum(argument); let last = (*((*scan).opaque as *mut Scanner)).clone(); let scanner = (*scan).opaque as *mut Scanner; @@ -153,12 +152,7 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { node: index_scan_state.unwrap(), }; - let mut result = client.search( - id, - (vector.into(), k), - ENABLE_PREFILTER.get(), - client_search, - ); + let mut result = client.search(id, (vector, k), ENABLE_PREFILTER.get(), client_search); result.reverse(); *scanner = Scanner::Type1 { index_scan_state: index_scan_state.unwrap(), @@ -175,7 +169,7 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { let client_search = ClientSearch {}; - let mut result = client.search(id, (vector.into(), k), false, client_search); + let mut result = client.search(id, (vector, k), false, client_search); result.reverse(); *scanner = Scanner::Type0 { data: result }; } diff --git a/src/index/am_setup.rs b/src/index/am_setup.rs index c4c89d06e..fe452f0e6 100644 --- a/src/index/am_setup.rs +++ b/src/index/am_setup.rs @@ -59,6 +59,12 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Dist result = (Distance::Dot, Kind::F32); } else if operator == regoperatorin("<=>(vector,vector)") { result = (Distance::Cos, Kind::F32); + } else if operator == regoperatorin("<->(vecf16,vecf16)") { + result = (Distance::L2, Kind::F16); + } else if operator == regoperatorin("<#>(vecf16,vecf16)") { + result = (Distance::Dot, Kind::F16); + } else if operator == regoperatorin("<=>(vecf16,vecf16)") { + result = (Distance::Cos, Kind::F16); } else { FriendlyError::BadOptions3.friendly(); }; diff --git a/src/index/mod.rs b/src/index/mod.rs index f04116b73..ef9feb973 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -9,6 +9,7 @@ mod client; mod hook_executor; mod hook_transaction; mod hooks; +mod utils; mod views; pub unsafe fn init() { diff --git a/src/index/utils.rs b/src/index/utils.rs new file mode 100644 index 000000000..3bd7978d7 --- /dev/null +++ b/src/index/utils.rs @@ -0,0 +1,24 @@ +use crate::datatype::vecf16::Vecf16; +use crate::datatype::vecf32::Vecf32; +use service::prelude::DynamicVector; + +#[repr(C, align(8))] +struct Header { + varlena: u32, + kind: u8, + len: u16, +} + +pub unsafe fn from_datum(datum: pgrx::pg_sys::Datum) -> DynamicVector { + let p = datum.cast_mut_ptr::(); + let q = pgrx::pg_sys::pg_detoast_datum(p); + let vector = match (*q.cast::
()).kind { + 32 => DynamicVector::F32((*q.cast::()).data().to_vec()), + 16 => DynamicVector::F16((*q.cast::()).data().to_vec()), + _ => unreachable!(), + }; + if p != q { + pgrx::pg_sys::pfree(q.cast()); + } + vector +} From e05cfb8c6020153a886920e92eb9f2c1537745dd Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 20:48:37 +0800 Subject: [PATCH 08/23] feat: hand-writing avx2 Signed-off-by: usamoi --- crates/c/.gitignore | 2 + crates/c/src/c.c | 48 ++++++++++++++++++++++- crates/c/src/c.h | 6 +++ crates/c/src/c.rs | 3 ++ crates/service/src/prelude/global/avx2.rs | 9 +++++ crates/service/src/prelude/global/f16.rs | 38 ++++++++++++++++-- crates/service/src/prelude/global/mod.rs | 1 + 7 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 crates/c/.gitignore create mode 100644 crates/service/src/prelude/global/avx2.rs diff --git a/crates/c/.gitignore b/crates/c/.gitignore new file mode 100644 index 000000000..b72b9e32f --- /dev/null +++ b/crates/c/.gitignore @@ -0,0 +1,2 @@ +*.s +*.o \ No newline at end of file diff --git a/crates/c/src/c.c b/crates/c/src/c.c index 486fa9fe0..8f21f4d54 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -25,8 +25,8 @@ vectors_f16_cosine_axv512(_Float16 const *restrict a, xx = _mm512_fmadd_ph(x, x, xx); yy = _mm512_fmadd_ph(y, y, yy); } - return (float)(_mm512_reduce_add_ps(xy) / - sqrt(_mm512_reduce_add_ps(xx) * _mm512_reduce_add_ps(yy))); + return (float)(_mm512_reduce_add_ph(xy) / + sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy))); } __attribute__((target("avx512fp16,bmi2"))) extern float @@ -71,3 +71,47 @@ vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a, return (float)_mm512_reduce_add_ph(dd); } + +__attribute__((target("avx2"))) extern float +vectors_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { + float xy = 0; + float xx = 0; + float yy = 0; +#pragma clang loop vectorize_width(8) + for (size_t i = 0; i < n; i++) { + float x = a[i]; + float y = b[i]; + xy += x * y; + xx += x * x; + yy += y * y; + } + return xy / sqrt(xx * yy); +} + +__attribute__((target("avx2"))) extern float +vectors_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { + float xy = 0; +#pragma clang loop vectorize_width(8) + for (size_t i = 0; i < n; i++) { + float x = a[i]; + float y = b[i]; + xy += x * y; + } + return xy; +} + +__attribute__((target("avx2"))) extern float +vectors_f16_distance_squared_l2_axv2(_Float16 const *restrict a, + _Float16 const *restrict b, size_t n) { + float dd = 0; +#pragma clang loop vectorize_width(8) + for (size_t i = 0; i < n; i++) { + float x = a[i]; + float y = b[i]; + float d = x - y; + dd += d * d; + } + return dd; +} diff --git a/crates/c/src/c.h b/crates/c/src/c.h index 6da5e3a17..43541eb78 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -7,3 +7,9 @@ extern float vectors_f16_dot_axv512(_Float16 const *, _Float16 const *, size_t n); extern float vectors_f16_distance_squared_l2_axv512(_Float16 const *, _Float16 const *, size_t n); + +extern float vectors_f16_cosine_axv2(_Float16 const *, _Float16 const *, + size_t n); +extern float vectors_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n); +extern float vectors_f16_distance_squared_l2_axv2(_Float16 const *, + _Float16 const *, size_t n); diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index a015bf6a7..5a80f7687 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -3,4 +3,7 @@ extern "C" { pub fn vectors_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; pub fn vectors_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32; pub fn vectors_f16_distance_squared_l2_axv512(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn vectors_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn vectors_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn vectors_f16_distance_squared_l2_axv2(a: *const u16, b: *const u16, n: usize) -> f32; } diff --git a/crates/service/src/prelude/global/avx2.rs b/crates/service/src/prelude/global/avx2.rs new file mode 100644 index 000000000..404ebbd48 --- /dev/null +++ b/crates/service/src/prelude/global/avx2.rs @@ -0,0 +1,9 @@ +#[cfg(not(target_arch = "x86_64"))] +pub fn detect() -> bool { + false +} + +#[cfg(target_arch = "x86_64")] +pub fn detect() -> bool { + std_detect::is_x86_feature_detected!("avx2") +} diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 3ccd36f53..84a6930d2 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -20,7 +20,15 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { assert!(lhs.len() == rhs.len()); let n = lhs.len(); - c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n); + return c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n) + .into(); + } + } + if super::avx2::detect() { + unsafe { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + return c::vectors_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } cosine(lhs, rhs) @@ -42,7 +50,14 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { assert!(lhs.len() == rhs.len()); let n = lhs.len(); - c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n); + return c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + if super::avx2::detect() { + unsafe { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + return c::vectors_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } cosine(lhs, rhs) @@ -65,7 +80,24 @@ pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { assert!(lhs.len() == rhs.len()); let n = lhs.len(); - c::vectors_f16_distance_squared_l2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n); + return c::vectors_f16_distance_squared_l2_axv512( + lhs.as_ptr().cast(), + rhs.as_ptr().cast(), + n, + ) + .into(); + } + } + if super::avx2::detect() { + unsafe { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + return c::vectors_f16_distance_squared_l2_axv2( + lhs.as_ptr().cast(), + rhs.as_ptr().cast(), + n, + ) + .into(); } } distance_squared_l2(lhs, rhs) diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs index 3bc40198b..00b518269 100644 --- a/crates/service/src/prelude/global/mod.rs +++ b/crates/service/src/prelude/global/mod.rs @@ -1,3 +1,4 @@ +mod avx2; mod avx512fp16; mod f16; mod f16_cos; From fe764afc311eac9dfbd196ce82a987183efb5d80 Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 20:51:05 +0800 Subject: [PATCH 09/23] fix: clippy Signed-off-by: usamoi --- crates/service/src/index/optimizing/indexing.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/service/src/index/optimizing/indexing.rs b/crates/service/src/index/optimizing/indexing.rs index d7d44e2b4..5a70ffa91 100644 --- a/crates/service/src/index/optimizing/indexing.rs +++ b/crates/service/src/index/optimizing/indexing.rs @@ -76,7 +76,11 @@ impl Seg { } } -pub fn optimizing_indexing(index: Arc>) -> Result<(), ()> { +#[derive(Debug, thiserror::Error)] +#[error("Interrupted, retry again.")] +pub struct RetryError; + +pub fn optimizing_indexing(index: Arc>) -> Result<(), RetryError> { use Seg::*; let segs = { let protect = index.protect.lock(); @@ -100,7 +104,7 @@ pub fn optimizing_indexing(index: Arc>) -> Result<(), ()> { } if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) { index.instant_index.store(Instant::now()); - return Err(()); + return Err(RetryError); } segs_1 }; From 2b9e9a30184236dd2990910039382574d1213ecf Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 21:11:38 +0800 Subject: [PATCH 10/23] fix: add rerun in build script Signed-off-by: usamoi --- crates/c/build.rs | 2 + crates/c/src/c.c | 24 +++++----- crates/c/src/c.h | 18 +++---- crates/c/src/c.rs | 12 ++--- crates/service/src/prelude/global/f16.rs | 53 ++++++++------------- crates/service/src/prelude/global/f16_l2.rs | 8 ++-- 6 files changed, 51 insertions(+), 66 deletions(-) diff --git a/crates/c/build.rs b/crates/c/build.rs index 802de32de..b39683c63 100644 --- a/crates/c/build.rs +++ b/crates/c/build.rs @@ -1,4 +1,6 @@ fn main() { + println!("rerun-if-changed:src/c.h"); + println!("rerun-if-changed:src/c.c"); cc::Build::new() .compiler("/usr/bin/clang-16") .file("./src/c.c") diff --git a/crates/c/src/c.c b/crates/c/src/c.c index 8f21f4d54..777d4b480 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -3,8 +3,8 @@ #include __attribute__((target("avx512fp16,bmi2"))) extern float -vectors_f16_cosine_axv512(_Float16 const *restrict a, - _Float16 const *restrict b, size_t n) { +v_f16_cosine_axv512(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { __m512h xy = _mm512_set1_ph(0); __m512h xx = _mm512_set1_ph(0); __m512h yy = _mm512_set1_ph(0); @@ -30,8 +30,8 @@ vectors_f16_cosine_axv512(_Float16 const *restrict a, } __attribute__((target("avx512fp16,bmi2"))) extern float -vectors_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { __m512h xy = _mm512_set1_ph(0); while (n >= 32) { @@ -50,8 +50,8 @@ vectors_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx512fp16,bmi2"))) extern float -vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a, - _Float16 const *restrict b, size_t n) { +v_f16_sl2_axv512(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { __m512h dd = _mm512_set1_ph(0); while (n >= 32) { @@ -73,8 +73,8 @@ vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a, } __attribute__((target("avx2"))) extern float -vectors_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { float xy = 0; float xx = 0; float yy = 0; @@ -90,8 +90,8 @@ vectors_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx2"))) extern float -vectors_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { float xy = 0; #pragma clang loop vectorize_width(8) for (size_t i = 0; i < n; i++) { @@ -103,8 +103,8 @@ vectors_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx2"))) extern float -vectors_f16_distance_squared_l2_axv2(_Float16 const *restrict a, - _Float16 const *restrict b, size_t n) { +v_f16_sl2_axv2(_Float16 const *restrict a, _Float16 const *restrict b, + size_t n) { float dd = 0; #pragma clang loop vectorize_width(8) for (size_t i = 0; i < n; i++) { diff --git a/crates/c/src/c.h b/crates/c/src/c.h index 43541eb78..b0575faf7 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -1,15 +1,9 @@ #include #include -extern float vectors_f16_cosine_axv512(_Float16 const *, _Float16 const *, - size_t n); -extern float vectors_f16_dot_axv512(_Float16 const *, _Float16 const *, - size_t n); -extern float vectors_f16_distance_squared_l2_axv512(_Float16 const *, - _Float16 const *, size_t n); - -extern float vectors_f16_cosine_axv2(_Float16 const *, _Float16 const *, - size_t n); -extern float vectors_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n); -extern float vectors_f16_distance_squared_l2_axv2(_Float16 const *, - _Float16 const *, size_t n); +extern float v_f16_cosine_axv512(_Float16 const *, _Float16 const *, size_t n); +extern float v_f16_dot_axv512(_Float16 const *, _Float16 const *, size_t n); +extern float v_f16_sl2_axv512(_Float16 const *, _Float16 const *, size_t n); +extern float v_f16_cosine_axv2(_Float16 const *, _Float16 const *, size_t n); +extern float v_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n); +extern float v_f16_sl2_axv2(_Float16 const *, _Float16 const *, size_t n); diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index 5a80f7687..4fcbd9d1d 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -1,9 +1,9 @@ #[link(name = "pgvectorsc", kind = "static")] extern "C" { - pub fn vectors_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn vectors_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn vectors_f16_distance_squared_l2_axv512(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn vectors_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn vectors_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn vectors_f16_distance_squared_l2_axv2(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_axv512(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_axv2(a: *const u16, b: *const u16, n: usize) -> f32; } diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 84a6930d2..5f1afdc55 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -17,18 +17,17 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { xy / (x2 * y2).sqrt() } if super::avx512fp16::detect() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); unsafe { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - return c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n) - .into(); + return c::v_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } if super::avx2::detect() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); unsafe { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - return c::vectors_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } cosine(lhs, rhs) @@ -47,26 +46,26 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { xy } if super::avx512fp16::detect() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); unsafe { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - return c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } if super::avx2::detect() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); unsafe { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - return c::vectors_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } cosine(lhs, rhs) } -pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { +pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets = "simd")] - pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { + pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); let mut d2 = F32::zero(); @@ -77,28 +76,18 @@ pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 { d2 } if super::avx512fp16::detect() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); unsafe { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - return c::vectors_f16_distance_squared_l2_axv512( - lhs.as_ptr().cast(), - rhs.as_ptr().cast(), - n, - ) - .into(); + return c::v_f16_sl2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } if super::avx2::detect() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); unsafe { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - return c::vectors_f16_distance_squared_l2_axv2( - lhs.as_ptr().cast(), - rhs.as_ptr().cast(), - n, - ) - .into(); + return c::v_f16_sl2_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } - distance_squared_l2(lhs, rhs) + sl2(lhs, rhs) } diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs index 2cb15b6c2..3b45d0022 100644 --- a/crates/service/src/prelude/global/f16_l2.rs +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -14,13 +14,13 @@ impl G for F16L2 { type L2 = F16L2; fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::distance_squared_l2(lhs, rhs) + super::f16::sl2(lhs, rhs) } fn elkan_k_means_normalize(_: &mut [F16]) {} fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::distance_squared_l2(lhs, rhs).sqrt() + super::f16::sl2(lhs, rhs).sqrt() } #[multiversion::multiversion(targets = "simd")] @@ -72,7 +72,7 @@ impl G for F16L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::f16::distance_squared_l2(lhs, rhs); + result += super::f16::sl2(lhs, rhs); } result } @@ -93,7 +93,7 @@ impl G for F16L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::f16::distance_squared_l2(lhs, rhs); + result += super::f16::sl2(lhs, rhs); } result } From 360ff162c93c3c45066ebf578c5fea0dbbe415d4 Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 21:19:48 +0800 Subject: [PATCH 11/23] fix: cross compilation Signed-off-by: usamoi --- crates/c/src/c.c | 9 ++++++++- crates/c/src/c.h | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/crates/c/src/c.c b/crates/c/src/c.c index 777d4b480..79bb55c14 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -1,7 +1,12 @@ #include "c.h" -#include #include +#if defined(__x86_64__) +#include +#endif + +#if defined(__x86_64__) + __attribute__((target("avx512fp16,bmi2"))) extern float v_f16_cosine_axv512(_Float16 const *restrict a, _Float16 const *restrict b, size_t n) { @@ -115,3 +120,5 @@ v_f16_sl2_axv2(_Float16 const *restrict a, _Float16 const *restrict b, } return dd; } + +#endif diff --git a/crates/c/src/c.h b/crates/c/src/c.h index b0575faf7..23bd96c4e 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -1,9 +1,13 @@ #include #include +#if defined(__x86_64__) + extern float v_f16_cosine_axv512(_Float16 const *, _Float16 const *, size_t n); extern float v_f16_dot_axv512(_Float16 const *, _Float16 const *, size_t n); extern float v_f16_sl2_axv512(_Float16 const *, _Float16 const *, size_t n); extern float v_f16_cosine_axv2(_Float16 const *, _Float16 const *, size_t n); extern float v_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n); extern float v_f16_sl2_axv2(_Float16 const *, _Float16 const *, size_t n); + +#endif From 0646a929bb4702eb1ceeeec1b6c94cf8702a4c5e Mon Sep 17 00:00:00 2001 From: usamoi Date: Fri, 8 Dec 2023 22:02:08 +0800 Subject: [PATCH 12/23] fix: do not leave uninitialized bytes in datatype input function Signed-off-by: usamoi --- src/datatype/casts_f32.rs | 4 ++-- src/datatype/operators_f16.rs | 8 +++---- src/datatype/operators_f32.rs | 8 +++---- src/datatype/vecf16.rs | 45 +++-------------------------------- src/datatype/vecf32.rs | 45 +++-------------------------------- src/index/utils.rs | 1 + 6 files changed, 17 insertions(+), 94 deletions(-) diff --git a/src/datatype/casts_f32.rs b/src/datatype/casts_f32.rs index 53261a744..8fcbd7498 100644 --- a/src/datatype/casts_f32.rs +++ b/src/datatype/casts_f32.rs @@ -13,11 +13,11 @@ fn vecf32_cast_array_to_vector( assert!(!array.contains_nulls()); let typmod = Typmod::parse_from_i32(typmod).unwrap(); let len = typmod.dims().unwrap_or(array.len() as u16); - let mut data = Vecf32::new_zeroed_in_postgres(len as usize); + let mut data = vec![F32::zero(); len as usize]; for (i, x) in array.iter().enumerate() { data[i] = F32(x.unwrap_or(f32::NAN)); } - data + Vecf32::new_in_postgres(&data) } #[pgrx::pg_extern(immutable, parallel_safe, strict)] diff --git a/src/datatype/operators_f16.rs b/src/datatype/operators_f16.rs index 72c6367fe..9ff044815 100644 --- a/src/datatype/operators_f16.rs +++ b/src/datatype/operators_f16.rs @@ -14,11 +14,11 @@ fn vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Outp .friendly(); } let n = lhs.len(); - let mut v = Vecf16::new_zeroed(n); + let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] + rhs[i]; } - v.copy_into_postgres() + Vecf16::new_in_postgres(&v) } #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] @@ -32,11 +32,11 @@ fn vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Ou .friendly(); } let n = lhs.len(); - let mut v = Vecf16::new_zeroed(n); + let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; } - v.copy_into_postgres() + Vecf16::new_in_postgres(&v) } #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] diff --git a/src/datatype/operators_f32.rs b/src/datatype/operators_f32.rs index 307055e1a..d4a67c22d 100644 --- a/src/datatype/operators_f32.rs +++ b/src/datatype/operators_f32.rs @@ -14,11 +14,11 @@ fn vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Outp .friendly(); } let n = lhs.len(); - let mut v = Vecf32::new_zeroed(n); + let mut v = vec![F32::zero(); n]; for i in 0..n { v[i] = lhs[i] + rhs[i]; } - v.copy_into_postgres() + Vecf32::new_in_postgres(&v) } #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] @@ -32,11 +32,11 @@ fn vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Ou .friendly(); } let n = lhs.len(); - let mut v = Vecf32::new_zeroed(n); + let mut v = vec![F32::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; } - v.copy_into_postgres() + Vecf32::new_in_postgres(&v) } #[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] diff --git a/src/datatype/vecf16.rs b/src/datatype/vecf16.rs index 2d7c48136..b4d5a6e6b 100644 --- a/src/datatype/vecf16.rs +++ b/src/datatype/vecf16.rs @@ -40,6 +40,7 @@ CREATE TYPE vecf16 ( pub struct Vecf16 { varlena: u32, kind: u8, + pad: u8, len: u16, phantom: [F16; 0], } @@ -55,53 +56,20 @@ impl Vecf16 { let layout = layout_alpha.extend(layout_beta).unwrap().0; layout.pad_to_align() } - pub fn new(slice: &[F16]) -> Box { - unsafe { - assert!(u16::try_from(slice.len()).is_ok()); - let layout = Vecf16::layout(slice.len()); - let ptr = std::alloc::alloc(layout) as *mut Vecf16; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(16); - std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - Box::from_raw(ptr) - } - } pub fn new_in_postgres(slice: &[F16]) -> Vecf16Output { unsafe { assert!(u16::try_from(slice.len()).is_ok()); let layout = Vecf16::layout(slice.len()); let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); std::ptr::addr_of_mut!((*ptr).kind).write(16); + std::ptr::addr_of_mut!((*ptr).pad).write(0); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Vecf16Output(NonNull::new(ptr).unwrap()) } } - pub fn new_zeroed(len: usize) -> Box { - unsafe { - assert!(u16::try_from(len).is_ok()); - let layout = Vecf16::layout(len); - let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf16; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(16); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - Box::from_raw(ptr) - } - } - #[allow(dead_code)] - pub fn new_zeroed_in_postgres(len: usize) -> Vecf16Output { - unsafe { - assert!(u64::try_from(len).is_ok()); - let layout = Vecf16::layout(len); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf16; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(16); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - Vecf16Output(NonNull::new(ptr).unwrap()) - } - } pub fn len(&self) -> usize { self.len as usize } @@ -115,13 +83,6 @@ impl Vecf16 { debug_assert_eq!(self.kind, 16); unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } } - #[allow(dead_code)] - pub fn copy(&self) -> Box { - Vecf16::new(self.data()) - } - pub fn copy_into_postgres(&self) -> Vecf16Output { - Vecf16::new_in_postgres(self.data()) - } } impl Deref for Vecf16 { diff --git a/src/datatype/vecf32.rs b/src/datatype/vecf32.rs index 6c9e20c4d..c42c4ecd7 100644 --- a/src/datatype/vecf32.rs +++ b/src/datatype/vecf32.rs @@ -40,6 +40,7 @@ CREATE TYPE vector ( pub struct Vecf32 { varlena: u32, kind: u8, + pad: u8, len: u16, phantom: [F32; 0], } @@ -55,53 +56,20 @@ impl Vecf32 { let layout = layout_alpha.extend(layout_beta).unwrap().0; layout.pad_to_align() } - pub fn new(slice: &[F32]) -> Box { - unsafe { - assert!(u16::try_from(slice.len()).is_ok()); - let layout = Vecf32::layout(slice.len()); - let ptr = std::alloc::alloc(layout) as *mut Vecf32; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(32); - std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - Box::from_raw(ptr) - } - } pub fn new_in_postgres(slice: &[F32]) -> Vecf32Output { unsafe { assert!(u16::try_from(slice.len()).is_ok()); let layout = Vecf32::layout(slice.len()); let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); std::ptr::addr_of_mut!((*ptr).kind).write(32); + std::ptr::addr_of_mut!((*ptr).pad).write(0); std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Vecf32Output(NonNull::new(ptr).unwrap()) } } - pub fn new_zeroed(len: usize) -> Box { - unsafe { - assert!(u16::try_from(len).is_ok()); - let layout = Vecf32::layout(len); - let ptr = std::alloc::alloc_zeroed(layout) as *mut Vecf32; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(32); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - Box::from_raw(ptr) - } - } - #[allow(dead_code)] - pub fn new_zeroed_in_postgres(len: usize) -> Vecf32Output { - unsafe { - assert!(u64::try_from(len).is_ok()); - let layout = Vecf32::layout(len); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf32; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(32); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - Vecf32Output(NonNull::new(ptr).unwrap()) - } - } pub fn len(&self) -> usize { self.len as usize } @@ -115,13 +83,6 @@ impl Vecf32 { debug_assert_eq!(self.kind, 32); unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } } - #[allow(dead_code)] - pub fn copy(&self) -> Box { - Vecf32::new(self.data()) - } - pub fn copy_into_postgres(&self) -> Vecf32Output { - Vecf32::new_in_postgres(self.data()) - } } impl Deref for Vecf32 { diff --git a/src/index/utils.rs b/src/index/utils.rs index 3bd7978d7..338771ab2 100644 --- a/src/index/utils.rs +++ b/src/index/utils.rs @@ -6,6 +6,7 @@ use service::prelude::DynamicVector; struct Header { varlena: u32, kind: u8, + pad: u8, len: u16, } From 41e7476a1d33ed440018553765fa0d1d2dd9d931 Mon Sep 17 00:00:00 2001 From: usamoi Date: Sat, 9 Dec 2023 05:40:19 +0800 Subject: [PATCH 13/23] fix: compiler built-in function calling convention workaround Signed-off-by: usamoi --- Cargo.lock | 1 + crates/c/.gitignore | 3 ++- crates/c/Cargo.toml | 3 +++ crates/c/build.rs | 5 +++-- crates/c/src/c.c | 18 ++++++------------ crates/c/src/c.h | 12 ++++++------ crates/c/src/c.rs | 14 ++++++++++++++ crates/c/src/lib.rs | 2 ++ 8 files changed, 37 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 32e8d2eee..96e5e099e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -486,6 +486,7 @@ name = "c" version = "0.0.0" dependencies = [ "cc", + "half 2.3.1", ] [[package]] diff --git a/crates/c/.gitignore b/crates/c/.gitignore index b72b9e32f..9f70fdf2e 100644 --- a/crates/c/.gitignore +++ b/crates/c/.gitignore @@ -1,2 +1,3 @@ *.s -*.o \ No newline at end of file +*.o +*.out \ No newline at end of file diff --git a/crates/c/Cargo.toml b/crates/c/Cargo.toml index 8dd8f339d..5dc084ed6 100644 --- a/crates/c/Cargo.toml +++ b/crates/c/Cargo.toml @@ -3,5 +3,8 @@ name = "c" version.workspace = true edition.workspace = true +[dependencies] +half = { version = "~2.3", features = ["use-intrinsics"] } + [build-dependencies] cc = "1.0" diff --git a/crates/c/build.rs b/crates/c/build.rs index b39683c63..dad66331b 100644 --- a/crates/c/build.rs +++ b/crates/c/build.rs @@ -1,9 +1,10 @@ fn main() { - println!("rerun-if-changed:src/c.h"); - println!("rerun-if-changed:src/c.c"); + println!("cargo:rerun-if-changed=src/c.h"); + println!("cargo:rerun-if-changed=src/c.c"); cc::Build::new() .compiler("/usr/bin/clang-16") .file("./src/c.c") .opt_level(3) + .debug(true) .compile("pgvectorsc"); } diff --git a/crates/c/src/c.c b/crates/c/src/c.c index 79bb55c14..eeee00933 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -8,8 +8,7 @@ #if defined(__x86_64__) __attribute__((target("avx512fp16,bmi2"))) extern float -v_f16_cosine_axv512(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_cosine_axv512(_Float16 *a, _Float16 *b, size_t n) { __m512h xy = _mm512_set1_ph(0); __m512h xx = _mm512_set1_ph(0); __m512h yy = _mm512_set1_ph(0); @@ -35,8 +34,7 @@ v_f16_cosine_axv512(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx512fp16,bmi2"))) extern float -v_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_dot_axv512(_Float16 *a, _Float16 *b, size_t n) { __m512h xy = _mm512_set1_ph(0); while (n >= 32) { @@ -55,8 +53,7 @@ v_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx512fp16,bmi2"))) extern float -v_f16_sl2_axv512(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_sl2_axv512(_Float16 *a, _Float16 *b, size_t n) { __m512h dd = _mm512_set1_ph(0); while (n >= 32) { @@ -78,8 +75,7 @@ v_f16_sl2_axv512(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx2"))) extern float -v_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_cosine_axv2(_Float16 *a, _Float16 *b, size_t n) { float xy = 0; float xx = 0; float yy = 0; @@ -95,8 +91,7 @@ v_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx2"))) extern float -v_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_dot_axv2(_Float16 *a, _Float16 *b, size_t n) { float xy = 0; #pragma clang loop vectorize_width(8) for (size_t i = 0; i < n; i++) { @@ -108,8 +103,7 @@ v_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b, } __attribute__((target("avx2"))) extern float -v_f16_sl2_axv2(_Float16 const *restrict a, _Float16 const *restrict b, - size_t n) { +v_f16_sl2_axv2(_Float16 *a, _Float16 *b, size_t n) { float dd = 0; #pragma clang loop vectorize_width(8) for (size_t i = 0; i < n; i++) { diff --git a/crates/c/src/c.h b/crates/c/src/c.h index 23bd96c4e..914546ed6 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -3,11 +3,11 @@ #if defined(__x86_64__) -extern float v_f16_cosine_axv512(_Float16 const *, _Float16 const *, size_t n); -extern float v_f16_dot_axv512(_Float16 const *, _Float16 const *, size_t n); -extern float v_f16_sl2_axv512(_Float16 const *, _Float16 const *, size_t n); -extern float v_f16_cosine_axv2(_Float16 const *, _Float16 const *, size_t n); -extern float v_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n); -extern float v_f16_sl2_axv2(_Float16 const *, _Float16 const *, size_t n); +extern float v_f16_cosine_axv512(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_axv512(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_axv512(_Float16 *, _Float16 *, size_t n); +extern float v_f16_cosine_axv2(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_axv2(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_axv2(_Float16 *, _Float16 *, size_t n); #endif diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index 4fcbd9d1d..5752b0f54 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -1,3 +1,4 @@ +#[cfg(target_arch = "x86_64")] #[link(name = "pgvectorsc", kind = "static")] extern "C" { pub fn v_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; @@ -7,3 +8,16 @@ extern "C" { pub fn v_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32; pub fn v_f16_sl2_axv2(a: *const u16, b: *const u16, n: usize) -> f32; } + +// `compiler_builtin` defines `__extendhfsf2` with integer calling convention. +// However C compilers links `__extendhfsf2` with floating calling convention. +// The code should be removed once Rust offically supports `f16`. + +#[no_mangle] +#[linkage = "external"] +extern "C" fn __extendhfsf2(f: f64) -> f32 { + unsafe { + let f: half::f16 = std::mem::transmute_copy(&f); + f.to_f32() + } +} diff --git a/crates/c/src/lib.rs b/crates/c/src/lib.rs index 6f1d73975..1d776513d 100644 --- a/crates/c/src/lib.rs +++ b/crates/c/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(linkage)] + mod c; pub use self::c::*; From a2ada6da18aba30677a130357614bd044f3689ee Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 12 Dec 2023 14:24:11 +0800 Subject: [PATCH 14/23] fix: cross compile on aarch64 Signed-off-by: usamoi --- .cargo/config.toml | 10 ++ .github/workflows/release.yml | 3 - Cargo.lock | 146 +++++++++--------- crates/c/src/c.rs | 1 + crates/c/src/lib.rs | 1 + crates/service/src/prelude/global/avx2.rs | 9 -- .../service/src/prelude/global/avx512fp16.rs | 10 -- crates/service/src/prelude/global/detect.rs | 10 ++ crates/service/src/prelude/global/f16.rs | 18 ++- crates/service/src/prelude/global/mod.rs | 3 +- scripts/ci_setup.sh | 1 + 11 files changed, 108 insertions(+), 104 deletions(-) delete mode 100644 crates/service/src/prelude/global/avx2.rs delete mode 100644 crates/service/src/prelude/global/avx512fp16.rs create mode 100644 crates/service/src/prelude/global/detect.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index c4db64902..52579941e 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -4,3 +4,13 @@ rustdocflags = ["--document-private-items"] [target.'cfg(target_os="macos")'] # Postgres symbols won't be available until runtime rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"] + +[target.x86_64-unknown-linux-gnu] +linker = "x86_64-linux-gnu-gcc" + +[target.aarch64-unknown-linux-gnu] +linker = "aarch64-linux-gnu-gcc" + +[env] +BINDGEN_EXTRA_CLANG_ARGS_x86_64_unknown_linux_gnu = "-isystem /usr/x86_64-linux-gnu/include/ -ccc-gcc-name x86_64-linux-gnu-gcc" +BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu = "-isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c694c8c9c..c0c73cf8f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -123,7 +123,6 @@ jobs: cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config if [[ "${{ matrix.arch }}" == "arm64" ]]; then sudo apt-get -y install crossbuild-essential-arm64 - rustup target add aarch64-unknown-linux-gnu fi - name: Build Release id: build_release @@ -133,8 +132,6 @@ jobs: mkdir ./artifacts cargo pgrx package if [[ "${{ matrix.arch }}" == "arm64" ]]; then - export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc - export BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu="-target aarch64-unknown-linux-gnu -isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc" cargo build --target aarch64-unknown-linux-gnu --release --features "pg${{ matrix.version }}" --no-default-features mv ./target/aarch64-unknown-linux-gnu/release/libvectors.so ./target/release/vectors-pg${{ matrix.version }}/usr/lib/postgresql/${{ matrix.version }}/lib/vectors.so fi diff --git a/Cargo.lock b/Cargo.lock index 96e5e099e..667f7f85b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,13 +127,13 @@ dependencies = [ [[package]] name = "async-global-executor" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b4353121d5644cdf2beb5726ab752e79a8db1ebb52031770ec47db31d245526" +checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c" dependencies = [ "async-channel 2.1.1", "async-executor", - "async-io 2.2.1", + "async-io 2.2.2", "async-lock 3.2.0", "blocking", "futures-lite 2.1.0", @@ -162,9 +162,9 @@ dependencies = [ [[package]] name = "async-io" -version = "2.2.1" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6d3b15875ba253d1110c740755e246537483f152fa334f91abd7fe84c88b3ff" +checksum = "6afaa937395a620e33dc6a742c593c01aced20aa376ffb0f628121198578ccc7" dependencies = [ "async-lock 3.2.0", "cfg-if", @@ -173,7 +173,7 @@ dependencies = [ "futures-lite 2.1.0", "parking", "polling 3.3.1", - "rustix 0.38.26", + "rustix 0.38.28", "slab", "tracing", "windows-sys 0.52.0", @@ -221,7 +221,7 @@ dependencies = [ "cfg-if", "event-listener 3.1.0", "futures-lite 1.13.0", - "rustix 0.38.26", + "rustix 0.38.28", "windows-sys 0.48.0", ] @@ -231,13 +231,13 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" dependencies = [ - "async-io 2.2.1", + "async-io 2.2.2", "async-lock 2.8.0", "atomic-waker", "cfg-if", "futures-core", "futures-io", - "rustix 0.38.26", + "rustix 0.38.28", "signal-hook-registry", "slab", "windows-sys 0.48.0", @@ -284,14 +284,14 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] name = "atomic-polyfill" -version = "0.1.11" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ff7eb3f316534d83a8a2c3d1674ace8a5a71198eba31e2e2b597833f699b28" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" dependencies = [ "critical-section", ] @@ -362,8 +362,6 @@ dependencies = [ [[package]] name = "bindgen" version = "0.68.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "726e4313eb6ec35d2730258ad4e15b547ee75d6afaa1361a922e78e59b7d8078" dependencies = [ "bitflags 2.4.1", "cexpr", @@ -376,7 +374,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -466,7 +464,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -555,9 +553,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.10" +version = "4.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fffed7514f420abec6d183b1d3acfd9099c79c3a10a06ade4f8203f1411272" +checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2" dependencies = [ "clap_builder", "clap_derive", @@ -575,9 +573,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.9" +version = "4.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63361bae7eef3771745f02d8d892bec2fee5f6e34af316ba556e7f97a7069ff1" +checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb" dependencies = [ "anstyle", "clap_lex", @@ -592,7 +590,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -755,9 +753,9 @@ dependencies = [ [[package]] name = "curl-sys" -version = "0.4.68+curl-8.4.0" +version = "0.4.70+curl-8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4a0d18d88360e374b16b2273c832b5e57258ffc1d4aa4f96b108e0738d5752f" +checksum = "3c0333d8849afe78a4c8102a429a446bfdd055832af071945520e835ae2d841e" dependencies = [ "cc", "libc", @@ -790,7 +788,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -801,7 +799,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -945,7 +943,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -1017,9 +1015,9 @@ dependencies = [ [[package]] name = "eyre" -version = "0.6.9" +version = "0.6.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80f656be11ddf91bd709454d15d5bd896fbaf4cc3314e69349e4d1569f5b46cd" +checksum = "8bbb8258be8305fb0237d7b295f47bb24ff1b136a535f473baf40e70468515aa" dependencies = [ "indenter", "once_cell", @@ -1162,7 +1160,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -1279,9 +1277,9 @@ checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" [[package]] name = "heapless" -version = "0.7.16" +version = "0.7.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db04bc24a18b9ea980628ecf00e6c0264f3c1426dac36c00cb49b6fbad8b0743" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" dependencies = [ "atomic-polyfill", "hash32", @@ -1330,9 +1328,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -1518,7 +1516,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.26", + "rustix 0.38.28", "windows-sys 0.48.0", ] @@ -1560,9 +1558,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" @@ -1633,9 +1631,9 @@ checksum = "db13adb97ab515a3691f56e4dbab09283d0b86cb45abd991d8634a9d6f501760" [[package]] name = "libc" -version = "0.2.150" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libloading" @@ -1774,9 +1772,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", @@ -1894,9 +1892,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openai_api_rust" @@ -1919,9 +1917,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.96" +version = "0.9.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f" +checksum = "c3eaad34cdd97d81de97964fc7f29e2d104f483840d906ef56daa1912338460b" dependencies = [ "cc", "libc", @@ -2169,7 +2167,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -2226,7 +2224,7 @@ dependencies = [ "cfg-if", "concurrent-queue", "pin-project-lite", - "rustix 0.38.26", + "rustix 0.38.28", "tracing", "windows-sys 0.52.0", ] @@ -2512,9 +2510,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "ring" -version = "0.17.6" +version = "0.17.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "684d5e6e18f669ccebf64a92236bb7db9a34f07be010e3627368182027180866" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", "getrandom", @@ -2570,9 +2568,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.26" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ "bitflags 2.4.1", "errno", @@ -2583,9 +2581,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.21.9" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", "ring", @@ -2623,9 +2621,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -2724,7 +2722,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -2783,7 +2781,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -2809,7 +2807,7 @@ dependencies = [ "parking_lot", "rand", "rayon", - "rustix 0.38.26", + "rustix 0.38.28", "serde", "serde_json", "serde_with", @@ -2984,9 +2982,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.39" +version = "2.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" dependencies = [ "proc-macro2", "quote", @@ -3029,7 +3027,7 @@ dependencies = [ "cfg-if", "fastrand 2.0.1", "redox_syscall", - "rustix 0.38.26", + "rustix 0.38.28", "windows-sys 0.48.0", ] @@ -3076,7 +3074,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3134,9 +3132,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.34.0" +version = "1.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" +checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" dependencies = [ "backtrace", "bytes", @@ -3158,7 +3156,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3261,7 +3259,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", ] [[package]] @@ -3285,9 +3283,9 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" @@ -3324,9 +3322,9 @@ checksum = "ccb97dac3243214f8d8507998906ca3e2e0b900bf9bf4870477f125b82e68f6e" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" [[package]] name = "unicode-ident" @@ -3470,7 +3468,7 @@ dependencies = [ "openai_api_rust", "pgrx", "pgrx-tests", - "rustix 0.38.26", + "rustix 0.38.28", "serde", "serde_json", "service", @@ -3546,7 +3544,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", "wasm-bindgen-shared", ] @@ -3580,7 +3578,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.40", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3791,9 +3789,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.19" +version = "0.5.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "829846f3e3db426d4cee4510841b71a8e58aa2a76b1132579487ae430ccd9c7b" +checksum = "6c830786f7720c2fd27a1a0e27a709dbd3c4d009b56d098fc742d4f4eab91fe2" dependencies = [ "memchr", ] diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index 5752b0f54..1d16999b3 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -13,6 +13,7 @@ extern "C" { // However C compilers links `__extendhfsf2` with floating calling convention. // The code should be removed once Rust offically supports `f16`. +#[cfg(target_arch = "x86_64")] #[no_mangle] #[linkage = "external"] extern "C" fn __extendhfsf2(f: f64) -> f32 { diff --git a/crates/c/src/lib.rs b/crates/c/src/lib.rs index 1d776513d..9c3d869be 100644 --- a/crates/c/src/lib.rs +++ b/crates/c/src/lib.rs @@ -2,4 +2,5 @@ mod c; +#[allow(unused_imports)] pub use self::c::*; diff --git a/crates/service/src/prelude/global/avx2.rs b/crates/service/src/prelude/global/avx2.rs deleted file mode 100644 index 404ebbd48..000000000 --- a/crates/service/src/prelude/global/avx2.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[cfg(not(target_arch = "x86_64"))] -pub fn detect() -> bool { - false -} - -#[cfg(target_arch = "x86_64")] -pub fn detect() -> bool { - std_detect::is_x86_feature_detected!("avx2") -} diff --git a/crates/service/src/prelude/global/avx512fp16.rs b/crates/service/src/prelude/global/avx512fp16.rs deleted file mode 100644 index 28613f0d5..000000000 --- a/crates/service/src/prelude/global/avx512fp16.rs +++ /dev/null @@ -1,10 +0,0 @@ -#[cfg(not(target_arch = "x86_64"))] -pub fn detect() -> bool { - false -} - -#[cfg(target_arch = "x86_64")] -pub fn detect() -> bool { - std_detect::is_x86_feature_detected!("avx512fp16") - && std_detect::is_x86_feature_detected!("bmi2") -} diff --git a/crates/service/src/prelude/global/detect.rs b/crates/service/src/prelude/global/detect.rs new file mode 100644 index 000000000..cf2282fc8 --- /dev/null +++ b/crates/service/src/prelude/global/detect.rs @@ -0,0 +1,10 @@ +#[cfg(target_arch = "x86_64")] +pub fn detect_avx512fp16() -> bool { + std_detect::is_x86_feature_detected!("avx512fp16") + && std_detect::is_x86_feature_detected!("bmi2") +} + +#[cfg(target_arch = "x86_64")] +pub fn detect_avx2() -> bool { + std_detect::is_x86_feature_detected!("avx2") +} diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 5f1afdc55..587a134cb 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -16,14 +16,16 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { } xy / (x2 * y2).sqrt() } - if super::avx512fp16::detect() { + #[cfg(target_arch = "x86_64")] + if super::detect::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { return c::v_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } - if super::avx2::detect() { + #[cfg(target_arch = "x86_64")] + if super::detect::detect_avx2() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -45,14 +47,16 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { } xy } - if super::avx512fp16::detect() { + #[cfg(target_arch = "x86_64")] + if super::detect::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { return c::v_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } - if super::avx2::detect() { + #[cfg(target_arch = "x86_64")] + if super::detect::detect_avx2() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -75,14 +79,16 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { } d2 } - if super::avx512fp16::detect() { + #[cfg(target_arch = "x86_64")] + if super::detect::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { return c::v_f16_sl2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } - if super::avx2::detect() { + #[cfg(target_arch = "x86_64")] + if super::detect::detect_avx2() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs index 00b518269..fa78a8fc4 100644 --- a/crates/service/src/prelude/global/mod.rs +++ b/crates/service/src/prelude/global/mod.rs @@ -1,5 +1,4 @@ -mod avx2; -mod avx512fp16; +mod detect; mod f16; mod f16_cos; mod f16_dot; diff --git a/scripts/ci_setup.sh b/scripts/ci_setup.sh index 6efaf81b0..53f9171ec 100755 --- a/scripts/ci_setup.sh +++ b/scripts/ci_setup.sh @@ -5,6 +5,7 @@ if [ "$OS" == "ubuntu-latest" ]; then if [ $VERSION != 14 ]; then sudo pg_dropcluster 14 main fi + sudo apt-get -y install crossbuild-essential-arm64 sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' From 8a925e3dafdf773453b2741a634f1747e1881d66 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 12 Dec 2023 16:46:20 +0800 Subject: [PATCH 15/23] fix: fix detect avx512fp16 Signed-off-by: usamoi --- Cargo.lock | 3 ++- README.md | 20 ++++++++++++-------- crates/service/Cargo.toml | 2 +- docs/installation.md | 22 +++++++++++++++++++--- 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 667f7f85b..f23bb8dcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,6 +362,7 @@ dependencies = [ [[package]] name = "bindgen" version = "0.68.1" +source = "git+https://github.com/usamoi/rust-bindgen.git?rev=7e683d3cc6a0667232f821088b8d0ed7f4d4c31e#7e683d3cc6a0667232f821088b8d0ed7f4d4c31e" dependencies = [ "bitflags 2.4.1", "cexpr", @@ -2927,7 +2928,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "std_detect" version = "0.1.5" -source = "git+https://github.com/usamoi/stdarch.git?rev=067a6e889f0ca995a9fe4114061ced6f67acfb95#067a6e889f0ca995a9fe4114061ced6f67acfb95" +source = "git+https://github.com/usamoi/stdarch.git?rev=d934b65e47ce82ce4c20f0268dab01f71fb7b9c7#d934b65e47ce82ce4c20f0268dab01f71fb7b9c7" dependencies = [ "cfg-if", "libc", diff --git a/README.md b/README.md index 4fe30f6c0..246f04ffc 100644 --- a/README.md +++ b/README.md @@ -21,13 +21,13 @@ pgvecto.rs is a Postgres extension that provides vector similarity search functi ## Comparison with pgvector -| | pgvecto.rs | pgvector | -| ------------------------------------------- | ------------------------------------------------------ | ------------------------ | -| Transaction support | ✅ | ⚠️ | -| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ | -| Vector Dimension Limit | 65535 | 2000 | -| Prefilter on HNSW | ✅ | ❌ | -| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used | +| | pgvecto.rs | pgvector | +| ------------------------------------------- | ------------------------------------------------------ | ----------------------- | +| Transaction support | ✅ | ⚠️ | +| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ | +| Vector Dimension Limit | 65535 | 2000 | +| Prefilter on HNSW | ✅ | ❌ | +| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used | | Async Index build | Ready for queries anytime and do not block insertions. | ❌ | | Quantization | Scalar/Product Quantization | ❌ | @@ -45,7 +45,11 @@ More details at [./docs/comparison-pgvector.md](./docs/comparison-pgvector.md) For users, we recommend you to try pgvecto.rs using our pre-built docker image, by running ```sh -docker run --name pgvecto-rs-demo -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d tensorchord/pgvecto-rs:pg16-latest +docker run \ + --name pgvecto-rs-demo \ + -e POSTGRES_PASSWORD=mysecretpassword \ + -p 5432:5432 \ + -d tensorchord/pgvecto-rs:pg16-latest ``` ## Development with envd diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 92650c3ec..509e0b43d 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -16,7 +16,7 @@ bincode.workspace = true half.workspace = true num-traits.workspace = true c = { path = "../c" } -std_detect = { git = "https://github.com/usamoi/stdarch.git", rev = "067a6e889f0ca995a9fe4114061ced6f67acfb95" } +std_detect = { git = "https://github.com/usamoi/stdarch.git", rev = "d934b65e47ce82ce4c20f0268dab01f71fb7b9c7" } rand = "0.8.5" crc32fast = "1.3.2" crossbeam = "0.8.2" diff --git a/docs/installation.md b/docs/installation.md index cd8d7a41f..c55903a07 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -19,12 +19,28 @@ To acheive full performance, please mount the volume to pg data directory by add You can configure PostgreSQL by the reference of the parent image in https://hub.docker.com/_/postgres/. -## Build from source +## Install from source Install Rust and base dependency. ```sh -sudo apt install -y build-essential libpq-dev libssl-dev pkg-config gcc libreadline-dev flex bison libxml2-dev libxslt-dev libxml2-utils xsltproc zlib1g-dev ccache clang git +sudo apt install -y \ + build-essential \ + libpq-dev \ + libssl-dev \ + pkg-config \ + gcc \ + libreadline-dev \ + flex \ + bison \ + libxml2-dev \ + libxslt-dev \ + libxml2-utils \ + xsltproc \ + zlib1g-dev \ + ccache \ + clang \ + git curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh ``` @@ -63,7 +79,7 @@ cargo pgrx init --pg15=/usr/lib/postgresql/15/bin/pg_config Install pgvecto.rs. ```sh -cargo pgrx install --release +cargo pgrx install --sudo --release ``` Configure your PostgreSQL by modifying the `shared_preload_libraries` to include `vectors.so`. From 308cb643f6ed6e2816d5759ce1fa39bf6ffef05b Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 12 Dec 2023 23:35:28 +0800 Subject: [PATCH 16/23] fix: avx512 codegen by multiversion Signed-off-by: usamoi --- crates/service/src/lib.rs | 1 + crates/service/src/prelude/global/f16.rs | 21 +++++- crates/service/src/prelude/global/f16_cos.rs | 63 +++++++++++++++--- crates/service/src/prelude/global/f16_dot.rs | 56 +++++++++++++--- crates/service/src/prelude/global/f16_l2.rs | 42 ++++++++++-- crates/service/src/prelude/global/f32_cos.rs | 70 +++++++++++++++++--- crates/service/src/prelude/global/f32_dot.rs | 70 +++++++++++++++++--- crates/service/src/prelude/global/f32_l2.rs | 49 ++++++++++++-- 8 files changed, 319 insertions(+), 53 deletions(-) diff --git a/crates/service/src/lib.rs b/crates/service/src/lib.rs index bf1f5018a..b534589f1 100644 --- a/crates/service/src/lib.rs +++ b/crates/service/src/lib.rs @@ -1,4 +1,5 @@ #![feature(core_intrinsics)] +#![feature(avx512_target_feature)] pub mod algorithms; pub mod index; diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 587a134cb..610707ea2 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -2,7 +2,12 @@ use crate::prelude::*; pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -37,7 +42,12 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -68,7 +78,12 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/service/src/prelude/global/f16_cos.rs index 2cfcacc4f..e6c104dc1 100644 --- a/crates/service/src/prelude/global/f16_cos.rs +++ b/crates/service/src/prelude/global/f16_cos.rs @@ -24,7 +24,12 @@ impl G for F16Cos { super::f16::dot(lhs, rhs).acos() } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance( dims: u16, max: &[F16], @@ -45,7 +50,12 @@ impl G for F16Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance2( dims: u16, max: &[F16], @@ -66,7 +76,12 @@ impl G for F16Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance( dims: u16, ratio: u16, @@ -91,7 +106,12 @@ impl G for F16Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance2( dims: u16, ratio: u16, @@ -117,7 +137,12 @@ impl G for F16Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance_with_delta( dims: u16, ratio: u16, @@ -146,7 +171,12 @@ impl G for F16Cos { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn length(vector: &[F16]) -> F16 { let n = vector.len(); let mut dot = F16::zero(); @@ -157,7 +187,12 @@ fn length(vector: &[F16]) -> F16 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn l2_normalize(vector: &mut [F16]) { let n = vector.len(); let l = length(vector); @@ -167,7 +202,12 @@ fn l2_normalize(vector: &mut [F16]) { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -183,7 +223,12 @@ fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/service/src/prelude/global/f16_dot.rs index bee46c56f..cba885fe9 100644 --- a/crates/service/src/prelude/global/f16_dot.rs +++ b/crates/service/src/prelude/global/f16_dot.rs @@ -24,7 +24,12 @@ impl G for F16Dot { super::f16::dot(lhs, rhs).acos() } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance( dims: u16, max: &[F16], @@ -41,7 +46,12 @@ impl G for F16Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance2( dims: u16, max: &[F16], @@ -58,7 +68,12 @@ impl G for F16Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance( dims: u16, ratio: u16, @@ -79,7 +94,12 @@ impl G for F16Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance2( dims: u16, ratio: u16, @@ -101,7 +121,12 @@ impl G for F16Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance_with_delta( dims: u16, ratio: u16, @@ -126,7 +151,12 @@ impl G for F16Dot { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn length(vector: &[F16]) -> F16 { let n = vector.len(); let mut dot = F16::zero(); @@ -137,7 +167,12 @@ fn length(vector: &[F16]) -> F16 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn l2_normalize(vector: &mut [F16]) { let n = vector.len(); let l = length(vector); @@ -147,7 +182,12 @@ fn l2_normalize(vector: &mut [F16]) { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n: usize = lhs.len(); diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs index 3b45d0022..bb03889fe 100644 --- a/crates/service/src/prelude/global/f16_l2.rs +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -23,7 +23,12 @@ impl G for F16L2 { super::f16::sl2(lhs, rhs).sqrt() } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance( dims: u16, max: &[F16], @@ -40,7 +45,12 @@ impl G for F16L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance2( dims: u16, max: &[F16], @@ -57,7 +67,12 @@ impl G for F16L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance( dims: u16, ratio: u16, @@ -77,7 +92,12 @@ impl G for F16L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance2( dims: u16, ratio: u16, @@ -98,7 +118,12 @@ impl G for F16L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance_with_delta( dims: u16, ratio: u16, @@ -122,7 +147,12 @@ impl G for F16L2 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/service/src/prelude/global/f32_cos.rs index 4c48af279..868035143 100644 --- a/crates/service/src/prelude/global/f32_cos.rs +++ b/crates/service/src/prelude/global/f32_cos.rs @@ -24,7 +24,12 @@ impl G for F32Cos { super::f32_dot::dot(lhs, rhs).acos() } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance( dims: u16, max: &[F32], @@ -45,7 +50,12 @@ impl G for F32Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance2( dims: u16, max: &[F32], @@ -66,7 +76,12 @@ impl G for F32Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance( dims: u16, ratio: u16, @@ -91,7 +106,12 @@ impl G for F32Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance2( dims: u16, ratio: u16, @@ -117,7 +137,12 @@ impl G for F32Cos { xy / (x2 * y2).sqrt() * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance_with_delta( dims: u16, ratio: u16, @@ -146,7 +171,12 @@ impl G for F32Cos { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn length(vector: &[F32]) -> F32 { let n = vector.len(); let mut dot = F32::zero(); @@ -157,7 +187,12 @@ fn length(vector: &[F32]) -> F32 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn l2_normalize(vector: &mut [F32]) { let n = vector.len(); let l = length(vector); @@ -167,7 +202,12 @@ fn l2_normalize(vector: &mut [F32]) { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -183,7 +223,12 @@ fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -199,7 +244,12 @@ fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/service/src/prelude/global/f32_dot.rs index 081f4eb39..f8d624d51 100644 --- a/crates/service/src/prelude/global/f32_dot.rs +++ b/crates/service/src/prelude/global/f32_dot.rs @@ -24,7 +24,12 @@ impl G for F32Dot { super::f32_dot::dot(lhs, rhs).acos() } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance( dims: u16, max: &[F32], @@ -41,7 +46,12 @@ impl G for F32Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance2( dims: u16, max: &[F32], @@ -58,7 +68,12 @@ impl G for F32Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance( dims: u16, ratio: u16, @@ -79,7 +94,12 @@ impl G for F32Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance2( dims: u16, ratio: u16, @@ -101,7 +121,12 @@ impl G for F32Dot { xy * (-1.0) } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance_with_delta( dims: u16, ratio: u16, @@ -126,7 +151,12 @@ impl G for F32Dot { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn length(vector: &[F32]) -> F32 { let n = vector.len(); let mut dot = F32::zero(); @@ -137,7 +167,12 @@ fn length(vector: &[F32]) -> F32 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn l2_normalize(vector: &mut [F32]) { let n = vector.len(); let l = length(vector); @@ -147,7 +182,12 @@ fn l2_normalize(vector: &mut [F32]) { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -163,7 +203,12 @@ fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -175,7 +220,12 @@ pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n: usize = lhs.len(); diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/service/src/prelude/global/f32_l2.rs index a5d6da5b5..52addc207 100644 --- a/crates/service/src/prelude/global/f32_l2.rs +++ b/crates/service/src/prelude/global/f32_l2.rs @@ -22,7 +22,12 @@ impl G for F32L2 { distance_squared_l2(lhs, rhs).sqrt() } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance( dims: u16, max: &[F32], @@ -39,7 +44,12 @@ impl G for F32L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn scalar_quantization_distance2( dims: u16, max: &[F32], @@ -56,7 +66,12 @@ impl G for F32L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance( dims: u16, ratio: u16, @@ -76,7 +91,12 @@ impl G for F32L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance2( dims: u16, ratio: u16, @@ -97,7 +117,12 @@ impl G for F32L2 { result } - #[multiversion::multiversion(targets = "simd")] + #[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" + ))] fn product_quantization_distance_with_delta( dims: u16, ratio: u16, @@ -121,7 +146,12 @@ impl G for F32L2 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -134,7 +164,12 @@ pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { } #[inline(always)] -#[multiversion::multiversion(targets = "simd")] +#[multiversion::multiversion(targets( + "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "aarch64+neon" +))] fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); From d7bfaa7f056030e641188524d2e338a29dc2d080 Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 13 Dec 2023 15:56:12 +0800 Subject: [PATCH 17/23] fix: enable more target features for c Signed-off-by: usamoi --- Cargo.lock | 54 ++++++++----- bindings/python/tests/test_psycopg.py | 2 +- bindings/python/tests/test_sqlalchemy.py | 2 +- crates/c/src/c.c | 24 +++--- crates/c/src/c.h | 12 +-- crates/c/src/c.rs | 12 +-- crates/service/Cargo.toml | 1 + crates/service/src/prelude/global/detect.rs | 10 --- crates/service/src/prelude/global/f16.rs | 43 +++++----- crates/service/src/prelude/global/f16_cos.rs | 54 ++++++------- crates/service/src/prelude/global/f16_dot.rs | 48 +++++------ crates/service/src/prelude/global/f16_l2.rs | 36 ++++----- crates/service/src/prelude/global/f32_cos.rs | 60 +++++++------- crates/service/src/prelude/global/f32_dot.rs | 60 +++++++------- crates/service/src/prelude/global/f32_l2.rs | 42 +++++----- crates/service/src/prelude/global/mod.rs | 1 - crates/service/src/utils/detect.rs | 3 + crates/service/src/utils/detect/x86_64.rs | 85 ++++++++++++++++++++ crates/service/src/utils/mod.rs | 1 + docs/comparison-with-specialized-vectordb.md | 2 +- docs/indexing.md | 28 +++---- src/sql/finalize.sql | 8 +- tests/sqllogictest/error.slt | 2 +- tests/sqllogictest/flat.slt | 2 +- tests/sqllogictest/hnsw.slt | 2 +- tests/sqllogictest/ivf.slt | 4 +- tests/sqllogictest/quantization.slt | 4 +- tests/sqllogictest/reindex.slt | 2 +- tests/sqllogictest/update.slt | 2 +- 29 files changed, 349 insertions(+), 257 deletions(-) delete mode 100644 crates/service/src/prelude/global/detect.rs create mode 100644 crates/service/src/utils/detect.rs create mode 100644 crates/service/src/utils/detect/x86_64.rs diff --git a/Cargo.lock b/Cargo.lock index f23bb8dcb..f44aff553 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -284,7 +284,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -362,7 +362,8 @@ dependencies = [ [[package]] name = "bindgen" version = "0.68.1" -source = "git+https://github.com/usamoi/rust-bindgen.git?rev=7e683d3cc6a0667232f821088b8d0ed7f4d4c31e#7e683d3cc6a0667232f821088b8d0ed7f4d4c31e" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "726e4313eb6ec35d2730258ad4e15b547ee75d6afaa1361a922e78e59b7d8078" dependencies = [ "bitflags 2.4.1", "cexpr", @@ -375,7 +376,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -465,7 +466,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -591,7 +592,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -731,6 +732,16 @@ dependencies = [ "typenum", ] +[[package]] +name = "ctor" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" +dependencies = [ + "quote", + "syn 2.0.41", +] + [[package]] name = "cty" version = "0.2.2" @@ -789,7 +800,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -800,7 +811,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -944,7 +955,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1016,9 +1027,9 @@ dependencies = [ [[package]] name = "eyre" -version = "0.6.10" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bbb8258be8305fb0237d7b295f47bb24ff1b136a535f473baf40e70468515aa" +checksum = "80f656be11ddf91bd709454d15d5bd896fbaf4cc3314e69349e4d1569f5b46cd" dependencies = [ "indenter", "once_cell", @@ -1161,7 +1172,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2168,7 +2179,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2723,7 +2734,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2782,7 +2793,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2797,6 +2808,7 @@ dependencies = [ "c", "crc32fast", "crossbeam", + "ctor", "dashmap", "half 2.3.1", "libc", @@ -2983,9 +2995,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.40" +version = "2.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" dependencies = [ "proc-macro2", "quote", @@ -3075,7 +3087,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3157,7 +3169,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3260,7 +3272,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3545,7 +3557,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-shared", ] @@ -3579,7 +3591,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/bindings/python/tests/test_psycopg.py b/bindings/python/tests/test_psycopg.py index f79b4e2af..07e3a924f 100644 --- a/bindings/python/tests/test_psycopg.py +++ b/bindings/python/tests/test_psycopg.py @@ -38,7 +38,7 @@ def conn(): @pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items()) def test_create_index(conn: Connection, index_name: str, index_setting: str): stat = sql.SQL( - "CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});", + "CREATE INDEX {} ON tb_test_item USING vectors (embedding vector_l2_ops) WITH (options={});", ).format(sql.Identifier(index_name), index_setting) conn.execute(stat) diff --git a/bindings/python/tests/test_sqlalchemy.py b/bindings/python/tests/test_sqlalchemy.py index d51e58b2b..9ae03e391 100644 --- a/bindings/python/tests/test_sqlalchemy.py +++ b/bindings/python/tests/test_sqlalchemy.py @@ -68,7 +68,7 @@ def test_create_index(session: Session, index_name: str, index_setting: str): Document.embedding, postgresql_using="vectors", postgresql_with={"options": f"$${index_setting}$$"}, - postgresql_ops={"embedding": "l2_ops"}, + postgresql_ops={"embedding": "vector_l2_ops"}, ) index.create(session.bind) session.commit() diff --git a/crates/c/src/c.c b/crates/c/src/c.c index eeee00933..e41f282d4 100644 --- a/crates/c/src/c.c +++ b/crates/c/src/c.c @@ -7,8 +7,8 @@ #if defined(__x86_64__) -__attribute__((target("avx512fp16,bmi2"))) extern float -v_f16_cosine_axv512(_Float16 *a, _Float16 *b, size_t n) { +__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float +v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { __m512h xy = _mm512_set1_ph(0); __m512h xx = _mm512_set1_ph(0); __m512h yy = _mm512_set1_ph(0); @@ -33,8 +33,8 @@ v_f16_cosine_axv512(_Float16 *a, _Float16 *b, size_t n) { sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy))); } -__attribute__((target("avx512fp16,bmi2"))) extern float -v_f16_dot_axv512(_Float16 *a, _Float16 *b, size_t n) { +__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float +v_f16_dot_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { __m512h xy = _mm512_set1_ph(0); while (n >= 32) { @@ -52,8 +52,8 @@ v_f16_dot_axv512(_Float16 *a, _Float16 *b, size_t n) { return (float)_mm512_reduce_add_ph(xy); } -__attribute__((target("avx512fp16,bmi2"))) extern float -v_f16_sl2_axv512(_Float16 *a, _Float16 *b, size_t n) { +__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float +v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { __m512h dd = _mm512_set1_ph(0); while (n >= 32) { @@ -74,8 +74,8 @@ v_f16_sl2_axv512(_Float16 *a, _Float16 *b, size_t n) { return (float)_mm512_reduce_add_ph(dd); } -__attribute__((target("avx2"))) extern float -v_f16_cosine_axv2(_Float16 *a, _Float16 *b, size_t n) { +__attribute__((target("arch=x86-64-v3"))) extern float +v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) { float xy = 0; float xx = 0; float yy = 0; @@ -90,8 +90,8 @@ v_f16_cosine_axv2(_Float16 *a, _Float16 *b, size_t n) { return xy / sqrt(xx * yy); } -__attribute__((target("avx2"))) extern float -v_f16_dot_axv2(_Float16 *a, _Float16 *b, size_t n) { +__attribute__((target("arch=x86-64-v3"))) extern float +v_f16_dot_v3(_Float16 *a, _Float16 *b, size_t n) { float xy = 0; #pragma clang loop vectorize_width(8) for (size_t i = 0; i < n; i++) { @@ -102,8 +102,8 @@ v_f16_dot_axv2(_Float16 *a, _Float16 *b, size_t n) { return xy; } -__attribute__((target("avx2"))) extern float -v_f16_sl2_axv2(_Float16 *a, _Float16 *b, size_t n) { +__attribute__((target("arch=x86-64-v3"))) extern float +v_f16_sl2_v3(_Float16 *a, _Float16 *b, size_t n) { float dd = 0; #pragma clang loop vectorize_width(8) for (size_t i = 0; i < n; i++) { diff --git a/crates/c/src/c.h b/crates/c/src/c.h index 914546ed6..d50c3d712 100644 --- a/crates/c/src/c.h +++ b/crates/c/src/c.h @@ -3,11 +3,11 @@ #if defined(__x86_64__) -extern float v_f16_cosine_axv512(_Float16 *, _Float16 *, size_t n); -extern float v_f16_dot_axv512(_Float16 *, _Float16 *, size_t n); -extern float v_f16_sl2_axv512(_Float16 *, _Float16 *, size_t n); -extern float v_f16_cosine_axv2(_Float16 *, _Float16 *, size_t n); -extern float v_f16_dot_axv2(_Float16 *, _Float16 *, size_t n); -extern float v_f16_sl2_axv2(_Float16 *, _Float16 *, size_t n); +extern float v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n); #endif diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs index 1d16999b3..a4ac2c255 100644 --- a/crates/c/src/c.rs +++ b/crates/c/src/c.rs @@ -1,12 +1,12 @@ #[cfg(target_arch = "x86_64")] #[link(name = "pgvectorsc", kind = "static")] extern "C" { - pub fn v_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn v_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn v_f16_sl2_axv512(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn v_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn v_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32; - pub fn v_f16_sl2_axv2(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32; } // `compiler_builtin` defines `__extendhfsf2` with integer calling convention. diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 509e0b43d..8693b4e72 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -32,6 +32,7 @@ arc-swap = "1.6.0" bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } serde_with = "3.4.0" multiversion = "0.7.3" +ctor = "0.2.6" [target.'cfg(target_os = "macos")'.dependencies] ulock-sys = "0.1.0" diff --git a/crates/service/src/prelude/global/detect.rs b/crates/service/src/prelude/global/detect.rs deleted file mode 100644 index cf2282fc8..000000000 --- a/crates/service/src/prelude/global/detect.rs +++ /dev/null @@ -1,10 +0,0 @@ -#[cfg(target_arch = "x86_64")] -pub fn detect_avx512fp16() -> bool { - std_detect::is_x86_feature_detected!("avx512fp16") - && std_detect::is_x86_feature_detected!("bmi2") -} - -#[cfg(target_arch = "x86_64")] -pub fn detect_avx2() -> bool { - std_detect::is_x86_feature_detected!("avx2") -} diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 610707ea2..95d48335f 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -1,11 +1,12 @@ use crate::prelude::*; +use crate::utils::detect; pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { @@ -22,19 +23,19 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { xy / (x2 * y2).sqrt() } #[cfg(target_arch = "x86_64")] - if super::detect::detect_avx512fp16() { + if self::detect::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { - return c::v_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } #[cfg(target_arch = "x86_64")] - if super::detect::detect_avx2() { + if self::detect::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { - return c::v_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } cosine(lhs, rhs) @@ -43,9 +44,9 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { @@ -58,19 +59,19 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { xy } #[cfg(target_arch = "x86_64")] - if super::detect::detect_avx512fp16() { + if self::detect::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { - return c::v_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } #[cfg(target_arch = "x86_64")] - if super::detect::detect_avx2() { + if self::detect::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { - return c::v_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } cosine(lhs, rhs) @@ -79,9 +80,9 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { @@ -95,19 +96,19 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { d2 } #[cfg(target_arch = "x86_64")] - if super::detect::detect_avx512fp16() { + if self::detect::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { - return c::v_f16_sl2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } #[cfg(target_arch = "x86_64")] - if super::detect::detect_avx2() { + if self::detect::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { - return c::v_f16_sl2_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); } } sl2(lhs, rhs) diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/service/src/prelude/global/f16_cos.rs index e6c104dc1..df9f60522 100644 --- a/crates/service/src/prelude/global/f16_cos.rs +++ b/crates/service/src/prelude/global/f16_cos.rs @@ -25,9 +25,9 @@ impl G for F16Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance( @@ -51,9 +51,9 @@ impl G for F16Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance2( @@ -77,9 +77,9 @@ impl G for F16Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance( @@ -107,9 +107,9 @@ impl G for F16Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance2( @@ -138,9 +138,9 @@ impl G for F16Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance_with_delta( @@ -172,9 +172,9 @@ impl G for F16Cos { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn length(vector: &[F16]) -> F16 { @@ -188,9 +188,9 @@ fn length(vector: &[F16]) -> F16 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn l2_normalize(vector: &mut [F16]) { @@ -203,9 +203,9 @@ fn l2_normalize(vector: &mut [F16]) { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { @@ -224,9 +224,9 @@ fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) { diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/service/src/prelude/global/f16_dot.rs index cba885fe9..085c2b827 100644 --- a/crates/service/src/prelude/global/f16_dot.rs +++ b/crates/service/src/prelude/global/f16_dot.rs @@ -25,9 +25,9 @@ impl G for F16Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance( @@ -47,9 +47,9 @@ impl G for F16Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance2( @@ -69,9 +69,9 @@ impl G for F16Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance( @@ -95,9 +95,9 @@ impl G for F16Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance2( @@ -122,9 +122,9 @@ impl G for F16Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance_with_delta( @@ -152,9 +152,9 @@ impl G for F16Dot { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn length(vector: &[F16]) -> F16 { @@ -168,9 +168,9 @@ fn length(vector: &[F16]) -> F16 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn l2_normalize(vector: &mut [F16]) { @@ -183,9 +183,9 @@ fn l2_normalize(vector: &mut [F16]) { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs index bb03889fe..647c6f900 100644 --- a/crates/service/src/prelude/global/f16_l2.rs +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -24,9 +24,9 @@ impl G for F16L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance( @@ -46,9 +46,9 @@ impl G for F16L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance2( @@ -68,9 +68,9 @@ impl G for F16L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance( @@ -93,9 +93,9 @@ impl G for F16L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance2( @@ -119,9 +119,9 @@ impl G for F16L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance_with_delta( @@ -148,9 +148,9 @@ impl G for F16L2 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/service/src/prelude/global/f32_cos.rs index 868035143..06cd7001e 100644 --- a/crates/service/src/prelude/global/f32_cos.rs +++ b/crates/service/src/prelude/global/f32_cos.rs @@ -25,9 +25,9 @@ impl G for F32Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance( @@ -51,9 +51,9 @@ impl G for F32Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance2( @@ -77,9 +77,9 @@ impl G for F32Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance( @@ -107,9 +107,9 @@ impl G for F32Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance2( @@ -138,9 +138,9 @@ impl G for F32Cos { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance_with_delta( @@ -172,9 +172,9 @@ impl G for F32Cos { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn length(vector: &[F32]) -> F32 { @@ -188,9 +188,9 @@ fn length(vector: &[F32]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn l2_normalize(vector: &mut [F32]) { @@ -203,9 +203,9 @@ fn l2_normalize(vector: &mut [F32]) { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { @@ -224,9 +224,9 @@ fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { @@ -245,9 +245,9 @@ fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) { diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/service/src/prelude/global/f32_dot.rs index f8d624d51..58108d89e 100644 --- a/crates/service/src/prelude/global/f32_dot.rs +++ b/crates/service/src/prelude/global/f32_dot.rs @@ -25,9 +25,9 @@ impl G for F32Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance( @@ -47,9 +47,9 @@ impl G for F32Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance2( @@ -69,9 +69,9 @@ impl G for F32Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance( @@ -95,9 +95,9 @@ impl G for F32Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance2( @@ -122,9 +122,9 @@ impl G for F32Dot { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance_with_delta( @@ -152,9 +152,9 @@ impl G for F32Dot { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn length(vector: &[F32]) -> F32 { @@ -168,9 +168,9 @@ fn length(vector: &[F32]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn l2_normalize(vector: &mut [F32]) { @@ -183,9 +183,9 @@ fn l2_normalize(vector: &mut [F32]) { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { @@ -204,9 +204,9 @@ fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { @@ -221,9 +221,9 @@ pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/service/src/prelude/global/f32_l2.rs index 52addc207..2672b6714 100644 --- a/crates/service/src/prelude/global/f32_l2.rs +++ b/crates/service/src/prelude/global/f32_l2.rs @@ -23,9 +23,9 @@ impl G for F32L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance( @@ -45,9 +45,9 @@ impl G for F32L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn scalar_quantization_distance2( @@ -67,9 +67,9 @@ impl G for F32L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance( @@ -92,9 +92,9 @@ impl G for F32L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance2( @@ -118,9 +118,9 @@ impl G for F32L2 { } #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn product_quantization_distance_with_delta( @@ -147,9 +147,9 @@ impl G for F32L2 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { @@ -165,9 +165,9 @@ pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { #[inline(always)] #[multiversion::multiversion(targets( - "x86_64+avx512vl+avx512f+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+avx2+avx+ssse3+sse4.1+sse3+sse2+sse+fma", - "x86_64+ssse3+sse4.1+sse3+sse2+sse+fma", + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", "aarch64+neon" ))] fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs index fa78a8fc4..2eedaf91a 100644 --- a/crates/service/src/prelude/global/mod.rs +++ b/crates/service/src/prelude/global/mod.rs @@ -1,4 +1,3 @@ -mod detect; mod f16; mod f16_cos; mod f16_dot; diff --git a/crates/service/src/utils/detect.rs b/crates/service/src/utils/detect.rs new file mode 100644 index 000000000..ce69de15a --- /dev/null +++ b/crates/service/src/utils/detect.rs @@ -0,0 +1,3 @@ +mod x86_64; + +pub use x86_64::*; diff --git a/crates/service/src/utils/detect/x86_64.rs b/crates/service/src/utils/detect/x86_64.rs new file mode 100644 index 000000000..5bd3c8705 --- /dev/null +++ b/crates/service/src/utils/detect/x86_64.rs @@ -0,0 +1,85 @@ +#![cfg(target_arch = "x86_64")] + +use std::sync::atomic::{AtomicBool, Ordering}; + +static ATOMIC_AVX512FP16: AtomicBool = AtomicBool::new(false); + +pub fn test_avx512fp16() -> bool { + std_detect::is_x86_feature_detected!("avx512fp16") && test_v4() +} + +#[ctor::ctor] +fn ctor_avx512fp16() { + ATOMIC_AVX512FP16.store(test_avx512fp16(), Ordering::Relaxed); +} + +pub fn detect_avx512fp16() -> bool { + ATOMIC_AVX512FP16.load(Ordering::Relaxed) +} + +static ATOMIC_V4: AtomicBool = AtomicBool::new(false); + +pub fn test_v4() -> bool { + std_detect::is_x86_feature_detected!("avx512bw") + && std_detect::is_x86_feature_detected!("avx512cd") + && std_detect::is_x86_feature_detected!("avx512dq") + && std_detect::is_x86_feature_detected!("avx512f") + && std_detect::is_x86_feature_detected!("avx512vl") + && test_v3() +} + +#[ctor::ctor] +fn ctor_v4() { + ATOMIC_V4.store(test_v4(), Ordering::Relaxed); +} + +pub fn _detect_v4() -> bool { + ATOMIC_V4.load(Ordering::Relaxed) +} + +static ATOMIC_V3: AtomicBool = AtomicBool::new(false); + +pub fn test_v3() -> bool { + std_detect::is_x86_feature_detected!("avx") + && std_detect::is_x86_feature_detected!("avx2") + && std_detect::is_x86_feature_detected!("bmi1") + && std_detect::is_x86_feature_detected!("bmi2") + && std_detect::is_x86_feature_detected!("f16c") + && std_detect::is_x86_feature_detected!("fma") + && std_detect::is_x86_feature_detected!("lzcnt") + && std_detect::is_x86_feature_detected!("movbe") + && std_detect::is_x86_feature_detected!("xsave") + && test_v2() +} + +#[ctor::ctor] +fn ctor_v3() { + ATOMIC_V3.store(test_v3(), Ordering::Relaxed); +} + +pub fn detect_v3() -> bool { + ATOMIC_V3.load(Ordering::Relaxed) +} + +static ATOMIC_V2: AtomicBool = AtomicBool::new(false); + +pub fn test_v2() -> bool { + std_detect::is_x86_feature_detected!("cmpxchg16b") + && std_detect::is_x86_feature_detected!("fxsr") + && std_detect::is_x86_feature_detected!("popcnt") + && std_detect::is_x86_feature_detected!("sse") + && std_detect::is_x86_feature_detected!("sse2") + && std_detect::is_x86_feature_detected!("sse3") + && std_detect::is_x86_feature_detected!("sse4.1") + && std_detect::is_x86_feature_detected!("sse4.2") + && std_detect::is_x86_feature_detected!("ssse3") +} + +#[ctor::ctor] +fn ctor_v2() { + ATOMIC_V2.store(test_v2(), Ordering::Relaxed); +} + +pub fn _detect_v2() -> bool { + ATOMIC_V2.load(Ordering::Relaxed) +} diff --git a/crates/service/src/utils/mod.rs b/crates/service/src/utils/mod.rs index 55f717a88..ab8b9619c 100644 --- a/crates/service/src/utils/mod.rs +++ b/crates/service/src/utils/mod.rs @@ -5,3 +5,4 @@ pub mod file_atomic; pub mod file_wal; pub mod mmap_array; pub mod vec2; +pub mod detect; diff --git a/docs/comparison-with-specialized-vectordb.md b/docs/comparison-with-specialized-vectordb.md index 7c6c9817b..f02b90e90 100644 --- a/docs/comparison-with-specialized-vectordb.md +++ b/docs/comparison-with-specialized-vectordb.md @@ -11,7 +11,7 @@ Why not just use Postgres to do the vector similarity search? This is the reason UPDATE documents SET embedding = ai_embedding_vector(content) WHERE length(embedding) = 0; -- Create an index on the embedding column -CREATE INDEX ON documents USING vectors (embedding l2_ops); +CREATE INDEX ON documents USING vectors (embedding vector_l2_ops); -- Query the similar embeddings SELECT * FROM documents ORDER BY embedding <-> ai_embedding_vector('hello world') LIMIT 5; diff --git a/docs/indexing.md b/docs/indexing.md index b46ea41be..2d4968934 100644 --- a/docs/indexing.md +++ b/docs/indexing.md @@ -5,19 +5,19 @@ Indexing is the core ability of pgvecto.rs. Assuming there is a table `items` and there is a column named `embedding` of type `vector(n)`, you can create a vector index for squared Euclidean distance with the following SQL. ```sql -CREATE INDEX ON items USING vectors (embedding l2_ops); +CREATE INDEX ON items USING vectors (embedding vector_l2_ops); ``` There is a table for you to choose a proper operator class for creating indexes. -| Vector type | Distance type | Operator class | -| ----------- | -------------------------- | ----------------- | -| vector | squared Euclidean distance | l2_ops | -| vector | negative dot product | dot_ops | -| vector | negative cosine similarity | cosine_ops | -| vecf16 | squared Euclidean distance | vecf16_l2_ops | -| vecf16 | negative dot product | vecf16_dot_ops | -| vecf16 | negative cosine similarity | vecf16_cosine_ops | +| Vector type | Distance type | Operator class | +| ----------- | -------------------------- | -------------- | +| vector | squared Euclidean distance | vector_l2_ops | +| vector | negative dot product | vector_dot_ops | +| vector | negative cosine similarity | vector_cos_ops | +| vecf16 | squared Euclidean distance | vecf16_l2_ops | +| vecf16 | negative dot product | vecf16_dot_ops | +| vecf16 | negative cosine similarity | vecf16_cos_ops | Now you can perform a KNN search with the following SQL again, but this time the vector index is used for searching. @@ -129,11 +129,11 @@ There are some examples. ```sql -- HNSW algorithm, default settings. -CREATE INDEX ON items USING vectors (embedding l2_ops); +CREATE INDEX ON items USING vectors (embedding vector_l2_ops); --- Or using bruteforce with PQ. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ [indexing.flat] quantization.product.ratio = "x16" @@ -141,7 +141,7 @@ $$); --- Or using IVFPQ algorithm. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ [indexing.ivf] quantization.product.ratio = "x16" @@ -149,14 +149,14 @@ $$); -- Use more threads for background building the index. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ optimizing.optimizing_threads = 16 $$); -- Prefer smaller HNSW graph. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ segment.max_growing_segment_size = 200000 $$); diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index 62a1f8270..3e59ad0c1 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -7,15 +7,15 @@ CREATE CAST (vector AS real[]) CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; -CREATE OPERATOR CLASS l2_ops +CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING vectors AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops; -CREATE OPERATOR CLASS dot_ops +CREATE OPERATOR CLASS vector_dot_ops FOR TYPE vector USING vectors AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops; -CREATE OPERATOR CLASS cosine_ops +CREATE OPERATOR CLASS vector_cos_ops FOR TYPE vector USING vectors AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops; @@ -27,7 +27,7 @@ CREATE OPERATOR CLASS vecf16_dot_ops FOR TYPE vecf16 USING vectors AS OPERATOR 1 <#> (vecf16, vecf16) FOR ORDER BY float_ops; -CREATE OPERATOR CLASS vecf16_cosine_ops +CREATE OPERATOR CLASS vecf16_cos_ops FOR TYPE vecf16 USING vectors AS OPERATOR 1 <=> (vecf16, vecf16) FOR ORDER BY float_ops; diff --git a/tests/sqllogictest/error.slt b/tests/sqllogictest/error.slt index 466e547f7..7e4428da4 100644 --- a/tests/sqllogictest/error.slt +++ b/tests/sqllogictest/error.slt @@ -5,7 +5,7 @@ statement ok CREATE TABLE t (val vector(3)); statement ok -CREATE INDEX ON t USING vectors (val l2_ops); +CREATE INDEX ON t USING vectors (val vector_l2_ops); statement error The given vector is invalid for input. INSERT INTO t (val) VALUES ('[0, 1, 2, 3]'); diff --git a/tests/sqllogictest/flat.slt b/tests/sqllogictest/flat.slt index d666518ab..601219db1 100644 --- a/tests/sqllogictest/flat.slt +++ b/tests/sqllogictest/flat.slt @@ -8,7 +8,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.flat]"); statement ok diff --git a/tests/sqllogictest/hnsw.slt b/tests/sqllogictest/hnsw.slt index 07f666de0..51f394828 100644 --- a/tests/sqllogictest/hnsw.slt +++ b/tests/sqllogictest/hnsw.slt @@ -11,7 +11,7 @@ INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM gene # And because of borrow checker, we can't remove this table before restarting the postgres. # Maybe we need better error handling. statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.hnsw]"); statement ok diff --git a/tests/sqllogictest/ivf.slt b/tests/sqllogictest/ivf.slt index 56b7aef8c..1f938f990 100644 --- a/tests/sqllogictest/ivf.slt +++ b/tests/sqllogictest/ivf.slt @@ -12,7 +12,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.ivf]"); statement ok @@ -47,7 +47,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.ivf.quantization.product]"); statement ok diff --git a/tests/sqllogictest/quantization.slt b/tests/sqllogictest/quantization.slt index af96edeb4..7908d7703 100644 --- a/tests/sqllogictest/quantization.slt +++ b/tests/sqllogictest/quantization.slt @@ -9,7 +9,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.hnsw.quantization.product]"); statement ok @@ -41,7 +41,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.hnsw.quantization.scalar]"); statement ok diff --git a/tests/sqllogictest/reindex.slt b/tests/sqllogictest/reindex.slt index de3b7e374..b871a5c9f 100644 --- a/tests/sqllogictest/reindex.slt +++ b/tests/sqllogictest/reindex.slt @@ -8,7 +8,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX ON t USING vectors (val l2_ops) +CREATE INDEX ON t USING vectors (val vector_l2_ops) WITH (options = "[indexing.hnsw]"); statement ok diff --git a/tests/sqllogictest/update.slt b/tests/sqllogictest/update.slt index 0306f7262..8fd44673b 100644 --- a/tests/sqllogictest/update.slt +++ b/tests/sqllogictest/update.slt @@ -8,7 +8,7 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000); statement ok -CREATE INDEX CONCURRENTLY ON t USING vectors (val l2_ops); +CREATE INDEX CONCURRENTLY ON t USING vectors (val vector_l2_ops); statement ok UPDATE t SET val = ARRAY[0.2, random(), random()]::real[] WHERE val = (SELECT val FROM t ORDER BY val <-> '[0.1,0.1,0.1]' LIMIT 1); From 1bc6316c373b809d1a3608eda6ecdd2bf666624d Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 13 Dec 2023 16:02:58 +0800 Subject: [PATCH 18/23] fix: use tensorchord/stdarch Signed-off-by: usamoi --- Cargo.lock | 2 +- crates/service/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f44aff553..6a57f8991 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2940,7 +2940,7 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "std_detect" version = "0.1.5" -source = "git+https://github.com/usamoi/stdarch.git?rev=d934b65e47ce82ce4c20f0268dab01f71fb7b9c7#d934b65e47ce82ce4c20f0268dab01f71fb7b9c7" +source = "git+https://github.com/tensorchord/stdarch.git?branch=avx512fp16#db0cdbc9b02074bfddabfd23a4a681f21640eada" dependencies = [ "cfg-if", "libc", diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 8693b4e72..367157072 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -16,7 +16,7 @@ bincode.workspace = true half.workspace = true num-traits.workspace = true c = { path = "../c" } -std_detect = { git = "https://github.com/usamoi/stdarch.git", rev = "d934b65e47ce82ce4c20f0268dab01f71fb7b9c7" } +std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } rand = "0.8.5" crc32fast = "1.3.2" crossbeam = "0.8.2" From 4101b97927f47ce238dd386f807b775c3b8c7266 Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 13 Dec 2023 16:21:53 +0800 Subject: [PATCH 19/23] fix: ci Signed-off-by: usamoi --- .github/workflows/check.yml | 12 +++++++++--- bindings/python/tests/__init__.py | 6 ++---- crates/service/src/prelude/global/f16.rs | 13 ++++++------- crates/service/src/utils/detect.rs | 4 +--- crates/service/src/utils/mod.rs | 2 +- scripts/ci_setup.sh | 2 +- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index d2f66e82e..85853a6c6 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -92,11 +92,17 @@ jobs: - name: Format check run: cargo fmt --check - name: Semantic check - run: cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" + run: | + cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu + cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu - name: Debug build - run: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" + run: | + cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu + cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu - name: Test - run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture + run: | + cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu + cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu --no-run - name: Install release run: ./scripts/ci_install.sh - name: Sqllogictest diff --git a/bindings/python/tests/__init__.py b/bindings/python/tests/__init__.py index 1e28cc768..0a2e9ace2 100644 --- a/bindings/python/tests/__init__.py +++ b/bindings/python/tests/__init__.py @@ -19,14 +19,12 @@ TOML_SETTINGS = { "flat": toml.dumps( { - "capacity": 2097152, - "algorithm": {"flat": {}}, + "indexing": {"flat": {}}, }, ), "hnsw": toml.dumps( { - "capacity": 2097152, - "algorithm": {"hnsw": {}}, + "indexing": {"hnsw": {}}, }, ), } diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 95d48335f..2f7e13cab 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -1,5 +1,4 @@ use crate::prelude::*; -use crate::utils::detect; pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] @@ -23,7 +22,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { xy / (x2 * y2).sqrt() } #[cfg(target_arch = "x86_64")] - if self::detect::detect_avx512fp16() { + if crate::utils::detect::x86_64::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -31,7 +30,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if self::detect::detect_v3() { + if crate::utils::detect::x86_64::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -59,7 +58,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { xy } #[cfg(target_arch = "x86_64")] - if self::detect::detect_avx512fp16() { + if crate::utils::detect::x86_64::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -67,7 +66,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if self::detect::detect_v3() { + if crate::utils::detect::x86_64::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -96,7 +95,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { d2 } #[cfg(target_arch = "x86_64")] - if self::detect::detect_avx512fp16() { + if crate::utils::detect::x86_64::detect_avx512fp16() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -104,7 +103,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if self::detect::detect_v3() { + if crate::utils::detect::x86_64::detect_v3() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { diff --git a/crates/service/src/utils/detect.rs b/crates/service/src/utils/detect.rs index ce69de15a..2a99bf589 100644 --- a/crates/service/src/utils/detect.rs +++ b/crates/service/src/utils/detect.rs @@ -1,3 +1 @@ -mod x86_64; - -pub use x86_64::*; +pub mod x86_64; diff --git a/crates/service/src/utils/mod.rs b/crates/service/src/utils/mod.rs index ab8b9619c..e42242438 100644 --- a/crates/service/src/utils/mod.rs +++ b/crates/service/src/utils/mod.rs @@ -1,8 +1,8 @@ pub mod cells; pub mod clean; +pub mod detect; pub mod dir_ops; pub mod file_atomic; pub mod file_wal; pub mod mmap_array; pub mod vec2; -pub mod detect; diff --git a/scripts/ci_setup.sh b/scripts/ci_setup.sh index 53f9171ec..9e1b8d74f 100755 --- a/scripts/ci_setup.sh +++ b/scripts/ci_setup.sh @@ -5,7 +5,6 @@ if [ "$OS" == "ubuntu-latest" ]; then if [ $VERSION != 14 ]; then sudo pg_dropcluster 14 main fi - sudo apt-get -y install crossbuild-essential-arm64 sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' @@ -14,6 +13,7 @@ if [ "$OS" == "ubuntu-latest" ]; then sudo apt-get update sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION sudo apt-get -y install clang-16 + sudo apt-get -y install crossbuild-essential-arm64 echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf From ac87e12f2b1bd817cb4230cc67d942760b725f97 Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 13 Dec 2023 16:49:27 +0800 Subject: [PATCH 20/23] fix: remove no-run cross test Signed-off-by: usamoi --- .github/workflows/check.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 85853a6c6..2e7ed884f 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -102,7 +102,6 @@ jobs: - name: Test run: | cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu - cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu --no-run - name: Install release run: ./scripts/ci_install.sh - name: Sqllogictest From 93e89f53d0ceee9a18409346890ca0ae7bebcf4f Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 13 Dec 2023 20:27:13 +0800 Subject: [PATCH 21/23] fix: vbase Signed-off-by: usamoi --- src/bgworker/mod.rs | 5 -- src/index/am_scan.rs | 73 +++++----------- src/ipc/client.rs | 93 -------------------- src/ipc/client/mod.rs | 121 +++++++++++++++++++++++++++ src/ipc/packet.rs | 87 ------------------- src/ipc/packet/create.rs | 7 ++ src/ipc/packet/delete.rs | 13 +++ src/ipc/packet/destory.rs | 6 ++ src/ipc/packet/flush.rs | 7 ++ src/ipc/packet/insert.rs | 7 ++ src/ipc/packet/mod.rs | 45 ++++++++++ src/ipc/packet/search.rs | 17 ++++ src/ipc/packet/stat.rs | 10 +++ src/ipc/packet/vbase.rs | 21 +++++ src/ipc/{server.rs => server/mod.rs} | 42 +++++----- 15 files changed, 295 insertions(+), 259 deletions(-) delete mode 100644 src/ipc/client.rs create mode 100644 src/ipc/client/mod.rs delete mode 100644 src/ipc/packet.rs create mode 100644 src/ipc/packet/create.rs create mode 100644 src/ipc/packet/delete.rs create mode 100644 src/ipc/packet/destory.rs create mode 100644 src/ipc/packet/flush.rs create mode 100644 src/ipc/packet/insert.rs create mode 100644 src/ipc/packet/mod.rs create mode 100644 src/ipc/packet/search.rs create mode 100644 src/ipc/packet/stat.rs create mode 100644 src/ipc/packet/vbase.rs rename src/ipc/{server.rs => server/mod.rs} (76%) diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 0ca91e557..7d09099c8 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -136,11 +136,6 @@ fn session(worker: Arc, mut handler: RpcHandler) -> Result<(), IpcError> let result = worker.call_stat(id); handler = x.leave(result)?; } - RpcHandle::Leave {} => { - log::debug!("Handle leave rpc."); - break; - } } } - Ok(()) } diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs index 433b1ebd5..35d2fc6be 100644 --- a/src/index/am_scan.rs +++ b/src/index/am_scan.rs @@ -8,14 +8,10 @@ use service::prelude::*; #[derive(Debug, Clone)] pub enum Scanner { Initial { - // fields to be filled by amhandler and hook vector: Option, index_scan_state: Option<*mut pgrx::pg_sys::IndexScanState>, }, - Type0 { - data: Vec, - }, - Type1 { + Search { index_scan_state: *mut pgrx::pg_sys::IndexScanState, data: Vec, }, @@ -95,13 +91,7 @@ pub unsafe fn start_scan( index_scan_state, }; } - Type0 { data: _ } => { - *scanner = Initial { - vector: Some(vector), - index_scan_state: None, - }; - } - Type1 { + Search { index_scan_state, data: _, } => { @@ -137,54 +127,31 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { let vector = vector.expect("`rescan` is never called."); let mut client = super::client::borrow_mut(); let k = K.get() as _; - if index_scan_state.is_some() { - struct ClientSearch { - node: *mut pgrx::pg_sys::IndexScanState, - } - - impl crate::ipc::client::ClientSearch for ClientSearch { - fn check(&mut self, p: Pointer) -> bool { - unsafe { check(self.node, p) } - } - } - - let client_search = ClientSearch { - node: index_scan_state.unwrap(), - }; - - let mut result = client.search(id, (vector, k), ENABLE_PREFILTER.get(), client_search); - result.reverse(); - *scanner = Scanner::Type1 { - index_scan_state: index_scan_state.unwrap(), - data: result, - }; - } else { - struct ClientSearch {} + assert!(index_scan_state.is_some()); + struct ClientSearch { + node: *mut pgrx::pg_sys::IndexScanState, + } - impl crate::ipc::client::ClientSearch for ClientSearch { - fn check(&mut self, _: Pointer) -> bool { - unreachable!() - } + impl crate::ipc::client::ClientSearch for ClientSearch { + fn check(&mut self, p: Pointer) -> bool { + unsafe { check(self.node, p) } } + } - let client_search = ClientSearch {}; + let client_search = ClientSearch { + node: index_scan_state.unwrap(), + }; - let mut result = client.search(id, (vector, k), false, client_search); - result.reverse(); - *scanner = Scanner::Type0 { data: result }; - } + let mut result = client.search(id, (vector, k), ENABLE_PREFILTER.get(), client_search); + result.reverse(); + *scanner = Scanner::Search { + index_scan_state: index_scan_state.unwrap(), + data: result, + }; } match scanner { Scanner::Initial { .. } => unreachable!(), - Scanner::Type0 { data } => { - if let Some(p) = data.pop() { - (*scan).xs_heaptid = p.into_sys(); - true - } else { - false - } - } - Scanner::Type1 { data, .. } => { + Scanner::Search { data, .. } => { if let Some(p) = data.pop() { (*scan).xs_heaptid = p.into_sys(); true diff --git a/src/ipc/client.rs b/src/ipc/client.rs deleted file mode 100644 index fc7a17cda..000000000 --- a/src/ipc/client.rs +++ /dev/null @@ -1,93 +0,0 @@ -use super::packet::*; -use super::transport::Socket; -use service::index::IndexOptions; -use service::index::IndexStat; -use service::prelude::*; - -pub struct Client { - socket: Socket, -} - -impl Client { - pub fn new(socket: Socket) -> Self { - Self { socket } - } - pub fn create(&mut self, id: Id, options: IndexOptions) { - let packet = RpcPacket::Create { id, options }; - self.socket.send(packet).friendly(); - let CreatePacket::Leave {} = self.socket.recv::().friendly(); - } - pub fn search( - &mut self, - id: Id, - search: (DynamicVector, usize), - prefilter: bool, - mut t: impl ClientSearch, - ) -> Vec { - let packet = RpcPacket::Search { - id, - search, - prefilter, - }; - self.socket.send(packet).friendly(); - loop { - match self.socket.recv::().friendly() { - SearchPacket::Check { p } => { - self.socket - .send(SearchCheckPacket::Leave { result: t.check(p) }) - .friendly(); - } - SearchPacket::Leave { result } => { - return result.friendly(); - } - } - } - } - pub fn delete(&mut self, id: Id, mut t: impl ClientDelete) { - let packet = RpcPacket::Delete { id }; - self.socket.send(packet).friendly(); - loop { - match self.socket.recv::().friendly() { - DeletePacket::Test { p } => { - self.socket - .send(DeleteTestPacket::Leave { delete: t.test(p) }) - .friendly(); - } - DeletePacket::Leave { result } => { - return result.friendly(); - } - } - } - } - pub fn insert(&mut self, id: Id, insert: (DynamicVector, Pointer)) { - let packet = RpcPacket::Insert { id, insert }; - self.socket.send(packet).friendly(); - let InsertPacket::Leave { result } = self.socket.recv::().friendly(); - result.friendly() - } - pub fn flush(&mut self, id: Id) { - let packet = RpcPacket::Flush { id }; - self.socket.send(packet).friendly(); - let FlushPacket::Leave { result } = self.socket.recv::().friendly(); - result.friendly() - } - pub fn destory(&mut self, ids: Vec) { - let packet = RpcPacket::Destory { ids }; - self.socket.send(packet).friendly(); - let DestoryPacket::Leave {} = self.socket.recv::().friendly(); - } - pub fn stat(&mut self, id: Id) -> IndexStat { - let packet = RpcPacket::Stat { id }; - self.socket.send(packet).friendly(); - let StatPacket::Leave { result } = self.socket.recv::().friendly(); - result.friendly() - } -} - -pub trait ClientSearch { - fn check(&mut self, p: Pointer) -> bool; -} - -pub trait ClientDelete { - fn test(&mut self, p: Pointer) -> bool; -} diff --git a/src/ipc/client/mod.rs b/src/ipc/client/mod.rs new file mode 100644 index 000000000..a7e1598d4 --- /dev/null +++ b/src/ipc/client/mod.rs @@ -0,0 +1,121 @@ +use super::packet::*; +use super::transport::Socket; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::prelude::*; + +pub struct Client { + socket: Socket, +} + +impl Client { + pub fn new(socket: Socket) -> Self { + Self { socket } + } + pub fn create(&mut self, id: Id, options: IndexOptions) { + let packet = ClientPacket::Create { id, options }; + self.socket.send(packet).friendly(); + let create::ServerPacket::Leave {} = self.socket.recv::().friendly(); + } + pub fn search( + &mut self, + id: Id, + search: (DynamicVector, usize), + prefilter: bool, + mut t: impl ClientSearch, + ) -> Vec { + let packet = ClientPacket::Search { + id, + search, + prefilter, + }; + self.socket.send(packet).friendly(); + loop { + match self.socket.recv::().friendly() { + search::ServerPacket::Check { p } => { + self.socket + .send(search::ClientCheckPacket { result: t.check(p) }) + .friendly(); + } + search::ServerPacket::Leave { result } => { + return result.friendly(); + } + } + } + } + pub fn delete(&mut self, id: Id, mut t: impl ClientDelete) { + let packet = ClientPacket::Delete { id }; + self.socket.send(packet).friendly(); + loop { + match self.socket.recv::().friendly() { + delete::ServerPacket::Test { p } => { + self.socket + .send(delete::ClientTestPacket { delete: t.test(p) }) + .friendly(); + } + delete::ServerPacket::Leave { result } => { + return result.friendly(); + } + } + } + } + pub fn insert(&mut self, id: Id, insert: (DynamicVector, Pointer)) { + let packet = ClientPacket::Insert { id, insert }; + self.socket.send(packet).friendly(); + let insert::ServerPacket::Leave { result } = + self.socket.recv::().friendly(); + result.friendly() + } + pub fn flush(&mut self, id: Id) { + let packet = ClientPacket::Flush { id }; + self.socket.send(packet).friendly(); + let flush::ServerPacket::Leave { result } = + self.socket.recv::().friendly(); + result.friendly() + } + pub fn destory(&mut self, ids: Vec) { + let packet = ClientPacket::Destory { ids }; + self.socket.send(packet).friendly(); + let destory::ServerPacket::Leave {} = + self.socket.recv::().friendly(); + } + pub fn stat(&mut self, id: Id) -> IndexStat { + let packet = ClientPacket::Stat { id }; + self.socket.send(packet).friendly(); + let stat::ServerPacket::Leave { result } = + self.socket.recv::().friendly(); + result.friendly() + } + pub fn vbase(&mut self, id: Id, search: (DynamicVector, usize)) -> ClientVbase<'_> { + let packet = ClientPacket::Vbase { id, search }; + self.socket.send(packet).friendly(); + let vbase::ServerPacket::Leave {} = self.socket.recv::().friendly(); + ClientVbase(self) + } +} + +pub trait ClientSearch { + fn check(&mut self, p: Pointer) -> bool; +} + +pub trait ClientDelete { + fn test(&mut self, p: Pointer) -> bool; +} + +pub struct ClientVbase<'a>(&'a mut Client); + +impl ClientVbase<'_> { + pub fn next(&mut self) -> Pointer { + let packet = vbase::ClientPacket::Next {}; + self.0.socket.send(packet).friendly(); + let vbase::ServerNextPacket { p } = + self.0.socket.recv::().friendly(); + p + } + pub fn leave(self) { + let packet = vbase::ClientPacket::Leave {}; + self.0.socket.send(packet).friendly(); + let vbase::ServerLeavePacket {} = + self.0.socket.recv::().friendly(); + } +} diff --git a/src/ipc/packet.rs b/src/ipc/packet.rs deleted file mode 100644 index 2de6ff48a..000000000 --- a/src/ipc/packet.rs +++ /dev/null @@ -1,87 +0,0 @@ -use serde::{Deserialize, Serialize}; -use service::index::IndexOptions; -use service::index::IndexStat; -use service::prelude::*; - -#[derive(Debug, Serialize, Deserialize)] -pub enum RpcPacket { - Create { - id: Id, - options: IndexOptions, - }, - Flush { - id: Id, - }, - Destory { - ids: Vec, - }, - Insert { - id: Id, - insert: (DynamicVector, Pointer), - }, - Delete { - id: Id, - }, - Search { - id: Id, - search: (DynamicVector, usize), - prefilter: bool, - }, - Stat { - id: Id, - }, - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum CreatePacket { - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum FlushPacket { - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum DestoryPacket { - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum InsertPacket { - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum DeletePacket { - Test { p: Pointer }, - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum DeleteTestPacket { - Leave { delete: bool }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SearchPacket { - Check { - p: Pointer, - }, - Leave { - result: Result, FriendlyError>, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SearchCheckPacket { - Leave { result: bool }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum StatPacket { - Leave { - result: Result, - }, -} diff --git a/src/ipc/packet/create.rs b/src/ipc/packet/create.rs new file mode 100644 index 000000000..8f41fabc1 --- /dev/null +++ b/src/ipc/packet/create.rs @@ -0,0 +1,7 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Leave {}, +} diff --git a/src/ipc/packet/delete.rs b/src/ipc/packet/delete.rs new file mode 100644 index 000000000..bc8e3c3af --- /dev/null +++ b/src/ipc/packet/delete.rs @@ -0,0 +1,13 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Test { p: Pointer }, + Leave { result: Result<(), FriendlyError> }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ClientTestPacket { + pub delete: bool, +} diff --git a/src/ipc/packet/destory.rs b/src/ipc/packet/destory.rs new file mode 100644 index 000000000..b1900275d --- /dev/null +++ b/src/ipc/packet/destory.rs @@ -0,0 +1,6 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Leave {}, +} diff --git a/src/ipc/packet/flush.rs b/src/ipc/packet/flush.rs new file mode 100644 index 000000000..fdad641a9 --- /dev/null +++ b/src/ipc/packet/flush.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Leave { result: Result<(), FriendlyError> }, +} diff --git a/src/ipc/packet/insert.rs b/src/ipc/packet/insert.rs new file mode 100644 index 000000000..fdad641a9 --- /dev/null +++ b/src/ipc/packet/insert.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Leave { result: Result<(), FriendlyError> }, +} diff --git a/src/ipc/packet/mod.rs b/src/ipc/packet/mod.rs new file mode 100644 index 000000000..7f5961335 --- /dev/null +++ b/src/ipc/packet/mod.rs @@ -0,0 +1,45 @@ +pub mod create; +pub mod delete; +pub mod destory; +pub mod flush; +pub mod insert; +pub mod search; +pub mod stat; +pub mod vbase; + +use serde::{Deserialize, Serialize}; +use service::index::IndexOptions; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ClientPacket { + Create { + id: Id, + options: IndexOptions, + }, + Delete { + id: Id, + }, + Destory { + ids: Vec, + }, + Flush { + id: Id, + }, + Insert { + id: Id, + insert: (DynamicVector, Pointer), + }, + Search { + id: Id, + search: (DynamicVector, usize), + prefilter: bool, + }, + Stat { + id: Id, + }, + Vbase { + id: Id, + search: (DynamicVector, usize), + }, +} diff --git a/src/ipc/packet/search.rs b/src/ipc/packet/search.rs new file mode 100644 index 000000000..30e786158 --- /dev/null +++ b/src/ipc/packet/search.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Check { + p: Pointer, + }, + Leave { + result: Result, FriendlyError>, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ClientCheckPacket { + pub result: bool, +} diff --git a/src/ipc/packet/stat.rs b/src/ipc/packet/stat.rs new file mode 100644 index 000000000..ff2416516 --- /dev/null +++ b/src/ipc/packet/stat.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; +use service::index::IndexStat; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Leave { + result: Result, + }, +} diff --git a/src/ipc/packet/vbase.rs b/src/ipc/packet/vbase.rs new file mode 100644 index 000000000..3e1b5cd94 --- /dev/null +++ b/src/ipc/packet/vbase.rs @@ -0,0 +1,21 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum ServerPacket { + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum ClientPacket { + Next {}, + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ServerNextPacket { + pub p: Pointer, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ServerLeavePacket {} diff --git a/src/ipc/server.rs b/src/ipc/server/mod.rs similarity index 76% rename from src/ipc/server.rs rename to src/ipc/server/mod.rs index 4c06ccced..1ece12702 100644 --- a/src/ipc/server.rs +++ b/src/ipc/server/mod.rs @@ -14,28 +14,28 @@ impl RpcHandler { Self { socket } } pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - RpcPacket::Create { id, options } => RpcHandle::Create { + Ok(match self.socket.recv::()? { + ClientPacket::Create { id, options } => RpcHandle::Create { id, options, x: Create { socket: self.socket, }, }, - RpcPacket::Insert { id, insert } => RpcHandle::Insert { + ClientPacket::Insert { id, insert } => RpcHandle::Insert { id, insert, x: Insert { socket: self.socket, }, }, - RpcPacket::Delete { id } => RpcHandle::Delete { + ClientPacket::Delete { id } => RpcHandle::Delete { id, x: Delete { socket: self.socket, }, }, - RpcPacket::Search { + ClientPacket::Search { id, search, prefilter, @@ -47,25 +47,25 @@ impl RpcHandler { socket: self.socket, }, }, - RpcPacket::Flush { id } => RpcHandle::Flush { + ClientPacket::Flush { id } => RpcHandle::Flush { id, x: Flush { socket: self.socket, }, }, - RpcPacket::Destory { ids } => RpcHandle::Destory { + ClientPacket::Destory { ids } => RpcHandle::Destory { ids, x: Destory { socket: self.socket, }, }, - RpcPacket::Stat { id } => RpcHandle::Stat { + ClientPacket::Stat { id } => RpcHandle::Stat { id, x: Stat { socket: self.socket, }, }, - RpcPacket::Leave {} => RpcHandle::Leave {}, + ClientPacket::Vbase { id, search } => todo!(), }) } } @@ -103,7 +103,6 @@ pub enum RpcHandle { id: Id, x: Stat, }, - Leave {}, } pub struct Create { @@ -112,7 +111,7 @@ pub struct Create { impl Create { pub fn leave(mut self) -> Result { - let packet = CreatePacket::Leave {}; + let packet = create::ServerPacket::Leave {}; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, @@ -126,7 +125,7 @@ pub struct Insert { impl Insert { pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = InsertPacket::Leave { result }; + let packet = insert::ServerPacket::Leave { result }; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, @@ -140,13 +139,13 @@ pub struct Delete { impl Delete { pub fn next(&mut self, p: Pointer) -> Result { - let packet = DeletePacket::Test { p }; + let packet = delete::ServerPacket::Test { p }; self.socket.send(packet)?; - let DeleteTestPacket::Leave { delete } = self.socket.recv::()?; + let delete::ClientTestPacket { delete } = self.socket.recv::()?; Ok(delete) } pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = DeletePacket::Leave { result }; + let packet = delete::ServerPacket::Leave { result }; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, @@ -160,16 +159,17 @@ pub struct Search { impl Search { pub fn check(&mut self, p: Pointer) -> Result { - let packet = SearchPacket::Check { p }; + let packet = search::ServerPacket::Check { p }; self.socket.send(packet)?; - let SearchCheckPacket::Leave { result } = self.socket.recv::()?; + let search::ClientCheckPacket { result } = + self.socket.recv::()?; Ok(result) } pub fn leave( mut self, result: Result, FriendlyError>, ) -> Result { - let packet = SearchPacket::Leave { result }; + let packet = search::ServerPacket::Leave { result }; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, @@ -183,7 +183,7 @@ pub struct Flush { impl Flush { pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = FlushPacket::Leave { result }; + let packet = flush::ServerPacket::Leave { result }; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, @@ -197,7 +197,7 @@ pub struct Destory { impl Destory { pub fn leave(mut self) -> Result { - let packet = DestoryPacket::Leave {}; + let packet = destory::ServerPacket::Leave {}; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, @@ -214,7 +214,7 @@ impl Stat { mut self, result: Result, ) -> Result { - let packet = StatPacket::Leave { result }; + let packet = stat::ServerPacket::Leave { result }; self.socket.send(packet)?; Ok(RpcHandler { socket: self.socket, From 9607a7e70c82e4c407d0000314c197b28ccb0238 Mon Sep 17 00:00:00 2001 From: usamoi Date: Thu, 14 Dec 2023 17:36:32 +0800 Subject: [PATCH 22/23] fix: error and document Signed-off-by: usamoi --- docs/searching.md | 7 ++++--- src/bgworker/mod.rs | 21 +++++++++++++++++---- src/ipc/client/mod.rs | 3 ++- src/ipc/packet/vbase.rs | 4 +++- src/ipc/server/mod.rs | 33 ++++++++++++++++++++++----------- src/lib.rs | 1 + 6 files changed, 49 insertions(+), 20 deletions(-) diff --git a/docs/searching.md b/docs/searching.md index 3cf432abf..2220b6bd7 100644 --- a/docs/searching.md +++ b/docs/searching.md @@ -15,11 +15,12 @@ If `vectors.k` is set to `64`, but your SQL returned less than `64` rows, for ex * The vector index returned `64` rows, but `32` of which are invisble to the transaction so PostgreSQL decided to hide these rows for you. * The vector index returned `64` rows, but `32` of which are satifying the condition `id % 2 = 0` in `WHERE` clause. -There are three ways to solve the problem: +There are four ways to solve the problem: * Set `vectors.k` larger. If you estimate that 20% of rows will satisfy the condition in `WHERE`, just set `vectors.k` to be 5 times than before. * Set `vectors.enable_vector_index` to `off`. If you estimate that 0.0001% of rows will satisfy the condition in `WHERE`, just do not use vector index. No alogrithms will be faster than brute force by PostgreSQL. * Set `vectors.enable_prefilter` to `on`. If you cannot estimate how many rows will satisfy the condition in `WHERE`, leave the job for the index. The index will check if the returned row can be accepted by PostgreSQL. However, it will make queries slower so the default value for this option is `off`. +* Set `vectors.enable_vbase` to `on`. It will use vbase optimization, so that the index will pull rows as many as you need. It only works for HNSW algorithm. ## Options @@ -30,6 +31,6 @@ Search options are specified by PostgreSQL GUC. You can use `SET` command to app | vectors.k | integer | Expected number of candidates returned by index. The parameter will influence the recall if you use HNSW or quantization for indexing. Default value is `64`. | | vectors.enable_prefilter | boolean | Enable prefiltering or not. Default value is `off`. | | vectors.enable_vector_index | boolean | Enable vector indexes or not. This option is for debugging. Default value is `on`. | -| vectors.vbase_range | int4 | The range size when using vbase optimization. When it is set to `0`, vbase optimization will be disabled. A recommended value is `86`. Default value is `0`. | +| vectors.enable_vbase | boolean | Enable vbase optimization. Default value is `off`. | -Note: When `vectors.vbase_range` is enabled, it will ignore `vectors.enable_prefilter`. +Note: When `vectors.enable_vbase` is enabled, prefilter does not work. diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 1db29539f..529a95e2e 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -136,13 +136,26 @@ fn session(worker: Arc, mut handler: RpcHandler) -> Result<(), IpcError> let result = worker.call_stat(id); handler = x.leave(result)?; } - RpcHandle::Vbase { id, vector, mut x } => { + RpcHandle::Vbase { id, vector, x } => { use crate::ipc::server::VbaseHandle::*; - let instance = worker.get_instance(id).expect("todo"); + let instance = match worker.get_instance(id) { + Ok(x) => x, + Err(e) => { + x.error(Err(e))?; + break Ok(()); + } + }; let view = instance.view(); - let mut it = view.vbase(vector).expect("todo"); + let mut it = match view.vbase(vector) { + Ok(x) => x, + Err(e) => { + x.error(Err(e))?; + break Ok(()); + } + }; + let mut x = x.error(Ok(()))?; loop { - match x.handle().expect("todo") { + match x.handle()? { Next { x: y } => { x = y.leave(it.next())?; } diff --git a/src/ipc/client/mod.rs b/src/ipc/client/mod.rs index d05d5a1d1..38aa32b97 100644 --- a/src/ipc/client/mod.rs +++ b/src/ipc/client/mod.rs @@ -123,7 +123,8 @@ impl Rpc { pub fn vbase(mut self: ClientGuard, id: Id, vector: DynamicVector) -> ClientGuard { let packet = RpcPacket::Vbase { id, vector }; self.socket.client_send(packet).friendly(); - let vbase::VbaseNopPacket {} = self.socket.client_recv().friendly(); + let vbase::VbaseErrorPacket { result } = self.socket.client_recv().friendly(); + result.friendly(); ClientGuard::map(self) } } diff --git a/src/ipc/packet/vbase.rs b/src/ipc/packet/vbase.rs index b943318d3..bce914b32 100644 --- a/src/ipc/packet/vbase.rs +++ b/src/ipc/packet/vbase.rs @@ -2,7 +2,9 @@ use serde::{Deserialize, Serialize}; use service::prelude::*; #[derive(Debug, Serialize, Deserialize)] -pub struct VbaseNopPacket {} +pub struct VbaseErrorPacket { + pub result: Result<(), FriendlyError>, +} #[derive(Debug, Serialize, Deserialize)] pub enum VbasePacket { diff --git a/src/ipc/server/mod.rs b/src/ipc/server/mod.rs index ef479aaf5..efafb6efd 100644 --- a/src/ipc/server/mod.rs +++ b/src/ipc/server/mod.rs @@ -65,16 +65,13 @@ impl RpcHandler { socket: self.socket, }, }, - RpcPacket::Vbase { id, vector } => { - self.socket.server_send(vbase::VbaseNopPacket {})?; - RpcHandle::Vbase { - id, - vector, - x: VbaseHandler { - socket: self.socket, - }, - } - } + RpcPacket::Vbase { id, vector } => RpcHandle::Vbase { + id, + vector, + x: Vbase { + socket: self.socket, + }, + }, }) } } @@ -115,7 +112,7 @@ pub enum RpcHandle { Vbase { id: Id, vector: DynamicVector, - x: VbaseHandler, + x: Vbase, }, } @@ -237,6 +234,20 @@ impl Stat { } } +pub struct Vbase { + socket: Socket, +} + +impl Vbase { + pub fn error(mut self, result: Result<(), FriendlyError>) -> Result { + self.socket + .server_send(vbase::VbaseErrorPacket { result })?; + Ok(VbaseHandler { + socket: self.socket, + }) + } +} + pub struct VbaseHandler { socket: Socket, } diff --git a/src/lib.rs b/src/lib.rs index b0e6d437b..cf69461ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ //! Provides an easy-to-use extension for vector similarity search. #![feature(offset_of)] #![feature(arbitrary_self_types)] +#![feature(try_blocks)] mod bgworker; mod datatype; From 305b73cc06f35212d9934fa9f45dbc5435dd37e4 Mon Sep 17 00:00:00 2001 From: usamoi Date: Thu, 14 Dec 2023 17:50:23 +0800 Subject: [PATCH 23/23] [skip ci] Signed-off-by: usamoi