diff --git a/.cargo/config.toml b/.cargo/config.toml index c4db64902..52579941e 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -4,3 +4,13 @@ rustdocflags = ["--document-private-items"] [target.'cfg(target_os="macos")'] # Postgres symbols won't be available until runtime rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"] + +[target.x86_64-unknown-linux-gnu] +linker = "x86_64-linux-gnu-gcc" + +[target.aarch64-unknown-linux-gnu] +linker = "aarch64-linux-gnu-gcc" + +[env] +BINDGEN_EXTRA_CLANG_ARGS_x86_64_unknown_linux_gnu = "-isystem /usr/x86_64-linux-gnu/include/ -ccc-gcc-name x86_64-linux-gnu-gcc" +BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu = "-isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc" diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 1e981361c..2e7ed884f 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -6,6 +6,7 @@ on: paths: - ".cargo/**" - ".github/**" + - "crates/**" - "scripts/**" - "src/**" - "tests/**" @@ -18,6 +19,7 @@ on: paths: - ".cargo/**" - ".github/**" + - "crates/**" - "scripts/**" - "src/**" - "tests/**" @@ -90,11 +92,16 @@ jobs: - name: Format check run: cargo fmt --check - name: Semantic check - run: cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" + run: | + cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu + cargo clippy --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu - name: Debug build - run: cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" + run: | + cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu + cargo build --no-default-features --features "pg${{ matrix.version }} pg_test" --target aarch64-unknown-linux-gnu - name: Test - run: cargo test --all --no-default-features --features "pg${{ matrix.version }} pg_test" -- --nocapture + run: | + cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu - name: Install release run: ./scripts/ci_install.sh - name: Sqllogictest diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index edf49c804..c0c73cf8f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -112,15 +112,17 @@ jobs: - uses: mozilla-actions/sccache-action@v0.0.3 - name: Prepare run: | - sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get -y install libpq-dev postgresql-${{ matrix.version }} postgresql-server-dev-${{ matrix.version }} + sudo apt-get -y install clang-16 cargo install cargo-pgrx --git https://github.com/tensorchord/pgrx.git --rev $(cat Cargo.toml | grep "pgrx =" | awk -F'rev = "' '{print $2}' | cut -d'"' -f1) cargo pgrx init --pg${{ matrix.version }}=/usr/lib/postgresql/${{ matrix.version }}/bin/pg_config if [[ "${{ matrix.arch }}" 
== "arm64" ]]; then sudo apt-get -y install crossbuild-essential-arm64 - rustup target add aarch64-unknown-linux-gnu fi - name: Build Release id: build_release @@ -130,8 +132,6 @@ jobs: mkdir ./artifacts cargo pgrx package if [[ "${{ matrix.arch }}" == "arm64" ]]; then - export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=aarch64-linux-gnu-gcc - export BINDGEN_EXTRA_CLANG_ARGS_aarch64_unknown_linux_gnu="-target aarch64-unknown-linux-gnu -isystem /usr/aarch64-linux-gnu/include/ -ccc-gcc-name aarch64-linux-gnu-gcc" cargo build --target aarch64-unknown-linux-gnu --release --features "pg${{ matrix.version }}" --no-default-features mv ./target/aarch64-unknown-linux-gnu/release/libvectors.so ./target/release/vectors-pg${{ matrix.version }}/usr/lib/postgresql/${{ matrix.version }}/lib/vectors.so fi diff --git a/.gitignore b/.gitignore index ee9dad69c..b50e6dba4 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ .vscode .ignore __pycache__ -.pytest_cache \ No newline at end of file +.pytest_cache +rustc-ice-*.txt diff --git a/Cargo.lock b/Cargo.lock index 06be97a82..6a57f8991 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,12 +100,12 @@ dependencies = [ [[package]] name = "async-channel" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d37875bd9915b7d67c2f117ea2c30a0989874d0b2cb694fe25403c85763c0c9e" +checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" dependencies = [ "concurrent-queue", - "event-listener 3.1.0", + "event-listener 4.0.0", "event-listener-strategy", "futures-core", "pin-project-lite", @@ -113,30 +113,30 @@ dependencies = [ [[package]] name = "async-executor" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc5ea910c42e5ab19012bab31f53cb4d63d54c3a27730f9a833a88efcf4bb52d" +checksum = "17ae5ebefcc48e7452b4987947920dac9450be1110cadf34d1b8c116bdbaf97c" dependencies = [ - "async-lock 3.1.1", + "async-lock 3.2.0", "async-task", "concurrent-queue", "fastrand 2.0.1", - "futures-lite 2.0.1", + "futures-lite 2.1.0", "slab", ] [[package]] name = "async-global-executor" -version = "2.3.1" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1b6f5d7df27bd294849f8eec66ecfc63d11814df7a4f5d74168a2394467b776" +checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c" dependencies = [ - "async-channel 1.9.0", + "async-channel 2.1.1", "async-executor", - "async-io 1.13.0", - "async-lock 2.8.0", + "async-io 2.2.2", + "async-lock 3.2.0", "blocking", - "futures-lite 1.13.0", + "futures-lite 2.1.0", "once_cell", ] @@ -162,22 +162,21 @@ dependencies = [ [[package]] name = "async-io" -version = "2.2.0" +version = "2.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41ed9d5715c2d329bf1b4da8d60455b99b187f27ba726df2883799af9af60997" +checksum = "6afaa937395a620e33dc6a742c593c01aced20aa376ffb0f628121198578ccc7" dependencies = [ - "async-lock 3.1.1", + "async-lock 3.2.0", "cfg-if", "concurrent-queue", "futures-io", - "futures-lite 2.0.1", + "futures-lite 2.1.0", "parking", - "polling 3.3.0", - "rustix 0.38.25", + "polling 3.3.1", + "rustix 0.38.28", "slab", "tracing", - "waker-fn", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -191,11 +190,11 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"655b9c7fe787d3b25cc0f804a1a8401790f0c5bc395beb5a64dc77d8de079105" +checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" dependencies = [ - "event-listener 3.1.0", + "event-listener 4.0.0", "event-listener-strategy", "pin-project-lite", ] @@ -222,8 +221,8 @@ dependencies = [ "cfg-if", "event-listener 3.1.0", "futures-lite 1.13.0", - "rustix 0.38.25", - "windows-sys", + "rustix 0.38.28", + "windows-sys 0.48.0", ] [[package]] @@ -232,16 +231,16 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" dependencies = [ - "async-io 2.2.0", + "async-io 2.2.2", "async-lock 2.8.0", "atomic-waker", "cfg-if", "futures-core", "futures-io", - "rustix 0.38.25", + "rustix 0.38.28", "signal-hook-registry", "slab", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -285,14 +284,14 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] name = "atomic-polyfill" -version = "0.1.11" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ff7eb3f316534d83a8a2c3d1674ace8a5a71198eba31e2e2b597833f699b28" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" dependencies = [ "critical-section", ] @@ -377,7 +376,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -434,12 +433,12 @@ version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" dependencies = [ - "async-channel 2.1.0", - "async-lock 3.1.1", + "async-channel 2.1.1", + "async-lock 3.2.0", "async-task", "fastrand 2.0.1", "futures-io", - "futures-lite 2.0.1", + "futures-lite 2.1.0", "piper", "tracing", ] @@ -455,12 +454,26 @@ name = "bytemuck" version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +dependencies = [ + "bytemuck_derive", +] [[package]] -name = "byteorder" +name = "bytemuck_derive" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.41", +] + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" @@ -468,6 +481,14 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +[[package]] +name = "c" +version = "0.0.0" +dependencies = [ + "cc", + "half 2.3.1", +] + [[package]] name = "cargo_toml" version = "0.16.3" @@ -518,7 +539,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -534,9 +555,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.8" +version = "4.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" +checksum = 
"bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2" dependencies = [ "clap_builder", "clap_derive", @@ -554,9 +575,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.8" +version = "4.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" +checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb" dependencies = [ "anstyle", "clap_lex", @@ -571,7 +592,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -582,9 +603,9 @@ checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "concurrent-queue" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" dependencies = [ "crossbeam-utils", ] @@ -600,9 +621,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.4" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" [[package]] name = "cpufeatures" @@ -712,13 +733,13 @@ dependencies = [ ] [[package]] -name = "cstr" -version = "0.2.11" +name = "ctor" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8aa998c33a6d3271e3678950a22134cd7dd27cef86dee1b611b5b14207d1d90b" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ - "proc-macro2", "quote", + "syn 2.0.41", ] [[package]] @@ -744,9 +765,9 @@ dependencies = [ [[package]] name = "curl-sys" -version = "0.4.68+curl-8.4.0" +version = "0.4.70+curl-8.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4a0d18d88360e374b16b2273c832b5e57258ffc1d4aa4f96b108e0738d5752f" +checksum = "3c0333d8849afe78a4c8102a429a446bfdd055832af071945520e835ae2d841e" dependencies = [ "cc", "libc", @@ -755,7 +776,7 @@ dependencies = [ "openssl-sys", "pkg-config", "vcpkg", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -779,7 +800,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -790,7 +811,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -800,7 +821,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "lock_api", "once_cell", "parking_lot_core", @@ -808,9 +829,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.9" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f32d04922c60427da6f9fef14d042d9edddef64cb9d4ce0d64d0685fbeb1fd3" +checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" dependencies = [ "powerfmt", "serde", @@ -867,7 +888,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -919,22 +940,22 @@ dependencies = [ [[package]] name = "enum-map" -version = "2.7.2" +version = "2.7.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "09e6b4f374c071b18172e23134e01026653dc980636ee139e0dfe59c538c61e5" +checksum = "6866f3bfdf8207509a033af1a75a7b08abda06bbaaeae6669323fd5a097df2e9" dependencies = [ "enum-map-derive", ] [[package]] name = "enum-map-derive" -version = "0.16.0" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfdb3d73d1beaf47c8593a1364e577fde072677cbfd103600345c0f547408cc0" +checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -958,12 +979,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f258a7194e7f7c2a7837a8913aeab7fd8c383457034fa20ce4dd3dcb813e8eb8" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -983,13 +1004,24 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "event-listener-strategy" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d96b852f1345da36d551b9473fa1e2b1eb5c5195585c6c018118bc92a8d91160" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" dependencies = [ - "event-listener 3.1.0", + "event-listener 4.0.0", "pin-project-lite", ] @@ -1063,9 +1095,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "form_urlencoded" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] @@ -1121,14 +1153,13 @@ dependencies = [ [[package]] name = "futures-lite" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3831c2651acb5177cbd83943f3d9c8912c5ad03c76afcc0e9511ba568ec5ebb" +checksum = "aeee267a1883f7ebef3700f262d2d54de95dfaf38189015a74fdc4e0c7ad8143" dependencies = [ "fastrand 2.0.1", "futures-core", "futures-io", - "memchr", "parking", "pin-project-lite", ] @@ -1141,7 +1172,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -1194,9 +1225,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.0" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" [[package]] name = "glob" @@ -1222,6 +1253,19 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "half" +version = "2.3.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "serde", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1239,15 +1283,15 @@ checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" [[package]] name = "hashbrown" -version = "0.14.2" +version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" [[package]] name = "heapless" -version = "0.7.16" +version = "0.7.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db04bc24a18b9ea980628ecf00e6c0264f3c1426dac36c00cb49b6fbad8b0743" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" dependencies = [ "atomic-polyfill", "hash32", @@ -1296,9 +1340,9 @@ dependencies = [ [[package]] name = "http-body" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", "http", @@ -1413,6 +1457,16 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "if_chain" version = "1.0.2" @@ -1443,7 +1497,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", - "hashbrown 0.14.2", + "hashbrown 0.14.3", "serde", ] @@ -1464,7 +1518,7 @@ checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1474,8 +1528,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.25", - "windows-sys", + "rustix 0.38.28", + "windows-sys 0.48.0", ] [[package]] @@ -1516,15 +1570,15 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" dependencies = [ "wasm-bindgen", ] @@ -1589,9 +1643,9 @@ checksum = "db13adb97ab515a3691f56e4dbab09283d0b86cb45abd991d8634a9d6f501760" [[package]] name = "libc" -version = "0.2.150" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = 
"libloading" @@ -1650,9 +1704,9 @@ checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "lock_api" @@ -1730,13 +1784,13 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1850,9 +1904,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openai_api_rust" @@ -1875,9 +1929,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.95" +version = "0.9.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9" +checksum = "c3eaad34cdd97d81de97964fc7f29e2d104f483840d906ef56daa1912338460b" dependencies = [ "cc", "libc", @@ -1923,7 +1977,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -1944,9 +1998,9 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" [[package]] name = "percent-encoding" -version = "2.3.0" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" @@ -2125,7 +2179,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -2170,21 +2224,21 @@ dependencies = [ "libc", "log", "pin-project-lite", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "polling" -version = "3.3.0" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e53b6af1f60f36f8c2ac2aad5459d75a5a9b4be1e8cdd40264f315d78193e531" +checksum = "cf63fa624ab313c11656b4cda960bfc46c410187ad493c41f6ba2d8c1e991c9e" dependencies = [ "cfg-if", "concurrent-queue", "pin-project-lite", - "rustix 0.38.25", + "rustix 0.38.28", "tracing", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2304,9 +2358,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.69" +version = "1.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" dependencies = [ "unicode-ident", ] @@ -2468,16 +2522,16 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "ring" -version = "0.17.5" +version = "0.17.7" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb0205304757e5d899b9c2e448b867ffd03ae7f988002e47cd24954391394d0b" +checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" dependencies = [ "cc", "getrandom", "libc", "spin", "untrusted", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2521,27 +2575,27 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys 0.3.8", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "rustix" -version = "0.38.25" +version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc99bc2d4f1fed22595588a013687477aedf3cdcfb26558c559edb67b4d9b22e" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.11", - "windows-sys", + "linux-raw-sys 0.4.12", + "windows-sys 0.52.0", ] [[package]] name = "rustls" -version = "0.21.9" +version = "0.21.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629648aced5775d558af50b2b4c7b02983a04b312126d45eeead26e7caa498b9" +checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", "ring", @@ -2579,9 +2633,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.15" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" [[package]] name = "same-file" @@ -2598,7 +2652,7 @@ version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2655,9 +2709,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] @@ -2668,19 +2722,19 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bef2ebfde456fb76bbcf9f59315333decc4fda0b2b44b420243c11e0f5ec1f5" dependencies = [ - "half", + "half 1.8.2", "serde", ] [[package]] name = "serde_derive" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -2739,7 +2793,43 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", +] + +[[package]] +name = "service" +version = "0.0.0" +dependencies = [ + "arc-swap", + "arrayvec", + "bincode", + "bytemuck", + "byteorder", + "c", + "crc32fast", + "crossbeam", + "ctor", + "dashmap", + "half 2.3.1", + "libc", + "log", + "memmap2", + "memoffset", + "multiversion", + "num-traits", + "parking_lot", + "rand", + "rayon", + "rustix 0.38.28", + "serde", + "serde_json", + "serde_with", + "std_detect", + "tempfile", + "thiserror", + "ulock-sys", + "uuid", + "validator", ] [[package]] @@ -2823,7 +2913,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -2848,10 +2938,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +name = "std_detect" +version = "0.1.5" +source = "git+https://github.com/tensorchord/stdarch.git?branch=avx512fp16#db0cdbc9b02074bfddabfd23a4a681f21640eada" +dependencies = [ + "cfg-if", + "libc", +] [[package]] name = "string_cache" @@ -2902,9 +2995,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.39" +version = "2.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" dependencies = [ "proc-macro2", "quote", @@ -2913,9 +3006,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.10" +version = "0.29.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a18d114d420ada3a891e6bc8e96a2023402203296a47cdd65083377dad18ba5" +checksum = "cd727fc423c2060f6c92d9534cef765c65a6ed3f428a03d7def74a8c4348e666" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2947,8 +3040,8 @@ dependencies = [ "cfg-if", "fastrand 2.0.1", "redox_syscall", - "rustix 0.38.25", - "windows-sys", + "rustix 0.38.28", + "windows-sys 0.48.0", ] [[package]] @@ -2994,7 +3087,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -3052,9 +3145,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.34.0" +version = "1.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" +checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" dependencies = [ "backtrace", "bytes", @@ -3065,7 +3158,7 @@ dependencies = [ "signal-hook-registry", "socket2 0.5.5", "tokio-macros", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -3076,7 +3169,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -3179,7 +3272,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", ] [[package]] @@ -3203,9 +3296,9 @@ dependencies = [ [[package]] name = "try-lock" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" @@ -3242,9 +3335,9 @@ checksum = "ccb97dac3243214f8d8507998906ca3e2e0b900bf9bf4870477f125b82e68f6e" [[package]] name = "unicode-bidi" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" +checksum = "6f2528f27a9eb2b21e69c95319b30bd0efd85d09c379741b0f78ea1d86be2416" [[package]] name = "unicode-ident" @@ -3281,9 +3374,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5ccd538d4a604753ebc2f17cd9946e89b77bf87f6a8e2309667c6f2e87855e3" +checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97" dependencies = [ "base64", "flate2", @@ -3299,20 +3392,20 @@ dependencies = [ [[package]] name = "url" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" dependencies = [ "form_urlencoded", - "idna", + "idna 0.5.0", "percent-encoding", ] [[package]] name = "uuid" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c58fe91d841bc04822c9801002db4ea904b9e4b8e6bbad25127b46eff8dc516b" +checksum = "5e395fcf16a7a3d8127ec99782007af141946b4795001f876d54fb0d55978560" dependencies = [ "getrandom", "serde", @@ -3324,7 +3417,7 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b92f40481c04ff1f4f61f304d61793c7b56ff76ac1469f1beb199b1445b253bd" dependencies = [ - "idna", + "idna 0.4.0", "lazy_static", "regex", "serde", @@ -3374,41 +3467,26 @@ checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "vectors" -version = "0.1.1" +version = "0.0.0" dependencies = [ - "arc-swap", - "arrayvec", "bincode", - "bytemuck", "byteorder", - "crc32fast", - "crossbeam", - "cstr", - "dashmap", "env_logger", + "half 2.3.1", "httpmock", "libc", "log", - "memmap2", - "memoffset", "mockall", - "multiversion", + "num-traits", "openai_api_rust", - "parking_lot", "pgrx", "pgrx-tests", - "rand", - "rayon", - "rustix 0.38.25", + "rustix 0.38.28", "serde", "serde_json", - "serde_with", - "static_assertions", - "tempfile", + "service", "thiserror", "toml", - "ulock-sys", - "uuid", "validator", ] @@ -3460,9 +3538,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7daec296f25a1bae309c0cd5c29c4b260e510e6d813c286b19eaadf409d40fce" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -3470,24 +3548,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e397f4664c0e4e428e8313a469aaa58310d302159845980fd23b0f22a847f217" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9afec9963e3d0994cac82455b2b3502b81a7f40f9a0d32181f7528d9f4b43e02" +checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12" dependencies = [ 
"cfg-if", "js-sys", @@ -3497,9 +3575,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5961017b3b08ad5f3fe39f1e79877f8ee7c23c5e5fd5eb80de95abc41f1f16b2" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3507,28 +3585,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5353b8dab669f5e10f5bd76df26a9360c748f054f862ff5f3f8aae0c7fb3907" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.88" +version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d046c5d029ba91a1ed14da14dca44b68bf2f124cfbaf741c54151fdb3e0750b" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" [[package]] name = "web-sys" -version = "0.3.65" +version = "0.3.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" dependencies = [ "js-sys", "wasm-bindgen", @@ -3536,9 +3614,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.25.2" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "whoami" @@ -3587,7 +3665,7 @@ version = "0.51.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1f8cf84f35d2db49a46868f947758c7a1138116f7fac3bc844f43ade1292e64" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", ] [[package]] @@ -3596,7 +3674,16 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", ] [[package]] @@ -3605,13 +3692,28 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + 
"windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", ] [[package]] @@ -3620,47 +3722,89 @@ version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winnow" -version = "0.5.19" +version = "0.5.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "829846f3e3db426d4cee4510841b71a8e58aa2a76b1132579487ae430ccd9c7b" +checksum = "6c830786f7720c2fd27a1a0e27a709dbd3c4d009b56d098fc742d4f4eab91fe2" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 403fb73fe..054f9ab13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "vectors" -version = 
"0.1.1" -edition = "2021" +version.workspace = true +edition.workspace = true [lib] crate-type = ["cdylib"] @@ -16,45 +16,60 @@ pg16 = ["pgrx/pg16", "pgrx-tests/pg16"] pg_test = [] [dependencies] +libc.workspace = true +log.workspace = true +serde.workspace = true +serde_json.workspace = true +validator.workspace = true +rustix.workspace = true +thiserror.workspace = true +byteorder.workspace = true +bincode.workspace = true +half.workspace = true +num-traits.workspace = true +service = { path = "crates/service" } pgrx = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b", default-features = false, features = [ ] } openai_api_rust = { git = "https://github.com/tensorchord/openai-api.git", rev = "228d54b6002e98257b3c81501a054942342f585f" } -static_assertions = "1.1.0" -libc = "~0.2" -serde = "1.0.163" -bincode = "1.3.3" -rand = "0.8.5" -byteorder = "1.4.3" -crc32fast = "1.3.2" -log = "0.4.18" env_logger = "0.10.0" -crossbeam = "0.8.2" -dashmap = "5.4.0" -parking_lot = "0.12.1" -memoffset = "0.9.0" -serde_json = "1" -thiserror = "1.0.40" -tempfile = "3.6.0" -cstr = "0.2.11" -arrayvec = { version = "0.7.3", features = ["serde"] } -memmap2 = "0.9.0" -validator = { version = "0.16.1", features = ["derive"] } toml = "0.8.8" -rayon = "1.6.1" -uuid = { version = "1.4.1", features = ["serde"] } -rustix = { version = "0.38.20", features = ["net", "mm"] } -arc-swap = "1.6.0" -bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } -serde_with = "3.4.0" -multiversion = "0.7.3" [dev-dependencies] pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", rev = "7c30e2023876c1efce613756f5ec81f3ab05696b" } httpmock = "0.6" mockall = "0.11.4" -[target.'cfg(target_os = "macos")'.dependencies] -ulock-sys = "0.1.0" +[lints] +clippy.too_many_arguments = "allow" +clippy.unnecessary_literal_unwrap = "allow" +clippy.unnecessary_unwrap = "allow" +rust.unsafe_op_in_unsafe_fn = "warn" + +[workspace] +resolver = "2" +members = ["crates/*"] + +[workspace.package] +version = "0.0.0" +edition = "2021" + +[workspace.dependencies] +libc = "~0.2" +log = "~0.4" +serde = "~1.0" +serde_json = "1" +thiserror = "~1.0" +bincode = "~1.3" +byteorder = "~1.4" +half = { version = "~2.3", features = [ + "bytemuck", + "num-traits", + "serde", + "use-intrinsics", +] } +num-traits = "~0.2" +validator = { version = "~0.16", features = ["derive"] } +rustix = { version = "~0.38", features = ["net", "mm"] } [profile.dev] panic = "unwind" @@ -65,10 +80,3 @@ opt-level = 3 lto = "fat" codegen-units = 1 debug = true - -[lints.clippy] -needless_range_loop = "allow" -derivable_impls = "allow" -unnecessary_literal_unwrap = "allow" -too_many_arguments = "allow" -unnecessary_unwrap = "allow" diff --git a/README.md b/README.md index 4fe30f6c0..246f04ffc 100644 --- a/README.md +++ b/README.md @@ -21,13 +21,13 @@ pgvecto.rs is a Postgres extension that provides vector similarity search functi ## Comparison with pgvector -| | pgvecto.rs | pgvector | -| ------------------------------------------- | ------------------------------------------------------ | ------------------------ | -| Transaction support | ✅ | ⚠️ | -| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ | -| Vector Dimension Limit | 65535 | 2000 | -| Prefilter on HNSW | ✅ | ❌ | -| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used | +| | pgvecto.rs | pgvector | +| ------------------------------------------- | ------------------------------------------------------ | 
----------------------- | +| Transaction support | ✅ | ⚠️ | +| Sufficient Result with Delete/Update/Filter | ✅ | ⚠️ | +| Vector Dimension Limit | 65535 | 2000 | +| Prefilter on HNSW | ✅ | ❌ | +| Parallel HNSW Index build | ⚡️ Linearly faster with more cores | 🐌 Only single core used | | Async Index build | Ready for queries anytime and do not block insertions. | ❌ | | Quantization | Scalar/Product Quantization | ❌ | @@ -45,7 +45,11 @@ More details at [./docs/comparison-pgvector.md](./docs/comparison-pgvector.md) For users, we recommend you to try pgvecto.rs using our pre-built docker image, by running ```sh -docker run --name pgvecto-rs-demo -e POSTGRES_PASSWORD=mysecretpassword -p 5432:5432 -d tensorchord/pgvecto-rs:pg16-latest +docker run \ + --name pgvecto-rs-demo \ + -e POSTGRES_PASSWORD=mysecretpassword \ + -p 5432:5432 \ + -d tensorchord/pgvecto-rs:pg16-latest ``` ## Development with envd diff --git a/bindings/python/tests/__init__.py b/bindings/python/tests/__init__.py index 1e28cc768..0a2e9ace2 100644 --- a/bindings/python/tests/__init__.py +++ b/bindings/python/tests/__init__.py @@ -19,14 +19,12 @@ TOML_SETTINGS = { "flat": toml.dumps( { - "capacity": 2097152, - "algorithm": {"flat": {}}, + "indexing": {"flat": {}}, }, ), "hnsw": toml.dumps( { - "capacity": 2097152, - "algorithm": {"hnsw": {}}, + "indexing": {"hnsw": {}}, }, ), } diff --git a/bindings/python/tests/test_psycopg.py b/bindings/python/tests/test_psycopg.py index f79b4e2af..07e3a924f 100644 --- a/bindings/python/tests/test_psycopg.py +++ b/bindings/python/tests/test_psycopg.py @@ -38,7 +38,7 @@ def conn(): @pytest.mark.parametrize(("index_name", "index_setting"), TOML_SETTINGS.items()) def test_create_index(conn: Connection, index_name: str, index_setting: str): stat = sql.SQL( - "CREATE INDEX {} ON tb_test_item USING vectors (embedding l2_ops) WITH (options={});", + "CREATE INDEX {} ON tb_test_item USING vectors (embedding vector_l2_ops) WITH (options={});", ).format(sql.Identifier(index_name), index_setting) conn.execute(stat) diff --git a/bindings/python/tests/test_sqlalchemy.py b/bindings/python/tests/test_sqlalchemy.py index d51e58b2b..9ae03e391 100644 --- a/bindings/python/tests/test_sqlalchemy.py +++ b/bindings/python/tests/test_sqlalchemy.py @@ -68,7 +68,7 @@ def test_create_index(session: Session, index_name: str, index_setting: str): Document.embedding, postgresql_using="vectors", postgresql_with={"options": f"$${index_setting}$$"}, - postgresql_ops={"embedding": "l2_ops"}, + postgresql_ops={"embedding": "vector_l2_ops"}, ) index.create(session.bind) session.commit() diff --git a/crates/c/.gitignore b/crates/c/.gitignore new file mode 100644 index 000000000..9f70fdf2e --- /dev/null +++ b/crates/c/.gitignore @@ -0,0 +1,3 @@ +*.s +*.o +*.out \ No newline at end of file diff --git a/crates/c/Cargo.toml b/crates/c/Cargo.toml new file mode 100644 index 000000000..5dc084ed6 --- /dev/null +++ b/crates/c/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "c" +version.workspace = true +edition.workspace = true + +[dependencies] +half = { version = "~2.3", features = ["use-intrinsics"] } + +[build-dependencies] +cc = "1.0" diff --git a/crates/c/build.rs b/crates/c/build.rs new file mode 100644 index 000000000..dad66331b --- /dev/null +++ b/crates/c/build.rs @@ -0,0 +1,10 @@ +fn main() { + println!("cargo:rerun-if-changed=src/c.h"); + println!("cargo:rerun-if-changed=src/c.c"); + cc::Build::new() + .compiler("/usr/bin/clang-16") + .file("./src/c.c") + .opt_level(3) + .debug(true) + .compile("pgvectorsc"); +} diff --git 
a/crates/c/src/c.c b/crates/c/src/c.c new file mode 100644 index 000000000..e41f282d4 --- /dev/null +++ b/crates/c/src/c.c @@ -0,0 +1,118 @@ +#include "c.h" +#include <math.h> + +#if defined(__x86_64__) +#include <immintrin.h> +#endif + +#if defined(__x86_64__) + +__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float +v_f16_cosine_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { + __m512h xy = _mm512_set1_ph(0); + __m512h xx = _mm512_set1_ph(0); + __m512h yy = _mm512_set1_ph(0); + + while (n >= 32) { + __m512h x = _mm512_loadu_ph(a); + __m512h y = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + xy = _mm512_fmadd_ph(x, y, xy); + xx = _mm512_fmadd_ph(x, x, xx); + yy = _mm512_fmadd_ph(y, y, yy); + } + if (n > 0) { + __mmask32 mask = _bzhi_u32(0xFFFFFFFF, n); + __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + xy = _mm512_fmadd_ph(x, y, xy); + xx = _mm512_fmadd_ph(x, x, xx); + yy = _mm512_fmadd_ph(y, y, yy); + } + return (float)(_mm512_reduce_add_ph(xy) / + sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy))); +} + +__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float +v_f16_dot_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { + __m512h xy = _mm512_set1_ph(0); + + while (n >= 32) { + __m512h x = _mm512_loadu_ph(a); + __m512h y = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + xy = _mm512_fmadd_ph(x, y, xy); + } + if (n > 0) { + __mmask32 mask = _bzhi_u32(0xFFFFFFFF, n); + __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + xy = _mm512_fmadd_ph(x, y, xy); + } + return (float)_mm512_reduce_add_ph(xy); +} + +__attribute__((target("arch=x86-64-v4,avx512fp16"))) extern float +v_f16_sl2_avx512fp16(_Float16 *a, _Float16 *b, size_t n) { + __m512h dd = _mm512_set1_ph(0); + + while (n >= 32) { + __m512h x = _mm512_loadu_ph(a); + __m512h y = _mm512_loadu_ph(b); + a += 32, b += 32, n -= 32; + __m512h d = _mm512_sub_ph(x, y); + dd = _mm512_fmadd_ph(d, d, dd); + } + if (n > 0) { + __mmask32 mask = _bzhi_u32(0xFFFFFFFF, n); + __m512h x = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + __m512h y = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + __m512h d = _mm512_sub_ph(x, y); + dd = _mm512_fmadd_ph(d, d, dd); + } + + return (float)_mm512_reduce_add_ph(dd); +} + +__attribute__((target("arch=x86-64-v3"))) extern float +v_f16_cosine_v3(_Float16 *a, _Float16 *b, size_t n) { + float xy = 0; + float xx = 0; + float yy = 0; +#pragma clang loop vectorize_width(8) + for (size_t i = 0; i < n; i++) { + float x = a[i]; + float y = b[i]; + xy += x * y; + xx += x * x; + yy += y * y; + } + return xy / sqrt(xx * yy); +} + +__attribute__((target("arch=x86-64-v3"))) extern float +v_f16_dot_v3(_Float16 *a, _Float16 *b, size_t n) { + float xy = 0; +#pragma clang loop vectorize_width(8) + for (size_t i = 0; i < n; i++) { + float x = a[i]; + float y = b[i]; + xy += x * y; + } + return xy; +} + +__attribute__((target("arch=x86-64-v3"))) extern float +v_f16_sl2_v3(_Float16 *a, _Float16 *b, size_t n) { + float dd = 0; +#pragma clang loop vectorize_width(8) + for (size_t i = 0; i < n; i++) { + float x = a[i]; + float y = b[i]; + float d = x - y; + dd += d * d; + } + return dd; +} + +#endif diff --git a/crates/c/src/c.h b/crates/c/src/c.h new file mode 100644 index 000000000..d50c3d712 --- /dev/null +++ b/crates/c/src/c.h @@ -0,0 +1,13 @@ +#include <stddef.h> +#include <stdint.h> + +#if defined(__x86_64__) + +extern float 
v_f16_cosine_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_avx512fp16(_Float16 *, _Float16 *, size_t n); +extern float v_f16_cosine_v3(_Float16 *, _Float16 *, size_t n); +extern float v_f16_dot_v3(_Float16 *, _Float16 *, size_t n); +extern float v_f16_sl2_v3(_Float16 *, _Float16 *, size_t n); + +#endif diff --git a/crates/c/src/c.rs b/crates/c/src/c.rs new file mode 100644 index 000000000..a4ac2c255 --- /dev/null +++ b/crates/c/src/c.rs @@ -0,0 +1,24 @@ +#[cfg(target_arch = "x86_64")] +#[link(name = "pgvectorsc", kind = "static")] +extern "C" { + pub fn v_f16_cosine_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_avx512fp16(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_cosine_v3(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_dot_v3(a: *const u16, b: *const u16, n: usize) -> f32; + pub fn v_f16_sl2_v3(a: *const u16, b: *const u16, n: usize) -> f32; +} + +// `compiler_builtins` defines `__extendhfsf2` with an integer calling convention. +// However, C compilers link `__extendhfsf2` with a floating-point calling convention. +// The code should be removed once Rust officially supports `f16`. + +#[cfg(target_arch = "x86_64")] +#[no_mangle] +#[linkage = "external"] +extern "C" fn __extendhfsf2(f: f64) -> f32 { + unsafe { + let f: half::f16 = std::mem::transmute_copy(&f); + f.to_f32() + } +} diff --git a/crates/c/src/lib.rs b/crates/c/src/lib.rs new file mode 100644 index 000000000..9c3d869be --- /dev/null +++ b/crates/c/src/lib.rs @@ -0,0 +1,6 @@ +#![feature(linkage)] + +mod c; + +#[allow(unused_imports)] +pub use self::c::*; diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml new file mode 100644 index 000000000..367157072 --- /dev/null +++ b/crates/service/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "service" +version.workspace = true +edition.workspace = true + +[dependencies] +libc.workspace = true +log.workspace = true +serde.workspace = true +serde_json.workspace = true +validator.workspace = true +rustix.workspace = true +thiserror.workspace = true +byteorder.workspace = true +bincode.workspace = true +half.workspace = true +num-traits.workspace = true +c = { path = "../c" } +std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } +rand = "0.8.5" +crc32fast = "1.3.2" +crossbeam = "0.8.2" +dashmap = "5.4.0" +parking_lot = "0.12.1" +memoffset = "0.9.0" +tempfile = "3.6.0" +arrayvec = { version = "0.7.3", features = ["serde"] } +memmap2 = "0.9.0" +rayon = "1.6.1" +uuid = { version = "1.6.1", features = ["serde"] } +arc-swap = "1.6.0" +bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] } +serde_with = "3.4.0" +multiversion = "0.7.3" +ctor = "0.2.6" + +[target.'cfg(target_os = "macos")'.dependencies] +ulock-sys = "0.1.0" + +[lints] +clippy.derivable_impls = "allow" +clippy.len_without_is_empty = "allow" +clippy.needless_range_loop = "allow" +clippy.too_many_arguments = "allow" +rust.unsafe_op_in_unsafe_fn = "warn" diff --git a/src/algorithms/clustering/elkan_k_means.rs b/crates/service/src/algorithms/clustering/elkan_k_means.rs similarity index 74% rename from src/algorithms/clustering/elkan_k_means.rs rename to crates/service/src/algorithms/clustering/elkan_k_means.rs index f30d84df3..494a10792 100644 --- a/src/algorithms/clustering/elkan_k_means.rs +++ 
b/crates/service/src/algorithms/clustering/elkan_k_means.rs @@ -4,38 +4,37 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::ops::{Index, IndexMut}; -pub struct ElkanKMeans { +pub struct ElkanKMeans { dims: u16, c: usize, - pub centroids: Vec2, + pub centroids: Vec2, lowerbound: Square, - upperbound: Vec, + upperbound: Vec, assign: Vec, rand: StdRng, - samples: Vec2, - d: Distance, + samples: Vec2, } const DELTA: f32 = 1.0 / 1024.0; -impl ElkanKMeans { - pub fn new(c: usize, samples: Vec2, d: Distance) -> Self { +impl ElkanKMeans { + pub fn new(c: usize, samples: Vec2) -> Self { let n = samples.len(); let dims = samples.dims(); let mut rand = StdRng::from_entropy(); let mut centroids = Vec2::new(dims, c); let mut lowerbound = Square::new(n, c); - let mut upperbound = vec![Scalar::Z; n]; + let mut upperbound = vec![F32::zero(); n]; let mut assign = vec![0usize; n]; centroids[0].copy_from_slice(&samples[rand.gen_range(0..n)]); - let mut weight = vec![Scalar::INFINITY; n]; + let mut weight = vec![F32::infinity(); n]; for i in 0..c { - let mut sum = Scalar::Z; + let mut sum = F32::zero(); for j in 0..n { - let dis = d.elkan_k_means_distance(&samples[j], ¢roids[i]); + let dis = S::elkan_k_means_distance(&samples[j], ¢roids[i]); lowerbound[(j, i)] = dis; if dis * dis < weight[j] { weight[j] = dis * dis; @@ -49,7 +48,7 @@ impl ElkanKMeans { let mut choice = sum * rand.gen_range(0.0..1.0); for j in 0..(n - 1) { choice -= weight[j]; - if choice <= Scalar::Z { + if choice <= F32::zero() { break 'a j; } } @@ -59,7 +58,7 @@ impl ElkanKMeans { } for i in 0..n { - let mut minimal = Scalar::INFINITY; + let mut minimal = F32::infinity(); let mut target = 0; for j in 0..c { let dis = lowerbound[(i, j)]; @@ -81,13 +80,11 @@ impl ElkanKMeans { assign, rand, samples, - d, } } pub fn iterate(&mut self) -> bool { let c = self.c; - let f = |lhs: &[Scalar], rhs: &[Scalar]| self.d.elkan_k_means_distance(lhs, rhs); let dims = self.dims; let samples = &self.samples; let rand = &mut self.rand; @@ -100,16 +97,16 @@ impl ElkanKMeans { // Step 1 let mut dist0 = Square::new(c, c); - let mut sp = vec![Scalar::Z; c]; + let mut sp = vec![F32::zero(); c]; for i in 0..c { for j in i + 1..c { - let dis = f(¢roids[i], ¢roids[j]) * 0.5; + let dis = S::elkan_k_means_distance(¢roids[i], ¢roids[j]) * 0.5; dist0[(i, j)] = dis; dist0[(j, i)] = dis; } } for i in 0..c { - let mut minimal = Scalar::INFINITY; + let mut minimal = F32::infinity(); for j in 0..c { if i == j { continue; @@ -127,7 +124,7 @@ impl ElkanKMeans { if upperbound[i] <= sp[assign[i]] { continue; } - let mut minimal = f(&samples[i], ¢roids[assign[i]]); + let mut minimal = S::elkan_k_means_distance(&samples[i], ¢roids[assign[i]]); lowerbound[(i, assign[i])] = minimal; upperbound[i] = minimal; // Step 3 @@ -142,7 +139,7 @@ impl ElkanKMeans { continue; } if minimal > lowerbound[(i, j)] || minimal > dist0[(assign[i], j)] { - let dis = f(&samples[i], ¢roids[j]); + let dis = S::elkan_k_means_distance(&samples[i], ¢roids[j]); lowerbound[(i, j)] = dis; if dis < minimal { minimal = dis; @@ -156,8 +153,8 @@ impl ElkanKMeans { // Step 4, 7 let old = std::mem::replace(centroids, Vec2::new(dims, c)); - let mut count = vec![Scalar::Z; c]; - centroids.fill(Scalar::Z); + let mut count = vec![F32::zero(); c]; + centroids.fill(S::Scalar::zero()); for i in 0..n { for j in 0..dims as usize { centroids[assign[i]][j] += samples[i][j]; @@ -165,21 +162,21 @@ impl ElkanKMeans { count[assign[i]] += 1.0; } for i in 0..c { - if count[i] == Scalar::Z { + if count[i] == F32::zero() 
{ continue; } for dim in 0..dims as usize { - centroids[i][dim] /= count[i]; + centroids[i][dim] /= S::Scalar::from_f32(count[i].into()); } } for i in 0..c { - if count[i] != Scalar::Z { + if count[i] != F32::zero() { continue; } let mut o = 0; loop { - let alpha = Scalar(rand.gen_range(0.0..1.0)); - let beta = (count[o] - 1.0) / (n - c) as Float; + let alpha = F32::from_f32(rand.gen_range(0.0..1.0f32)); + let beta = (count[o] - 1.0) / (n - c) as f32; if alpha < beta { break; } @@ -188,28 +185,28 @@ impl ElkanKMeans { centroids.copy_within(o, i); for dim in 0..dims as usize { if dim % 2 == 0 { - centroids[i][dim] *= 1.0 + DELTA; - centroids[o][dim] *= 1.0 - DELTA; + centroids[i][dim] *= S::Scalar::from_f32(1.0 + DELTA); + centroids[o][dim] *= S::Scalar::from_f32(1.0 - DELTA); } else { - centroids[i][dim] *= 1.0 - DELTA; - centroids[o][dim] *= 1.0 + DELTA; + centroids[i][dim] *= S::Scalar::from_f32(1.0 - DELTA); + centroids[o][dim] *= S::Scalar::from_f32(1.0 + DELTA); } } count[i] = count[o] / 2.0; count[o] = count[o] - count[i]; } for i in 0..c { - self.d.elkan_k_means_normalize(&mut centroids[i]); + S::elkan_k_means_normalize(&mut centroids[i]); } // Step 5, 6 - let mut dist1 = vec![Scalar::Z; c]; + let mut dist1 = vec![F32::zero(); c]; for i in 0..c { - dist1[i] = f(&old[i], ¢roids[i]); + dist1[i] = S::elkan_k_means_distance(&old[i], ¢roids[i]); } for i in 0..n { for j in 0..c { - lowerbound[(i, j)] = (lowerbound[(i, j)] - dist1[j]).max(Scalar::Z); + lowerbound[(i, j)] = std::cmp::max(lowerbound[(i, j)] - dist1[j], F32::zero()); } } for i in 0..n { @@ -219,7 +216,7 @@ impl ElkanKMeans { change == 0 } - pub fn finish(self) -> Vec2 { + pub fn finish(self) -> Vec2 { self.centroids } } @@ -227,7 +224,7 @@ impl ElkanKMeans { pub struct Square { x: usize, y: usize, - v: Box<[Scalar]>, + v: Vec, } impl Square { @@ -235,13 +232,13 @@ impl Square { Self { x, y, - v: bytemuck::zeroed_slice_box(x * y), + v: bytemuck::zeroed_vec(x * y), } } } impl Index<(usize, usize)> for Square { - type Output = Scalar; + type Output = F32; fn index(&self, (x, y): (usize, usize)) -> &Self::Output { debug_assert!(x < self.x); diff --git a/src/algorithms/clustering/mod.rs b/crates/service/src/algorithms/clustering/mod.rs similarity index 100% rename from src/algorithms/clustering/mod.rs rename to crates/service/src/algorithms/clustering/mod.rs diff --git a/src/algorithms/flat.rs b/crates/service/src/algorithms/flat.rs similarity index 63% rename from src/algorithms/flat.rs rename to crates/service/src/algorithms/flat.rs index 0317a45f0..22650f4d0 100644 --- a/src/algorithms/flat.rs +++ b/crates/service/src/algorithms/flat.rs @@ -9,16 +9,16 @@ use std::fs::create_dir; use std::path::PathBuf; use std::sync::Arc; -pub struct Flat { - mmap: FlatMmap, +pub struct Flat { + mmap: FlatMmap, } -impl Flat { +impl Flat { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options.clone()); @@ -35,7 +35,7 @@ impl Flat { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -43,35 +43,33 @@ impl Flat { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for Flat {} -unsafe 
impl Sync for Flat {} +unsafe impl Send for Flat {} +unsafe impl Sync for Flat {} -pub struct FlatRam { - raw: Arc, - quantization: Quantization, - d: Distance, +pub struct FlatRam { + raw: Arc>, + quantization: Quantization, } -pub struct FlatMmap { - raw: Arc, - quantization: Quantization, - d: Distance, +pub struct FlatMmap { + raw: Arc>, + quantization: Quantization, } -unsafe impl Send for FlatMmap {} -unsafe impl Sync for FlatMmap {} +unsafe impl Send for FlatMmap {} +unsafe impl Sync for FlatMmap {} -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> FlatRam { +) -> FlatRam { let idx_opts = options.indexing.clone().unwrap_flat(); let raw = Arc::new(Raw::create( path.join("raw"), @@ -85,22 +83,17 @@ pub fn make( idx_opts.quantization, &raw, ); - FlatRam { - raw, - quantization, - d: options.vector.d, - } + FlatRam { raw, quantization } } -pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap { +pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap { FlatMmap { raw: ram.raw, quantization: ram.quantization, - d: ram.d, } } -pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { let idx_opts = options.indexing.clone().unwrap_flat(); let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = Quantization::open( @@ -109,17 +102,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { idx_opts.quantization, &raw, ); - FlatMmap { - raw, - quantization, - d: options.vector.d, - } + FlatMmap { raw, quantization } } -pub fn search(mmap: &FlatMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &FlatMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let mut result = Heap::new(k); for i in 0..mmap.raw.len() { - let distance = mmap.quantization.distance(mmap.d, vector, i); + let distance = mmap.quantization.distance(vector, i); let payload = mmap.raw.payload(i); if filter.check(payload) { result.push(HeapElement { distance, payload }); diff --git a/src/algorithms/hnsw.rs b/crates/service/src/algorithms/hnsw.rs similarity index 81% rename from src/algorithms/hnsw.rs rename to crates/service/src/algorithms/hnsw.rs index e0c63744a..9ee6c7258 100644 --- a/src/algorithms/hnsw.rs +++ b/crates/service/src/algorithms/hnsw.rs @@ -3,7 +3,7 @@ use super::raw::Raw; use crate::index::indexing::hnsw::HnswIndexingOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::{IndexOptions, VectorOptions}; +use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::mmap_array::MmapArray; @@ -17,16 +17,16 @@ use std::ops::RangeInclusive; use std::path::PathBuf; use std::sync::Arc; -pub struct Hnsw { - mmap: HnswMmap, +pub struct Hnsw { + mmap: HnswMmap, } -impl Hnsw { +impl Hnsw { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options.clone()); @@ -43,7 +43,7 @@ impl Hnsw { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -51,27 +51,21 @@ impl Hnsw { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn 
search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } - pub fn search_vbase<'index, 'vector>( - &'index self, - range: usize, - vector: &'vector [Scalar], - ) -> HnswIndexIter<'index, 'vector> { + pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> { search_vbase(&self.mmap, range, vector) } } -unsafe impl Send for Hnsw {} -unsafe impl Sync for Hnsw {} +unsafe impl Send for Hnsw {} +unsafe impl Sync for Hnsw {} -pub struct HnswRam { - raw: Arc, - quantization: Quantization, - // ---------------------- - d: Distance, +pub struct HnswRam { + raw: Arc>, + quantization: Quantization, // ---------------------- m: u32, // ---------------------- @@ -95,14 +89,12 @@ impl HnswRamVertex { } struct HnswRamLayer { - edges: Vec<(Scalar, u32)>, + edges: Vec<(F32, u32)>, } -pub struct HnswMmap { - raw: Arc, - quantization: Quantization, - // ---------------------- - d: Distance, +pub struct HnswMmap { + raw: Arc>, + quantization: Quantization, // ---------------------- m: u32, // ---------------------- @@ -114,20 +106,19 @@ pub struct HnswMmap { } #[derive(Debug, Clone, Copy, Default)] -struct HnswMmapEdge(Scalar, u32); +struct HnswMmapEdge(F32, u32); -unsafe impl Send for HnswMmap {} -unsafe impl Sync for HnswMmap {} +unsafe impl Send for HnswMmap {} +unsafe impl Sync for HnswMmap {} unsafe impl Pod for HnswMmapEdge {} unsafe impl Zeroable for HnswMmapEdge {} -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> HnswRam { - let VectorOptions { d, .. } = options.vector; +) -> HnswRam { let HnswIndexingOptions { m, ef_construction, @@ -159,23 +150,22 @@ pub fn make( let entry = RwLock::>::new(None); let visited = VisitedPool::new(raw.len()); (0..n).into_par_iter().for_each(|i| { - fn fast_search( - quantization: &Quantization, + fn fast_search( + quantization: &Quantization, graph: &HnswRamGraph, - d: Distance, levels: RangeInclusive, u: u32, - target: &[Scalar], + target: &[S::Scalar], ) -> u32 { let mut u = u; - let mut u_dis = quantization.distance(d, target, u); + let mut u_dis = quantization.distance(target, u); for i in levels.rev() { let mut changed = true; while changed { changed = false; let guard = graph.vertexs[u as usize].layers[i as usize].read(); for &(_, v) in guard.edges.iter() { - let v_dis = quantization.distance(d, target, v); + let v_dis = quantization.distance(target, v); if v_dis < u_dis { u = v; u_dis = v_dis; @@ -186,21 +176,20 @@ pub fn make( } u } - fn local_search( - quantization: &Quantization, + fn local_search( + quantization: &Quantization, graph: &HnswRamGraph, - d: Distance, visited: &mut VisitedGuard, - vector: &[Scalar], + vector: &[S::Scalar], s: u32, k: usize, i: u8, - ) -> Vec<(Scalar, u32)> { + ) -> Vec<(F32, u32)> { assert!(k > 0); let mut visited = visited.fetch(); - let mut candidates = BinaryHeap::>::new(); + let mut candidates = BinaryHeap::>::new(); let mut results = BinaryHeap::new(); - let s_dis = quantization.distance(d, vector, s); + let s_dis = quantization.distance(vector, s); visited.mark(s); candidates.push(Reverse((s_dis, s))); results.push((s_dis, s)); @@ -217,7 +206,7 @@ pub fn make( continue; } visited.mark(v); - let v_dis = quantization.distance(d, vector, v); + let v_dis = quantization.distance(vector, v); if results.len() < k || v_dis < results.peek().unwrap().0 { candidates.push(Reverse((v_dis, v))); results.push((v_dis, v)); @@ -229,12 +218,7 @@ pub fn make( 
} results.into_sorted_vec() } - fn select( - quantization: &Quantization, - d: Distance, - input: &mut Vec<(Scalar, u32)>, - size: u32, - ) { + fn select(quantization: &Quantization, input: &mut Vec<(F32, u32)>, size: u32) { if input.len() <= size as usize { return; } @@ -245,7 +229,7 @@ pub fn make( } let check = res .iter() - .map(|&(_, v)| quantization.distance2(d, u, v)) + .map(|&(_, v)| quantization.distance2(u, v)) .all(|dist| dist > u_dis); if check { res.push((u_dis, u)); @@ -290,14 +274,13 @@ pub fn make( }; let top = graph.vertexs[u as usize].levels(); if top > levels { - u = fast_search(&quantization, &graph, d, levels + 1..=top, u, target); + u = fast_search(&quantization, &graph, levels + 1..=top, u, target); } let mut result = Vec::with_capacity(1 + std::cmp::min(levels, top) as usize); for j in (0..=std::cmp::min(levels, top)).rev() { let mut edges = local_search( &quantization, &graph, - d, &mut visited, target, u, @@ -305,12 +288,7 @@ pub fn make( j, ); edges.sort(); - select( - &quantization, - d, - &mut edges, - count_max_edges_of_a_layer(m, j), - ); + select(&quantization, &mut edges, count_max_edges_of_a_layer(m, j)); u = edges.first().unwrap().1; result.push(edges); } @@ -325,7 +303,6 @@ pub fn make( write.edges.insert(index, element); select( &quantization, - d, &mut write.edges, count_max_edges_of_a_layer(m, j), ); @@ -338,14 +315,13 @@ pub fn make( HnswRam { raw, quantization, - d, m, graph, visited, } } -pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { +pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { let edges = MmapArray::create( path.join("edges"), ram.graph @@ -369,7 +345,6 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { HnswMmap { raw: ram.raw, quantization: ram.quantization, - d: ram.d, m: ram.m, edges, by_layer_id, @@ -378,7 +353,7 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { let idx_opts = options.indexing.clone().unwrap_hnsw(); let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = Quantization::open( @@ -395,7 +370,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { HnswMmap { raw, quantization, - d: options.vector.d, m: idx_opts.m, edges, by_layer_id, @@ -404,7 +378,12 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { } } -pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &HnswMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let Some(s) = entry(mmap, filter) else { return Heap::new(k); }; @@ -413,11 +392,11 @@ pub fn search(mmap: &HnswMmap, k: usize, vector: &[Scalar], filter: &mut impl Fi local_search(mmap, k, u, vector, filter) } -pub fn search_vbase<'index, 'vector>( - mmap: &'index HnswMmap, +pub fn search_vbase<'a, S: G>( + mmap: &'a HnswMmap, range: usize, - vector: &'vector [Scalar], -) -> HnswIndexIter<'index, 'vector> { + vector: &[S::Scalar], +) -> HnswIndexIter<'a, S> { let filter_fn = &mut |_| true; let Some(s) = entry(mmap, filter_fn) else { return HnswIndexIter(None); @@ -427,7 +406,7 @@ pub fn search_vbase<'index, 'vector>( local_search_vbase(mmap, range, u, vector) } -pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option { +pub fn entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option { let m = mmap.m; let n = mmap.raw.len(); let mut shift = 1u64; @@ -455,15 +434,15 @@ pub fn 
entry(mmap: &HnswMmap, filter: &mut impl Filter) -> Option { None } -pub fn fast_search( - mmap: &HnswMmap, +pub fn fast_search( + mmap: &HnswMmap, levels: RangeInclusive, u: u32, - vector: &[Scalar], + vector: &[S::Scalar], filter: &mut impl Filter, ) -> u32 { let mut u = u; - let mut u_dis = mmap.quantization.distance(mmap.d, vector, u); + let mut u_dis = mmap.quantization.distance(vector, u); for i in levels.rev() { let mut changed = true; while changed { @@ -473,7 +452,7 @@ pub fn fast_search( if !filter.check(mmap.raw.payload(v)) { continue; } - let v_dis = mmap.quantization.distance(mmap.d, vector, v); + let v_dis = mmap.quantization.distance(vector, v); if v_dis < u_dis { u = v; u_dis = v_dis; @@ -485,20 +464,20 @@ pub fn fast_search( u } -pub fn local_search( - mmap: &HnswMmap, +pub fn local_search( + mmap: &HnswMmap, k: usize, s: u32, - vector: &[Scalar], + vector: &[S::Scalar], filter: &mut impl Filter, ) -> Heap { assert!(k > 0); let mut visited = mmap.visited.fetch(); let mut visited = visited.fetch(); - let mut candidates = BinaryHeap::>::new(); + let mut candidates = BinaryHeap::>::new(); let mut results = Heap::new(k); visited.mark(s); - let s_dis = mmap.quantization.distance(mmap.d, vector, s); + let s_dis = mmap.quantization.distance(vector, s); candidates.push(Reverse((s_dis, s))); results.push(HeapElement { distance: s_dis, @@ -517,7 +496,7 @@ pub fn local_search( if !filter.check(mmap.raw.payload(v)) { continue; } - let v_dis = mmap.quantization.distance(mmap.d, vector, v); + let v_dis = mmap.quantization.distance(vector, v); if !results.check(v_dis) { continue; } @@ -531,20 +510,20 @@ pub fn local_search( results } -fn local_search_vbase<'mmap, 'vector>( - mmap: &'mmap HnswMmap, +fn local_search_vbase<'a, S: G>( + mmap: &'a HnswMmap, range: usize, s: u32, - vector: &'vector [Scalar], -) -> HnswIndexIter<'mmap, 'vector> { + vector: &[S::Scalar], +) -> HnswIndexIter<'a, S> { assert!(range > 0); let mut visited_guard = mmap.visited.fetch(); let mut visited = visited_guard.fetch(); - let mut candidates = BinaryHeap::>::new(); + let mut candidates = BinaryHeap::>::new(); let mut results = Heap::new(range); let mut lost = Vec::>::new(); visited.mark(s); - let s_dis = mmap.quantization.distance(mmap.d, vector, s); + let s_dis = mmap.quantization.distance(vector, s); candidates.push(Reverse((s_dis, s))); results.push(HeapElement { distance: s_dis, @@ -561,7 +540,7 @@ fn local_search_vbase<'mmap, 'vector>( continue; } visited.mark(v); - let v_dis = mmap.quantization.distance(mmap.d, vector, v); + let v_dis = mmap.quantization.distance(vector, v); if !results.check(v_dis) { continue; } @@ -582,7 +561,7 @@ fn local_search_vbase<'mmap, 'vector>( results: results.into_reversed_heap(), lost, visited: visited_guard, - vector, + vector: vector.to_vec(), })) } @@ -614,7 +593,7 @@ fn caluate_offsets(iter: impl Iterator) -> impl Iterator &[HnswMmapEdge] { +fn find_edges(mmap: &HnswMmap, u: u32, level: u8) -> &[HnswMmapEdge] { let offset = u as usize; let index = mmap.by_vertex_id[offset]..mmap.by_vertex_id[offset + 1]; let offset = index.start + level as usize; @@ -670,7 +649,7 @@ impl<'a> Drop for VisitedGuard<'a> { fn drop(&mut self) { let src = VisitedBuffer { version: 0, - data: Box::new([]), + data: Vec::new(), }; let buffer = std::mem::replace(&mut self.buffer, src); self.pool.locked_buffers.lock().push(buffer); @@ -692,39 +671,39 @@ impl<'a> VisitedChecker<'a> { struct VisitedBuffer { version: usize, - data: Box<[usize]>, + data: Vec, } impl VisitedBuffer { fn new(capacity: usize) 
-> Self { Self { version: 0, - data: bytemuck::zeroed_slice_box(capacity), + data: bytemuck::zeroed_vec(capacity), } } } -pub struct HnswIndexIter<'mmap, 'vector>(Option>); +pub struct HnswIndexIter<'mmap, S: G>(Option>); -pub struct HnswIndexIterInner<'mmap, 'vector> { - mmap: &'mmap HnswMmap, +pub struct HnswIndexIterInner<'mmap, S: G> { + mmap: &'mmap HnswMmap, range: usize, - candidates: BinaryHeap>, + candidates: BinaryHeap>, results: BinaryHeap>, // The points lost in the first stage, we should keep it to the second stage. lost: Vec>, visited: VisitedGuard<'mmap>, - vector: &'vector [Scalar], + vector: Vec, } -impl Iterator for HnswIndexIter<'_, '_> { +impl Iterator for HnswIndexIter<'_, S> { type Item = HeapElement; fn next(&mut self) -> Option { self.0.as_mut()?.next() } } -impl Iterator for HnswIndexIterInner<'_, '_> { +impl Iterator for HnswIndexIterInner<'_, S> { type Item = HeapElement; fn next(&mut self) -> Option { if self.results.len() > self.range { @@ -739,7 +718,7 @@ impl Iterator for HnswIndexIterInner<'_, '_> { continue; } visited.mark(v); - let v_dis = self.mmap.quantization.distance(self.mmap.d, self.vector, v); + let v_dis = self.mmap.quantization.distance(&self.vector, v); self.candidates.push(Reverse((v_dis, v))); self.results.push(Reverse(HeapElement { distance: v_dis, @@ -755,7 +734,7 @@ impl Iterator for HnswIndexIterInner<'_, '_> { } } -impl HnswIndexIterInner<'_, '_> { +impl HnswIndexIterInner<'_, S> { fn pop(&mut self) -> Option { if self.results.peek() > self.lost.last() { self.results.pop().map(|x| x.0) diff --git a/src/algorithms/ivf/ivf_naive.rs b/crates/service/src/algorithms/ivf/ivf_naive.rs similarity index 77% rename from src/algorithms/ivf/ivf_naive.rs rename to crates/service/src/algorithms/ivf/ivf_naive.rs index 06920264a..2d4e0ea1d 100644 --- a/src/algorithms/ivf/ivf_naive.rs +++ b/crates/service/src/algorithms/ivf/ivf_naive.rs @@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; use std::sync::Arc; -pub struct IvfNaive { - mmap: IvfMmap, +pub struct IvfNaive { + mmap: IvfMmap, } -impl IvfNaive { +impl IvfNaive { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options); @@ -47,7 +47,7 @@ impl IvfNaive { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -55,65 +55,63 @@ impl IvfNaive { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for IvfNaive {} -unsafe impl Sync for IvfNaive {} +unsafe impl Send for IvfNaive {} +unsafe impl Sync for IvfNaive {} -pub struct IvfRam { - raw: Arc, - quantization: Quantization, +pub struct IvfRam { + raw: Arc>, + quantization: Quantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2, heads: Vec, nexts: Vec>, } -unsafe impl Send for IvfRam {} -unsafe impl Sync for IvfRam {} +unsafe impl Send for IvfRam {} +unsafe impl Sync for IvfRam {} -pub struct IvfMmap { - raw: Arc, - quantization: Quantization, +pub struct IvfMmap { + raw: Arc>, + quantization: Quantization, // 
---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: MmapArray, + centroids: MmapArray, heads: MmapArray, nexts: MmapArray, } -unsafe impl Send for IvfMmap {} -unsafe impl Sync for IvfMmap {} +unsafe impl Send for IvfMmap {} +unsafe impl Sync for IvfMmap {} -impl IvfMmap { - fn centroids(&self, i: u32) -> &[Scalar] { +impl IvfMmap { + fn centroids(&self, i: u32) -> &[S::Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.centroids[s..e] } } -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> IvfRam { - let VectorOptions { dims, d } = options.vector; +) -> IvfRam { + let VectorOptions { dims, .. } = options.vector; let IvfIndexingOptions { least_iterations, iterations, @@ -140,9 +138,9 @@ pub fn make( let mut samples = Vec2::new(dims, m as usize); for i in 0..m { samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); - d.elkan_k_means_normalize(&mut samples[i as usize]); + S::elkan_k_means_normalize(&mut samples[i as usize]); } - let mut k_means = ElkanKMeans::new(nlist as usize, samples, d); + let mut k_means = ElkanKMeans::new(nlist as usize, samples); for _ in 0..least_iterations { k_means.iterate(); } @@ -164,10 +162,10 @@ pub fn make( }; (0..n).into_par_iter().for_each(|i| { let mut vector = raw.vector(i).to_vec(); - d.elkan_k_means_normalize(&mut vector); - let mut result = (Scalar::INFINITY, 0); + S::elkan_k_means_normalize(&mut vector); + let mut result = (F32::infinity(), 0); for i in 0..nlist { - let dis = d.elkan_k_means_distance(&vector, ¢roids[i as usize]); + let dis = S::elkan_k_means_distance(&vector, ¢roids[i as usize]); result = std::cmp::min(result, (dis, i)); } let centroid_id = result.1; @@ -191,11 +189,10 @@ pub fn make( nprobe, nlist, dims, - d, } } -pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { +pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { let centroids = MmapArray::create( path.join("centroids"), (0..ram.nlist) @@ -214,7 +211,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { raw: ram.raw, quantization: ram.quantization, dims: ram.dims, - d: ram.d, nlist: ram.nlist, nprobe: ram.nprobe, centroids, @@ -223,7 +219,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = Quantization::open( path.join("quantization"), @@ -239,7 +235,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { raw, quantization, dims: options.vector.dims, - d: options.vector.d, nlist, nprobe, centroids, @@ -248,13 +243,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { } } -pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &IvfMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let mut target = vector.to_vec(); - mmap.d.elkan_k_means_normalize(&mut target); + S::elkan_k_means_normalize(&mut target); let mut lists = Heap::new(mmap.nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = mmap.d.elkan_k_means_distance(&target, centroid); + let distance = S::elkan_k_means_distance(&target, centroid); if lists.check(distance) { 
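// Keep only the nprobe nearest centroids; their inverted lists are scanned below.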
lists.push(HeapElement { distance, @@ -267,7 +267,7 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil for i in lists.iter().map(|e| e.payload as usize) { let mut j = mmap.heads[i]; while u32::MAX != j { - let distance = mmap.quantization.distance(mmap.d, vector, j); + let distance = mmap.quantization.distance(vector, j); let payload = mmap.raw.payload(j); if result.check(distance) && filter.check(payload) { result.push(HeapElement { distance, payload }); diff --git a/src/algorithms/ivf/ivf_pq.rs b/crates/service/src/algorithms/ivf/ivf_pq.rs similarity index 77% rename from src/algorithms/ivf/ivf_pq.rs rename to crates/service/src/algorithms/ivf/ivf_pq.rs index e1ac8b7de..33e5621cc 100644 --- a/src/algorithms/ivf/ivf_pq.rs +++ b/crates/service/src/algorithms/ivf/ivf_pq.rs @@ -20,16 +20,16 @@ use std::sync::atomic::AtomicU32; use std::sync::atomic::Ordering::{Acquire, Relaxed, Release}; use std::sync::Arc; -pub struct IvfPq { - mmap: IvfMmap, +pub struct IvfPq { + mmap: IvfMmap, } -impl IvfPq { +impl IvfPq { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { create_dir(&path).unwrap(); let ram = make(path.clone(), sealed, growing, options); @@ -47,7 +47,7 @@ impl IvfPq { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.mmap.raw.vector(i) } @@ -55,65 +55,63 @@ impl IvfPq { self.mmap.raw.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { search(&self.mmap, k, vector, filter) } } -unsafe impl Send for IvfPq {} -unsafe impl Sync for IvfPq {} +unsafe impl Send for IvfPq {} +unsafe impl Sync for IvfPq {} -pub struct IvfRam { - raw: Arc, - quantization: ProductQuantization, +pub struct IvfRam { + raw: Arc>, + quantization: ProductQuantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2, heads: Vec, nexts: Vec>, } -unsafe impl Send for IvfRam {} -unsafe impl Sync for IvfRam {} +unsafe impl Send for IvfRam {} +unsafe impl Sync for IvfRam {} -pub struct IvfMmap { - raw: Arc, - quantization: ProductQuantization, +pub struct IvfMmap { + raw: Arc>, + quantization: ProductQuantization, // ---------------------- dims: u16, - d: Distance, // ---------------------- nlist: u32, nprobe: u32, // ---------------------- - centroids: MmapArray, + centroids: MmapArray, heads: MmapArray, nexts: MmapArray, } -unsafe impl Send for IvfMmap {} -unsafe impl Sync for IvfMmap {} +unsafe impl Send for IvfMmap {} +unsafe impl Sync for IvfMmap {} -impl IvfMmap { - fn centroids(&self, i: u32) -> &[Scalar] { +impl IvfMmap { + fn centroids(&self, i: u32) -> &[S::Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.centroids[s..e] } } -pub fn make( +pub fn make( path: PathBuf, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> IvfRam { - let VectorOptions { dims, d } = options.vector; +) -> IvfRam { + let VectorOptions { dims, .. 
} = options.vector; let IvfIndexingOptions { least_iterations, iterations, @@ -134,9 +132,9 @@ pub fn make( let mut samples = Vec2::new(dims, m as usize); for i in 0..m { samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); - d.elkan_k_means_normalize(&mut samples[i as usize]); + S::elkan_k_means_normalize(&mut samples[i as usize]); } - let mut k_means = ElkanKMeans::new(nlist as usize, samples, d); + let mut k_means = ElkanKMeans::new(nlist as usize, samples); for _ in 0..least_iterations { k_means.iterate(); } @@ -163,10 +161,10 @@ pub fn make( &raw, |i, target| { let mut vector = target.to_vec(); - d.elkan_k_means_normalize(&mut vector); - let mut result = (Scalar::INFINITY, 0); + S::elkan_k_means_normalize(&mut vector); + let mut result = (F32::infinity(), 0); for i in 0..nlist { - let dis = d.elkan_k_means_distance(&vector, ¢roids[i as usize]); + let dis = S::elkan_k_means_distance(&vector, ¢roids[i as usize]); result = std::cmp::min(result, (dis, i)); } let centroid_id = result.1; @@ -194,11 +192,10 @@ pub fn make( nprobe, nlist, dims, - d, } } -pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { +pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { let centroids = MmapArray::create( path.join("centroids"), (0..ram.nlist) @@ -217,7 +214,6 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { raw: ram.raw, quantization: ram.quantization, dims: ram.dims, - d: ram.d, nlist: ram.nlist, nprobe: ram.nprobe, centroids, @@ -226,7 +222,7 @@ pub fn save(mut ram: IvfRam, path: PathBuf) -> IvfMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { +pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); let quantization = ProductQuantization::open( path.join("quantization"), @@ -242,7 +238,6 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { raw, quantization, dims: options.vector.dims, - d: options.vector.d, nlist, nprobe, centroids, @@ -251,13 +246,18 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { } } -pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { +pub fn search( + mmap: &IvfMmap, + k: usize, + vector: &[S::Scalar], + filter: &mut impl Filter, +) -> Heap { let mut target = vector.to_vec(); - mmap.d.elkan_k_means_normalize(&mut target); + S::elkan_k_means_normalize(&mut target); let mut lists = Heap::new(mmap.nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = mmap.d.elkan_k_means_distance(&target, centroid); + let distance = S::elkan_k_means_distance(&target, centroid); if lists.check(distance) { lists.push(HeapElement { distance, @@ -270,9 +270,9 @@ pub fn search(mmap: &IvfMmap, k: usize, vector: &[Scalar], filter: &mut impl Fil for i in lists.iter().map(|e| e.payload as u32) { let mut j = mmap.heads[i as usize]; while u32::MAX != j { - let distance = - mmap.quantization - .distance_with_delta(mmap.d, vector, j, mmap.centroids(i)); + let distance = mmap + .quantization + .distance_with_delta(vector, j, mmap.centroids(i)); let payload = mmap.raw.payload(j); if result.check(distance) && filter.check(payload) { result.push(HeapElement { distance, payload }); diff --git a/src/algorithms/ivf/mod.rs b/crates/service/src/algorithms/ivf/mod.rs similarity index 84% rename from src/algorithms/ivf/mod.rs rename to crates/service/src/algorithms/ivf/mod.rs index 9bf645421..1fa73ee66 100644 --- a/src/algorithms/ivf/mod.rs +++ 
b/crates/service/src/algorithms/ivf/mod.rs @@ -10,17 +10,17 @@ use crate::prelude::*; use std::path::PathBuf; use std::sync::Arc; -pub enum Ivf { - Naive(IvfNaive), - Pq(IvfPq), +pub enum Ivf { + Naive(IvfNaive), + Pq(IvfPq), } -impl Ivf { +impl Ivf { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { if options .indexing @@ -56,7 +56,7 @@ impl Ivf { } } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { match self { Ivf::Naive(x) => x.vector(i), Ivf::Pq(x) => x.vector(i), @@ -70,7 +70,7 @@ impl Ivf { } } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { match self { Ivf::Naive(x) => x.search(k, vector, filter), Ivf::Pq(x) => x.search(k, vector, filter), diff --git a/src/algorithms/mod.rs b/crates/service/src/algorithms/mod.rs similarity index 84% rename from src/algorithms/mod.rs rename to crates/service/src/algorithms/mod.rs index 9c20f5655..a3c5ffd52 100644 --- a/src/algorithms/mod.rs +++ b/crates/service/src/algorithms/mod.rs @@ -1,5 +1,4 @@ pub mod clustering; -pub mod diskann; pub mod flat; pub mod hnsw; pub mod ivf; diff --git a/src/algorithms/quantization/mod.rs b/crates/service/src/algorithms/quantization/mod.rs similarity index 80% rename from src/algorithms/quantization/mod.rs rename to crates/service/src/algorithms/quantization/mod.rs index 5b5e790d6..5e23c284e 100644 --- a/src/algorithms/quantization/mod.rs +++ b/crates/service/src/algorithms/quantization/mod.rs @@ -56,35 +56,35 @@ impl QuantizationOptions { } } -pub trait Quan { +pub trait Quan { fn create( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self; fn open( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self; - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar; - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar; + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32; + fn distance2(&self, lhs: u32, rhs: u32) -> F32; } -pub enum Quantization { - Trivial(TrivialQuantization), - Scalar(ScalarQuantization), - Product(ProductQuantization), +pub enum Quantization { + Trivial(TrivialQuantization), + Scalar(ScalarQuantization), + Product(ProductQuantization), } -impl Quantization { +impl Quantization { pub fn create( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { match quantization_options { QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::create( @@ -112,7 +112,7 @@ impl Quantization { path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { match quantization_options { QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open( @@ -136,21 +136,21 @@ impl Quantization { } } - pub fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { + pub fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { use Quantization::*; match self { - Trivial(x) => x.distance(d, lhs, rhs), - Scalar(x) => x.distance(d, lhs, rhs), - Product(x) => x.distance(d, lhs, rhs), + Trivial(x) => x.distance(lhs, rhs), + Scalar(x) => x.distance(lhs, rhs), + Product(x) => x.distance(lhs, rhs), } } - pub fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> 
Scalar { + pub fn distance2(&self, lhs: u32, rhs: u32) -> F32 { use Quantization::*; match self { - Trivial(x) => x.distance2(d, lhs, rhs), - Scalar(x) => x.distance2(d, lhs, rhs), - Product(x) => x.distance2(d, lhs, rhs), + Trivial(x) => x.distance2(lhs, rhs), + Scalar(x) => x.distance2(lhs, rhs), + Product(x) => x.distance2(lhs, rhs), } } } diff --git a/src/algorithms/quantization/product.rs b/crates/service/src/algorithms/quantization/product.rs similarity index 81% rename from src/algorithms/quantization/product.rs rename to crates/service/src/algorithms/quantization/product.rs index a87c529f3..96855632a 100644 --- a/src/algorithms/quantization/product.rs +++ b/crates/service/src/algorithms/quantization/product.rs @@ -55,17 +55,17 @@ impl Default for ProductQuantizationOptionsRatio { } } -pub struct ProductQuantization { +pub struct ProductQuantization { dims: u16, ratio: u16, - centroids: Vec, + centroids: Vec, codes: MmapArray, } -unsafe impl Send for ProductQuantization {} -unsafe impl Sync for ProductQuantization {} +unsafe impl Send for ProductQuantization {} +unsafe impl Sync for ProductQuantization {} -impl ProductQuantization { +impl ProductQuantization { fn codes(&self, i: u32) -> &[u8] { let width = self.dims.div_ceil(self.ratio); let s = i as usize * width as usize; @@ -74,12 +74,12 @@ impl ProductQuantization { } } -impl Quan for ProductQuantization { +impl Quan for ProductQuantization { fn create( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { Self::with_normalizer(path, options, quantization_options, raw, |_, _| ()) } @@ -88,7 +88,7 @@ impl Quan for ProductQuantization { path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - _: &Arc, + _: &Arc>, ) -> Self { let centroids = serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap(); @@ -101,32 +101,32 @@ impl Quan for ProductQuantization { } } - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); - d.product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs) + S::product_quantization_distance(dims, ratio, &self.centroids, lhs, rhs) } - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { + fn distance2(&self, lhs: u32, rhs: u32) -> F32 { let dims = self.dims; let ratio = self.ratio; let lhs = self.codes(lhs); let rhs = self.codes(rhs); - d.product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs) + S::product_quantization_distance2(dims, ratio, &self.centroids, lhs, rhs) } } -impl ProductQuantization { +impl ProductQuantization { pub fn with_normalizer( path: PathBuf, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Raw, + raw: &Raw, normalizer: F, ) -> Self where - F: Fn(u32, &mut [Scalar]), + F: Fn(u32, &mut [S::Scalar]), { std::fs::create_dir(&path).unwrap(); let quantization_options = quantization_options.unwrap_product_quantization(); @@ -136,22 +136,22 @@ impl ProductQuantization { let m = std::cmp::min(n, quantization_options.sample); let samples = { let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); - let mut samples = Vec2::new(options.vector.dims, m as usize); + let mut samples = Vec2::::new(options.vector.dims, m as usize); for i in 0..m { samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); } samples }; let width = dims.div_ceil(ratio); - let mut 
centroids = vec![Scalar::Z; 256 * dims as usize]; + let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize]; for i in 0..width { let subdims = std::cmp::min(ratio, dims - ratio * i); - let mut subsamples = Vec2::new(subdims, m as usize); + let mut subsamples = Vec2::::new(subdims, m as usize); for j in 0..m { let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize]; subsamples[j as usize].copy_from_slice(src); } - let mut k_means = ElkanKMeans::new(256, subsamples, Distance::L2); + let mut k_means = ElkanKMeans::::new(256, subsamples); for _ in 0..25 { if k_means.iterate() { break; @@ -170,13 +170,13 @@ impl ProductQuantization { let mut result = Vec::with_capacity(width as usize); for i in 0..width { let subdims = std::cmp::min(ratio, dims - ratio * i); - let mut minimal = Scalar::INFINITY; + let mut minimal = F32::infinity(); let mut target = 0u8; let left = &vector[(i * ratio) as usize..][..subdims as usize]; for j in 0u8..=255 { let right = ¢roids[j as usize * dims as usize..][(i * ratio) as usize..] [..subdims as usize]; - let dis = Distance::L2.distance(left, right); + let dis = S::L2::distance(left, right); if dis < minimal { minimal = dis; target = j; @@ -201,16 +201,10 @@ impl ProductQuantization { } } - pub fn distance_with_delta( - &self, - d: Distance, - lhs: &[Scalar], - rhs: u32, - delta: &[Scalar], - ) -> Scalar { + pub fn distance_with_delta(&self, lhs: &[S::Scalar], rhs: u32, delta: &[S::Scalar]) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); - d.product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta) + S::product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta) } } diff --git a/src/algorithms/quantization/scalar.rs b/crates/service/src/algorithms/quantization/scalar.rs similarity index 75% rename from src/algorithms/quantization/scalar.rs rename to crates/service/src/algorithms/quantization/scalar.rs index 46d7f6b64..f6effab2b 100644 --- a/src/algorithms/quantization/scalar.rs +++ b/crates/service/src/algorithms/quantization/scalar.rs @@ -19,17 +19,17 @@ impl Default for ScalarQuantizationOptions { } } -pub struct ScalarQuantization { +pub struct ScalarQuantization { dims: u16, - max: Vec, - min: Vec, + max: Vec, + min: Vec, codes: MmapArray, } -unsafe impl Send for ScalarQuantization {} -unsafe impl Sync for ScalarQuantization {} +unsafe impl Send for ScalarQuantization {} +unsafe impl Sync for ScalarQuantization {} -impl ScalarQuantization { +impl ScalarQuantization { fn codes(&self, i: u32) -> &[u8] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; @@ -37,17 +37,17 @@ impl ScalarQuantization { } } -impl Quan for ScalarQuantization { +impl Quan for ScalarQuantization { fn create( path: PathBuf, options: IndexOptions, _: QuantizationOptions, - raw: &Arc, + raw: &Arc>, ) -> Self { std::fs::create_dir(&path).unwrap(); let dims = options.vector.dims; - let mut max = vec![Scalar::NEG_INFINITY; dims as usize]; - let mut min = vec![Scalar::INFINITY; dims as usize]; + let mut max = vec![S::Scalar::neg_infinity(); dims as usize]; + let mut min = vec![S::Scalar::infinity(); dims as usize]; let n = raw.len(); for i in 0..n { let vector = raw.vector(i); @@ -62,7 +62,7 @@ impl Quan for ScalarQuantization { let vector = raw.vector(i); let mut result = vec![0u8; dims as usize]; for i in 0..dims as usize { - let w = ((vector[i] - min[i]) / (max[i] - min[i]) * 256.0).0 as u32; + let w = (((vector[i] - min[i]) / (max[i] - 
min[i])).to_f32() * 256.0) as u32; result[i] = w.clamp(0, 255) as u8; } result.into_iter() @@ -77,7 +77,7 @@ impl Quan for ScalarQuantization { } } - fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc) -> Self { + fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc>) -> Self { let dims = options.vector.dims; let max = serde_json::from_slice(&std::fs::read("max").unwrap()).unwrap(); let min = serde_json::from_slice(&std::fs::read("min").unwrap()).unwrap(); @@ -90,16 +90,16 @@ impl Quan for ScalarQuantization { } } - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { let dims = self.dims; let rhs = self.codes(rhs); - d.scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) + S::scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) } - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { + fn distance2(&self, lhs: u32, rhs: u32) -> F32 { let dims = self.dims; let lhs = self.codes(lhs); let rhs = self.codes(rhs); - d.scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs) + S::scalar_quantization_distance2(dims, &self.max, &self.min, lhs, rhs) } } diff --git a/src/algorithms/quantization/trivial.rs b/crates/service/src/algorithms/quantization/trivial.rs similarity index 64% rename from src/algorithms/quantization/trivial.rs rename to crates/service/src/algorithms/quantization/trivial.rs index 7bff6510f..416e398a5 100644 --- a/src/algorithms/quantization/trivial.rs +++ b/crates/service/src/algorithms/quantization/trivial.rs @@ -17,24 +17,24 @@ impl Default for TrivialQuantizationOptions { } } -pub struct TrivialQuantization { - raw: Arc, +pub struct TrivialQuantization { + raw: Arc>, } -impl Quan for TrivialQuantization { - fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc) -> Self { +impl Quan for TrivialQuantization { + fn create(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc>) -> Self { Self { raw: raw.clone() } } - fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc) -> Self { + fn open(_: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc>) -> Self { Self { raw: raw.clone() } } - fn distance(&self, d: Distance, lhs: &[Scalar], rhs: u32) -> Scalar { - d.distance(lhs, self.raw.vector(rhs)) + fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { + S::distance(lhs, self.raw.vector(rhs)) } - fn distance2(&self, d: Distance, lhs: u32, rhs: u32) -> Scalar { - d.distance(self.raw.vector(lhs), self.raw.vector(rhs)) + fn distance2(&self, lhs: u32, rhs: u32) -> F32 { + S::distance(self.raw.vector(lhs), self.raw.vector(rhs)) } } diff --git a/src/algorithms/raw.rs b/crates/service/src/algorithms/raw.rs similarity index 74% rename from src/algorithms/raw.rs rename to crates/service/src/algorithms/raw.rs index ec91673aa..f94ba9d59 100644 --- a/src/algorithms/raw.rs +++ b/crates/service/src/algorithms/raw.rs @@ -6,16 +6,16 @@ use crate::utils::mmap_array::MmapArray; use std::path::PathBuf; use std::sync::Arc; -pub struct Raw { - mmap: RawMmap, +pub struct Raw { + mmap: RawMmap, } -impl Raw { +impl Raw { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { std::fs::create_dir(&path).unwrap(); let ram = make(sealed, growing, options); @@ -33,7 +33,7 @@ impl Raw { self.mmap.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { 
self.mmap.vector(i) } @@ -42,21 +42,21 @@ impl Raw { } } -unsafe impl Send for Raw {} -unsafe impl Sync for Raw {} +unsafe impl Send for Raw {} +unsafe impl Sync for Raw {} -struct RawRam { - sealed: Vec>, - growing: Vec>, +struct RawRam { + sealed: Vec>>, + growing: Vec>>, dims: u16, } -impl RawRam { +impl RawRam { fn len(&self) -> u32 { self.sealed.iter().map(|x| x.len()).sum::() + self.growing.iter().map(|x| x.len()).sum::() } - fn vector(&self, mut index: u32) -> &[Scalar] { + fn vector(&self, mut index: u32) -> &[S::Scalar] { for x in self.sealed.iter() { if index < x.len() { return x.vector(index); @@ -88,18 +88,18 @@ impl RawRam { } } -struct RawMmap { - vectors: MmapArray, +struct RawMmap { + vectors: MmapArray, payload: MmapArray, dims: u16, } -impl RawMmap { +impl RawMmap { fn len(&self) -> u32 { self.payload.len() as u32 } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.vectors[s..e] @@ -110,14 +110,14 @@ impl RawMmap { } } -unsafe impl Send for RawMmap {} -unsafe impl Sync for RawMmap {} +unsafe impl Send for RawMmap {} +unsafe impl Sync for RawMmap {} -fn make( - sealed: Vec>, - growing: Vec>, +fn make( + sealed: Vec>>, + growing: Vec>>, options: IndexOptions, -) -> RawRam { +) -> RawRam { RawRam { sealed, growing, @@ -125,7 +125,7 @@ fn make( } } -fn save(ram: RawRam, path: PathBuf) -> RawMmap { +fn save(ram: RawRam, path: PathBuf) -> RawMmap { let n = ram.len(); let vectors_iter = (0..n).flat_map(|i| ram.vector(i)).copied(); let payload_iter = (0..n).map(|i| ram.payload(i)); @@ -138,8 +138,8 @@ fn save(ram: RawRam, path: PathBuf) -> RawMmap { } } -fn load(path: PathBuf, options: IndexOptions) -> RawMmap { - let vectors: MmapArray = MmapArray::open(path.join("vectors")); +fn load(path: PathBuf, options: IndexOptions) -> RawMmap { + let vectors = MmapArray::open(path.join("vectors")); let payload = MmapArray::open(path.join("payload")); RawMmap { vectors, diff --git a/src/algorithms/diskann/vamana.rs b/crates/service/src/algorithms/vamana.rs.txt similarity index 100% rename from src/algorithms/diskann/vamana.rs rename to crates/service/src/algorithms/vamana.rs.txt diff --git a/src/index/delete.rs b/crates/service/src/index/delete.rs similarity index 100% rename from src/index/delete.rs rename to crates/service/src/index/delete.rs diff --git a/src/index/indexing/flat.rs b/crates/service/src/index/indexing/flat.rs similarity index 77% rename from src/index/indexing/flat.rs rename to crates/service/src/index/indexing/flat.rs index becb27f67..d288216f2 100644 --- a/src/index/indexing/flat.rs +++ b/crates/service/src/index/indexing/flat.rs @@ -24,16 +24,16 @@ impl Default for FlatIndexingOptions { } } -pub struct FlatIndexing { - raw: crate::algorithms::flat::Flat, +pub struct FlatIndexing { + raw: crate::algorithms::flat::Flat, } -impl AbstractIndexing for FlatIndexing { +impl AbstractIndexing for FlatIndexing { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { let raw = Flat::create(path, options, sealed, growing); Self { raw } @@ -48,7 +48,7 @@ impl AbstractIndexing for FlatIndexing { self.raw.len() } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { self.raw.vector(i) } @@ -56,7 +56,7 @@ impl AbstractIndexing for FlatIndexing { self.raw.payload(i) } - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + 
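// Flat indexing builds no auxiliary structure; search delegates to the brute-force scan in algorithms::flat.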
fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.raw.search(k, vector, filter) } } diff --git a/src/index/indexing/hnsw.rs b/crates/service/src/index/indexing/hnsw.rs similarity index 79% rename from src/index/indexing/hnsw.rs rename to crates/service/src/index/indexing/hnsw.rs index 7ba4b0aa2..d1bc8eea8 100644 --- a/src/index/indexing/hnsw.rs +++ b/crates/service/src/index/indexing/hnsw.rs @@ -41,16 +41,16 @@ impl Default for HnswIndexingOptions { } } -pub struct HnswIndexing { - raw: Hnsw, +pub struct HnswIndexing { + raw: Hnsw, } -impl AbstractIndexing for HnswIndexing { +impl AbstractIndexing for HnswIndexing { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { let raw = Hnsw::create(path, options, sealed, growing); Self { raw } @@ -65,7 +65,7 @@ impl AbstractIndexing for HnswIndexing { self.raw.len() } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { self.raw.vector(i) } @@ -73,17 +73,13 @@ impl AbstractIndexing for HnswIndexing { self.raw.payload(i) } - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.raw.search(k, vector, filter) } } -impl HnswIndexing { - pub fn search_vbase<'index, 'vector>( - &'index self, - range: usize, - vector: &'vector [Scalar], - ) -> HnswIndexIter<'index, 'vector> { +impl HnswIndexing { + pub fn search_vbase(&self, range: usize, vector: &[S::Scalar]) -> HnswIndexIter<'_, S> { self.raw.search_vbase(range, vector) } } diff --git a/src/index/indexing/ivf.rs b/crates/service/src/index/indexing/ivf.rs similarity index 88% rename from src/index/indexing/ivf.rs rename to crates/service/src/index/indexing/ivf.rs index db83d52c9..9d074864a 100644 --- a/src/index/indexing/ivf.rs +++ b/crates/service/src/index/indexing/ivf.rs @@ -4,7 +4,6 @@ use crate::algorithms::quantization::QuantizationOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; use crate::index::IndexOptions; -use crate::prelude::Scalar; use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::path::PathBuf; @@ -64,16 +63,16 @@ impl Default for IvfIndexingOptions { } } -pub struct IvfIndexing { - raw: Ivf, +pub struct IvfIndexing { + raw: Ivf, } -impl AbstractIndexing for IvfIndexing { +impl AbstractIndexing for IvfIndexing { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { let raw = Ivf::create(path, options, sealed, growing); Self { raw } @@ -88,7 +87,7 @@ impl AbstractIndexing for IvfIndexing { self.raw.len() } - fn vector(&self, i: u32) -> &[Scalar] { + fn vector(&self, i: u32) -> &[S::Scalar] { self.raw.vector(i) } @@ -96,7 +95,7 @@ impl AbstractIndexing for IvfIndexing { self.raw.payload(i) } - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.raw.search(k, vector, filter) } } diff --git a/src/index/indexing/mod.rs b/crates/service/src/index/indexing/mod.rs similarity index 81% rename from src/index/indexing/mod.rs rename to crates/service/src/index/indexing/mod.rs index cc931dad2..ec8839295 100644 --- a/src/index/indexing/mod.rs +++ b/crates/service/src/index/indexing/mod.rs @@ -60,36 +60,36 @@ impl Validate for IndexingOptions { } } -pub trait 
AbstractIndexing: Sized { +pub trait AbstractIndexing: Sized { fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self; fn open(path: PathBuf, options: IndexOptions) -> Self; fn len(&self) -> u32; - fn vector(&self, i: u32) -> &[Scalar]; + fn vector(&self, i: u32) -> &[S::Scalar]; fn payload(&self, i: u32) -> Payload; - fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap; + fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap; } -pub enum DynamicIndexing { - Flat(FlatIndexing), - Ivf(IvfIndexing), - Hnsw(HnswIndexing), +pub enum DynamicIndexing { + Flat(FlatIndexing), + Ivf(IvfIndexing), + Hnsw(HnswIndexing), } -pub enum DynamicIndexIter<'index, 'vector> { - Hnsw(HnswIndexIter<'index, 'vector>), +pub enum DynamicIndexIter<'a, S: G> { + Hnsw(HnswIndexIter<'a, S>), } -impl DynamicIndexing { +impl DynamicIndexing { pub fn create( path: PathBuf, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { match options.indexing { IndexingOptions::Flat(_) => { @@ -120,7 +120,7 @@ impl DynamicIndexing { } } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { match self { DynamicIndexing::Flat(x) => x.vector(i), DynamicIndexing::Ivf(x) => x.vector(i), @@ -136,7 +136,7 @@ impl DynamicIndexing { } } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { match self { DynamicIndexing::Flat(x) => x.search(k, vector, filter), DynamicIndexing::Ivf(x) => x.search(k, vector, filter), @@ -144,11 +144,7 @@ impl DynamicIndexing { } } - pub fn search_vbase<'index, 'vector>( - &'index self, - range: usize, - vector: &'vector [Scalar], - ) -> DynamicIndexIter<'index, 'vector> { + pub fn vbase(&self, range: usize, vector: &[S::Scalar]) -> DynamicIndexIter<'_, S> { use DynamicIndexIter::*; match self { DynamicIndexing::Hnsw(x) => Hnsw(x.search_vbase(range, vector)), @@ -157,7 +153,7 @@ impl DynamicIndexing { } } -impl Iterator for DynamicIndexIter<'_, '_> { +impl Iterator for DynamicIndexIter<'_, S> { type Item = HeapElement; fn next(&mut self) -> Option { use DynamicIndexIter::*; diff --git a/crates/service/src/index/mod.rs b/crates/service/src/index/mod.rs new file mode 100644 index 000000000..2d30902a0 --- /dev/null +++ b/crates/service/src/index/mod.rs @@ -0,0 +1,506 @@ +pub mod delete; +pub mod indexing; +pub mod optimizing; +pub mod segments; + +use self::delete::Delete; +use self::indexing::IndexingOptions; +use self::optimizing::OptimizingOptions; +use self::segments::growing::GrowingSegment; +use self::segments::growing::GrowingSegmentInsertError; +use self::segments::sealed::SealedSegment; +use self::segments::SegmentsOptions; +use crate::index::indexing::DynamicIndexIter; +use crate::index::optimizing::indexing::OptimizerIndexing; +use crate::index::optimizing::sealing::OptimizerSealing; +use crate::prelude::*; +use crate::utils::clean::clean; +use crate::utils::dir_ops::sync_dir; +use crate::utils::file_atomic::FileAtomic; +use arc_swap::ArcSwap; +use crossbeam::atomic::AtomicCell; +use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::collections::HashMap; +use std::collections::HashSet; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Instant; +use thiserror::Error; +use 
uuid::Uuid; +use validator::Validate; + +#[derive(Debug, Error)] +#[error("The index view is outdated.")] +pub struct OutdatedError(#[from] pub Option); + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +pub struct VectorOptions { + #[validate(range(min = 1, max = 65535))] + #[serde(rename = "dimensions")] + pub dims: u16, + #[serde(rename = "distance")] + pub d: Distance, + #[serde(rename = "kind")] + pub k: Kind, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +pub struct IndexOptions { + #[validate] + pub vector: VectorOptions, + #[validate] + pub segment: SegmentsOptions, + #[validate] + pub optimizing: OptimizingOptions, + #[validate] + pub indexing: IndexingOptions, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexStat { + pub indexing: bool, + pub sealed: Vec, + pub growing: Vec, + pub write: u32, + pub options: IndexOptions, +} + +pub struct Index { + path: PathBuf, + options: IndexOptions, + delete: Arc, + protect: Mutex>, + view: ArcSwap>, + instant_index: AtomicCell, + instant_write: AtomicCell, + _tracker: Arc, +} + +impl Index { + pub fn create(path: PathBuf, options: IndexOptions) -> Arc { + assert!(options.validate().is_ok()); + std::fs::create_dir(&path).unwrap(); + std::fs::create_dir(path.join("segments")).unwrap(); + let startup = FileAtomic::create( + path.join("startup"), + IndexStartup { + sealeds: HashSet::new(), + growings: HashSet::new(), + }, + ); + let delete = Delete::create(path.join("delete")); + sync_dir(&path); + let index = Arc::new(Index { + path: path.clone(), + options: options.clone(), + delete: delete.clone(), + protect: Mutex::new(IndexProtect { + startup, + sealed: HashMap::new(), + growing: HashMap::new(), + write: None, + }), + view: ArcSwap::new(Arc::new(IndexView { + options: options.clone(), + sealed: HashMap::new(), + growing: HashMap::new(), + delete: delete.clone(), + write: None, + })), + instant_index: AtomicCell::new(Instant::now()), + instant_write: AtomicCell::new(Instant::now()), + _tracker: Arc::new(IndexTracker { path }), + }); + OptimizerIndexing::new(index.clone()).spawn(); + OptimizerSealing::new(index.clone()).spawn(); + index + } + pub fn open(path: PathBuf, options: IndexOptions) -> Arc { + let tracker = Arc::new(IndexTracker { path: path.clone() }); + let startup = FileAtomic::::open(path.join("startup")); + clean( + path.join("segments"), + startup + .get() + .sealeds + .iter() + .map(|s| s.to_string()) + .chain(startup.get().growings.iter().map(|s| s.to_string())), + ); + let sealed = startup + .get() + .sealeds + .iter() + .map(|&uuid| { + ( + uuid, + SealedSegment::open( + tracker.clone(), + path.join("segments").join(uuid.to_string()), + uuid, + options.clone(), + ), + ) + }) + .collect::>(); + let growing = startup + .get() + .growings + .iter() + .map(|&uuid| { + ( + uuid, + GrowingSegment::open( + tracker.clone(), + path.join("segments").join(uuid.to_string()), + uuid, + ), + ) + }) + .collect::>(); + let delete = Delete::open(path.join("delete")); + let index = Arc::new(Index { + path: path.clone(), + options: options.clone(), + delete: delete.clone(), + protect: Mutex::new(IndexProtect { + startup, + sealed: sealed.clone(), + growing: growing.clone(), + write: None, + }), + view: ArcSwap::new(Arc::new(IndexView { + options: options.clone(), + delete: delete.clone(), + sealed, + growing, + write: None, + })), + instant_index: AtomicCell::new(Instant::now()), + instant_write: AtomicCell::new(Instant::now()), + _tracker: tracker, + }); + 
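// A reopened index spawns the same background optimizer tasks as Index::create.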
OptimizerIndexing::new(index.clone()).spawn(); + OptimizerSealing::new(index.clone()).spawn(); + index + } + pub fn options(&self) -> &IndexOptions { + &self.options + } + pub fn view(&self) -> Arc> { + self.view.load_full() + } + pub fn refresh(&self) { + let mut protect = self.protect.lock(); + if let Some((uuid, write)) = protect.write.clone() { + if !write.is_full() { + return; + } + write.seal(); + protect.growing.insert(uuid, write); + } + let write_segment_uuid = Uuid::new_v4(); + let write_segment = GrowingSegment::create( + self._tracker.clone(), + self.path + .join("segments") + .join(write_segment_uuid.to_string()), + write_segment_uuid, + self.options.clone(), + ); + protect.write = Some((write_segment_uuid, write_segment)); + protect.maintain(self.options.clone(), self.delete.clone(), &self.view); + self.instant_write.store(Instant::now()); + } + pub fn seal(&self, check: Uuid) { + let mut protect = self.protect.lock(); + if let Some((uuid, write)) = protect.write.clone() { + if check != uuid { + return; + } + write.seal(); + protect.growing.insert(uuid, write); + } + protect.write = None; + protect.maintain(self.options.clone(), self.delete.clone(), &self.view); + self.instant_write.store(Instant::now()); + } + pub fn stat(&self) -> IndexStat { + let view = self.view(); + IndexStat { + indexing: self.instant_index.load() < self.instant_write.load(), + sealed: view.sealed.values().map(|x| x.len()).collect(), + growing: view.growing.values().map(|x| x.len()).collect(), + write: view.write.as_ref().map(|(_, x)| x.len()).unwrap_or(0), + options: self.options().clone(), + } + } +} + +impl Drop for Index { + fn drop(&mut self) {} +} + +#[derive(Debug, Clone)] +pub struct IndexTracker { + path: PathBuf, +} + +impl Drop for IndexTracker { + fn drop(&mut self) { + std::fs::remove_dir_all(&self.path).unwrap(); + } +} + +pub struct IndexView { + pub options: IndexOptions, + pub delete: Arc, + pub sealed: HashMap>>, + pub growing: HashMap>>, + pub write: Option<(Uuid, Arc>)>, +} + +impl IndexView { + pub fn search bool>( + &self, + k: usize, + vector: &[S::Scalar], + mut filter: F, + ) -> Vec { + assert_eq!(self.options.vector.dims as usize, vector.len()); + + struct Comparer(BinaryHeap>); + + impl PartialEq for Comparer { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } + } + + impl Eq for Comparer {} + + impl PartialOrd for Comparer { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for Comparer { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.peek().cmp(&other.0.peek()).reverse() + } + } + + let mut filter = |payload| { + if let Some(p) = self.delete.check(payload) { + filter(p) + } else { + false + } + }; + let n = self.sealed.len() + self.growing.len() + 1; + let mut result = Heap::new(k); + let mut heaps = BinaryHeap::with_capacity(1 + n); + for (_, sealed) in self.sealed.iter() { + let p = sealed.search(k, vector, &mut filter).into_reversed_heap(); + heaps.push(Comparer(p)); + } + for (_, growing) in self.growing.iter() { + let p = growing.search(k, vector, &mut filter).into_reversed_heap(); + heaps.push(Comparer(p)); + } + if let Some((_, write)) = &self.write { + let p = write.search(k, vector, &mut filter).into_reversed_heap(); + heaps.push(Comparer(p)); + } + while let Some(Comparer(mut heap)) = heaps.pop() { + if let Some(Reverse(x)) = heap.pop() { + result.push(x); + heaps.push(Comparer(heap)); + } + } + result + .into_sorted_vec() + .iter() + .map(|x| Pointer::from_u48(x.payload >> 16)) + 
.collect() + } + pub fn vbase(&self, vector: &[S::Scalar]) -> impl Iterator + '_ { + assert_eq!(self.options.vector.dims as usize, vector.len()); + + let range = 86; + + struct Comparer<'a, S: G> { + iter: ComparerIter<'a, S>, + item: Option, + } + + enum ComparerIter<'a, S: G> { + Sealed(DynamicIndexIter<'a, S>), + Growing(std::vec::IntoIter), + } + + impl PartialEq for Comparer<'_, S> { + fn eq(&self, other: &Self) -> bool { + self.cmp(other).is_eq() + } + } + + impl Eq for Comparer<'_, S> {} + + impl PartialOrd for Comparer<'_, S> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl Ord for Comparer<'_, S> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.item.cmp(&other.item).reverse() + } + } + + impl Iterator for ComparerIter<'_, S> { + type Item = HeapElement; + fn next(&mut self) -> Option { + match self { + Self::Sealed(iter) => iter.next(), + Self::Growing(iter) => iter.next(), + } + } + } + + impl Iterator for Comparer<'_, S> { + type Item = HeapElement; + fn next(&mut self) -> Option { + let item = self.item.take(); + self.item = self.iter.next(); + item + } + } + + fn from_iter(mut iter: ComparerIter<'_, S>) -> Comparer<'_, S> { + let item = iter.next(); + Comparer { iter, item } + } + + use ComparerIter::*; + let filter = |payload| self.delete.check(payload).is_some(); + let n = self.sealed.len() + self.growing.len() + 1; + let mut heaps: BinaryHeap> = BinaryHeap::with_capacity(1 + n); + for (_, sealed) in self.sealed.iter() { + let res = sealed.vbase(range, vector); + heaps.push(from_iter(Sealed(res))); + } + for (_, growing) in self.growing.iter() { + let mut res = growing.vbase(vector); + res.sort_unstable(); + heaps.push(from_iter(Growing(res.into_iter()))); + } + if let Some((_, write)) = &self.write { + let mut res = write.vbase(vector); + res.sort_unstable(); + heaps.push(from_iter(Growing(res.into_iter()))); + } + std::iter::from_fn(move || { + while let Some(mut iter) = heaps.pop() { + if let Some(x) = iter.next() { + if !filter(x.payload) { + continue; + } + heaps.push(iter); + return Some(Pointer::from_u48(x.payload >> 16)); + } + } + None + }) + } + pub fn insert(&self, vector: Vec, pointer: Pointer) -> Result<(), OutdatedError> { + assert_eq!(self.options.vector.dims as usize, vector.len()); + let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload; + if let Some((_, growing)) = self.write.as_ref() { + Ok(growing.insert(vector, payload)?) 
+ } else { + Err(OutdatedError(None)) + } + } + pub fn delete bool>(&self, mut f: F) { + for (_, sealed) in self.sealed.iter() { + let n = sealed.len(); + for i in 0..n { + if let Some(p) = self.delete.check(sealed.payload(i)) { + if f(p) { + self.delete.delete(p); + } + } + } + } + for (_, growing) in self.growing.iter() { + let n = growing.len(); + for i in 0..n { + if let Some(p) = self.delete.check(growing.payload(i)) { + if f(p) { + self.delete.delete(p); + } + } + } + } + if let Some((_, write)) = &self.write { + let n = write.len(); + for i in 0..n { + if let Some(p) = self.delete.check(write.payload(i)) { + if f(p) { + self.delete.delete(p); + } + } + } + } + } + pub fn flush(&self) { + self.delete.flush(); + if let Some((_, write)) = &self.write { + write.flush(); + } + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +struct IndexStartup { + sealeds: HashSet, + growings: HashSet, +} + +struct IndexProtect { + startup: FileAtomic, + sealed: HashMap>>, + growing: HashMap>>, + write: Option<(Uuid, Arc>)>, +} + +impl IndexProtect { + fn maintain( + &mut self, + options: IndexOptions, + delete: Arc, + swap: &ArcSwap>, + ) { + let view = Arc::new(IndexView { + options, + delete, + sealed: self.sealed.clone(), + growing: self.growing.clone(), + write: self.write.clone(), + }); + let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid); + let startup_sealeds = self.sealed.keys().copied().collect(); + let startup_growings = self.growing.keys().copied().chain(startup_write).collect(); + self.startup.set(IndexStartup { + sealeds: startup_sealeds, + growings: startup_growings, + }); + swap.swap(view); + } +} diff --git a/crates/service/src/index/optimizing/indexing.rs b/crates/service/src/index/optimizing/indexing.rs new file mode 100644 index 000000000..5a70ffa91 --- /dev/null +++ b/crates/service/src/index/optimizing/indexing.rs @@ -0,0 +1,148 @@ +use crate::index::GrowingSegment; +use crate::index::Index; +use crate::index::SealedSegment; +use crate::prelude::*; +use std::cmp::Reverse; +use std::sync::Arc; +use std::time::Instant; +use uuid::Uuid; + +pub struct OptimizerIndexing { + index: Arc>, +} + +impl OptimizerIndexing { + pub fn new(index: Arc>) -> Self { + Self { index } + } + pub fn spawn(self) { + std::thread::spawn(move || { + self.main(); + }); + } + pub fn main(self) { + let index = self.index; + let pool = rayon::ThreadPoolBuilder::new() + .num_threads(index.options.optimizing.optimizing_threads) + .build() + .unwrap(); + let weak_index = Arc::downgrade(&index); + std::mem::drop(index); + loop { + { + let Some(index) = weak_index.upgrade() else { + return; + }; + if let Ok(()) = pool.install(|| optimizing_indexing(index.clone())) { + continue; + } + } + std::thread::sleep(std::time::Duration::from_secs(60)); + } + } +} + +enum Seg { + Sealed(Arc>), + Growing(Arc>), +} + +impl Seg { + fn uuid(&self) -> Uuid { + use Seg::*; + match self { + Sealed(x) => x.uuid(), + Growing(x) => x.uuid(), + } + } + fn len(&self) -> u32 { + use Seg::*; + match self { + Sealed(x) => x.len(), + Growing(x) => x.len(), + } + } + fn get_sealed(&self) -> Option>> { + match self { + Seg::Sealed(x) => Some(x.clone()), + _ => None, + } + } + fn get_growing(&self) -> Option>> { + match self { + Seg::Growing(x) => Some(x.clone()), + _ => None, + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("Interrupted, retry again.")] +pub struct RetryError; + +pub fn optimizing_indexing(index: Arc>) -> Result<(), RetryError> { + use Seg::*; + let segs = { + let protect = 
index.protect.lock(); + let mut segs_0 = Vec::new(); + segs_0.extend(protect.growing.values().map(|x| Growing(x.clone()))); + segs_0.extend(protect.sealed.values().map(|x| Sealed(x.clone()))); + segs_0.sort_by_key(|case| Reverse(case.len())); + let mut segs_1 = Vec::new(); + let mut total = 0u64; + let mut count = 0; + while let Some(seg) = segs_0.pop() { + if total + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 { + total += seg.len() as u64; + if let Growing(_) = seg { + count += 1; + } + segs_1.push(seg); + } else { + break; + } + } + if segs_1.is_empty() || (segs_1.len() == 1 && count == 0) { + index.instant_index.store(Instant::now()); + return Err(RetryError); + } + segs_1 + }; + let sealed_segment = merge(&index, &segs); + { + let mut protect = index.protect.lock(); + for seg in segs.iter() { + if protect.sealed.contains_key(&seg.uuid()) { + continue; + } + if protect.growing.contains_key(&seg.uuid()) { + continue; + } + return Ok(()); + } + for seg in segs.iter() { + protect.sealed.remove(&seg.uuid()); + protect.growing.remove(&seg.uuid()); + } + protect.sealed.insert(sealed_segment.uuid(), sealed_segment); + protect.maintain(index.options.clone(), index.delete.clone(), &index.view); + } + Ok(()) +} + +fn merge(index: &Arc>, segs: &[Seg]) -> Arc> { + let sealed = segs.iter().filter_map(|x| x.get_sealed()).collect(); + let growing = segs.iter().filter_map(|x| x.get_growing()).collect(); + let sealed_segment_uuid = Uuid::new_v4(); + SealedSegment::create( + index._tracker.clone(), + index + .path + .join("segments") + .join(sealed_segment_uuid.to_string()), + sealed_segment_uuid, + index.options.clone(), + sealed, + growing, + ) +} diff --git a/src/index/optimizing/mod.rs b/crates/service/src/index/optimizing/mod.rs similarity index 67% rename from src/index/optimizing/mod.rs rename to crates/service/src/index/optimizing/mod.rs index a0426720d..d0c4b2ae0 100644 --- a/src/index/optimizing/mod.rs +++ b/crates/service/src/index/optimizing/mod.rs @@ -1,4 +1,5 @@ pub mod indexing; +pub mod sealing; pub mod vacuum; use serde::{Deserialize, Serialize}; @@ -6,9 +7,12 @@ use validator::Validate; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] pub struct OptimizingOptions { - #[serde(default = "OptimizingOptions::default_waiting_secs", skip)] - #[validate(range(min = 0, max = 600))] - pub waiting_secs: u64, + #[serde(default = "OptimizingOptions::default_sealing_secs")] + #[validate(range(min = 0, max = 60))] + pub sealing_secs: u64, + #[serde(default = "OptimizingOptions::default_sealing_size")] + #[validate(range(min = 1, max = 4_000_000_000))] + pub sealing_size: u32, #[serde(default = "OptimizingOptions::default_deleted_threshold", skip)] #[validate(range(min = 0.01, max = 1.00))] pub deleted_threshold: f64, @@ -18,9 +22,12 @@ pub struct OptimizingOptions { } impl OptimizingOptions { - fn default_waiting_secs() -> u64 { + fn default_sealing_secs() -> u64 { 60 } + fn default_sealing_size() -> u32 { + 1 + } fn default_deleted_threshold() -> f64 { 0.2 } @@ -35,7 +42,8 @@ impl OptimizingOptions { impl Default for OptimizingOptions { fn default() -> Self { Self { - waiting_secs: Self::default_waiting_secs(), + sealing_secs: Self::default_sealing_secs(), + sealing_size: Self::default_sealing_size(), deleted_threshold: Self::default_deleted_threshold(), optimizing_threads: Self::default_optimizing_threads(), } diff --git a/crates/service/src/index/optimizing/sealing.rs b/crates/service/src/index/optimizing/sealing.rs new file mode 100644 index 
000000000..0e02ead88 --- /dev/null +++ b/crates/service/src/index/optimizing/sealing.rs @@ -0,0 +1,49 @@ +use crate::index::Index; +use crate::prelude::*; +use std::sync::Arc; +use std::time::Duration; + +pub struct OptimizerSealing { + index: Arc>, +} + +impl OptimizerSealing { + pub fn new(index: Arc>) -> Self { + Self { index } + } + pub fn spawn(self) { + std::thread::spawn(move || { + self.main(); + }); + } + pub fn main(self) { + let index = self.index; + let dur = Duration::from_secs(index.options.optimizing.sealing_secs); + let least = index.options.optimizing.sealing_size; + let weak_index = Arc::downgrade(&index); + std::mem::drop(index); + let mut check = None; + loop { + { + let Some(index) = weak_index.upgrade() else { + return; + }; + let view = index.view(); + let stamp = view + .write + .as_ref() + .map(|(uuid, segment)| (*uuid, segment.len())); + if stamp == check { + if let Some((uuid, len)) = stamp { + if len >= least { + index.seal(uuid); + } + } + } else { + check = stamp; + } + } + std::thread::sleep(dur); + } + } +} diff --git a/src/index/optimizing/vacuum.rs b/crates/service/src/index/optimizing/vacuum.rs similarity index 100% rename from src/index/optimizing/vacuum.rs rename to crates/service/src/index/optimizing/vacuum.rs diff --git a/src/index/segments/growing.rs b/crates/service/src/index/segments/growing.rs similarity index 79% rename from src/index/segments/growing.rs rename to crates/service/src/index/segments/growing.rs index 972f23d92..cd393b47c 100644 --- a/src/index/segments/growing.rs +++ b/crates/service/src/index/segments/growing.rs @@ -1,7 +1,8 @@ +#![allow(clippy::all)] // Clippy bug. + use super::SegmentTracker; use crate::index::IndexOptions; use crate::index::IndexTracker; -use crate::index::VectorOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::file_wal::FileWal; @@ -19,17 +20,16 @@ use uuid::Uuid; #[error("`GrowingSegment` stopped growing.")] pub struct GrowingSegmentInsertError; -pub struct GrowingSegment { +pub struct GrowingSegment { uuid: Uuid, - options: VectorOptions, - vec: Vec>>, + vec: Vec>>>, wal: Mutex, len: AtomicUsize, pro: Mutex, _tracker: Arc, } -impl GrowingSegment { +impl GrowingSegment { pub fn create( _tracker: Arc, path: PathBuf, @@ -42,7 +42,6 @@ impl GrowingSegment { sync_dir(&path); Arc::new(Self { uuid, - options: options.vector, vec: unsafe { let mut vec = Vec::with_capacity(capacity as usize); vec.set_len(capacity as usize); @@ -57,23 +56,17 @@ impl GrowingSegment { _tracker: Arc::new(SegmentTracker { path, _tracker }), }) } - pub fn open( - _tracker: Arc, - path: PathBuf, - uuid: Uuid, - options: IndexOptions, - ) -> Arc { + pub fn open(_tracker: Arc, path: PathBuf, uuid: Uuid) -> Arc { let mut wal = FileWal::open(path.join("wal")); let mut vec = Vec::new(); while let Some(log) = wal.read() { - let log = bincode::deserialize::(&log).unwrap(); + let log = bincode::deserialize::>(&log).unwrap(); vec.push(UnsafeCell::new(MaybeUninit::new(log))); } wal.truncate(); let n = vec.len(); Arc::new(Self { uuid, - options: options.vector, vec, wal: { Mutex::new(wal) }, len: AtomicUsize::new(n), @@ -87,6 +80,20 @@ impl GrowingSegment { pub fn uuid(&self) -> Uuid { self.uuid } + pub fn is_full(&self) -> bool { + let n; + { + let pro = self.pro.lock(); + if pro.inflight < pro.capacity { + return false; + } + n = pro.inflight; + } + while self.len.load(Ordering::Acquire) != n { + std::hint::spin_loop(); + } + true + } pub fn seal(&self) { let n; { @@ -104,7 +111,7 @@ impl GrowingSegment { } pub fn 
insert( &self, - vector: Vec, + vector: Vec, payload: Payload, ) -> Result<(), GrowingSegmentInsertError> { let log = Log { vector, payload }; @@ -126,13 +133,13 @@ impl GrowingSegment { self.len.store(1 + i, Ordering::Release); self.wal .lock() - .write(&bincode::serialize::(&log).unwrap()); + .write(&bincode::serialize::>(&log).unwrap()); Ok(()) } pub fn len(&self) -> u32 { self.len.load(Ordering::Acquire) as u32 } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { let i = i as usize; if i >= self.len.load(Ordering::Acquire) { panic!("Out of bound."); @@ -148,12 +155,12 @@ impl GrowingSegment { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; log.payload } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { let n = self.len.load(Ordering::Acquire); let mut heap = Heap::new(k); for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; - let distance = self.options.d.distance(vector, &log.vector); + let distance = S::distance(vector, &log.vector); if heap.check(distance) && filter.check(log.payload) { heap.push(HeapElement { distance, @@ -163,12 +170,12 @@ impl GrowingSegment { } heap } - pub fn search_all(&self, vector: &[Scalar]) -> Vec { + pub fn vbase(&self, vector: &[S::Scalar]) -> Vec { let n = self.len.load(Ordering::Acquire); let mut result = Vec::new(); for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; - let distance = self.options.d.distance(vector, &log.vector); + let distance = S::distance(vector, &log.vector); result.push(HeapElement { distance, payload: log.payload, @@ -178,10 +185,10 @@ impl GrowingSegment { } } -unsafe impl Send for GrowingSegment {} -unsafe impl Sync for GrowingSegment {} +unsafe impl Send for GrowingSegment {} +unsafe impl Sync for GrowingSegment {} -impl Drop for GrowingSegment { +impl Drop for GrowingSegment { fn drop(&mut self) { let n = *self.len.get_mut(); for i in 0..n { @@ -193,8 +200,8 @@ impl Drop for GrowingSegment { } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -struct Log { - vector: Vec, +struct Log { + vector: Vec, payload: Payload, } diff --git a/src/index/segments/mod.rs b/crates/service/src/index/segments/mod.rs similarity index 67% rename from src/index/segments/mod.rs rename to crates/service/src/index/segments/mod.rs index 83b0682af..b09c1f939 100644 --- a/src/index/segments/mod.rs +++ b/crates/service/src/index/segments/mod.rs @@ -10,14 +10,10 @@ use validator::ValidationError; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[validate(schema(function = "Self::validate_0"))] -#[validate(schema(function = "Self::validate_1"))] pub struct SegmentsOptions { #[serde(default = "SegmentsOptions::default_max_growing_segment_size")] #[validate(range(min = 1, max = 4_000_000_000))] pub max_growing_segment_size: u32, - #[serde(default = "SegmentsOptions::default_min_sealed_segment_size")] - #[validate(range(min = 1, max = 4_000_000_000))] - pub min_sealed_segment_size: u32, #[serde(default = "SegmentsOptions::default_max_sealed_segment_size")] #[validate(range(min = 1, max = 4_000_000_000))] pub max_sealed_segment_size: u32, @@ -27,22 +23,11 @@ impl SegmentsOptions { fn default_max_growing_segment_size() -> u32 { 20_000 } - fn default_min_sealed_segment_size() -> u32 { - 1_000 - } fn default_max_sealed_segment_size() -> u32 { 1_000_000 } - // min_sealed_segment_size <= 
max_growing_segment_size <= max_sealed_segment_size + // max_growing_segment_size <= max_sealed_segment_size fn validate_0(&self) -> Result<(), ValidationError> { - if self.min_sealed_segment_size > self.max_growing_segment_size { - return Err(ValidationError::new( - "`min_sealed_segment_size` must be less than or equal to `max_growing_segment_size`", - )); - } - Ok(()) - } - fn validate_1(&self) -> Result<(), ValidationError> { if self.max_growing_segment_size > self.max_sealed_segment_size { return Err(ValidationError::new( "`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`", @@ -56,7 +41,6 @@ impl Default for SegmentsOptions { fn default() -> Self { Self { max_growing_segment_size: Self::default_max_growing_segment_size(), - min_sealed_segment_size: Self::default_min_sealed_segment_size(), max_sealed_segment_size: Self::default_max_sealed_segment_size(), } } diff --git a/src/index/segments/sealed.rs b/crates/service/src/index/segments/sealed.rs similarity index 74% rename from src/index/segments/sealed.rs rename to crates/service/src/index/segments/sealed.rs index 1bb4c7b74..c7ad16a25 100644 --- a/src/index/segments/sealed.rs +++ b/crates/service/src/index/segments/sealed.rs @@ -8,20 +8,20 @@ use std::path::PathBuf; use std::sync::Arc; use uuid::Uuid; -pub struct SealedSegment { +pub struct SealedSegment { uuid: Uuid, - indexing: DynamicIndexing, + indexing: DynamicIndexing, _tracker: Arc, } -impl SealedSegment { +impl SealedSegment { pub fn create( _tracker: Arc, path: PathBuf, uuid: Uuid, options: IndexOptions, - sealed: Vec>, - growing: Vec>, + sealed: Vec>>, + growing: Vec>>, ) -> Arc { std::fs::create_dir(&path).unwrap(); let indexing = DynamicIndexing::create(path.join("indexing"), options, sealed, growing); @@ -51,20 +51,16 @@ impl SealedSegment { pub fn len(&self) -> u32 { self.indexing.len() } - pub fn vector(&self, i: u32) -> &[Scalar] { + pub fn vector(&self, i: u32) -> &[S::Scalar] { self.indexing.vector(i) } pub fn payload(&self, i: u32) -> Payload { self.indexing.payload(i) } - pub fn search(&self, k: usize, vector: &[Scalar], filter: &mut impl Filter) -> Heap { + pub fn search(&self, k: usize, vector: &[S::Scalar], filter: &mut impl Filter) -> Heap { self.indexing.search(k, vector, filter) } - pub fn search_vbase<'index, 'vector>( - &'index self, - range: usize, - vector: &'vector [Scalar], - ) -> DynamicIndexIter<'index, 'vector> { - self.indexing.search_vbase(range, vector) + pub fn vbase(&self, range: usize, vector: &[S::Scalar]) -> DynamicIndexIter<'_, S> { + self.indexing.vbase(range, vector) } } diff --git a/crates/service/src/lib.rs b/crates/service/src/lib.rs new file mode 100644 index 000000000..b534589f1 --- /dev/null +++ b/crates/service/src/lib.rs @@ -0,0 +1,9 @@ +#![feature(core_intrinsics)] +#![feature(avx512_target_feature)] + +pub mod algorithms; +pub mod index; +pub mod prelude; +pub mod worker; + +mod utils; diff --git a/src/prelude/error.rs b/crates/service/src/prelude/error.rs similarity index 61% rename from src/prelude/error.rs rename to crates/service/src/prelude/error.rs index e3d8b8b45..fbefffe71 100644 --- a/src/prelude/error.rs +++ b/crates/service/src/prelude/error.rs @@ -1,5 +1,3 @@ -use crate::ipc::IpcError; -use crate::prelude::*; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -15,100 +13,77 @@ or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_ ")] BadInit, #[error("\ -The given index option is invalid. 
-INFORMATION: reason = {0:?}\ -")] - BadOption(String), - #[error("\ -The given vector is invalid for input. -INFORMATION: vector = {0:?} -ADVICE: Check if dimensions of the vector is matched with the index.\ +Bad literal. +INFORMATION: hint = {hint}\ ")] - BadVector(Vec), + BadLiteral { + hint: String, + }, #[error("\ Modifier of the type is invalid. ADVICE: Check if modifier of the type is an integer among 1 and 65535.\ ")] - BadTypmod, + BadTypeDimensions, #[error("\ Dimensions of the vector is invalid. ADVICE: Check if dimensions of the vector are among 1 and 65535.\ ")] - BadVecForDims, + BadValueDimensions, #[error("\ -Dimensions of the vector is unmatched with the type modifier. -INFORMATION: type_dimensions = {type_dimensions}, value_dimensions = {value_dimensions}\ +The given index option is invalid. +INFORMATION: reason = {validation:?}\ ")] - BadVecForUnmatchedDims { - value_dimensions: u16, - type_dimensions: u16, - }, + BadOption { validation: String }, #[error("\ -Operands of the operator differs in dimensions. -INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\ +Dimensions type modifier of a vector column is needed for building the index.\ ")] - DifferentVectorDims { - left_dimensions: u16, - right_dimensions: u16, - }, + BadOption2, #[error("\ Indexes can only be built on built-in distance functions. ADVICE: If you want pgvecto.rs to support more distance functions, \ visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas.\ ")] - UnsupportedOperator, + BadOptions3, #[error("\ The index is not existing in the background worker. ADVICE: Drop or rebuild the index.\ ")] - Index404, + UnknownIndex, #[error("\ -Dimensions type modifier of a vector column is needed for building the index.\ -")] - DimsIsNeeded, - #[error("\ -Bad vector string. -INFORMATION: hint = {hint}\ +Operands of the operator differs in dimensions or scalar type. +INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\ ")] - BadVectorString { - hint: String, + Unmatched { + left_dimensions: u16, + right_dimensions: u16, }, #[error("\ -`mmap` transport is not supported by MacOS.\ +The given vector is invalid for input. +ADVICE: Check if dimensions and scalar type of the vector is matched with the index.\ ")] - MmapTransportNotSupported, + Unmatched2, } -impl FriendlyError { - pub fn friendly(self) -> ! { - panic!("pgvecto.rs: {}", self); - } +pub trait FriendlyErrorLike { + fn friendly(self) -> !; } -impl IpcError { - pub fn friendly(self) -> ! { +impl FriendlyErrorLike for FriendlyError { + fn friendly(self) -> ! 
{ panic!("pgvecto.rs: {}", self); } } -pub trait Friendly { +pub trait FriendlyResult { type Output; fn friendly(self) -> Self::Output; } -impl Friendly for Result { - type Output = T; - - fn friendly(self) -> T { - match self { - Ok(x) => x, - Err(e) => e.friendly(), - } - } -} - -impl Friendly for Result { +impl FriendlyResult for Result +where + E: FriendlyErrorLike, +{ type Output = T; fn friendly(self) -> T { diff --git a/src/prelude/filter.rs b/crates/service/src/prelude/filter.rs similarity index 100% rename from src/prelude/filter.rs rename to crates/service/src/prelude/filter.rs diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs new file mode 100644 index 000000000..2f7e13cab --- /dev/null +++ b/crates/service/src/prelude/global/f16.rs @@ -0,0 +1,114 @@ +use crate::prelude::*; + +pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { + #[inline(always)] + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); + } + xy / (x2 * y2).sqrt() + } + #[cfg(target_arch = "x86_64")] + if crate::utils::detect::x86_64::detect_avx512fp16() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + #[cfg(target_arch = "x86_64")] + if crate::utils::detect::x86_64::detect_v3() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + cosine(lhs, rhs) +} + +pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { + #[inline(always)] + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + } + xy + } + #[cfg(target_arch = "x86_64")] + if crate::utils::detect::x86_64::detect_avx512fp16() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + #[cfg(target_arch = "x86_64")] + if crate::utils::detect::x86_64::detect_v3() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + cosine(lhs, rhs) +} + +pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { + #[inline(always)] + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i].to_f() - rhs[i].to_f(); + d2 += d * d; + } + d2 + } + #[cfg(target_arch = "x86_64")] + if crate::utils::detect::x86_64::detect_avx512fp16() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + #[cfg(target_arch = "x86_64")] + if 
crate::utils::detect::x86_64::detect_v3() { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { + return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); + } + } + sl2(lhs, rhs) +} diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/service/src/prelude/global/f16_cos.rs new file mode 100644 index 000000000..df9f60522 --- /dev/null +++ b/crates/service/src/prelude/global/f16_cos.rs @@ -0,0 +1,244 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F16Cos {} + +impl G for F16Cos { + type Scalar = F16; + + const DISTANCE: Distance = Distance::Cos; + + type L2 = F16L2; + + fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::cosine(lhs, rhs) * (-1.0) + } + + fn elkan_k_means_normalize(vector: &mut [F16]) { + l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance2( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + 
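+            // `rhs[i]` is the PQ code of the i-th subvector; `rhsp` locates the
+            // corresponding centroid row in the flattened codebook, and the
+            // slice below selects the sub-centroid covering dimensions
+            // [i * ratio, i * ratio + k).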
let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + delta: &[F16], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn length(vector: &[F16]) -> F16 { + let n = vector.len(); + let mut dot = F16::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn l2_normalize(vector: &mut [F16]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); + } + (xy, x2, y2) +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f()); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += (rhs[i].to_f() + del[i].to_f()) * (rhs[i].to_f() + del[i].to_f()); + } + (xy, x2, y2) +} diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/service/src/prelude/global/f16_dot.rs new file mode 100644 index 000000000..085c2b827 --- /dev/null +++ b/crates/service/src/prelude/global/f16_dot.rs @@ -0,0 +1,199 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F16Dot {} + +impl G for F16Dot { + type Scalar = F16; + + const DISTANCE: Distance = Distance::Dot; + + type L2 = F16L2; + + fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::dot(lhs, rhs) * (-1.0) + } + + fn elkan_k_means_normalize(vector: &mut [F16]) { + l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + 
"x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance2( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = super::f16::dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = super::f16::dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + delta: &[F16], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let _xy = dot_delta(lhs, rhs, del); + xy += _xy; + } + xy * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn length(vector: &[F16]) -> F16 { + let n = vector.len(); + let mut dot = F16::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn l2_normalize(vector: &mut [F16]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + 
vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n: usize = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * (rhs[i].to_f() + del[i].to_f()); + } + xy +} diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs new file mode 100644 index 000000000..647c6f900 --- /dev/null +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -0,0 +1,165 @@ +use super::G; +use crate::prelude::scalar::F16; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F16L2 {} + +impl G for F16L2 { + type Scalar = F16; + + const DISTANCE: Distance = Distance::L2; + + type L2 = F16L2; + + fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::sl2(lhs, rhs) + } + + fn elkan_k_means_normalize(_: &mut [F16]) {} + + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::sl2(lhs, rhs).sqrt() + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i].to_f(); + let _y = (F32(rhs[i] as f32) / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance2( + dims: u16, + max: &[F16], + min: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + let _y = F32(rhs[i] as f32 / 256.0) * (max[i].to_f() - min[i].to_f()) + min[i].to_f(); + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += super::f16::sl2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += super::f16::sl2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets( + 
"x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F16], + lhs: &[F16], + rhs: &[u8], + delta: &[F16], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2_delta(lhs, rhs, del); + } + result + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i].to_f() - (rhs[i].to_f() + del[i].to_f()); + d2 += d * d; + } + d2 +} diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/service/src/prelude/global/f32_cos.rs new file mode 100644 index 000000000..06cd7001e --- /dev/null +++ b/crates/service/src/prelude/global/f32_cos.rs @@ -0,0 +1,265 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F32Cos {} + +impl G for F32Cos { + type Scalar = F32; + + const DISTANCE: Distance = Distance::Cos; + + type L2 = F32L2; + + fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { + cosine(lhs, rhs) * (-1.0) + } + + fn elkan_k_means_normalize(vector: &mut [F32]) { + l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32_dot::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance2( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + x2 += _x * _x; + y2 += _y * _y; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs 
= ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + delta: &[F32], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del); + xy += _xy; + x2 += _x2; + y2 += _y2; + } + xy / (x2 * y2).sqrt() * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn length(vector: &[F32]) -> F32 { + let n = vector.len(); + let mut dot = F32::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn l2_normalize(vector: &mut [F32]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + (xy, x2, y2) +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: 
&[F32]) -> (F32, F32, F32) { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * (rhs[i] + del[i]); + x2 += lhs[i] * lhs[i]; + y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]); + } + (xy, x2, y2) +} diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/service/src/prelude/global/f32_dot.rs new file mode 100644 index 000000000..58108d89e --- /dev/null +++ b/crates/service/src/prelude/global/f32_dot.rs @@ -0,0 +1,237 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F32Dot {} + +impl G for F32Dot { + type Scalar = F32; + + const DISTANCE: Distance = Distance::Dot; + + type L2 = F32L2; + + fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { + dot(lhs, rhs) * (-1.0) + } + + fn elkan_k_means_normalize(vector: &mut [F32]) { + l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32_dot::dot(lhs, rhs).acos() + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance2( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut xy = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + xy += _x * _y; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let _xy = dot(lhs, rhs); + xy += _xy; + } + xy * (-1.0) + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: 
&[u8], + delta: &[F32], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut xy = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + let _xy = dot_delta(lhs, rhs, del); + xy += _xy; + } + xy * (-1.0) + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn length(vector: &[F32]) -> F32 { + let n = vector.len(); + let mut dot = F32::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn l2_normalize(vector: &mut [F32]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + } + xy +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n: usize = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i] * (rhs[i] + del[i]); + } + xy +} diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/service/src/prelude/global/f32_l2.rs new file mode 100644 index 000000000..2672b6714 --- /dev/null +++ b/crates/service/src/prelude/global/f32_l2.rs @@ -0,0 +1,182 @@ +use super::G; +use crate::prelude::scalar::F32; +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum F32L2 {} + +impl G for F32L2 { + type Scalar = F32; + + const DISTANCE: Distance = Distance::L2; + + type L2 = F32L2; + + fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { + distance_squared_l2(lhs, rhs) + } + + fn elkan_k_means_normalize(_: &mut [F32]) {} + + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { + distance_squared_l2(lhs, rhs).sqrt() + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn scalar_quantization_distance( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = lhs[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn 
scalar_quantization_distance2( + dims: u16, + max: &[F32], + min: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + let _x = F32(lhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + let _y = F32(rhs[i] as f32 / 256.0) * (max[i] - min[i]) + min[i]; + result += (_x - _y) * (_x - _y); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[u8], + rhs: &[u8], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhsp = lhs[i as usize] as usize * dims as usize; + let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2(lhs, rhs); + } + result + } + + #[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" + ))] + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[F32], + lhs: &[F32], + rhs: &[u8], + delta: &[F32], + ) -> F32 { + let width = dims.div_ceil(ratio); + let mut result = F32::zero(); + for i in 0..width { + let k = std::cmp::min(ratio, dims - ratio * i); + let lhs = &lhs[(i * ratio) as usize..][..k as usize]; + let rhsp = rhs[i as usize] as usize * dims as usize; + let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; + let del = &delta[(i * ratio) as usize..][..k as usize]; + result += distance_squared_l2_delta(lhs, rhs, del); + } + result + } +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i] - rhs[i]; + d2 += d * d; + } + d2 +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i] - (rhs[i] + del[i]); + d2 += d * d; + } + d2 +} diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs new file mode 100644 index 000000000..2eedaf91a --- /dev/null +++ b/crates/service/src/prelude/global/mod.rs @@ -0,0 +1,121 @@ +mod f16; +mod f16_cos; +mod f16_dot; +mod f16_l2; +mod f32_cos; +mod f32_dot; +mod f32_l2; + +pub use f16_cos::F16Cos; +pub use 
f16_dot::F16Dot; +pub use f16_l2::F16L2; +pub use f32_cos::F32Cos; +pub use f32_dot::F32Dot; +pub use f32_l2::F32L2; + +use crate::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +pub trait G: Copy + std::fmt::Debug + 'static { + type Scalar: Copy + + Send + + Sync + + std::fmt::Debug + + std::fmt::Display + + serde::Serialize + + for<'a> serde::Deserialize<'a> + + Ord + + bytemuck::Zeroable + + bytemuck::Pod + + num_traits::Float + + num_traits::NumOps + + num_traits::NumAssignOps + + FloatCast; + const DISTANCE: Distance; + type L2: G; + + fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; + fn elkan_k_means_normalize(vector: &mut [Self::Scalar]); + fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; + fn scalar_quantization_distance( + dims: u16, + max: &[Self::Scalar], + min: &[Self::Scalar], + lhs: &[Self::Scalar], + rhs: &[u8], + ) -> F32; + fn scalar_quantization_distance2( + dims: u16, + max: &[Self::Scalar], + min: &[Self::Scalar], + lhs: &[u8], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[Self::Scalar], + lhs: &[Self::Scalar], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[Self::Scalar], + lhs: &[u8], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[Self::Scalar], + lhs: &[Self::Scalar], + rhs: &[u8], + delta: &[Self::Scalar], + ) -> F32; +} + +pub trait FloatCast: Sized { + fn from_f32(x: f32) -> Self; + fn to_f32(self) -> f32; + fn from_f(x: F32) -> Self { + Self::from_f32(x.0) + } + fn to_f(self) -> F32 { + F32(Self::to_f32(self)) + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub enum DynamicVector { + F32(Vec), + F16(Vec), +} + +impl From> for DynamicVector { + fn from(value: Vec) -> Self { + Self::F32(value) + } +} + +impl From> for DynamicVector { + fn from(value: Vec) -> Self { + Self::F16(value) + } +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Distance { + L2, + Cos, + Dot, +} + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum Kind { + F32, + F16, +} diff --git a/src/prelude/heap.rs b/crates/service/src/prelude/heap.rs similarity index 89% rename from src/prelude/heap.rs rename to crates/service/src/prelude/heap.rs index 908fe463e..a5e2be3fd 100644 --- a/src/prelude/heap.rs +++ b/crates/service/src/prelude/heap.rs @@ -1,9 +1,9 @@ -use crate::prelude::{Payload, Scalar}; +use crate::prelude::{Payload, F32}; use std::{cmp::Reverse, collections::BinaryHeap}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct HeapElement { - pub distance: Scalar, + pub distance: F32, pub payload: Payload, } @@ -20,7 +20,7 @@ impl Heap { k, } } - pub fn check(&self, distance: Scalar) -> bool { + pub fn check(&self, distance: F32) -> bool { self.binary_heap.len() < self.k || distance < self.binary_heap.peek().unwrap().distance } pub fn push(&mut self, element: HeapElement) -> Option { diff --git a/crates/service/src/prelude/mod.rs b/crates/service/src/prelude/mod.rs new file mode 100644 index 000000000..5b66102f2 --- /dev/null +++ b/crates/service/src/prelude/mod.rs @@ -0,0 +1,16 @@ +mod error; +mod filter; +mod global; +mod heap; +mod scalar; +mod sys; + +pub use self::error::{FriendlyError, FriendlyErrorLike, FriendlyResult}; +pub use self::global::*; +pub use 
self::scalar::{F16, F32}; + +pub use self::filter::{Filter, Payload}; +pub use self::heap::{Heap, HeapElement}; +pub use self::sys::{Id, Pointer}; + +pub use num_traits::{Float, Zero}; diff --git a/crates/service/src/prelude/scalar/f16.rs b/crates/service/src/prelude/scalar/f16.rs new file mode 100644 index 000000000..467542f06 --- /dev/null +++ b/crates/service/src/prelude/scalar/f16.rs @@ -0,0 +1,653 @@ +use crate::prelude::global::FloatCast; +use half::f16; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::fmt::{Debug, Display}; +use std::num::ParseFloatError; +use std::ops::*; +use std::str::FromStr; + +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +#[repr(transparent)] +#[serde(transparent)] +pub struct F16(pub f16); + +impl Debug for F16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.0, f) + } +} + +impl Display for F16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl PartialEq for F16 { + fn eq(&self, other: &Self) -> bool { + self.0.total_cmp(&other.0) == Ordering::Equal + } +} + +impl Eq for F16 {} + +impl PartialOrd for F16 { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(Ord::cmp(self, other)) + } +} + +impl Ord for F16 { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + self.0.total_cmp(&other.0) + } +} + +unsafe impl bytemuck::Zeroable for F16 {} + +unsafe impl bytemuck::Pod for F16 {} + +impl num_traits::Zero for F16 { + fn zero() -> Self { + Self(f16::zero()) + } + + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl num_traits::One for F16 { + fn one() -> Self { + Self(f16::one()) + } +} + +impl num_traits::FromPrimitive for F16 { + fn from_i64(n: i64) -> Option { + f16::from_i64(n).map(Self) + } + + fn from_u64(n: u64) -> Option { + f16::from_u64(n).map(Self) + } + + fn from_isize(n: isize) -> Option { + f16::from_isize(n).map(Self) + } + + fn from_i8(n: i8) -> Option { + f16::from_i8(n).map(Self) + } + + fn from_i16(n: i16) -> Option { + f16::from_i16(n).map(Self) + } + + fn from_i32(n: i32) -> Option { + f16::from_i32(n).map(Self) + } + + fn from_i128(n: i128) -> Option { + f16::from_i128(n).map(Self) + } + + fn from_usize(n: usize) -> Option { + f16::from_usize(n).map(Self) + } + + fn from_u8(n: u8) -> Option { + f16::from_u8(n).map(Self) + } + + fn from_u16(n: u16) -> Option { + f16::from_u16(n).map(Self) + } + + fn from_u32(n: u32) -> Option { + f16::from_u32(n).map(Self) + } + + fn from_u128(n: u128) -> Option { + f16::from_u128(n).map(Self) + } + + fn from_f32(n: f32) -> Option { + Some(Self(f16::from_f32(n))) + } + + fn from_f64(n: f64) -> Option { + Some(Self(f16::from_f64(n))) + } +} + +impl num_traits::ToPrimitive for F16 { + fn to_i64(&self) -> Option { + self.0.to_i64() + } + + fn to_u64(&self) -> Option { + self.0.to_u64() + } + + fn to_isize(&self) -> Option { + self.0.to_isize() + } + + fn to_i8(&self) -> Option { + self.0.to_i8() + } + + fn to_i16(&self) -> Option { + self.0.to_i16() + } + + fn to_i32(&self) -> Option { + self.0.to_i32() + } + + fn to_i128(&self) -> Option { + self.0.to_i128() + } + + fn to_usize(&self) -> Option { + self.0.to_usize() + } + + fn to_u8(&self) -> Option { + self.0.to_u8() + } + + fn to_u16(&self) -> Option { + self.0.to_u16() + } + + fn to_u32(&self) -> Option { + self.0.to_u32() + } + + fn to_u128(&self) -> Option { + self.0.to_u128() + } + + fn to_f32(&self) -> Option { + Some(self.0.to_f32()) + } + + fn to_f64(&self) -> 
Option { + Some(self.0.to_f64()) + } +} + +impl num_traits::NumCast for F16 { + fn from(n: T) -> Option { + num_traits::NumCast::from(n).map(Self) + } +} + +impl num_traits::Num for F16 { + type FromStrRadixErr = ::FromStrRadixErr; + + fn from_str_radix(str: &str, radix: u32) -> Result { + f16::from_str_radix(str, radix).map(Self) + } +} + +impl num_traits::Float for F16 { + fn nan() -> Self { + Self(f16::nan()) + } + + fn infinity() -> Self { + Self(f16::infinity()) + } + + fn neg_infinity() -> Self { + Self(f16::neg_infinity()) + } + + fn neg_zero() -> Self { + Self(f16::neg_zero()) + } + + fn min_value() -> Self { + Self(f16::min_value()) + } + + fn min_positive_value() -> Self { + Self(f16::min_positive_value()) + } + + fn max_value() -> Self { + Self(f16::max_value()) + } + + fn is_nan(self) -> bool { + self.0.is_nan() + } + + fn is_infinite(self) -> bool { + self.0.is_infinite() + } + + fn is_finite(self) -> bool { + self.0.is_finite() + } + + fn is_normal(self) -> bool { + self.0.is_normal() + } + + fn classify(self) -> std::num::FpCategory { + self.0.classify() + } + + fn floor(self) -> Self { + Self(self.0.floor()) + } + + fn ceil(self) -> Self { + Self(self.0.ceil()) + } + + fn round(self) -> Self { + Self(self.0.round()) + } + + fn trunc(self) -> Self { + Self(self.0.trunc()) + } + + fn fract(self) -> Self { + Self(self.0.fract()) + } + + fn abs(self) -> Self { + Self(self.0.abs()) + } + + fn signum(self) -> Self { + Self(self.0.signum()) + } + + fn is_sign_positive(self) -> bool { + self.0.is_sign_positive() + } + + fn is_sign_negative(self) -> bool { + self.0.is_sign_negative() + } + + fn mul_add(self, a: Self, b: Self) -> Self { + Self(self.0.mul_add(a.0, b.0)) + } + + fn recip(self) -> Self { + Self(self.0.recip()) + } + + fn powi(self, n: i32) -> Self { + Self(self.0.powi(n)) + } + + fn powf(self, n: Self) -> Self { + Self(self.0.powf(n.0)) + } + + fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + fn exp(self) -> Self { + Self(self.0.exp()) + } + + fn exp2(self) -> Self { + Self(self.0.exp2()) + } + + fn ln(self) -> Self { + Self(self.0.ln()) + } + + fn log(self, base: Self) -> Self { + Self(self.0.log(base.0)) + } + + fn log2(self) -> Self { + Self(self.0.log2()) + } + + fn log10(self) -> Self { + Self(self.0.log10()) + } + + fn max(self, other: Self) -> Self { + Self(self.0.max(other.0)) + } + + fn min(self, other: Self) -> Self { + Self(self.0.min(other.0)) + } + + fn abs_sub(self, _: Self) -> Self { + unimplemented!() + } + + fn cbrt(self) -> Self { + Self(self.0.cbrt()) + } + + fn hypot(self, other: Self) -> Self { + Self(self.0.hypot(other.0)) + } + + fn sin(self) -> Self { + Self(self.0.sin()) + } + + fn cos(self) -> Self { + Self(self.0.cos()) + } + + fn tan(self) -> Self { + Self(self.0.tan()) + } + + fn asin(self) -> Self { + Self(self.0.asin()) + } + + fn acos(self) -> Self { + Self(self.0.acos()) + } + + fn atan(self) -> Self { + Self(self.0.atan()) + } + + fn atan2(self, other: Self) -> Self { + Self(self.0.atan2(other.0)) + } + + fn sin_cos(self) -> (Self, Self) { + let (_x, _y) = self.0.sin_cos(); + (Self(_x), Self(_y)) + } + + fn exp_m1(self) -> Self { + Self(self.0.exp_m1()) + } + + fn ln_1p(self) -> Self { + Self(self.0.ln_1p()) + } + + fn sinh(self) -> Self { + Self(self.0.sinh()) + } + + fn cosh(self) -> Self { + Self(self.0.cosh()) + } + + fn tanh(self) -> Self { + Self(self.0.tanh()) + } + + fn asinh(self) -> Self { + Self(self.0.asinh()) + } + + fn acosh(self) -> Self { + Self(self.0.acosh()) + } + + fn atanh(self) -> Self { + Self(self.0.atanh()) 
+ } + + fn integer_decode(self) -> (u64, i16, i8) { + self.0.integer_decode() + } + + fn epsilon() -> Self { + Self(f16::EPSILON) + } + + fn is_subnormal(self) -> bool { + self.0.classify() == std::num::FpCategory::Subnormal + } + + fn to_degrees(self) -> Self { + Self(self.0.to_degrees()) + } + + fn to_radians(self) -> Self { + Self(self.0.to_radians()) + } + + fn copysign(self, sign: Self) -> Self { + Self(self.0.copysign(sign.0)) + } +} + +impl Add for F16 { + type Output = F16; + + #[inline(always)] + fn add(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fadd_fast(self.0, rhs.0).into() } + } +} + +impl AddAssign for F16 { + #[inline(always)] + fn add_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs.0) } + } +} + +impl Sub for F16 { + type Output = F16; + + #[inline(always)] + fn sub(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fsub_fast(self.0, rhs.0).into() } + } +} + +impl SubAssign for F16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs.0) } + } +} + +impl Mul for F16 { + type Output = F16; + + #[inline(always)] + fn mul(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fmul_fast(self.0, rhs.0).into() } + } +} + +impl MulAssign for F16 { + #[inline(always)] + fn mul_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs.0) } + } +} + +impl Div for F16 { + type Output = F16; + + #[inline(always)] + fn div(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::fdiv_fast(self.0, rhs.0).into() } + } +} + +impl DivAssign for F16 { + #[inline(always)] + fn div_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs.0) } + } +} + +impl Rem for F16 { + type Output = F16; + + #[inline(always)] + fn rem(self, rhs: F16) -> F16 { + unsafe { self::intrinsics::frem_fast(self.0, rhs.0).into() } + } +} + +impl RemAssign for F16 { + #[inline(always)] + fn rem_assign(&mut self, rhs: F16) { + unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs.0) } + } +} + +impl Neg for F16 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(self.0.neg()) + } +} + +impl FromStr for F16 { + type Err = ParseFloatError; + + fn from_str(s: &str) -> Result { + f16::from_str(s).map(|x| x.into()) + } +} + +impl FloatCast for F16 { + fn from_f32(x: f32) -> Self { + Self(f16::from_f32(x)) + } + + fn to_f32(self) -> f32 { + f16::to_f32(self.0) + } +} + +impl From for F16 { + fn from(value: f16) -> Self { + Self(value) + } +} + +impl From for f16 { + fn from(F16(float): F16) -> Self { + float + } +} + +impl Add for F16 { + type Output = F16; + + #[inline(always)] + fn add(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fadd_fast(self.0, rhs).into() } + } +} + +impl AddAssign for F16 { + fn add_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs) } + } +} + +impl Sub for F16 { + type Output = F16; + + #[inline(always)] + fn sub(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fsub_fast(self.0, rhs).into() } + } +} + +impl SubAssign for F16 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs) } + } +} + +impl Mul for F16 { + type Output = F16; + + #[inline(always)] + fn mul(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fmul_fast(self.0, rhs).into() } + } +} + +impl MulAssign for F16 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f16) { + unsafe { self.0 = 
self::intrinsics::fmul_fast(self.0, rhs) } + } +} + +impl Div for F16 { + type Output = F16; + + #[inline(always)] + fn div(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::fdiv_fast(self.0, rhs).into() } + } +} + +impl DivAssign for F16 { + #[inline(always)] + fn div_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs) } + } +} + +impl Rem for F16 { + type Output = F16; + + #[inline(always)] + fn rem(self, rhs: f16) -> F16 { + unsafe { self::intrinsics::frem_fast(self.0, rhs).into() } + } +} + +impl RemAssign for F16 { + #[inline(always)] + fn rem_assign(&mut self, rhs: f16) { + unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs) } + } +} + +mod intrinsics { + use half::f16; + + pub unsafe fn fadd_fast(lhs: f16, rhs: f16) -> f16 { + lhs + rhs + } + pub unsafe fn fsub_fast(lhs: f16, rhs: f16) -> f16 { + lhs - rhs + } + pub unsafe fn fmul_fast(lhs: f16, rhs: f16) -> f16 { + lhs * rhs + } + pub unsafe fn fdiv_fast(lhs: f16, rhs: f16) -> f16 { + lhs / rhs + } + pub unsafe fn frem_fast(lhs: f16, rhs: f16) -> f16 { + lhs % rhs + } +} diff --git a/crates/service/src/prelude/scalar/f32.rs b/crates/service/src/prelude/scalar/f32.rs new file mode 100644 index 000000000..a4e70a10a --- /dev/null +++ b/crates/service/src/prelude/scalar/f32.rs @@ -0,0 +1,632 @@ +use crate::prelude::global::FloatCast; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; +use std::fmt::{Debug, Display}; +use std::num::ParseFloatError; +use std::ops::*; +use std::str::FromStr; + +#[derive(Clone, Copy, Default, Serialize, Deserialize)] +#[repr(transparent)] +#[serde(transparent)] +pub struct F32(pub f32); + +impl Debug for F32 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Debug::fmt(&self.0, f) + } +} + +impl Display for F32 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +impl PartialEq for F32 { + fn eq(&self, other: &Self) -> bool { + self.0.total_cmp(&other.0) == Ordering::Equal + } +} + +impl Eq for F32 {} + +impl PartialOrd for F32 { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(Ord::cmp(self, other)) + } +} + +impl Ord for F32 { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + self.0.total_cmp(&other.0) + } +} + +unsafe impl bytemuck::Zeroable for F32 {} + +unsafe impl bytemuck::Pod for F32 {} + +impl num_traits::Zero for F32 { + fn zero() -> Self { + Self(f32::zero()) + } + + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl num_traits::One for F32 { + fn one() -> Self { + Self(f32::one()) + } +} + +impl num_traits::FromPrimitive for F32 { + fn from_i64(n: i64) -> Option { + f32::from_i64(n).map(Self) + } + + fn from_u64(n: u64) -> Option { + f32::from_u64(n).map(Self) + } + + fn from_isize(n: isize) -> Option { + f32::from_isize(n).map(Self) + } + + fn from_i8(n: i8) -> Option { + f32::from_i8(n).map(Self) + } + + fn from_i16(n: i16) -> Option { + f32::from_i16(n).map(Self) + } + + fn from_i32(n: i32) -> Option { + f32::from_i32(n).map(Self) + } + + fn from_i128(n: i128) -> Option { + f32::from_i128(n).map(Self) + } + + fn from_usize(n: usize) -> Option { + f32::from_usize(n).map(Self) + } + + fn from_u8(n: u8) -> Option { + f32::from_u8(n).map(Self) + } + + fn from_u16(n: u16) -> Option { + f32::from_u16(n).map(Self) + } + + fn from_u32(n: u32) -> Option { + f32::from_u32(n).map(Self) + } + + fn from_u128(n: u128) -> Option { + f32::from_u128(n).map(Self) + } + + fn from_f32(n: f32) -> Option { + 
f32::from_f32(n).map(Self) + } + + fn from_f64(n: f64) -> Option { + f32::from_f64(n).map(Self) + } +} + +impl num_traits::ToPrimitive for F32 { + fn to_i64(&self) -> Option { + self.0.to_i64() + } + + fn to_u64(&self) -> Option { + self.0.to_u64() + } + + fn to_isize(&self) -> Option { + self.0.to_isize() + } + + fn to_i8(&self) -> Option { + self.0.to_i8() + } + + fn to_i16(&self) -> Option { + self.0.to_i16() + } + + fn to_i32(&self) -> Option { + self.0.to_i32() + } + + fn to_i128(&self) -> Option { + self.0.to_i128() + } + + fn to_usize(&self) -> Option { + self.0.to_usize() + } + + fn to_u8(&self) -> Option { + self.0.to_u8() + } + + fn to_u16(&self) -> Option { + self.0.to_u16() + } + + fn to_u32(&self) -> Option { + self.0.to_u32() + } + + fn to_u128(&self) -> Option { + self.0.to_u128() + } + + fn to_f32(&self) -> Option { + self.0.to_f32() + } + + fn to_f64(&self) -> Option { + self.0.to_f64() + } +} + +impl num_traits::NumCast for F32 { + fn from(n: T) -> Option { + num_traits::NumCast::from(n).map(Self) + } +} + +impl num_traits::Num for F32 { + type FromStrRadixErr = ::FromStrRadixErr; + + fn from_str_radix(str: &str, radix: u32) -> Result { + f32::from_str_radix(str, radix).map(Self) + } +} + +impl num_traits::Float for F32 { + fn nan() -> Self { + Self(f32::nan()) + } + + fn infinity() -> Self { + Self(f32::infinity()) + } + + fn neg_infinity() -> Self { + Self(f32::neg_infinity()) + } + + fn neg_zero() -> Self { + Self(f32::neg_zero()) + } + + fn min_value() -> Self { + Self(f32::min_value()) + } + + fn min_positive_value() -> Self { + Self(f32::min_positive_value()) + } + + fn max_value() -> Self { + Self(f32::max_value()) + } + + fn is_nan(self) -> bool { + self.0.is_nan() + } + + fn is_infinite(self) -> bool { + self.0.is_infinite() + } + + fn is_finite(self) -> bool { + self.0.is_finite() + } + + fn is_normal(self) -> bool { + self.0.is_normal() + } + + fn classify(self) -> std::num::FpCategory { + self.0.classify() + } + + fn floor(self) -> Self { + Self(self.0.floor()) + } + + fn ceil(self) -> Self { + Self(self.0.ceil()) + } + + fn round(self) -> Self { + Self(self.0.round()) + } + + fn trunc(self) -> Self { + Self(self.0.trunc()) + } + + fn fract(self) -> Self { + Self(self.0.fract()) + } + + fn abs(self) -> Self { + Self(self.0.abs()) + } + + fn signum(self) -> Self { + Self(self.0.signum()) + } + + fn is_sign_positive(self) -> bool { + self.0.is_sign_positive() + } + + fn is_sign_negative(self) -> bool { + self.0.is_sign_negative() + } + + fn mul_add(self, a: Self, b: Self) -> Self { + Self(self.0.mul_add(a.0, b.0)) + } + + fn recip(self) -> Self { + Self(self.0.recip()) + } + + fn powi(self, n: i32) -> Self { + Self(self.0.powi(n)) + } + + fn powf(self, n: Self) -> Self { + Self(self.0.powf(n.0)) + } + + fn sqrt(self) -> Self { + Self(self.0.sqrt()) + } + + fn exp(self) -> Self { + Self(self.0.exp()) + } + + fn exp2(self) -> Self { + Self(self.0.exp2()) + } + + fn ln(self) -> Self { + Self(self.0.ln()) + } + + fn log(self, base: Self) -> Self { + Self(self.0.log(base.0)) + } + + fn log2(self) -> Self { + Self(self.0.log2()) + } + + fn log10(self) -> Self { + Self(self.0.log10()) + } + + fn max(self, other: Self) -> Self { + Self(self.0.max(other.0)) + } + + fn min(self, other: Self) -> Self { + Self(self.0.min(other.0)) + } + + fn abs_sub(self, _: Self) -> Self { + unimplemented!() + } + + fn cbrt(self) -> Self { + Self(self.0.cbrt()) + } + + fn hypot(self, other: Self) -> Self { + Self(self.0.hypot(other.0)) + } + + fn sin(self) -> Self { + Self(self.0.sin()) + 
} + + fn cos(self) -> Self { + Self(self.0.cos()) + } + + fn tan(self) -> Self { + Self(self.0.tan()) + } + + fn asin(self) -> Self { + Self(self.0.asin()) + } + + fn acos(self) -> Self { + Self(self.0.acos()) + } + + fn atan(self) -> Self { + Self(self.0.atan()) + } + + fn atan2(self, other: Self) -> Self { + Self(self.0.atan2(other.0)) + } + + fn sin_cos(self) -> (Self, Self) { + let (_x, _y) = self.0.sin_cos(); + (Self(_x), Self(_y)) + } + + fn exp_m1(self) -> Self { + Self(self.0.exp_m1()) + } + + fn ln_1p(self) -> Self { + Self(self.0.ln_1p()) + } + + fn sinh(self) -> Self { + Self(self.0.sinh()) + } + + fn cosh(self) -> Self { + Self(self.0.cosh()) + } + + fn tanh(self) -> Self { + Self(self.0.tanh()) + } + + fn asinh(self) -> Self { + Self(self.0.asinh()) + } + + fn acosh(self) -> Self { + Self(self.0.acosh()) + } + + fn atanh(self) -> Self { + Self(self.0.atanh()) + } + + fn integer_decode(self) -> (u64, i16, i8) { + self.0.integer_decode() + } + + fn epsilon() -> Self { + Self(f32::EPSILON) + } + + fn is_subnormal(self) -> bool { + self.0.classify() == std::num::FpCategory::Subnormal + } + + fn to_degrees(self) -> Self { + Self(self.0.to_degrees()) + } + + fn to_radians(self) -> Self { + Self(self.0.to_radians()) + } + + fn copysign(self, sign: Self) -> Self { + Self(self.0.copysign(sign.0)) + } +} + +impl Add for F32 { + type Output = F32; + + #[inline(always)] + fn add(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fadd_fast(self.0, rhs.0).into() } + } +} + +impl AddAssign for F32 { + #[inline(always)] + fn add_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs.0) } + } +} + +impl Sub for F32 { + type Output = F32; + + #[inline(always)] + fn sub(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fsub_fast(self.0, rhs.0).into() } + } +} + +impl SubAssign for F32 { + #[inline(always)] + fn sub_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs.0) } + } +} + +impl Mul for F32 { + type Output = F32; + + #[inline(always)] + fn mul(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fmul_fast(self.0, rhs.0).into() } + } +} + +impl MulAssign for F32 { + #[inline(always)] + fn mul_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs.0) } + } +} + +impl Div for F32 { + type Output = F32; + + #[inline(always)] + fn div(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::fdiv_fast(self.0, rhs.0).into() } + } +} + +impl DivAssign for F32 { + #[inline(always)] + fn div_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs.0) } + } +} + +impl Rem for F32 { + type Output = F32; + + #[inline(always)] + fn rem(self, rhs: F32) -> F32 { + unsafe { std::intrinsics::frem_fast(self.0, rhs.0).into() } + } +} + +impl RemAssign for F32 { + #[inline(always)] + fn rem_assign(&mut self, rhs: F32) { + unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs.0) } + } +} + +impl Neg for F32 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(self.0.neg()) + } +} + +impl FromStr for F32 { + type Err = ParseFloatError; + + fn from_str(s: &str) -> Result { + f32::from_str(s).map(|x| x.into()) + } +} + +impl FloatCast for F32 { + fn from_f32(x: f32) -> Self { + Self(x) + } + + fn to_f32(self) -> f32 { + self.0 + } +} + +impl From for F32 { + fn from(value: f32) -> Self { + Self(value) + } +} + +impl From for f32 { + fn from(F32(float): F32) -> Self { + float + } +} + +impl Add for F32 { + type Output = F32; + + #[inline(always)] + fn add(self, 
rhs: f32) -> F32 { + unsafe { std::intrinsics::fadd_fast(self.0, rhs).into() } + } +} + +impl AddAssign for F32 { + fn add_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs) } + } +} + +impl Sub for F32 { + type Output = F32; + + #[inline(always)] + fn sub(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fsub_fast(self.0, rhs).into() } + } +} + +impl SubAssign for F32 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs) } + } +} + +impl Mul for F32 { + type Output = F32; + + #[inline(always)] + fn mul(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fmul_fast(self.0, rhs).into() } + } +} + +impl MulAssign for F32 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs) } + } +} + +impl Div for F32 { + type Output = F32; + + #[inline(always)] + fn div(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::fdiv_fast(self.0, rhs).into() } + } +} + +impl DivAssign for F32 { + #[inline(always)] + fn div_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs) } + } +} + +impl Rem for F32 { + type Output = F32; + + #[inline(always)] + fn rem(self, rhs: f32) -> F32 { + unsafe { std::intrinsics::frem_fast(self.0, rhs).into() } + } +} + +impl RemAssign for F32 { + #[inline(always)] + fn rem_assign(&mut self, rhs: f32) { + unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) } + } +} diff --git a/crates/service/src/prelude/scalar/mod.rs b/crates/service/src/prelude/scalar/mod.rs new file mode 100644 index 000000000..1894a906f --- /dev/null +++ b/crates/service/src/prelude/scalar/mod.rs @@ -0,0 +1,5 @@ +mod f16; +mod f32; + +pub use f16::F16; +pub use f32::F32; diff --git a/src/prelude/sys.rs b/crates/service/src/prelude/sys.rs similarity index 53% rename from src/prelude/sys.rs rename to crates/service/src/prelude/sys.rs index 0318149ef..640ff7a72 100644 --- a/src/prelude/sys.rs +++ b/crates/service/src/prelude/sys.rs @@ -3,15 +3,10 @@ use std::{fmt::Display, num::ParseIntError, str::FromStr}; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Id { - newtype: u32, + pub newtype: u32, } impl Id { - pub fn from_sys(sys: pgrx::pg_sys::Oid) -> Self { - Self { - newtype: sys.as_u32(), - } - } pub fn as_u32(self) -> u32 { self.newtype } @@ -35,26 +30,10 @@ impl FromStr for Id { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] pub struct Pointer { - newtype: u64, + pub newtype: u64, } impl Pointer { - pub fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self { - let mut newtype = 0; - newtype |= (sys.ip_blkid.bi_hi as u64) << 32; - newtype |= (sys.ip_blkid.bi_lo as u64) << 16; - newtype |= sys.ip_posid as u64; - Self { newtype } - } - pub fn into_sys(self) -> pgrx::pg_sys::ItemPointerData { - pgrx::pg_sys::ItemPointerData { - ip_blkid: pgrx::pg_sys::BlockIdData { - bi_hi: ((self.newtype >> 32) & 0xffff) as u16, - bi_lo: ((self.newtype >> 16) & 0xffff) as u16, - }, - ip_posid: (self.newtype & 0xffff) as u16, - } - } pub fn from_u48(value: u64) -> Self { assert!(value < (1u64 << 48)); Self { newtype: value } diff --git a/crates/service/src/utils/cells.rs b/crates/service/src/utils/cells.rs new file mode 100644 index 000000000..83a0a7a57 --- /dev/null +++ b/crates/service/src/utils/cells.rs @@ -0,0 +1,26 @@ +use std::cell::UnsafeCell; + +#[repr(transparent)] +pub struct SyncUnsafeCell { + value: UnsafeCell, +} + +unsafe 
impl Sync for SyncUnsafeCell {} + +impl SyncUnsafeCell { + pub const fn new(value: T) -> Self { + Self { + value: UnsafeCell::new(value), + } + } +} + +impl SyncUnsafeCell { + pub fn get(&self) -> *mut T { + self.value.get() + } + + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } +} diff --git a/src/utils/clean.rs b/crates/service/src/utils/clean.rs similarity index 100% rename from src/utils/clean.rs rename to crates/service/src/utils/clean.rs diff --git a/crates/service/src/utils/detect.rs b/crates/service/src/utils/detect.rs new file mode 100644 index 000000000..2a99bf589 --- /dev/null +++ b/crates/service/src/utils/detect.rs @@ -0,0 +1 @@ +pub mod x86_64; diff --git a/crates/service/src/utils/detect/x86_64.rs b/crates/service/src/utils/detect/x86_64.rs new file mode 100644 index 000000000..5bd3c8705 --- /dev/null +++ b/crates/service/src/utils/detect/x86_64.rs @@ -0,0 +1,85 @@ +#![cfg(target_arch = "x86_64")] + +use std::sync::atomic::{AtomicBool, Ordering}; + +static ATOMIC_AVX512FP16: AtomicBool = AtomicBool::new(false); + +pub fn test_avx512fp16() -> bool { + std_detect::is_x86_feature_detected!("avx512fp16") && test_v4() +} + +#[ctor::ctor] +fn ctor_avx512fp16() { + ATOMIC_AVX512FP16.store(test_avx512fp16(), Ordering::Relaxed); +} + +pub fn detect_avx512fp16() -> bool { + ATOMIC_AVX512FP16.load(Ordering::Relaxed) +} + +static ATOMIC_V4: AtomicBool = AtomicBool::new(false); + +pub fn test_v4() -> bool { + std_detect::is_x86_feature_detected!("avx512bw") + && std_detect::is_x86_feature_detected!("avx512cd") + && std_detect::is_x86_feature_detected!("avx512dq") + && std_detect::is_x86_feature_detected!("avx512f") + && std_detect::is_x86_feature_detected!("avx512vl") + && test_v3() +} + +#[ctor::ctor] +fn ctor_v4() { + ATOMIC_V4.store(test_v4(), Ordering::Relaxed); +} + +pub fn _detect_v4() -> bool { + ATOMIC_V4.load(Ordering::Relaxed) +} + +static ATOMIC_V3: AtomicBool = AtomicBool::new(false); + +pub fn test_v3() -> bool { + std_detect::is_x86_feature_detected!("avx") + && std_detect::is_x86_feature_detected!("avx2") + && std_detect::is_x86_feature_detected!("bmi1") + && std_detect::is_x86_feature_detected!("bmi2") + && std_detect::is_x86_feature_detected!("f16c") + && std_detect::is_x86_feature_detected!("fma") + && std_detect::is_x86_feature_detected!("lzcnt") + && std_detect::is_x86_feature_detected!("movbe") + && std_detect::is_x86_feature_detected!("xsave") + && test_v2() +} + +#[ctor::ctor] +fn ctor_v3() { + ATOMIC_V3.store(test_v3(), Ordering::Relaxed); +} + +pub fn detect_v3() -> bool { + ATOMIC_V3.load(Ordering::Relaxed) +} + +static ATOMIC_V2: AtomicBool = AtomicBool::new(false); + +pub fn test_v2() -> bool { + std_detect::is_x86_feature_detected!("cmpxchg16b") + && std_detect::is_x86_feature_detected!("fxsr") + && std_detect::is_x86_feature_detected!("popcnt") + && std_detect::is_x86_feature_detected!("sse") + && std_detect::is_x86_feature_detected!("sse2") + && std_detect::is_x86_feature_detected!("sse3") + && std_detect::is_x86_feature_detected!("sse4.1") + && std_detect::is_x86_feature_detected!("sse4.2") + && std_detect::is_x86_feature_detected!("ssse3") +} + +#[ctor::ctor] +fn ctor_v2() { + ATOMIC_V2.store(test_v2(), Ordering::Relaxed); +} + +pub fn _detect_v2() -> bool { + ATOMIC_V2.load(Ordering::Relaxed) +} diff --git a/src/utils/dir_ops.rs b/crates/service/src/utils/dir_ops.rs similarity index 100% rename from src/utils/dir_ops.rs rename to crates/service/src/utils/dir_ops.rs diff --git a/src/utils/file_atomic.rs 
b/crates/service/src/utils/file_atomic.rs similarity index 100% rename from src/utils/file_atomic.rs rename to crates/service/src/utils/file_atomic.rs diff --git a/src/utils/file_wal.rs b/crates/service/src/utils/file_wal.rs similarity index 100% rename from src/utils/file_wal.rs rename to crates/service/src/utils/file_wal.rs diff --git a/src/utils/mmap_array.rs b/crates/service/src/utils/mmap_array.rs similarity index 95% rename from src/utils/mmap_array.rs rename to crates/service/src/utils/mmap_array.rs index 02d37233a..2a83b9e84 100644 --- a/src/utils/mmap_array.rs +++ b/crates/service/src/utils/mmap_array.rs @@ -111,9 +111,11 @@ fn read_information(mut file: &File) -> Information { unsafe fn read_mmap(file: &File, len: usize) -> memmap2::Mmap { let len = len.next_multiple_of(4096); - memmap2::MmapOptions::new() - .populate() - .len(len) - .map(file) - .unwrap() + unsafe { + memmap2::MmapOptions::new() + .populate() + .len(len) + .map(file) + .unwrap() + } } diff --git a/crates/service/src/utils/mod.rs b/crates/service/src/utils/mod.rs new file mode 100644 index 000000000..e42242438 --- /dev/null +++ b/crates/service/src/utils/mod.rs @@ -0,0 +1,8 @@ +pub mod cells; +pub mod clean; +pub mod detect; +pub mod dir_ops; +pub mod file_atomic; +pub mod file_wal; +pub mod mmap_array; +pub mod vec2; diff --git a/src/utils/vec2.rs b/crates/service/src/utils/vec2.rs similarity index 78% rename from src/utils/vec2.rs rename to crates/service/src/utils/vec2.rs index 1a338ddea..2640f51c8 100644 --- a/src/utils/vec2.rs +++ b/crates/service/src/utils/vec2.rs @@ -2,16 +2,16 @@ use crate::prelude::*; use std::ops::{Deref, DerefMut, Index, IndexMut}; #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct Vec2 { +pub struct Vec2 { dims: u16, - v: Box<[Scalar]>, + v: Vec, } -impl Vec2 { +impl Vec2 { pub fn new(dims: u16, n: usize) -> Self { Self { dims, - v: bytemuck::zeroed_slice_box(dims as usize * n), + v: bytemuck::zeroed_vec(dims as usize * n), } } pub fn dims(&self) -> u16 { @@ -32,29 +32,29 @@ impl Vec2 { } } -impl Index for Vec2 { - type Output = [Scalar]; +impl Index for Vec2 { + type Output = [S::Scalar]; fn index(&self, index: usize) -> &Self::Output { &self.v[self.dims as usize * index..][..self.dims as usize] } } -impl IndexMut for Vec2 { +impl IndexMut for Vec2 { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.v[self.dims as usize * index..][..self.dims as usize] } } -impl Deref for Vec2 { - type Target = [Scalar]; +impl Deref for Vec2 { + type Target = [S::Scalar]; fn deref(&self) -> &Self::Target { self.v.deref() } } -impl DerefMut for Vec2 { +impl DerefMut for Vec2 { fn deref_mut(&mut self) -> &mut Self::Target { self.v.deref_mut() } diff --git a/crates/service/src/worker/instance.rs b/crates/service/src/worker/instance.rs new file mode 100644 index 000000000..16c2df1ce --- /dev/null +++ b/crates/service/src/worker/instance.rs @@ -0,0 +1,248 @@ +use crate::index::Index; +use crate::index::IndexOptions; +use crate::index::IndexStat; +use crate::index::IndexView; +use crate::index::OutdatedError; +use crate::prelude::*; +use std::path::PathBuf; +use std::sync::Arc; + +#[derive(Clone)] +pub enum Instance { + F32Cos(Arc>), + F32Dot(Arc>), + F32L2(Arc>), + F16Cos(Arc>), + F16Dot(Arc>), + F16L2(Arc>), +} + +impl Instance { + pub fn create(path: PathBuf, options: IndexOptions) -> Self { + match (options.vector.d, options.vector.k) { + (Distance::Cos, Kind::F32) => Self::F32Cos(Index::create(path, options)), + (Distance::Dot, Kind::F32) => 
Self::F32Dot(Index::create(path, options)), + (Distance::L2, Kind::F32) => Self::F32L2(Index::create(path, options)), + (Distance::Cos, Kind::F16) => Self::F16Cos(Index::create(path, options)), + (Distance::Dot, Kind::F16) => Self::F16Dot(Index::create(path, options)), + (Distance::L2, Kind::F16) => Self::F16L2(Index::create(path, options)), + } + } + pub fn open(path: PathBuf, options: IndexOptions) -> Self { + match (options.vector.d, options.vector.k) { + (Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path, options)), + (Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path, options)), + (Distance::L2, Kind::F32) => Self::F32L2(Index::open(path, options)), + (Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path, options)), + (Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path, options)), + (Distance::L2, Kind::F16) => Self::F16L2(Index::open(path, options)), + } + } + pub fn options(&self) -> &IndexOptions { + match self { + Instance::F32Cos(x) => x.options(), + Instance::F32Dot(x) => x.options(), + Instance::F32L2(x) => x.options(), + Instance::F16Cos(x) => x.options(), + Instance::F16Dot(x) => x.options(), + Instance::F16L2(x) => x.options(), + } + } + pub fn refresh(&self) { + match self { + Instance::F32Cos(x) => x.refresh(), + Instance::F32Dot(x) => x.refresh(), + Instance::F32L2(x) => x.refresh(), + Instance::F16Cos(x) => x.refresh(), + Instance::F16Dot(x) => x.refresh(), + Instance::F16L2(x) => x.refresh(), + } + } + pub fn view(&self) -> InstanceView { + match self { + Instance::F32Cos(x) => InstanceView::F32Cos(x.view()), + Instance::F32Dot(x) => InstanceView::F32Dot(x.view()), + Instance::F32L2(x) => InstanceView::F32L2(x.view()), + Instance::F16Cos(x) => InstanceView::F16Cos(x.view()), + Instance::F16Dot(x) => InstanceView::F16Dot(x.view()), + Instance::F16L2(x) => InstanceView::F16L2(x.view()), + } + } + pub fn stat(&self) -> IndexStat { + match self { + Instance::F32Cos(x) => x.stat(), + Instance::F32Dot(x) => x.stat(), + Instance::F32L2(x) => x.stat(), + Instance::F16Cos(x) => x.stat(), + Instance::F16Dot(x) => x.stat(), + Instance::F16L2(x) => x.stat(), + } + } +} + +pub enum InstanceView { + F32Cos(Arc>), + F32Dot(Arc>), + F32L2(Arc>), + F16Cos(Arc>), + F16Dot(Arc>), + F16L2(Arc>), +} + +impl InstanceView { + pub fn search bool>( + &self, + k: usize, + vector: DynamicVector, + filter: F, + ) -> Result, FriendlyError> { + match (self, vector) { + (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return 
Err(FriendlyError::Unmatched2); + } + Ok(x.search(k, &vector, filter)) + } + _ => Err(FriendlyError::Unmatched2), + } + } + pub fn vbase( + &self, + vector: DynamicVector, + ) -> Result + '_, FriendlyError> { + match (self, vector) { + (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(Box::new(x.vbase(&vector)) as Box>) + } + (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(Box::new(x.vbase(&vector))) + } + (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(Box::new(x.vbase(&vector))) + } + (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(Box::new(x.vbase(&vector))) + } + (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(Box::new(x.vbase(&vector))) + } + (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(Box::new(x.vbase(&vector))) + } + _ => Err(FriendlyError::Unmatched2), + } + } + pub fn insert( + &self, + vector: DynamicVector, + pointer: Pointer, + ) -> Result, FriendlyError> { + match (self, vector) { + (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { + if x.options.vector.dims as usize != vector.len() { + return Err(FriendlyError::Unmatched2); + } + Ok(x.insert(vector, pointer)) + } + _ => Err(FriendlyError::Unmatched2), + } + } + pub fn delete bool>(&self, f: F) { + match self { + InstanceView::F32Cos(x) => x.delete(f), + InstanceView::F32Dot(x) => x.delete(f), + InstanceView::F32L2(x) => x.delete(f), + InstanceView::F16Cos(x) => x.delete(f), + InstanceView::F16Dot(x) => x.delete(f), + InstanceView::F16L2(x) => x.delete(f), + } + } + pub fn flush(&self) { + match self { + InstanceView::F32Cos(x) => x.flush(), + InstanceView::F32Dot(x) => x.flush(), + InstanceView::F32L2(x) => x.flush(), + InstanceView::F16Cos(x) => x.flush(), + InstanceView::F16Dot(x) => x.flush(), + InstanceView::F16L2(x) => x.flush(), + } + } +} diff --git a/src/bgworker/worker.rs b/crates/service/src/worker/mod.rs similarity index 61% rename from src/bgworker/worker.rs rename to 
crates/service/src/worker/mod.rs index 15b53d662..3e7c97418 100644 --- a/src/bgworker/worker.rs +++ b/crates/service/src/worker/mod.rs @@ -1,7 +1,9 @@ -use crate::index::Index; -use crate::index::IndexInsertError; +pub mod instance; + +use self::instance::Instance; use crate::index::IndexOptions; -use crate::index::IndexSearchError; +use crate::index::IndexStat; +use crate::index::OutdatedError; use crate::prelude::*; use crate::utils::clean::clean; use crate::utils::dir_ops::sync_dir; @@ -57,7 +59,7 @@ impl Worker { let mut indexes = HashMap::new(); for (&id, options) in startup.get().indexes.iter() { let path = path.join("indexes").join(id.to_string()); - let index = Index::open(path, options.clone()); + let index = Instance::open(path, options.clone()); indexes.insert(id, index); } let view = Arc::new(WorkerView { @@ -72,7 +74,7 @@ impl Worker { } pub fn call_create(&self, id: Id, options: IndexOptions) { let mut protect = self.protect.lock(); - let index = Index::create(self.path.join("indexes").join(id.to_string()), options); + let index = Instance::create(self.path.join("indexes").join(id.to_string()), options); if protect.indexes.insert(id, index).is_some() { panic!("index {} already exists", id) } @@ -81,44 +83,29 @@ impl Worker { pub fn call_search( &self, id: Id, - search: (Vec, usize), + search: (DynamicVector, usize), filter: F, ) -> Result, FriendlyError> where F: FnMut(Pointer) -> bool, { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; let view = index.view(); - match view.search(search.1, &search.0, filter) { - Ok(x) => Ok(x), - Err(IndexSearchError::InvalidVector(x)) => Err(FriendlyError::BadVector(x)), - } + view.search(search.1, search.0, filter) } - pub fn call_search_vbase( + pub fn call_insert( &self, id: Id, - search: (Vec, usize), - next: F, - ) -> Result<(), FriendlyError> - where - F: FnMut(Pointer) -> bool, - { + insert: (DynamicVector, Pointer), + ) -> Result<(), FriendlyError> { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; - let view = index.view(); - view.search_vbase(search.1, &search.0, next) - .map_err(|IndexSearchError::InvalidVector(x)| FriendlyError::BadVector(x)) - } - pub fn call_insert(&self, id: Id, insert: (Vec, Pointer)) -> Result<(), FriendlyError> { - let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; loop { let view = index.view(); - match view.insert(insert.0.clone(), insert.1) { + match view.insert(insert.0.clone(), insert.1)? 
{ Ok(()) => break Ok(()), - Err(IndexInsertError::InvalidVector(x)) => break Err(FriendlyError::BadVector(x)), - Err(IndexInsertError::OutdatedView(_)) => index.refresh(), + Err(OutdatedError(_)) => index.refresh(), } } } @@ -127,16 +114,16 @@ impl Worker { F: FnMut(Pointer) -> bool, { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; let view = index.view(); view.delete(f); Ok(()) } pub fn call_flush(&self, id: Id) -> Result<(), FriendlyError> { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; let view = index.view(); - view.flush().unwrap(); + view.flush(); Ok(()) } pub fn call_destory(&self, ids: Vec) { @@ -149,44 +136,25 @@ impl Worker { protect.maintain(&self.view); } } - pub fn call_stat(&self, id: Id) -> Result { + pub fn call_stat(&self, id: Id) -> Result { let view = self.view.load_full(); - let index = view.indexes.get(&id).ok_or(FriendlyError::Index404)?; - let view = index.view(); - let idx_sealed_len = view.sealed_len(); - let idx_growing_len = view.growing_len(); - let idx_write = view.write_len(); - let res = VectorIndexInfo { - indexing: index.indexing(), - idx_tuples: (idx_write + idx_sealed_len + idx_growing_len) - .try_into() - .unwrap(), - idx_sealed_len: idx_sealed_len.try_into().unwrap(), - idx_growing_len: idx_growing_len.try_into().unwrap(), - idx_write: idx_write.try_into().unwrap(), - idx_sealed: view - .sealed_len_vec() - .into_iter() - .map(|x| x.try_into().unwrap()) - .collect(), - idx_growing: view - .growing_len_vec() - .into_iter() - .map(|x| x.try_into().unwrap()) - .collect(), - idx_config: serde_json::to_string(index.options()).unwrap(), - }; - Ok(res) + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; + Ok(index.stat()) + } + pub fn get_instance(&self, id: Id) -> Result { + let view = self.view.load_full(); + let index = view.indexes.get(&id).ok_or(FriendlyError::UnknownIndex)?; + Ok(index.clone()) } } struct WorkerView { - indexes: HashMap>, + indexes: HashMap, } struct WorkerProtect { startup: FileAtomic, - indexes: HashMap>, + indexes: HashMap, } impl WorkerProtect { diff --git a/docs/comparison-with-specialized-vectordb.md b/docs/comparison-with-specialized-vectordb.md index 7c6c9817b..f02b90e90 100644 --- a/docs/comparison-with-specialized-vectordb.md +++ b/docs/comparison-with-specialized-vectordb.md @@ -11,7 +11,7 @@ Why not just use Postgres to do the vector similarity search? This is the reason UPDATE documents SET embedding = ai_embedding_vector(content) WHERE length(embedding) = 0; -- Create an index on the embedding column -CREATE INDEX ON documents USING vectors (embedding l2_ops); +CREATE INDEX ON documents USING vectors (embedding vector_l2_ops); -- Query the similar embeddings SELECT * FROM documents ORDER BY embedding <-> ai_embedding_vector('hello world') LIMIT 5; diff --git a/docs/get-started.md b/docs/get-started.md index 0321795a2..24b5ca3eb 100644 --- a/docs/get-started.md +++ b/docs/get-started.md @@ -46,9 +46,9 @@ We support three operators to calculate the distance between two vectors. 
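(Editorial aside, not part of the patch.) The three SQL operators listed next map onto the distance kernels this patch adds under `crates/service/src/prelude/global`. A minimal Rust sketch of what they compute for the example vectors, assuming the documented semantics (squared L2 for `<->`, negated dot product for `<#>`, negated cosine similarity for `<=>`; `F32Cos` itself is not shown in this hunk):

```rust
fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [3.0f32, 2.0, 1.0];

    // <->  squared Euclidean distance, as in distance_squared_l2: 4 + 0 + 4 = 8
    let l2: f32 = a.iter().zip(&b).map(|(x, y)| (x - y) * (x - y)).sum();

    // <#>  negative dot product, as in F32Dot::distance: -(3 + 4 + 3) = -10
    let xy: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
    let neg_dot = -xy;

    // <=>  negative cosine similarity: -(xy / sqrt(x2 * y2)) = -(10 / 14) ≈ -0.714
    let x2: f32 = a.iter().map(|x| x * x).sum();
    let y2: f32 = b.iter().map(|y| y * y).sum();
    let neg_cos = -(xy / (x2 * y2).sqrt());

    println!("<-> {l2}  <#> {neg_dot}  <=> {neg_cos}");
}
```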
-- squared Euclidean distance SELECT '[1, 2, 3]'::vector <-> '[3, 2, 1]'::vector; -- negative dot product -SELECT '[1, 2, 3]' <#> '[3, 2, 1]'; +SELECT '[1, 2, 3]'::vector <#> '[3, 2, 1]'::vector; -- negative cosine similarity -SELECT '[1, 2, 3]' <=> '[3, 2, 1]'; +SELECT '[1, 2, 3]'::vector <=> '[3, 2, 1]'::vector; ``` You can search for a vector simply like this. @@ -58,6 +58,10 @@ You can search for a vector simply like this. SELECT * FROM items ORDER BY embedding <-> '[3,2,1]' LIMIT 5; ``` +## Half-precision floating-point + +The `vecf16` type is the same as `vector` in everything but the scalar type: it stores 16-bit floating-point numbers. If you want to reduce memory usage and get better performance, you can try replacing the `vector` type with `vecf16`. + ## Things You Need to Know `vector(n)` is a valid data type only if $1 \leq n \leq 65535$. Due to limits of PostgreSQL, it's possible to create a value of type `vector(3)` of $5$ dimensions and `vector` is also a valid data. However, you cannot still put $0$ scalar or more than $65535$ scalars to a vector. If you use `vector` for a column or there is some values mismatched with dimension denoted by the column, you won't able to create an index on it. diff --git a/docs/indexing.md b/docs/indexing.md index 9fcd8a834..2d4968934 100644 --- a/docs/indexing.md +++ b/docs/indexing.md @@ -5,11 +5,19 @@ Indexing is the core ability of pgvecto.rs. Assuming there is a table `items` and there is a column named `embedding` of type `vector(n)`, you can create a vector index for squared Euclidean distance with the following SQL. ```sql -CREATE INDEX ON items USING vectors (embedding l2_ops); +CREATE INDEX ON items USING vectors (embedding vector_l2_ops); ``` -For negative dot product, replace `l2_ops` with `dot_ops`. -For negative cosine similarity, replace `l2_ops` with `cosine_ops`. +Use the table below to choose the proper operator class when creating an index. + +| Vector type | Distance type | Operator class | +| ----------- | -------------------------- | -------------- | +| vector | squared Euclidean distance | vector_l2_ops | +| vector | negative dot product | vector_dot_ops | +| vector | negative cosine similarity | vector_cos_ops | +| vecf16 | squared Euclidean distance | vecf16_l2_ops | +| vecf16 | negative dot product | vecf16_dot_ops | +| vecf16 | negative cosine similarity | vecf16_cos_ops | Now you can perform a KNN search with the following SQL again, but this time the vector index is used for searching. @@ -36,14 +44,15 @@ Options for table `segment`. | Key | Type | Description | | ------------------------ | ------- | ------------------------------------------------------------------- | | max_growing_segment_size | integer | Maximum size of unindexed vectors. Default value is `20_000`. | -| min_sealed_segment_size | integer | Minimum size of vectors for indexing. Default value is `1_000`. | | max_sealed_segment_size | integer | Maximum size of vectors for indexing. Default value is `1_000_000`. | Options for table `optimizing`. -| Key | Type | Description | -| ------------------ | ------- | --------------------------------------------------------------------------- | -| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. 
| +| Key | Type | Description | +| ------------------ | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | +| optimizing_threads | integer | Maximum threads for indexing. Default value is the sqrt of number of cores. | +| sealing_secs | integer | If a writing segment larger than `sealing_size` do not accept new data for `sealing_secs` seconds, the writing segment will be turned to a sealed segment. | +| sealing_size | integer | See above. | Options for table `indexing`. @@ -99,23 +108,19 @@ Options for table `product`. ## Progress View We also provide a view `pg_vector_index_info` to monitor the progress of indexing. -Note that whether idx_sealed_len is equal to idx_tuples doesn't relate to the completion of indexing. -It may do further optimization after indexing. It may also stop indexing because there are too few tuples left. - -| Column | Type | Description | -| --------------- | ------ | --------------------------------------------- | -| tablerelid | oid | The oid of the table. | -| indexrelid | oid | The oid of the index. | -| tablename | name | The name of the table. | -| indexname | name | The name of the index. | -| indexing | bool | Whether the background thread is indexing. | -| idx_tuples | int4 | The number of tuples. | -| idx_sealed_len | int4 | The number of tuples in sealed segments. | -| idx_growing_len | int4 | The number of tuples in growing segments. | -| idx_write | int4 | The number of tuples in write buffer. | -| idx_sealed | int4[] | The number of tuples in each sealed segment. | -| idx_growing | int4[] | The number of tuples in each growing segment. | -| idx_config | text | The configuration of the index. | + +| Column | Type | Description | +| ------------ | ------ | --------------------------------------------- | +| tablerelid | oid | The oid of the table. | +| indexrelid | oid | The oid of the index. | +| tablename | name | The name of the table. | +| indexname | name | The name of the index. | +| idx_indexing | bool | Whether the background thread is indexing. | +| idx_tuples | int8 | The number of tuples. | +| idx_sealed | int8[] | The number of tuples in each sealed segment. | +| idx_growing | int8[] | The number of tuples in each growing segment. | +| idx_write | int8 | The number of tuples in write buffer. | +| idx_config | text | The configuration of the index. | ## Examples @@ -124,11 +129,11 @@ There are some examples. ```sql -- HNSW algorithm, default settings. -CREATE INDEX ON items USING vectors (embedding l2_ops); +CREATE INDEX ON items USING vectors (embedding vector_l2_ops); --- Or using bruteforce with PQ. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ [indexing.flat] quantization.product.ratio = "x16" @@ -136,7 +141,7 @@ $$); --- Or using IVFPQ algorithm. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ [indexing.ivf] quantization.product.ratio = "x16" @@ -144,14 +149,14 @@ $$); -- Use more threads for background building the index. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ optimizing.optimizing_threads = 16 $$); -- Prefer smaller HNSW graph. 
-- Prefer smaller HNSW graph. -CREATE INDEX ON items USING vectors (embedding l2_ops) +CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$ segment.max_growing_segment_size = 200000 $$); diff --git a/docs/installation.md b/docs/installation.md index 3b8610146..c55903a07 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -19,24 +19,49 @@ To acheive full performance, please mount the volume to pg data directory by add You can configure PostgreSQL with reference to the parent image in https://hub.docker.com/_/postgres/. -## Build from source +## Install from source Install Rust and base dependencies. ```sh -sudo apt install -y build-essential libpq-dev libssl-dev pkg-config gcc libreadline-dev flex bison libxml2-dev libxslt-dev libxml2-utils xsltproc zlib1g-dev ccache clang git +sudo apt install -y \ + build-essential \ + libpq-dev \ + libssl-dev \ + pkg-config \ + gcc \ + libreadline-dev \ + flex \ + bison \ + libxml2-dev \ + libxslt-dev \ + libxml2-utils \ + xsltproc \ + zlib1g-dev \ + ccache \ + clang \ + git curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh ``` Install PostgreSQL. ```sh -sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' +sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - sudo apt-get update sudo apt-get -y install libpq-dev postgresql-15 postgresql-server-dev-15 ``` +Install clang-16. + +```sh +sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' +wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - +sudo apt-get update +sudo apt-get -y install clang-16 +``` + Clone the Repository. ```sh @@ -54,7 +79,7 @@ cargo pgrx init --pg15=/usr/lib/postgresql/15/bin/pg_config Install pgvecto.rs. ```sh -cargo pgrx install --release +cargo pgrx install --sudo --release ``` Configure your PostgreSQL by modifying the `shared_preload_libraries` to include `vectors.so`. diff --git a/docs/searching.md b/docs/searching.md index 3cf432abf..2220b6bd7 100644 --- a/docs/searching.md +++ b/docs/searching.md @@ -15,11 +15,12 @@ If `vectors.k` is set to `64`, but your SQL returned less than `64` rows, for ex * The vector index returned `64` rows, but `32` of them are invisible to the transaction, so PostgreSQL decided to hide these rows for you. * The vector index returned `64` rows, but `32` of them do not satisfy the condition `id % 2 = 0` in the `WHERE` clause. -There are three ways to solve the problem: +There are four ways to solve the problem: * Set `vectors.k` larger. If you estimate that 20% of rows will satisfy the condition in `WHERE`, just set `vectors.k` to 5 times its previous value. * Set `vectors.enable_vector_index` to `off`. If you estimate that 0.0001% of rows will satisfy the condition in `WHERE`, just do not use the vector index. No algorithm will be faster than PostgreSQL's brute-force scan. * Set `vectors.enable_prefilter` to `on`. If you cannot estimate how many rows will satisfy the condition in `WHERE`, leave the job to the index. The index will check if the returned row can be accepted by PostgreSQL. However, it will make queries slower, so the default value for this option is `off`. +* Set `vectors.enable_vbase` to `on`. It will use the vbase optimization, so that the index will pull as many rows as you need.
It only works for HNSW algorithm. ## Options @@ -30,6 +31,6 @@ Search options are specified by PostgreSQL GUC. You can use `SET` command to app | vectors.k | integer | Expected number of candidates returned by index. The parameter will influence the recall if you use HNSW or quantization for indexing. Default value is `64`. | | vectors.enable_prefilter | boolean | Enable prefiltering or not. Default value is `off`. | | vectors.enable_vector_index | boolean | Enable vector indexes or not. This option is for debugging. Default value is `on`. | -| vectors.vbase_range | int4 | The range size when using vbase optimization. When it is set to `0`, vbase optimization will be disabled. A recommended value is `86`. Default value is `0`. | +| vectors.enable_vbase | boolean | Enable vbase optimization. Default value is `off`. | -Note: When `vectors.vbase_range` is enabled, it will ignore `vectors.enable_prefilter`. +Note: When `vectors.enable_vbase` is enabled, prefilter does not work. diff --git a/scripts/ci_setup.sh b/scripts/ci_setup.sh index f6043ed6b..9e1b8d74f 100755 --- a/scripts/ci_setup.sh +++ b/scripts/ci_setup.sh @@ -6,10 +6,14 @@ if [ "$OS" == "ubuntu-latest" ]; then sudo pg_dropcluster 14 main fi sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' - sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" > /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" >> /etc/apt/sources.list.d/pgdg.list' + sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - + wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - sudo apt-get update sudo apt-get -y install build-essential libpq-dev postgresql-$VERSION postgresql-server-dev-$VERSION + sudo apt-get -y install clang-16 + sudo apt-get -y install crossbuild-essential-arm64 echo "local all all trust" | sudo tee /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all 127.0.0.1/32 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf echo "host all all ::1/128 trust" | sudo tee -a /etc/postgresql/$VERSION/main/pg_hba.conf diff --git a/src/algorithms/diskann/mod.rs b/src/algorithms/diskann/mod.rs deleted file mode 100644 index 87a515a9d..000000000 --- a/src/algorithms/diskann/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod vamana; diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 75d2fa4d7..529a95e2e 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -1,11 +1,26 @@ -pub mod worker; - -use self::worker::Worker; use crate::ipc::server::RpcHandler; use crate::ipc::IpcError; +use service::worker::Worker; use std::path::{Path, PathBuf}; use std::sync::Arc; +pub unsafe fn init() { + use pgrx::bgworkers::BackgroundWorkerBuilder; + use pgrx::bgworkers::BgWorkerStartTime; + BackgroundWorkerBuilder::new("vectors") + .set_function("vectors_main") + .set_library("vectors") + .set_argument(None) + .enable_shmem_access(None) + .set_start_time(BgWorkerStartTime::PostmasterStart) + .load(); +} + +#[no_mangle] +extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) { + let _ = std::panic::catch_unwind(crate::bgworker::main); +} + pub fn main() { { let mut builder = env_logger::builder(); @@ -109,10 +124,6 @@ fn session(worker: Arc, mut handler: RpcHandler) -> Result<(), IpcError> 
handler = x.leave(res)?; } } - RpcHandle::SearchVbase { id, search, mut x } => { - let res = worker.call_search_vbase(id, search, |p| x.next(p).unwrap()); - handler = x.leave(res)?; - } RpcHandle::Flush { id, x } => { let result = worker.call_flush(id); handler = x.leave(result)?; @@ -125,11 +136,36 @@ fn session(worker: Arc, mut handler: RpcHandler) -> Result<(), IpcError> let result = worker.call_stat(id); handler = x.leave(result)?; } - RpcHandle::Leave {} => { - log::debug!("Handle leave rpc."); - break; + RpcHandle::Vbase { id, vector, x } => { + use crate::ipc::server::VbaseHandle::*; + let instance = match worker.get_instance(id) { + Ok(x) => x, + Err(e) => { + x.error(Err(e))?; + break Ok(()); + } + }; + let view = instance.view(); + let mut it = match view.vbase(vector) { + Ok(x) => x, + Err(e) => { + x.error(Err(e))?; + break Ok(()); + } + }; + let mut x = x.error(Ok(()))?; + loop { + match x.handle()? { + Next { x: y } => { + x = y.leave(it.next())?; + } + Leave { x } => { + handler = x; + break; + } + } + } } } } - Ok(()) } diff --git a/src/datatype/casts_f32.rs b/src/datatype/casts_f32.rs new file mode 100644 index 000000000..8fcbd7498 --- /dev/null +++ b/src/datatype/casts_f32.rs @@ -0,0 +1,26 @@ +use crate::datatype::typmod::Typmod; +use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; +use service::prelude::*; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf32_cast_array_to_vector( + array: pgrx::Array, + typmod: i32, + _explicit: bool, +) -> Vecf32Output { + assert!(!array.is_empty()); + assert!(array.len() <= 65535); + assert!(!array.contains_nulls()); + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + let len = typmod.dims().unwrap_or(array.len() as u16); + let mut data = vec![F32::zero(); len as usize]; + for (i, x) in array.iter().enumerate() { + data[i] = F32(x.unwrap_or(f32::NAN)); + } + Vecf32::new_in_postgres(&data) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf32_cast_vector_to_array(vector: Vecf32Input<'_>, _typmod: i32, _explicit: bool) -> Vec { + vector.data().iter().map(|x| x.to_f32()).collect() +} diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs new file mode 100644 index 000000000..1b0ef7a78 --- /dev/null +++ b/src/datatype/mod.rs @@ -0,0 +1,6 @@ +pub mod casts_f32; +pub mod operators_f16; +pub mod operators_f32; +pub mod typmod; +pub mod vecf16; +pub mod vecf32; diff --git a/src/postgres/operators.rs b/src/datatype/operators_f16.rs similarity index 56% rename from src/postgres/operators.rs rename to src/datatype/operators_f16.rs index 4f7a63654..9ff044815 100644 --- a/src/postgres/operators.rs +++ b/src/datatype/operators_f16.rs @@ -1,53 +1,53 @@ -use crate::postgres::datatype::{Vector, VectorInput, VectorOutput}; -use crate::prelude::*; +use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output}; +use service::prelude::*; use std::ops::Deref; -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(+)] #[pgrx::commutator(+)] -fn operator_add(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { +fn vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } let n = lhs.len(); - let mut v = Vector::new_zeroed(n); + let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] + 
rhs[i]; } - v.copy_into_postgres() + Vecf16::new_in_postgres(&v) } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(-)] -fn operator_minus(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> VectorOutput { +fn vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } let n = lhs.len(); - let mut v = Vector::new_zeroed(n); + let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; } - v.copy_into_postgres() + Vecf16::new_in_postgres(&v) } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<)] #[pgrx::negator(>=)] #[pgrx::commutator(>)] #[pgrx::restrict(scalarltsel)] #[pgrx::join(scalarltjoinsel)] -fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -56,15 +56,15 @@ fn operator_lt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() < rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<=)] #[pgrx::negator(>)] #[pgrx::commutator(>=)] #[pgrx::restrict(scalarltsel)] #[pgrx::join(scalarltjoinsel)] -fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_lte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -73,15 +73,15 @@ fn operator_lte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() <= rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(>)] #[pgrx::negator(<=)] #[pgrx::commutator(<)] #[pgrx::restrict(scalargtsel)] #[pgrx::join(scalargtjoinsel)] -fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_gt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -90,15 +90,15 @@ fn operator_gt(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() > rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(>=)] #[pgrx::negator(<)] #[pgrx::commutator(<=)] #[pgrx::restrict(scalargtsel)] #[pgrx::join(scalargtjoinsel)] -fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_gte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -107,15 +107,15 @@ fn operator_gte(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() >= 
rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(=)] #[pgrx::negator(<>)] #[pgrx::commutator(=)] #[pgrx::restrict(eqsel)] #[pgrx::join(eqjoinsel)] -fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_eq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -124,15 +124,15 @@ fn operator_eq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() == rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<>)] #[pgrx::negator(=)] #[pgrx::commutator(<>)] #[pgrx::restrict(eqsel)] #[pgrx::join(eqjoinsel)] -fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { +fn vecf16_operator_neq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } @@ -141,44 +141,44 @@ fn operator_neq(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> bool { lhs.deref() != rhs.deref() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<=>)] #[pgrx::commutator(<=>)] -fn operator_cosine(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { +fn vecf16_operator_cosine(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } - Distance::Cosine.distance(&lhs, &rhs) + F16Cos::distance(&lhs, &rhs).to_f32() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<#>)] #[pgrx::commutator(<#>)] -fn operator_dot(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { +fn vecf16_operator_dot(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } - Distance::Dot.distance(&lhs, &rhs) + F16Dot::distance(&lhs, &rhs).to_f32() } -#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vector"])] +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf16"])] #[pgrx::opname(<->)] #[pgrx::commutator(<->)] -fn operator_l2(lhs: VectorInput<'_>, rhs: VectorInput<'_>) -> Scalar { +fn vecf16_operator_l2(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { + FriendlyError::Unmatched { left_dimensions: lhs.len() as _, right_dimensions: rhs.len() as _, } .friendly(); } - Distance::L2.distance(&lhs, &rhs) + F16L2::distance(&lhs, &rhs).to_f32() } diff --git a/src/datatype/operators_f32.rs b/src/datatype/operators_f32.rs new file mode 100644 index 000000000..d4a67c22d --- /dev/null +++ b/src/datatype/operators_f32.rs @@ -0,0 +1,184 @@ +use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; +use service::prelude::*; +use std::ops::Deref; + +#[pgrx::pg_operator(immutable, parallel_safe, requires = 
["vecf32"])] +#[pgrx::opname(+)] +#[pgrx::commutator(+)] +fn vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + let n = lhs.len(); + let mut v = vec![F32::zero(); n]; + for i in 0..n { + v[i] = lhs[i] + rhs[i]; + } + Vecf32::new_in_postgres(&v) +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(-)] +fn vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + let n = lhs.len(); + let mut v = vec![F32::zero(); n]; + for i in 0..n { + v[i] = lhs[i] - rhs[i]; + } + Vecf32::new_in_postgres(&v) +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<)] +#[pgrx::negator(>=)] +#[pgrx::commutator(>)] +#[pgrx::restrict(scalarltsel)] +#[pgrx::join(scalarltjoinsel)] +fn vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() < rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<=)] +#[pgrx::negator(>)] +#[pgrx::commutator(>=)] +#[pgrx::restrict(scalarltsel)] +#[pgrx::join(scalarltjoinsel)] +fn vecf32_operator_lte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() <= rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(>)] +#[pgrx::negator(<=)] +#[pgrx::commutator(<)] +#[pgrx::restrict(scalargtsel)] +#[pgrx::join(scalargtjoinsel)] +fn vecf32_operator_gt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() > rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(>=)] +#[pgrx::negator(<)] +#[pgrx::commutator(<=)] +#[pgrx::restrict(scalargtsel)] +#[pgrx::join(scalargtjoinsel)] +fn vecf32_operator_gte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() >= rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(=)] +#[pgrx::negator(<>)] +#[pgrx::commutator(=)] +#[pgrx::restrict(eqsel)] +#[pgrx::join(eqjoinsel)] +fn vecf32_operator_eq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() == rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<>)] +#[pgrx::negator(=)] +#[pgrx::commutator(<>)] +#[pgrx::restrict(eqsel)] +#[pgrx::join(eqjoinsel)] +fn vecf32_operator_neq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + 
left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + lhs.deref() != rhs.deref() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<=>)] +#[pgrx::commutator(<=>)] +fn vecf32_operator_cosine(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + F32Cos::distance(&lhs, &rhs).to_f32() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<#>)] +#[pgrx::commutator(<#>)] +fn vecf32_operator_dot(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + F32Dot::distance(&lhs, &rhs).to_f32() +} + +#[pgrx::pg_operator(immutable, parallel_safe, requires = ["vecf32"])] +#[pgrx::opname(<->)] +#[pgrx::commutator(<->)] +fn vecf32_operator_l2(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { + if lhs.len() != rhs.len() { + FriendlyError::Unmatched { + left_dimensions: lhs.len() as _, + right_dimensions: rhs.len() as _, + } + .friendly(); + } + F32L2::distance(&lhs, &rhs).to_f32() +} diff --git a/src/datatype/typmod.rs b/src/datatype/typmod.rs new file mode 100644 index 000000000..4760efa83 --- /dev/null +++ b/src/datatype/typmod.rs @@ -0,0 +1,77 @@ +use pgrx::Array; +use serde::{Deserialize, Serialize}; +use service::prelude::*; +use std::ffi::{CStr, CString}; +use std::num::NonZeroU16; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum Typmod { + Any, + Dims(NonZeroU16), +} + +impl Typmod { + pub fn parse_from_str(s: &str) -> Option { + use Typmod::*; + if let Ok(x) = s.parse::() { + Some(Dims(x)) + } else { + None + } + } + pub fn parse_from_i32(x: i32) -> Option { + use Typmod::*; + if x == -1 { + Some(Any) + } else if 1 <= x && x <= u16::MAX as i32 { + Some(Dims(NonZeroU16::new(x as u16).unwrap())) + } else { + None + } + } + pub fn into_option_string(self) -> Option { + use Typmod::*; + match self { + Any => None, + Dims(x) => Some(i32::from(x.get()).to_string()), + } + } + pub fn into_i32(self) -> i32 { + use Typmod::*; + match self { + Any => -1, + Dims(x) => i32::from(x.get()), + } + } + pub fn dims(self) -> Option { + use Typmod::*; + match self { + Any => None, + Dims(dims) => Some(dims.get()), + } + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn typmod_in(list: Array<&CStr>) -> i32 { + if list.is_empty() { + -1 + } else if list.len() == 1 { + let s = list.get(0).unwrap().unwrap().to_str().unwrap(); + let typmod = Typmod::parse_from_str(s) + .ok_or(FriendlyError::BadTypeDimensions) + .friendly(); + typmod.into_i32() + } else { + FriendlyError::BadTypeDimensions.friendly(); + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn typmod_out(typmod: i32) -> CString { + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + match typmod.into_option_string() { + Some(s) => CString::new(format!("({})", s)).unwrap(), + None => CString::new("()").unwrap(), + } +} diff --git a/src/datatype/vecf16.rs b/src/datatype/vecf16.rs new file mode 100644 index 000000000..a8db6072c --- /dev/null +++ b/src/datatype/vecf16.rs @@ -0,0 +1,343 @@ +use crate::datatype::typmod::Typmod; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use 
pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use service::prelude::*; +use std::alloc::Layout; +use std::cmp::Ordering; +use std::ffi::CStr; +use std::ffi::CString; +use std::ops::Deref; +use std::ops::DerefMut; +use std::ops::Index; +use std::ops::IndexMut; +use std::ptr::NonNull; + +pgrx::extension_sql!( + r#" +CREATE TYPE vecf16 ( + INPUT = vecf16_in, + OUTPUT = vecf16_out, + TYPMOD_IN = typmod_in, + TYPMOD_OUT = typmod_out, + STORAGE = EXTENDED, + INTERNALLENGTH = VARIABLE, + ALIGNMENT = double +); +"#, + name = "vecf16", + creates = [Type(Vecf16)], + requires = [vecf16_in, vecf16_out, typmod_in, typmod_out], +); + +#[repr(C, align(8))] +pub struct Vecf16 { + varlena: u32, + len: u16, + kind: u8, + reserved: u8, + phantom: [F16; 0], +} + +impl Vecf16 { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout_alpha = Layout::new::(); + let layout_beta = Layout::array::(len).unwrap(); + let layout = layout_alpha.extend(layout_beta).unwrap().0; + layout.pad_to_align() + } + pub fn new_in_postgres(slice: &[F16]) -> Vecf16Output { + unsafe { + assert!(u16::try_from(slice.len()).is_ok()); + let layout = Vecf16::layout(slice.len()); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(1); + std::ptr::addr_of_mut!((*ptr).reserved).write(0); + std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + Vecf16Output(NonNull::new(ptr).unwrap()) + } + } + pub fn len(&self) -> usize { + self.len as usize + } + pub fn data(&self) -> &[F16] { + debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 1); + unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } + } + pub fn data_mut(&mut self) -> &mut [F16] { + debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 1); + unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } + } +} + +impl Deref for Vecf16 { + type Target = [F16]; + + fn deref(&self) -> &Self::Target { + self.data() + } +} + +impl DerefMut for Vecf16 { + fn deref_mut(&mut self) -> &mut Self::Target { + self.data_mut() + } +} + +impl AsRef<[F16]> for Vecf16 { + fn as_ref(&self) -> &[F16] { + self.data() + } +} + +impl AsMut<[F16]> for Vecf16 { + fn as_mut(&mut self) -> &mut [F16] { + self.data_mut() + } +} + +impl Index for Vecf16 { + type Output = F16; + + fn index(&self, index: usize) -> &Self::Output { + self.data().index(index) + } +} + +impl IndexMut for Vecf16 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.data_mut().index_mut(index) + } +} + +impl PartialEq for Vecf16 { + fn eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + let n = self.len(); + for i in 0..n { + if self[i] != other[i] { + return false; + } + } + true + } +} + +impl Eq for Vecf16 {} + +impl PartialOrd for Vecf16 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Vecf16 { + fn cmp(&self, other: &Self) -> Ordering { + use Ordering::*; + if let x @ Less | x @ Greater = 
self.len().cmp(&other.len()) { + return x; + } + let n = self.len(); + for i in 0..n { + if let x @ Less | x @ Greater = self[i].cmp(&other[i]) { + return x; + } + } + Equal + } +} + +pub enum Vecf16Input<'a> { + Owned(Vecf16Output), + Borrowed(&'a Vecf16), +} + +impl<'a> Vecf16Input<'a> { + pub unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Vecf16Input::Owned(Vecf16Output(q)) + } else { + unsafe { Vecf16Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for Vecf16Input<'_> { + type Target = Vecf16; + + fn deref(&self) -> &Self::Target { + match self { + Vecf16Input::Owned(x) => x, + Vecf16Input::Borrowed(x) => x, + } + } +} + +pub struct Vecf16Output(NonNull); + +impl Vecf16Output { + pub fn into_raw(self) -> *mut Vecf16 { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for Vecf16Output { + type Target = Vecf16; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl DerefMut for Vecf16Output { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { self.0.as_mut() } + } +} + +impl Drop for Vecf16Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for Vecf16Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Vecf16Input::new(ptr)) } + } + } +} + +impl IntoDatum for Vecf16Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vecf16") + } +} + +unsafe impl SqlTranslatable for Vecf16Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vecf16"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) + } +} + +unsafe impl SqlTranslatable for Vecf16Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vecf16"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output { + fn solve(option: Option, hint: &str) -> T { + if let Some(x) = option { + x + } else { + FriendlyError::BadLiteral { + hint: hint.to_string(), + } + .friendly() + } + } + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum State { + MatchingLeft, + Reading, + MatchedRight, + } + use State::*; + let input = input.to_bytes(); + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + let mut vector = Vec::::with_capacity(typmod.dims().unwrap_or(0) as usize); + let mut state = MatchingLeft; + let mut token: Option = None; + for &c in input { + match (state, c) { + (MatchingLeft, b'[') => { + state = Reading; + } + (Reading, b'0'..=b'9' | b'.' 
| b'e' | b'+' | b'-') => { + let token = token.get_or_insert(String::new()); + token.push(char::from_u32(c as u32).unwrap()); + } + (Reading, b',') => { + let token = solve(token.take(), "Expect a number."); + vector.push(solve(token.parse().ok(), "Bad number.")); + } + (Reading, b']') => { + if let Some(token) = token.take() { + vector.push(solve(token.parse().ok(), "Bad number.")); + } + state = MatchedRight; + } + (_, b' ') => {} + _ => { + FriendlyError::BadLiteral { + hint: format!("Bad charactor with ascii {:#x}.", c), + } + .friendly(); + } + } + } + if state != MatchedRight { + FriendlyError::BadLiteral { + hint: "Bad sequence.".to_string(), + } + .friendly(); + } + if vector.is_empty() || vector.len() > 65535 { + FriendlyError::BadValueDimensions.friendly(); + } + Vecf16::new_in_postgres(&vector) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf16_out(vector: Vecf16Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + if let Some(&x) = vector.data().first() { + buffer.push_str(format!("{}", x).as_str()); + } + for &x in vector.data().iter().skip(1) { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/datatype/vecf32.rs b/src/datatype/vecf32.rs new file mode 100644 index 000000000..3663b45fc --- /dev/null +++ b/src/datatype/vecf32.rs @@ -0,0 +1,343 @@ +use crate::datatype::typmod::Typmod; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use service::prelude::*; +use std::alloc::Layout; +use std::cmp::Ordering; +use std::ffi::CStr; +use std::ffi::CString; +use std::ops::Deref; +use std::ops::DerefMut; +use std::ops::Index; +use std::ops::IndexMut; +use std::ptr::NonNull; + +pgrx::extension_sql!( + r#" +CREATE TYPE vector ( + INPUT = vecf32_in, + OUTPUT = vecf32_out, + TYPMOD_IN = typmod_in, + TYPMOD_OUT = typmod_out, + STORAGE = EXTENDED, + INTERNALLENGTH = VARIABLE, + ALIGNMENT = double +); +"#, + name = "vecf32", + creates = [Type(Vecf32)], + requires = [vecf32_in, vecf32_out, typmod_in, typmod_out], +); + +#[repr(C, align(8))] +pub struct Vecf32 { + varlena: u32, + len: u16, + kind: u8, + reserved: u8, + phantom: [F32; 0], +} + +impl Vecf32 { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout_alpha = Layout::new::(); + let layout_beta = Layout::array::(len).unwrap(); + let layout = layout_alpha.extend(layout_beta).unwrap().0; + layout.pad_to_align() + } + pub fn new_in_postgres(slice: &[F32]) -> Vecf32Output { + unsafe { + assert!(u16::try_from(slice.len()).is_ok()); + let layout = Vecf32::layout(slice.len()); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(0); + std::ptr::addr_of_mut!((*ptr).reserved).write(0); + std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + Vecf32Output(NonNull::new(ptr).unwrap()) + } + } + pub fn len(&self) -> 
usize { + self.len as usize + } + pub fn data(&self) -> &[F32] { + debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 0); + unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } + } + pub fn data_mut(&mut self) -> &mut [F32] { + debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 0); + unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } + } +} + +impl Deref for Vecf32 { + type Target = [F32]; + + fn deref(&self) -> &Self::Target { + self.data() + } +} + +impl DerefMut for Vecf32 { + fn deref_mut(&mut self) -> &mut Self::Target { + self.data_mut() + } +} + +impl AsRef<[F32]> for Vecf32 { + fn as_ref(&self) -> &[F32] { + self.data() + } +} + +impl AsMut<[F32]> for Vecf32 { + fn as_mut(&mut self) -> &mut [F32] { + self.data_mut() + } +} + +impl Index for Vecf32 { + type Output = F32; + + fn index(&self, index: usize) -> &Self::Output { + self.data().index(index) + } +} + +impl IndexMut for Vecf32 { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + self.data_mut().index_mut(index) + } +} + +impl PartialEq for Vecf32 { + fn eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + let n = self.len(); + for i in 0..n { + if self[i] != other[i] { + return false; + } + } + true + } +} + +impl Eq for Vecf32 {} + +impl PartialOrd for Vecf32 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Vecf32 { + fn cmp(&self, other: &Self) -> Ordering { + use Ordering::*; + if let x @ Less | x @ Greater = self.len().cmp(&other.len()) { + return x; + } + let n = self.len(); + for i in 0..n { + if let x @ Less | x @ Greater = self[i].cmp(&other[i]) { + return x; + } + } + Equal + } +} + +pub enum Vecf32Input<'a> { + Owned(Vecf32Output), + Borrowed(&'a Vecf32), +} + +impl<'a> Vecf32Input<'a> { + pub unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Vecf32Input::Owned(Vecf32Output(q)) + } else { + unsafe { Vecf32Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for Vecf32Input<'_> { + type Target = Vecf32; + + fn deref(&self) -> &Self::Target { + match self { + Vecf32Input::Owned(x) => x, + Vecf32Input::Borrowed(x) => x, + } + } +} + +pub struct Vecf32Output(NonNull); + +impl Vecf32Output { + pub fn into_raw(self) -> *mut Vecf32 { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for Vecf32Output { + type Target = Vecf32; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl DerefMut for Vecf32Output { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { self.0.as_mut() } + } +} + +impl Drop for Vecf32Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for Vecf32Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Vecf32Input::new(ptr)) } + } + } +} + +impl IntoDatum for Vecf32Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vector") + } +} + +unsafe impl SqlTranslatable for Vecf32Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vector"))) + } + fn return_sql() -> Result { + 
Ok(Returns::One(SqlMapping::As(String::from("vector")))) + } +} + +unsafe impl SqlTranslatable for Vecf32Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vector")))) + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output { + fn solve(option: Option, hint: &str) -> T { + if let Some(x) = option { + x + } else { + FriendlyError::BadLiteral { + hint: hint.to_string(), + } + .friendly() + } + } + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum State { + MatchingLeft, + Reading, + MatchedRight, + } + use State::*; + let input = input.to_bytes(); + let typmod = Typmod::parse_from_i32(typmod).unwrap(); + let mut vector = Vec::::with_capacity(typmod.dims().unwrap_or(0) as usize); + let mut state = MatchingLeft; + let mut token: Option = None; + for &c in input { + match (state, c) { + (MatchingLeft, b'[') => { + state = Reading; + } + (Reading, b'0'..=b'9' | b'.' | b'e' | b'+' | b'-') => { + let token = token.get_or_insert(String::new()); + token.push(char::from_u32(c as u32).unwrap()); + } + (Reading, b',') => { + let token = solve(token.take(), "Expect a number."); + vector.push(solve(token.parse().ok(), "Bad number.")); + } + (Reading, b']') => { + if let Some(token) = token.take() { + vector.push(solve(token.parse().ok(), "Bad number.")); + } + state = MatchedRight; + } + (_, b' ') => {} + _ => { + FriendlyError::BadLiteral { + hint: format!("Bad charactor with ascii {:#x}.", c), + } + .friendly(); + } + } + } + if state != MatchedRight { + FriendlyError::BadLiteral { + hint: "Bad sequence.".to_string(), + } + .friendly(); + } + if vector.is_empty() || vector.len() > 65535 { + FriendlyError::BadValueDimensions.friendly(); + } + Vecf32::new_in_postgres(&vector) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn vecf32_out(vector: Vecf32Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + if let Some(&x) = vector.data().first() { + buffer.push_str(format!("{}", x).as_str()); + } + for &x in vector.data().iter().skip(1) { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/embedding/udf.rs b/src/embedding/udf.rs index cc6bfb14c..97131bb82 100644 --- a/src/embedding/udf.rs +++ b/src/embedding/udf.rs @@ -1,14 +1,12 @@ use super::openai::{EmbeddingCreator, OpenAIEmbedding}; use super::Embedding; -use crate::postgres::datatype::Vector; -use crate::postgres::datatype::VectorOutput; -use crate::postgres::gucs::OPENAI_API_KEY_GUC; -use crate::prelude::Float; -use crate::prelude::Scalar; +use crate::datatype::vecf32::{Vecf32, Vecf32Output}; +use crate::gucs::OPENAI_API_KEY_GUC; use pgrx::prelude::*; +use service::prelude::F32; #[pg_extern] -fn ai_embedding_vector(input: String) -> VectorOutput { +fn ai_embedding_vector(input: String) -> Vecf32Output { let api_key = match OPENAI_API_KEY_GUC.get() { Some(key) => key .to_str() @@ -26,9 +24,9 @@ fn ai_embedding_vector(input: String) -> VectorOutput { Ok(embedding) => { let embedding = embedding .into_iter() - .map(|x| Scalar(x as Float)) + .map(|x| F32(x as f32)) .collect::>(); - Vector::new_in_postgres(&embedding) + Vecf32::new_in_postgres(&embedding) } Err(e) => { error!("{}", e) diff --git a/src/postgres/gucs.rs b/src/gucs.rs similarity index 89% rename from src/postgres/gucs.rs rename to src/gucs.rs index 977bcb008..e9c3e8189 100644 --- 
a/src/postgres/gucs.rs +++ b/src/gucs.rs @@ -23,7 +23,7 @@ pub static ENABLE_VECTOR_INDEX: GucSetting = GucSetting::::new(true) pub static ENABLE_PREFILTER: GucSetting = GucSetting::::new(false); -pub static VBASE_RANGE: GucSetting = GucSetting::::new(0); +pub static ENABLE_VBASE: GucSetting = GucSetting::::new(false); pub static TRANSPORT: GucSetting = GucSetting::::new(Transport::default()); @@ -62,13 +62,11 @@ pub unsafe fn init() { GucContext::Userset, GucFlags::default(), ); - GucRegistry::define_int_guc( - "vectors.vbase_range", - "The range of vbase.", - "The range of vbase.", - &VBASE_RANGE, - 0, - u16::MAX as _, + GucRegistry::define_bool_guc( + "vectors.enable_vbase", + "Whether to enable vbase.", + "When enabled, it will use vbase for filtering.", + &ENABLE_VBASE, GucContext::Userset, GucFlags::default(), ); diff --git a/src/postgres/index.rs b/src/index/am.rs similarity index 84% rename from src/postgres/index.rs rename to src/index/am.rs index 4454fedb1..684b6ad88 100644 --- a/src/postgres/index.rs +++ b/src/index/am.rs @@ -1,11 +1,12 @@ -use super::index_build; -use super::index_scan; -use super::index_setup; -use super::index_update; -use crate::postgres::datatype::VectorInput; -use crate::postgres::gucs::ENABLE_VECTOR_INDEX; +use super::am_build; +use super::am_scan; +use super::am_setup; +use super::am_update; +use crate::gucs::ENABLE_VECTOR_INDEX; +use crate::index::utils::from_datum; use crate::prelude::*; use crate::utils::cells::PgCell; +use service::prelude::*; static RELOPT_KIND: PgCell = unsafe { PgCell::new(0) }; @@ -28,9 +29,7 @@ pub unsafe fn init() { #[pgrx::pg_extern(sql = " CREATE OR REPLACE FUNCTION vectors_amhandler(internal) RETURNS index_am_handler PARALLEL SAFE IMMUTABLE STRICT LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@'; - CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; - COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; -", requires = ["vector"])] +", requires = ["vecf32"])] fn vectors_amhandler( _fcinfo: pgrx::pg_sys::FunctionCallInfo, ) -> pgrx::PgBox { @@ -85,7 +84,7 @@ const AM_HANDLER: pgrx::pg_sys::IndexAmRoutine = { #[pgrx::pg_guard] pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool { - index_setup::convert_opclass_to_distance(opclass_oid); + am_setup::convert_opclass_to_distance(opclass_oid); true } @@ -99,7 +98,7 @@ pub unsafe extern "C" fn amoptions( let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { optname: "options".as_pg_cstr(), opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: index_setup::helper_offset() as i32, + offset: am_setup::helper_offset() as i32, }]; let mut noptions = 0; let options = @@ -111,10 +110,10 @@ pub unsafe extern "C" fn amoptions( relopt.gen.as_mut().unwrap().lockmode = pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; } - let rdopts = pgrx::pg_sys::allocateReloptStruct(index_setup::helper_size(), options, noptions); + let rdopts = pgrx::pg_sys::allocateReloptStruct(am_setup::helper_size(), options, noptions); pgrx::pg_sys::fillRelOptions( rdopts, - index_setup::helper_size(), + am_setup::helper_size(), options, noptions, validate, @@ -136,13 +135,13 @@ pub unsafe extern "C" fn amoptions( let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { optname: "options".as_pg_cstr(), opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: index_setup::helper_offset() as i32, + offset: am_setup::helper_offset() as i32, }]; let rdopts = pgrx::pg_sys::build_reloptions( 
reloptions, validate, RELOPT_KIND.get(), - index_setup::helper_size(), + am_setup::helper_size(), tab.as_ptr(), tab.len() as _, ); @@ -182,7 +181,7 @@ pub unsafe extern "C" fn ambuild( index_info: *mut pgrx::pg_sys::IndexInfo, ) -> *mut pgrx::pg_sys::IndexBuildResult { let result = pgrx::PgBox::::alloc0(); - index_build::build( + am_build::build( index_relation, Some((heap_relation, index_info, result.as_ptr())), ); @@ -191,7 +190,7 @@ pub unsafe extern "C" fn ambuild( #[pgrx::pg_guard] pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) { - index_build::build(index_relation, None); + am_build::build(index_relation, None); } #[cfg(any(feature = "pg12", feature = "pg13"))] @@ -199,18 +198,16 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) { pub unsafe extern "C" fn aminsert( index_relation: pgrx::pg_sys::Relation, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, heap_tid: pgrx::pg_sys::ItemPointer, _heap_relation: pgrx::pg_sys::Relation, _check_unique: pgrx::pg_sys::IndexUniqueCheck, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use pgrx::FromDatum; let oid = (*index_relation).rd_node.relNode; let id = Id::from_sys(oid); - let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let vector = vector.data().to_vec(); - index_update::update_insert(id, vector, *heap_tid); + let vector = from_datum(*values.add(0)); + am_update::update_insert(id, vector, *heap_tid); true } @@ -219,22 +216,20 @@ pub unsafe extern "C" fn aminsert( pub unsafe extern "C" fn aminsert( index_relation: pgrx::pg_sys::Relation, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, heap_tid: pgrx::pg_sys::ItemPointer, _heap_relation: pgrx::pg_sys::Relation, _check_unique: pgrx::pg_sys::IndexUniqueCheck, _index_unchanged: bool, _index_info: *mut pgrx::pg_sys::IndexInfo, ) -> bool { - use pgrx::FromDatum; #[cfg(any(feature = "pg14", feature = "pg15"))] let oid = (*index_relation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*index_relation).rd_locator.relNumber; let id = Id::from_sys(oid); - let vector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let vector = vector.data().to_vec(); - index_update::update_insert(id, vector, *heap_tid); + let vector = from_datum(*values.add(0)); + am_update::update_insert(id, vector, *heap_tid); true } @@ -242,20 +237,26 @@ pub unsafe extern "C" fn aminsert( pub unsafe extern "C" fn ambeginscan( index_relation: pgrx::pg_sys::Relation, n_keys: std::os::raw::c_int, - n_order_bys: std::os::raw::c_int, + n_orderbys: std::os::raw::c_int, ) -> pgrx::pg_sys::IndexScanDesc { - index_scan::make_scan(index_relation, n_keys, n_order_bys) + assert!(n_keys == 0); + assert!(n_orderbys == 1); + am_scan::make_scan(index_relation) } #[pgrx::pg_guard] pub unsafe extern "C" fn amrescan( scan: pgrx::pg_sys::IndexScanDesc, - keys: pgrx::pg_sys::ScanKey, + _keys: pgrx::pg_sys::ScanKey, n_keys: std::os::raw::c_int, orderbys: pgrx::pg_sys::ScanKey, n_orderbys: std::os::raw::c_int, ) { - index_scan::start_scan(scan, keys, n_keys, orderbys, n_orderbys); + assert!((*scan).numberOfKeys == n_keys); + assert!((*scan).numberOfOrderBys == n_orderbys); + assert!(n_keys == 0); + assert!(n_orderbys == 1); + am_scan::start_scan(scan, orderbys); } #[pgrx::pg_guard] @@ -264,12 +265,12 @@ pub unsafe extern "C" fn amgettuple( direction: pgrx::pg_sys::ScanDirection, ) -> bool { assert!(direction == pgrx::pg_sys::ScanDirection_ForwardScanDirection); - 
index_scan::next_scan(scan) + am_scan::next_scan(scan) } #[pgrx::pg_guard] pub unsafe extern "C" fn amendscan(scan: pgrx::pg_sys::IndexScanDesc) { - index_scan::end_scan(scan); + am_scan::end_scan(scan); } #[pgrx::pg_guard] @@ -285,7 +286,7 @@ pub unsafe extern "C" fn ambulkdelete( let oid = (*(*info).index).rd_locator.relNumber; let id = Id::from_sys(oid); if let Some(callback) = callback { - index_update::update_delete(id, |pointer| { + am_update::update_delete(id, |pointer| { callback( &mut pointer.into_sys() as *mut pgrx::pg_sys::ItemPointerData, callback_state, diff --git a/src/postgres/index_build.rs b/src/index/am_build.rs similarity index 60% rename from src/postgres/index_build.rs rename to src/index/am_build.rs index f6b41b114..603d2544d 100644 --- a/src/postgres/index_build.rs +++ b/src/index/am_build.rs @@ -1,11 +1,13 @@ -use super::hook_transaction::{client, flush_if_commit}; -use crate::ipc::client::Rpc; -use crate::postgres::index_setup::options; +use super::hook_transaction::flush_if_commit; +use crate::index::utils::from_datum; +use crate::ipc::client::ClientGuard; use crate::prelude::*; +use crate::{index::am_setup::options, ipc::client::Rpc}; use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData}; +use service::prelude::*; pub struct Builder { - pub rpc: Rpc, + pub rpc: ClientGuard, pub heap_relation: *mut RelationData, pub index_info: *mut IndexInfo, pub result: *mut IndexBuildResult, @@ -22,27 +24,22 @@ pub unsafe fn build( let id = Id::from_sys(oid); flush_if_commit(id); let options = options(index); - client(|mut rpc| { - rpc.create(id, options).friendly(); - rpc - }); + let mut rpc = crate::ipc::client::borrow_mut(); + rpc.create(id, options); if let Some((heap_relation, index_info, result)) = data { - client(|rpc| { - let mut builder = Builder { - rpc, - heap_relation, - index_info, - result, - }; - pgrx::pg_sys::IndexBuildHeapScan( - heap_relation, - index, - index_info, - Some(callback), - &mut builder, - ); - builder.rpc - }); + let mut builder = Builder { + rpc, + heap_relation, + index_info, + result, + }; + pgrx::pg_sys::IndexBuildHeapScan( + heap_relation, + index, + index_info, + Some(callback), + &mut builder, + ); } } @@ -52,21 +49,17 @@ unsafe extern "C" fn callback( index_relation: pgrx::pg_sys::Relation, htup: pgrx::pg_sys::HeapTuple, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - use super::datatype::VectorInput; - use pgrx::FromDatum; - let ctid = &(*htup).t_self; - let oid = (*index_relation).rd_node.relNode; let id = Id::from_sys(oid); let state = &mut *(state as *mut Builder); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = (pgvector.to_vec(), Pointer::from_sys(*ctid)); - state.rpc.insert(id, data).friendly().friendly(); + let vector = from_datum(*values.add(0)); + let data = (vector, Pointer::from_sys(*ctid)); + state.rpc.insert(id, data); (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; } @@ -77,22 +70,19 @@ unsafe extern "C" fn callback( index_relation: pgrx::pg_sys::Relation, ctid: pgrx::pg_sys::ItemPointer, values: *mut pgrx::pg_sys::Datum, - is_null: *mut bool, + _is_null: *mut bool, _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - use super::datatype::VectorInput; - use pgrx::FromDatum; - #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))] let oid = (*index_relation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = 
(*index_relation).rd_locator.relNumber; let id = Id::from_sys(oid); let state = &mut *(state as *mut Builder); - let pgvector = VectorInput::from_datum(*values.add(0), *is_null.add(0)).unwrap(); - let data = (pgvector.to_vec(), Pointer::from_sys(*ctid)); - state.rpc.insert(id, data).friendly().friendly(); + let vector = from_datum(*values.add(0)); + let data = (vector, Pointer::from_sys(*ctid)); + state.rpc.insert(id, data); (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; } diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs new file mode 100644 index 000000000..f437de764 --- /dev/null +++ b/src/index/am_scan.rs @@ -0,0 +1,234 @@ +use crate::gucs::ENABLE_PREFILTER; +use crate::gucs::ENABLE_VBASE; +use crate::gucs::K; +use crate::index::utils::from_datum; +use crate::ipc::client::ClientGuard; +use crate::ipc::client::Vbase; +use crate::prelude::*; +use pgrx::FromDatum; +use service::prelude::*; + +pub enum Scanner { + Initial { + node: Option<*mut pgrx::pg_sys::IndexScanState>, + vector: Option, + }, + Search { + node: *mut pgrx::pg_sys::IndexScanState, + data: Vec, + }, + Vbase { + node: *mut pgrx::pg_sys::IndexScanState, + vbase: ClientGuard, + }, +} + +impl Scanner { + fn node(&self) -> Option<*mut pgrx::pg_sys::IndexScanState> { + match self { + Scanner::Initial { node, .. } => *node, + Scanner::Search { node, .. } => Some(*node), + Scanner::Vbase { node, .. } => Some(*node), + } + } +} + +pub unsafe fn make_scan(index_relation: pgrx::pg_sys::Relation) -> pgrx::pg_sys::IndexScanDesc { + use pgrx::PgMemoryContexts; + + let scan = pgrx::pg_sys::RelationGetIndexScan(index_relation, 0, 1); + + (*scan).xs_recheck = false; + (*scan).xs_recheckorderby = false; + + (*scan).opaque = + PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(Scanner::Initial { + vector: None, + node: None, + }) as _; + + (*scan).xs_orderbyvals = pgrx::pg_sys::palloc0(std::mem::size_of::()) as _; + + (*scan).xs_orderbynulls = { + let data = pgrx::pg_sys::palloc(std::mem::size_of::()) as *mut bool; + data.write_bytes(1, 1); + data + }; + + scan +} + +pub unsafe fn start_scan(scan: pgrx::pg_sys::IndexScanDesc, orderbys: pgrx::pg_sys::ScanKey) { + std::ptr::copy(orderbys, (*scan).orderByData, 1); + + let vector = from_datum((*orderbys.add(0)).sk_argument); + + let scanner = &mut *((*scan).opaque as *mut Scanner); + let scanner = std::mem::replace( + scanner, + Scanner::Initial { + node: scanner.node(), + vector: Some(vector), + }, + ); + + match scanner { + Scanner::Initial { .. } => {} + Scanner::Search { .. } => {} + Scanner::Vbase { vbase, .. 
} => { + vbase.leave(); + } + } +} + +pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { + let scanner = &mut *((*scan).opaque as *mut Scanner); + if let Scanner::Initial { node, vector } = scanner { + let node = node.expect("Hook failed."); + let vector = vector.as_ref().expect("Scan failed."); + + #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] + let oid = (*(*scan).indexRelation).rd_node.relNode; + #[cfg(feature = "pg16")] + let oid = (*(*scan).indexRelation).rd_locator.relNumber; + let id = Id::from_sys(oid); + + let mut rpc = crate::ipc::client::borrow_mut(); + + if ENABLE_VBASE.get() { + let vbase = rpc.vbase(id, vector.clone()); + *scanner = Scanner::Vbase { node, vbase }; + } else { + let k = K.get() as _; + struct Search { + node: *mut pgrx::pg_sys::IndexScanState, + } + + impl crate::ipc::client::Search for Search { + fn check(&mut self, p: Pointer) -> bool { + unsafe { check(self.node, p) } + } + } + + let search = Search { node }; + + let mut data = rpc.search(id, (vector.clone(), k), ENABLE_PREFILTER.get(), search); + data.reverse(); + *scanner = Scanner::Search { node, data }; + } + } + match scanner { + Scanner::Initial { .. } => unreachable!(), + Scanner::Search { data, .. } => { + if let Some(p) = data.pop() { + (*scan).xs_heaptid = p.into_sys(); + true + } else { + false + } + } + Scanner::Vbase { vbase, .. } => { + if let Some(p) = vbase.next() { + (*scan).xs_heaptid = p.into_sys(); + true + } else { + false + } + } + } +} + +pub unsafe fn end_scan(scan: pgrx::pg_sys::IndexScanDesc) { + let scanner = &mut *((*scan).opaque as *mut Scanner); + let scanner = std::mem::replace( + scanner, + Scanner::Initial { + node: scanner.node(), + vector: None, + }, + ); + + match scanner { + Scanner::Initial { .. } => {} + Scanner::Search { .. } => {} + Scanner::Vbase { vbase, .. 
} => { + vbase.leave(); + } + } +} + +unsafe fn execute_boolean_qual( + state: *mut pgrx::pg_sys::ExprState, + econtext: *mut pgrx::pg_sys::ExprContext, +) -> bool { + use pgrx::PgMemoryContexts; + if state.is_null() { + return true; + } + assert!((*state).flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0); + let mut is_null = true; + pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory); + let ret = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory) + .switch_to(|_| (*state).evalfunc.unwrap()(state, econtext, &mut is_null)); + assert!(!is_null); + bool::from_datum(ret, is_null).unwrap() +} + +unsafe fn check_quals(node: *mut pgrx::pg_sys::IndexScanState) -> bool { + let slot = (*node).ss.ss_ScanTupleSlot; + let econtext = (*node).ss.ps.ps_ExprContext; + (*econtext).ecxt_scantuple = slot; + if (*node).ss.ps.qual.is_null() { + return true; + } + let state = (*node).ss.ps.qual; + let econtext = (*node).ss.ps.ps_ExprContext; + execute_boolean_qual(state, econtext) +} + +unsafe fn check_mvcc(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool { + let scan_desc = (*node).iss_ScanDesc; + let heap_fetch = (*scan_desc).xs_heapfetch; + let index_relation = (*heap_fetch).rel; + let rd_tableam = (*index_relation).rd_tableam; + let snapshot = (*scan_desc).xs_snapshot; + let index_fetch_tuple = (*rd_tableam).index_fetch_tuple.unwrap(); + let mut all_dead = false; + let slot = (*node).ss.ss_ScanTupleSlot; + let mut heap_continue = false; + let found = index_fetch_tuple( + heap_fetch, + &mut p.into_sys(), + snapshot, + slot, + &mut heap_continue, + &mut all_dead, + ); + if found { + return true; + } + while heap_continue { + let found = index_fetch_tuple( + heap_fetch, + &mut p.into_sys(), + snapshot, + slot, + &mut heap_continue, + &mut all_dead, + ); + if found { + return true; + } + } + false +} + +unsafe fn check(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool { + if !check_mvcc(node, p) { + return false; + } + if !check_quals(node) { + return false; + } + true +} diff --git a/src/postgres/index_setup.rs b/src/index/am_setup.rs similarity index 74% rename from src/postgres/index_setup.rs rename to src/index/am_setup.rs index 433fd6819..fe452f0e6 100644 --- a/src/postgres/index_setup.rs +++ b/src/index/am_setup.rs @@ -1,22 +1,22 @@ -use crate::index::indexing::IndexingOptions; -use crate::index::optimizing::OptimizingOptions; -use crate::index::segments::SegmentsOptions; -use crate::index::{IndexOptions, VectorOptions}; -use crate::postgres::datatype::VectorTypmod; -use crate::prelude::*; +use crate::datatype::typmod::Typmod; use serde::{Deserialize, Serialize}; +use service::index::indexing::IndexingOptions; +use service::index::optimizing::OptimizingOptions; +use service::index::segments::SegmentsOptions; +use service::index::{IndexOptions, VectorOptions}; +use service::prelude::*; use std::ffi::CStr; use validator::Validate; pub fn helper_offset() -> usize { - memoffset::offset_of!(Helper, offset) + std::mem::offset_of!(Helper, offset) } pub fn helper_size() -> usize { std::mem::size_of::() } -pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distance { +pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> (Distance, Kind) { let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _; let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into()); assert!( @@ -25,12 +25,12 @@ pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> Distanc ); let classform = 
pgrx::pg_sys::GETSTRUCT(tuple).cast::(); let opfamily = (*classform).opcfamily; - let distance = convert_opfamily_to_distance(opfamily); + let result = convert_opfamily_to_distance(opfamily); pgrx::pg_sys::ReleaseSysCache(tuple); - distance + result } -pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Distance { +pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Distance, Kind) { let opfamily_cache_id = pgrx::pg_sys::SysCacheIdentifier_OPFAMILYOID as _; let opstrategy_cache_id = pgrx::pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _; let tuple = pgrx::pg_sys::SearchSysCache1(opfamily_cache_id, opfamily.into()); @@ -52,19 +52,25 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> Dista assert!((*amop).amopstrategy == 1); assert!((*amop).amoppurpose == pgrx::pg_sys::AMOP_ORDER as libc::c_char); let operator = (*amop).amopopr; - let distance; + let result; if operator == regoperatorin("<->(vector,vector)") { - distance = Distance::L2; + result = (Distance::L2, Kind::F32); } else if operator == regoperatorin("<#>(vector,vector)") { - distance = Distance::Dot; + result = (Distance::Dot, Kind::F32); } else if operator == regoperatorin("<=>(vector,vector)") { - distance = Distance::Cosine; + result = (Distance::Cos, Kind::F32); + } else if operator == regoperatorin("<->(vecf16,vecf16)") { + result = (Distance::L2, Kind::F16); + } else if operator == regoperatorin("<#>(vecf16,vecf16)") { + result = (Distance::Dot, Kind::F16); + } else if operator == regoperatorin("<=>(vecf16,vecf16)") { + result = (Distance::Cos, Kind::F16); } else { - FriendlyError::UnsupportedOperator.friendly(); + FriendlyError::BadOptions3.friendly(); }; pgrx::pg_sys::ReleaseCatCacheList(list); pgrx::pg_sys::ReleaseSysCache(tuple); - distance + result } pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { @@ -72,22 +78,25 @@ pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { assert!(nkeysatts == 1, "Can not be built on multicolumns."); // get distance let opfamily = (*index_relation).rd_opfamily.read(); - let d = convert_opfamily_to_distance(opfamily); + let (d, k) = convert_opfamily_to_distance(opfamily); // get dims let attrs = (*(*index_relation).rd_att).attrs.as_slice(1); let attr = &attrs[0]; - let typmod = VectorTypmod::parse_from_i32(attr.type_mod()).unwrap(); - let dims = typmod.dims().ok_or(FriendlyError::DimsIsNeeded).friendly(); + let typmod = Typmod::parse_from_i32(attr.type_mod()).unwrap(); + let dims = typmod.dims().ok_or(FriendlyError::BadOption2).friendly(); // get other options let parsed = get_parsed_from_varlena((*index_relation).rd_options); let options = IndexOptions { - vector: VectorOptions { dims, d }, + vector: VectorOptions { dims, d, k }, segment: parsed.segment, optimizing: parsed.optimizing, indexing: parsed.indexing, }; if let Err(errors) = options.validate() { - FriendlyError::BadOption(errors.to_string()).friendly(); + FriendlyError::BadOption { + validation: errors.to_string(), + } + .friendly(); } options } diff --git a/src/index/am_update.rs b/src/index/am_update.rs new file mode 100644 index 000000000..e81205f38 --- /dev/null +++ b/src/index/am_update.rs @@ -0,0 +1,31 @@ +use crate::index::hook_transaction::flush_if_commit; +use crate::prelude::*; +use service::prelude::*; + +pub fn update_insert(id: Id, vector: DynamicVector, tid: pgrx::pg_sys::ItemPointerData) { + flush_if_commit(id); + let p = Pointer::from_sys(tid); + let mut rpc = 
crate::ipc::client::borrow_mut(); + rpc.insert(id, (vector, p)); +} + +pub fn update_delete(id: Id, hook: impl Fn(Pointer) -> bool) { + struct Delete<H> { + hook: H, + } + + impl<H> crate::ipc::client::Delete for Delete<H> + where + H: Fn(Pointer) -> bool, + { + fn test(&mut self, p: Pointer) -> bool { + (self.hook)(p) + } + } + + let client_delete = Delete { hook }; + + flush_if_commit(id); + let mut rpc = crate::ipc::client::borrow_mut(); + rpc.delete(id, client_delete); +} diff --git a/src/postgres/hook_executor.rs b/src/index/hook_executor.rs similarity index 88% rename from src/postgres/hook_executor.rs rename to src/index/hook_executor.rs index 1610372db..e11b2d27f 100644 --- a/src/postgres/hook_executor.rs +++ b/src/index/hook_executor.rs @@ -1,5 +1,4 @@ -use crate::postgres::index_scan::Scanner; -use crate::postgres::index_scan::ScannerState; +use crate::index::am_scan::Scanner; use std::ptr::null_mut; pub unsafe fn post_executor_start(query_desc: *mut pgrx::pg_sys::QueryDesc) { @@ -21,7 +20,7 @@ unsafe extern "C" fn rewrite_plan_state( if index_relation .as_ref() .and_then(|p| p.rd_indam.as_ref()) - .map(|p| p.amvalidate == Some(super::index::amvalidate)) + .map(|p| p.amvalidate == Some(super::am::amvalidate)) .unwrap_or(false) { // The logic is copied from Postgres source code. @@ -33,6 +32,13 @@ unsafe extern "C" fn rewrite_plan_state( (*node).iss_NumScanKeys, (*node).iss_NumOrderByKeys, ); + + let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner); + *scanner = Scanner::Initial { + node: Some(node), + vector: None, + }; + if (*node).iss_NumRuntimeKeys == 0 || (*node).iss_RuntimeKeysReady { pgrx::pg_sys::index_rescan( (*node).iss_ScanDesc, @@ -42,10 +48,6 @@ unsafe extern "C" fn rewrite_plan_state( (*node).iss_NumOrderByKeys, ); } - // inject - let scanner = &mut *((*(*node).iss_ScanDesc).opaque as *mut Scanner); - scanner.index_scan_state = node; - assert!(matches!(scanner.state, ScannerState::Initial { ..
})); } } } diff --git a/src/index/hook_transaction.rs b/src/index/hook_transaction.rs new file mode 100644 index 000000000..c9017a37b --- /dev/null +++ b/src/index/hook_transaction.rs @@ -0,0 +1,26 @@ +use crate::utils::cells::PgRefCell; +use service::prelude::*; +use std::collections::BTreeSet; + +static FLUSH_IF_COMMIT: PgRefCell> = unsafe { PgRefCell::new(BTreeSet::new()) }; + +pub fn aborting() { + *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); +} + +pub fn committing() { + { + let flush_if_commit = FLUSH_IF_COMMIT.borrow(); + if flush_if_commit.len() != 0 { + let mut rpc = crate::ipc::client::borrow_mut(); + for id in flush_if_commit.iter().copied() { + rpc.flush(id); + } + } + } + *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); +} + +pub fn flush_if_commit(id: Id) { + FLUSH_IF_COMMIT.borrow_mut().insert(id); +} diff --git a/src/postgres/hooks.rs b/src/index/hooks.rs similarity index 89% rename from src/postgres/hooks.rs rename to src/index/hooks.rs index f652bf0ea..a3fda3aaa 100644 --- a/src/postgres/hooks.rs +++ b/src/index/hooks.rs @@ -1,5 +1,5 @@ -use crate::postgres::hook_transaction::client; use crate::prelude::*; +use service::prelude::*; static mut PREV_EXECUTOR_START: pgrx::pg_sys::ExecutorStart_hook_type = None; @@ -46,10 +46,8 @@ unsafe fn xact_delete() { .iter() .map(|node| Id::from_sys(node.relNode)) .collect::>(); - client(|mut rpc| { - rpc.destory(ids).friendly(); - rpc - }); + let mut rpc = crate::ipc::client::borrow_mut(); + rpc.destory(ids); } } @@ -63,9 +61,7 @@ unsafe fn xact_delete() { .iter() .map(|node| Id::from_sys(node.relNumber)) .collect::>(); - client(|mut rpc| { - rpc.destory(ids).friendly(); - rpc - }); + let mut rpc = crate::ipc::client::borrow_mut(); + rpc.destory(ids); } } diff --git a/src/index/mod.rs b/src/index/mod.rs index 0f0f0fc80..8aaa6b2ec 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -1,552 +1,17 @@ -pub mod delete; -pub mod indexing; -pub mod optimizing; -pub mod segments; - -use self::delete::Delete; -use self::indexing::IndexingOptions; -use self::optimizing::OptimizingOptions; -use self::segments::growing::GrowingSegment; -use self::segments::growing::GrowingSegmentInsertError; -use self::segments::sealed::SealedSegment; -use self::segments::SegmentsOptions; -use crate::index::indexing::DynamicIndexIter; -use crate::prelude::*; -use crate::utils::clean::clean; -use crate::utils::dir_ops::sync_dir; -use crate::utils::file_atomic::FileAtomic; -use arc_swap::ArcSwap; -use crossbeam::sync::Parker; -use crossbeam::sync::Unparker; -use parking_lot::Mutex; -use serde::{Deserialize, Serialize}; -use std::cmp::Reverse; -use std::collections::BinaryHeap; -use std::collections::HashMap; -use std::collections::HashSet; -use std::path::PathBuf; -use std::sync::{Arc, Weak}; -use thiserror::Error; -use uuid::Uuid; -use validator::Validate; - -#[derive(Debug, Error)] -pub enum IndexInsertError { - #[error("The vector is invalid.")] - InvalidVector(Vec), - #[error("The index view is outdated.")] - OutdatedView(#[from] Option), -} - -#[derive(Debug, Error)] -pub enum IndexSearchError { - #[error("The vector is invalid.")] - InvalidVector(Vec), -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -pub struct VectorOptions { - #[validate(range(min = 1, max = 65535))] - #[serde(rename = "dimensions")] - pub dims: u16, - #[serde(rename = "distance")] - pub d: Distance, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -pub struct IndexOptions { - #[validate] - pub vector: VectorOptions, - #[validate] - pub segment: 
SegmentsOptions, - #[validate] - pub optimizing: OptimizingOptions, - #[validate] - pub indexing: IndexingOptions, -} - -pub struct Index { - path: PathBuf, - options: IndexOptions, - delete: Arc, - protect: Mutex, - view: ArcSwap, - optimize_unparker: Unparker, - indexing: Mutex, - _tracker: Arc, -} - -impl Index { - pub fn create(path: PathBuf, options: IndexOptions) -> Arc { - assert!(options.validate().is_ok()); - std::fs::create_dir(&path).unwrap(); - std::fs::create_dir(path.join("segments")).unwrap(); - let startup = FileAtomic::create( - path.join("startup"), - IndexStartup { - sealeds: HashSet::new(), - growings: HashSet::new(), - }, - ); - let delete = Delete::create(path.join("delete")); - sync_dir(&path); - let parker = Parker::new(); - let index = Arc::new(Index { - path: path.clone(), - options: options.clone(), - delete: delete.clone(), - protect: Mutex::new(IndexProtect { - startup, - sealed: HashMap::new(), - growing: HashMap::new(), - write: None, - }), - view: ArcSwap::new(Arc::new(IndexView { - options: options.clone(), - sealed: HashMap::new(), - growing: HashMap::new(), - delete: delete.clone(), - write: None, - })), - optimize_unparker: parker.unparker().clone(), - indexing: Mutex::new(true), - _tracker: Arc::new(IndexTracker { path }), - }); - IndexBackground { - index: Arc::downgrade(&index), - parker, - } - .spawn(); - index - } - pub fn open(path: PathBuf, options: IndexOptions) -> Arc { - let tracker = Arc::new(IndexTracker { path: path.clone() }); - let startup = FileAtomic::::open(path.join("startup")); - clean( - path.join("segments"), - startup - .get() - .sealeds - .iter() - .map(|s| s.to_string()) - .chain(startup.get().growings.iter().map(|s| s.to_string())), - ); - let sealed = startup - .get() - .sealeds - .iter() - .map(|&uuid| { - ( - uuid, - SealedSegment::open( - tracker.clone(), - path.join("segments").join(uuid.to_string()), - uuid, - options.clone(), - ), - ) - }) - .collect::>(); - let growing = startup - .get() - .growings - .iter() - .map(|&uuid| { - ( - uuid, - GrowingSegment::open( - tracker.clone(), - path.join("segments").join(uuid.to_string()), - uuid, - options.clone(), - ), - ) - }) - .collect::>(); - let delete = Delete::open(path.join("delete")); - let parker = Parker::new(); - let index = Arc::new(Index { - path: path.clone(), - options: options.clone(), - delete: delete.clone(), - protect: Mutex::new(IndexProtect { - startup, - sealed: sealed.clone(), - growing: growing.clone(), - write: None, - }), - view: ArcSwap::new(Arc::new(IndexView { - options: options.clone(), - delete: delete.clone(), - sealed, - growing, - write: None, - })), - optimize_unparker: parker.unparker().clone(), - indexing: Mutex::new(true), - _tracker: tracker, - }); - IndexBackground { - index: Arc::downgrade(&index), - parker, - } - .spawn(); - index - } - pub fn options(&self) -> &IndexOptions { - &self.options - } - pub fn view(&self) -> Arc { - self.view.load_full() - } - pub fn refresh(&self) { - let mut protect = self.protect.lock(); - if let Some((uuid, write)) = protect.write.clone() { - write.seal(); - protect.growing.insert(uuid, write); - } - let write_segment_uuid = Uuid::new_v4(); - let write_segment = GrowingSegment::create( - self._tracker.clone(), - self.path - .join("segments") - .join(write_segment_uuid.to_string()), - write_segment_uuid, - self.options.clone(), - ); - protect.write = Some((write_segment_uuid, write_segment)); - protect.maintain(self.options.clone(), self.delete.clone(), &self.view); - self.optimize_unparker.unpark(); - } - 
pub fn indexing(&self) -> bool { - *self.indexing.lock() - } -} - -impl Drop for Index { - fn drop(&mut self) { - self.optimize_unparker.unpark(); - } -} - -#[derive(Debug, Clone)] -pub struct IndexTracker { - path: PathBuf, -} - -impl Drop for IndexTracker { - fn drop(&mut self) { - std::fs::remove_dir_all(&self.path).unwrap(); - } -} - -pub struct IndexView { - options: IndexOptions, - delete: Arc, - sealed: HashMap>, - growing: HashMap>, - write: Option<(Uuid, Arc)>, -} - -impl IndexView { - pub fn sealed_len(&self) -> u32 { - self.sealed.values().map(|x| x.len()).sum::() - } - pub fn growing_len(&self) -> u32 { - self.growing.values().map(|x| x.len()).sum::() - } - pub fn write_len(&self) -> u32 { - self.write.as_ref().map(|x| x.1.len()).unwrap_or(0) - } - pub fn sealed_len_vec(&self) -> Vec { - self.sealed.values().map(|x| x.len()).collect() - } - pub fn growing_len_vec(&self) -> Vec { - self.growing.values().map(|x| x.len()).collect() - } - pub fn search bool>( - &self, - k: usize, - vector: &[Scalar], - mut filter: F, - ) -> Result, IndexSearchError> { - if self.options.vector.dims as usize != vector.len() { - return Err(IndexSearchError::InvalidVector(vector.to_vec())); - } - - struct Comparer(BinaryHeap>); - - impl PartialEq for Comparer { - fn eq(&self, other: &Self) -> bool { - self.cmp(other).is_eq() - } - } - - impl Eq for Comparer {} - - impl PartialOrd for Comparer { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - - impl Ord for Comparer { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.peek().cmp(&other.0.peek()).reverse() - } - } - - let mut filter = |payload| { - if let Some(p) = self.delete.check(payload) { - filter(p) - } else { - false - } - }; - let n = self.sealed.len() + self.growing.len() + 1; - let mut result = Heap::new(k); - let mut heaps = BinaryHeap::with_capacity(1 + n); - for (_, sealed) in self.sealed.iter() { - let p = sealed.search(k, vector, &mut filter).into_reversed_heap(); - heaps.push(Comparer(p)); - } - for (_, growing) in self.growing.iter() { - let p = growing.search(k, vector, &mut filter).into_reversed_heap(); - heaps.push(Comparer(p)); - } - if let Some((_, write)) = &self.write { - let p = write.search(k, vector, &mut filter).into_reversed_heap(); - heaps.push(Comparer(p)); - } - while let Some(Comparer(mut heap)) = heaps.pop() { - if let Some(Reverse(x)) = heap.pop() { - result.push(x); - heaps.push(Comparer(heap)); - } - } - Ok(result - .into_sorted_vec() - .iter() - .map(|x| Pointer::from_u48(x.payload >> 16)) - .collect()) - } - pub fn search_vbase( - &self, - range: usize, - vector: &[Scalar], - mut next: F, - ) -> Result<(), IndexSearchError> - where - F: FnMut(Pointer) -> bool, - { - if self.options.vector.dims as usize != vector.len() { - return Err(IndexSearchError::InvalidVector(vector.to_vec())); - } - - struct Comparer<'index, 'vector> { - iter: ComparerIter<'index, 'vector>, - item: Option, - } - - enum ComparerIter<'index, 'vector> { - Sealed(DynamicIndexIter<'index, 'vector>), - Growing(std::vec::IntoIter), - } - - impl PartialEq for Comparer<'_, '_> { - fn eq(&self, other: &Self) -> bool { - self.cmp(other).is_eq() - } - } - - impl Eq for Comparer<'_, '_> {} - - impl PartialOrd for Comparer<'_, '_> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } - } - - impl Ord for Comparer<'_, '_> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.item.cmp(&other.item).reverse() - } - } - - impl Iterator for ComparerIter<'_, '_> { - type Item = 
HeapElement; - fn next(&mut self) -> Option { - match self { - Self::Sealed(iter) => iter.next(), - Self::Growing(iter) => iter.next(), - } - } - } - - impl Iterator for Comparer<'_, '_> { - type Item = HeapElement; - fn next(&mut self) -> Option { - let item = self.item.take(); - self.item = self.iter.next(); - item - } - } - - fn from_iter<'index, 'vector>( - mut iter: ComparerIter<'index, 'vector>, - ) -> Comparer<'index, 'vector> { - let item = iter.next(); - Comparer { iter, item } - } - - use ComparerIter::*; - let filter = |payload| self.delete.check(payload).is_some(); - let n = self.sealed.len() + self.growing.len() + 1; - let mut heaps: BinaryHeap = BinaryHeap::with_capacity(1 + n); - for (_, sealed) in self.sealed.iter() { - let res = sealed.search_vbase(range, vector); - heaps.push(from_iter(Sealed(res))); - } - for (_, growing) in self.growing.iter() { - let mut res = growing.search_all(vector); - res.sort_unstable(); - heaps.push(from_iter(Growing(res.into_iter()))); - } - if let Some((_, write)) = &self.write { - let mut res = write.search_all(vector); - res.sort_unstable(); - heaps.push(from_iter(Growing(res.into_iter()))); - } - while let Some(mut iter) = heaps.pop() { - if let Some(x) = iter.next() { - if !filter(x.payload) { - continue; - } - let stop = next(Pointer::from_u48(x.payload >> 16)); - if stop { - break; - } - heaps.push(iter); - } - } - Ok(()) - } - pub fn insert(&self, vector: Vec, pointer: Pointer) -> Result<(), IndexInsertError> { - if self.options.vector.dims as usize != vector.len() { - return Err(IndexInsertError::InvalidVector(vector)); - } - let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload; - if let Some((_, growing)) = self.write.as_ref() { - Ok(growing.insert(vector, payload)?) 
- } else { - Err(IndexInsertError::OutdatedView(None)) - } - } - pub fn delete bool>(&self, mut f: F) { - for (_, sealed) in self.sealed.iter() { - let n = sealed.len(); - for i in 0..n { - if let Some(p) = self.delete.check(sealed.payload(i)) { - if f(p) { - self.delete.delete(p); - } - } - } - } - for (_, growing) in self.growing.iter() { - let n = growing.len(); - for i in 0..n { - if let Some(p) = self.delete.check(growing.payload(i)) { - if f(p) { - self.delete.delete(p); - } - } - } - } - if let Some((_, write)) = &self.write { - let n = write.len(); - for i in 0..n { - if let Some(p) = self.delete.check(write.payload(i)) { - if f(p) { - self.delete.delete(p); - } - } - } - } - } - pub fn flush(&self) -> Result<(), IndexInsertError> { - self.delete.flush(); - if let Some((_, write)) = &self.write { - write.flush(); - } - Ok(()) - } -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -struct IndexStartup { - sealeds: HashSet, - growings: HashSet, -} - -struct IndexProtect { - startup: FileAtomic, - sealed: HashMap>, - growing: HashMap>, - write: Option<(Uuid, Arc)>, -} - -impl IndexProtect { - fn maintain(&mut self, options: IndexOptions, delete: Arc, swap: &ArcSwap) { - let view: Arc = Arc::new(IndexView { - options, - delete, - sealed: self.sealed.clone(), - growing: self.growing.clone(), - write: self.write.clone(), - }); - let startup_write = self.write.as_ref().map(|(uuid, _)| *uuid); - let startup_sealeds = self.sealed.keys().copied().collect(); - let startup_growings = self.growing.keys().copied().chain(startup_write).collect(); - self.startup.set(IndexStartup { - sealeds: startup_sealeds, - growings: startup_growings, - }); - swap.swap(view); - } -} - -pub struct IndexBackground { - index: Weak, - parker: Parker, -} - -impl IndexBackground { - pub fn main(self) { - let pool; - if let Some(index) = self.index.upgrade() { - pool = rayon::ThreadPoolBuilder::new() - .num_threads(index.options.optimizing.optimizing_threads) - .build() - .unwrap(); - } else { - return; - } - while let Some(index) = self.index.upgrade() { - let done = pool.install(|| optimizing::indexing::optimizing_indexing(index.clone())); - if done { - *index.indexing.lock() = false; - drop(index); - self.parker.park(); - if let Some(index) = self.index.upgrade() { - *index.indexing.lock() = true; - } - } - } - } - pub fn spawn(self) { - std::thread::spawn(move || { - self.main(); - }); - } +#![allow(unsafe_op_in_unsafe_fn)] + +mod am; +mod am_build; +mod am_scan; +mod am_setup; +mod am_update; +mod hook_executor; +mod hook_transaction; +mod hooks; +mod utils; +mod views; + +pub unsafe fn init() { + self::hooks::init(); + self::am::init(); } diff --git a/src/index/optimizing/indexing.rs b/src/index/optimizing/indexing.rs deleted file mode 100644 index 445496201..000000000 --- a/src/index/optimizing/indexing.rs +++ /dev/null @@ -1,105 +0,0 @@ -use crate::index::GrowingSegment; -use crate::index::Index; -use crate::index::SealedSegment; -use std::cmp::Reverse; -use std::sync::Arc; -use uuid::Uuid; - -enum Seg { - Sealed(Arc), - Growing(Arc), -} - -impl Seg { - fn uuid(&self) -> Uuid { - use Seg::*; - match self { - Sealed(x) => x.uuid(), - Growing(x) => x.uuid(), - } - } - fn len(&self) -> u32 { - use Seg::*; - match self { - Sealed(x) => x.len(), - Growing(x) => x.len(), - } - } - fn get_sealed(&self) -> Option> { - match self { - Seg::Sealed(x) => Some(x.clone()), - _ => None, - } - } - fn get_growing(&self) -> Option> { - match self { - Seg::Growing(x) => Some(x.clone()), - _ => None, - } - } -} - 
-pub fn optimizing_indexing(index: Arc) -> bool { - use Seg::*; - let segs = { - let mut all_segs = { - let protect = index.protect.lock(); - let mut all_segs = Vec::new(); - all_segs.extend(protect.growing.values().map(|x| Growing(x.clone()))); - all_segs.extend(protect.sealed.values().map(|x| Sealed(x.clone()))); - all_segs.sort_by_key(|case| Reverse(case.len())); - all_segs - }; - let mut segs = Vec::new(); - let mut segs_len = 0u64; - while let Some(seg) = all_segs.pop() { - if segs_len + seg.len() as u64 <= index.options.segment.max_sealed_segment_size as u64 { - segs_len += seg.len() as u64; - segs.push(seg); - } else { - break; - } - } - if segs_len < index.options.segment.min_sealed_segment_size as u64 || segs.len() < 3 { - return true; - } - segs - }; - let sealed_segment = merge(&index, &segs); - { - let mut protect = index.protect.lock(); - for seg in segs.iter() { - if protect.sealed.contains_key(&seg.uuid()) { - continue; - } - if protect.growing.contains_key(&seg.uuid()) { - continue; - } - return false; - } - for seg in segs.iter() { - protect.sealed.remove(&seg.uuid()); - protect.growing.remove(&seg.uuid()); - } - protect.sealed.insert(sealed_segment.uuid(), sealed_segment); - protect.maintain(index.options.clone(), index.delete.clone(), &index.view); - } - false -} - -fn merge(index: &Arc, segs: &[Seg]) -> Arc { - let sealed = segs.iter().filter_map(|x| x.get_sealed()).collect(); - let growing = segs.iter().filter_map(|x| x.get_growing()).collect(); - let sealed_segment_uuid = Uuid::new_v4(); - SealedSegment::create( - index._tracker.clone(), - index - .path - .join("segments") - .join(sealed_segment_uuid.to_string()), - sealed_segment_uuid, - index.options.clone(), - sealed, - growing, - ) -} diff --git a/src/index/utils.rs b/src/index/utils.rs new file mode 100644 index 000000000..1be0145e0 --- /dev/null +++ b/src/index/utils.rs @@ -0,0 +1,25 @@ +use crate::datatype::vecf16::Vecf16; +use crate::datatype::vecf32::Vecf32; +use service::prelude::DynamicVector; + +#[repr(C, align(8))] +struct Header { + varlena: u32, + len: u16, + kind: u8, + reserved: u8, +} + +pub unsafe fn from_datum(datum: pgrx::pg_sys::Datum) -> DynamicVector { + let p = datum.cast_mut_ptr::(); + let q = pgrx::pg_sys::pg_detoast_datum(p); + let vector = match (*q.cast::
()).kind { + 0 => DynamicVector::F32((*q.cast::()).data().to_vec()), + 1 => DynamicVector::F16((*q.cast::()).data().to_vec()), + _ => unreachable!(), + }; + if p != q { + pgrx::pg_sys::pfree(q.cast()); + } + vector +} diff --git a/src/index/views.rs b/src/index/views.rs new file mode 100644 index 000000000..113b4e349 --- /dev/null +++ b/src/index/views.rs @@ -0,0 +1,46 @@ +use crate::prelude::*; +use service::prelude::*; + +pgrx::extension_sql!( + "\ +CREATE TYPE VectorIndexStat AS ( + idx_indexing BOOL, + idx_tuples BIGINT, + idx_sealed BIGINT[], + idx_growing BIGINT[], + idx_write BIGINT, + idx_options TEXT +);", + name = "create_composites", +); + +#[pgrx::pg_extern(volatile, strict)] +fn vector_stat(oid: pgrx::pg_sys::Oid) -> pgrx::composite_type!("VectorIndexStat") { + let id = Id::from_sys(oid); + let mut res = pgrx::prelude::PgHeapTuple::new_composite_type("VectorIndexStat").unwrap(); + let mut rpc = crate::ipc::client::borrow_mut(); + let stat = rpc.stat(id); + res.set_by_name("idx_indexing", stat.indexing).unwrap(); + res.set_by_name("idx_tuples", { + let mut tuples = 0; + tuples += stat.sealed.iter().map(|x| *x as i64).sum::(); + tuples += stat.growing.iter().map(|x| *x as i64).sum::(); + tuples += stat.write as i64; + tuples + }) + .unwrap(); + res.set_by_name("idx_sealed", { + let sealed = stat.sealed; + sealed.into_iter().map(|x| x as i64).collect::>() + }) + .unwrap(); + res.set_by_name("idx_growing", { + let growing = stat.growing; + growing.into_iter().map(|x| x as i64).collect::>() + }) + .unwrap(); + res.set_by_name("idx_write", stat.write as i64).unwrap(); + res.set_by_name("idx_options", serde_json::to_string(&stat.options)) + .unwrap(); + res +} diff --git a/src/ipc/client.rs b/src/ipc/client.rs deleted file mode 100644 index 8d0913d5c..000000000 --- a/src/ipc/client.rs +++ /dev/null @@ -1,238 +0,0 @@ -use crate::index::IndexOptions; -use crate::ipc::packet::*; -use crate::ipc::transport::Socket; -use crate::ipc::IpcError; -use crate::prelude::*; - -pub struct Rpc { - socket: Socket, -} - -impl Rpc { - pub(super) fn new(socket: Socket) -> Self { - Self { socket } - } - pub fn create(&mut self, id: Id, options: IndexOptions) -> Result<(), IpcError> { - let packet = RpcPacket::Create { id, options }; - self.socket.send(packet)?; - let CreatePacket::Leave {} = self.socket.recv::()?; - Ok(()) - } - pub fn search( - mut self, - id: Id, - search: (Vec, usize), - prefilter: bool, - ) -> Result { - let packet = RpcPacket::Search { - id, - search, - prefilter, - }; - self.socket.send(packet)?; - Ok(SearchHandler { - socket: self.socket, - }) - } - pub fn search_vbase( - mut self, - id: Id, - search: (Vec, usize), - ) -> Result { - let packet = RpcPacket::SearchVbase { id, search }; - self.socket.send(packet)?; - Ok(SearchVbaseHandler { - socket: self.socket, - }) - } - pub fn delete(mut self, id: Id) -> Result { - let packet = RpcPacket::Delete { id }; - self.socket.send(packet)?; - Ok(DeleteHandler { - socket: self.socket, - }) - } - pub fn insert( - &mut self, - id: Id, - insert: (Vec, Pointer), - ) -> Result, IpcError> { - let packet = RpcPacket::Insert { id, insert }; - self.socket.send(packet)?; - let InsertPacket::Leave { result } = self.socket.recv::()?; - Ok(result) - } - pub fn flush(&mut self, id: Id) -> Result, IpcError> { - let packet = RpcPacket::Flush { id }; - self.socket.send(packet)?; - let FlushPacket::Leave { result } = self.socket.recv::()?; - Ok(result) - } - pub fn destory(&mut self, ids: Vec) -> Result<(), IpcError> { - let packet = RpcPacket::Destory { 
ids }; - self.socket.send(packet)?; - let DestoryPacket::Leave {} = self.socket.recv::()?; - Ok(()) - } - pub fn stat(&mut self, id: Id) -> Result, IpcError> { - let packet = RpcPacket::Stat { id }; - self.socket.send(packet)?; - let StatPacket::Leave { result } = self.socket.recv::()?; - Ok(result) - } -} - -pub enum SearchHandle { - Check { - p: Pointer, - x: SearchCheck, - }, - Leave { - result: Result, FriendlyError>, - x: Rpc, - }, -} - -pub struct SearchHandler { - socket: Socket, -} - -impl SearchHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - SearchPacket::Check { p } => SearchHandle::Check { - p, - x: SearchCheck { - socket: self.socket, - }, - }, - SearchPacket::Leave { result } => SearchHandle::Leave { - result, - x: Rpc { - socket: self.socket, - }, - }, - }) - } -} - -pub struct SearchCheck { - socket: Socket, -} - -impl SearchCheck { - pub fn leave(mut self, result: bool) -> Result { - let packet = SearchCheckPacket::Leave { result }; - self.socket.send(packet)?; - Ok(SearchHandler { - socket: self.socket, - }) - } -} - -pub enum SearchVbaseHandle { - Next { - p: Pointer, - x: SearchVbaseNext, - }, - Leave { - result: Result<(), FriendlyError>, - x: Rpc, - }, -} - -pub struct SearchVbaseHandler { - socket: Socket, -} - -impl SearchVbaseHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - SearchVbasePacket::Next { p } => SearchVbaseHandle::Next { - p, - x: SearchVbaseNext { - socket: self.socket, - }, - }, - SearchVbasePacket::Leave { result } => SearchVbaseHandle::Leave { - result, - x: Rpc { - socket: self.socket, - }, - }, - }) - } -} - -pub struct SearchVbaseNext { - socket: Socket, -} - -impl SearchVbaseNext { - pub fn next(mut self) -> Result { - let packet = SearchVbaseNextPacket::Leave { stop: false }; - self.socket.send(packet)?; - Ok(SearchVbaseHandler { - socket: self.socket, - }) - } - pub fn stop(mut self) -> Result { - let packet = SearchVbaseNextPacket::Leave { stop: true }; - self.socket.send(packet)?; - match self.socket.recv::()? { - SearchVbasePacket::Leave { result } => result.friendly(), - _ => unreachable!(), - }; - Ok(Rpc { - socket: self.socket, - }) - } -} - -pub enum DeleteHandle { - Next { - p: Pointer, - x: DeleteNext, - }, - Leave { - result: Result<(), FriendlyError>, - x: Rpc, - }, -} - -pub struct DeleteHandler { - socket: Socket, -} - -impl DeleteHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? 
{ - DeletePacket::Next { p } => DeleteHandle::Next { - p, - x: DeleteNext { - socket: self.socket, - }, - }, - DeletePacket::Leave { result } => DeleteHandle::Leave { - result, - x: Rpc { - socket: self.socket, - }, - }, - }) - } -} - -pub struct DeleteNext { - socket: Socket, -} - -impl DeleteNext { - pub fn leave(mut self, delete: bool) -> Result { - let packet = DeleteNextPacket::Leave { delete }; - self.socket.send(packet)?; - Ok(DeleteHandler { - socket: self.socket, - }) - } -} diff --git a/src/ipc/client/mod.rs b/src/ipc/client/mod.rs new file mode 100644 index 000000000..38aa32b97 --- /dev/null +++ b/src/ipc/client/mod.rs @@ -0,0 +1,235 @@ +use super::packet::*; +use super::transport::Socket; +use crate::gucs::{Transport, TRANSPORT}; +use crate::utils::cells::PgRefCell; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::prelude::*; +use std::mem::ManuallyDrop; +use std::ops::Deref; +use std::ops::DerefMut; + +pub trait ClientLike: 'static { + const RESET: bool = false; + + fn from_socket(socket: Socket) -> Self; + fn to_socket(self) -> Socket; +} + +pub struct ClientGuard(pub ManuallyDrop); + +impl ClientGuard { + fn map(mut self) -> ClientGuard { + unsafe { + let t = ManuallyDrop::take(&mut self.0); + std::mem::forget(self); + ClientGuard::new(U::from_socket(t.to_socket())) + } + } +} + +impl Deref for ClientGuard { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for ClientGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +pub struct Rpc { + socket: Socket, +} + +impl Rpc { + pub fn new(socket: Socket) -> Self { + Self { socket } + } + pub fn create(self: &mut ClientGuard, id: Id, options: IndexOptions) { + let packet = RpcPacket::Create { id, options }; + self.socket.client_send(packet).friendly(); + let create::CreatePacket::Leave {} = self.socket.client_recv().friendly(); + } + pub fn search( + self: &mut ClientGuard, + id: Id, + search: (DynamicVector, usize), + prefilter: bool, + mut t: impl Search, + ) -> Vec { + let packet = RpcPacket::Search { + id, + search, + prefilter, + }; + self.socket.client_send(packet).friendly(); + loop { + match self.socket.client_recv().friendly() { + search::SearchPacket::Check { p } => { + self.socket + .client_send(search::SearchCheckPacket { result: t.check(p) }) + .friendly(); + } + search::SearchPacket::Leave { result } => { + return result.friendly(); + } + } + } + } + pub fn delete(self: &mut ClientGuard, id: Id, mut t: impl Delete) { + let packet = RpcPacket::Delete { id }; + self.socket.client_send(packet).friendly(); + loop { + match self.socket.client_recv().friendly() { + delete::DeletePacket::Test { p } => { + self.socket + .client_send(delete::DeleteTestPacket { delete: t.test(p) }) + .friendly(); + } + delete::DeletePacket::Leave { result } => { + return result.friendly(); + } + } + } + } + pub fn insert(self: &mut ClientGuard, id: Id, insert: (DynamicVector, Pointer)) { + let packet = RpcPacket::Insert { id, insert }; + self.socket.client_send(packet).friendly(); + let insert::InsertPacket::Leave { result } = self.socket.client_recv().friendly(); + result.friendly() + } + pub fn flush(self: &mut ClientGuard, id: Id) { + let packet = RpcPacket::Flush { id }; + self.socket.client_send(packet).friendly(); + let flush::FlushPacket::Leave { result } = self.socket.client_recv().friendly(); + result.friendly() + } + pub fn destory(self: &mut ClientGuard, ids: Vec) { + let packet = RpcPacket::Destory { ids }; + 
self.socket.client_send(packet).friendly(); + let destory::DestoryPacket::Leave {} = self.socket.client_recv().friendly(); + } + pub fn stat(self: &mut ClientGuard, id: Id) -> IndexStat { + let packet = RpcPacket::Stat { id }; + self.socket.client_send(packet).friendly(); + let stat::StatPacket::Leave { result } = self.socket.client_recv().friendly(); + result.friendly() + } + pub fn vbase(mut self: ClientGuard, id: Id, vector: DynamicVector) -> ClientGuard { + let packet = RpcPacket::Vbase { id, vector }; + self.socket.client_send(packet).friendly(); + let vbase::VbaseErrorPacket { result } = self.socket.client_recv().friendly(); + result.friendly(); + ClientGuard::map(self) + } +} + +impl ClientLike for Rpc { + const RESET: bool = true; + + fn from_socket(socket: Socket) -> Self { + Self { socket } + } + + fn to_socket(self) -> Socket { + self.socket + } +} + +pub trait Search { + fn check(&mut self, p: Pointer) -> bool; +} + +pub trait Delete { + fn test(&mut self, p: Pointer) -> bool; +} + +pub struct Vbase { + socket: Socket, +} + +impl Vbase { + pub fn next(self: &mut ClientGuard) -> Option { + let packet = vbase::VbasePacket::Next {}; + self.socket.client_send(packet).friendly(); + let vbase::VbaseNextPacket { p } = self.socket.client_recv().friendly(); + p + } + pub fn leave(mut self: ClientGuard) -> ClientGuard { + let packet = vbase::VbasePacket::Leave {}; + self.socket.client_send(packet).friendly(); + let vbase::VbaseLeavePacket {} = self.socket.client_recv().friendly(); + ClientGuard::map(self) + } +} + +impl ClientLike for Vbase { + fn from_socket(socket: Socket) -> Self { + Self { socket } + } + + fn to_socket(self) -> Socket { + self.socket + } +} + +enum Status { + Borrowed, + Lost, + Reset(Socket), +} + +static CLIENT: PgRefCell = unsafe { PgRefCell::new(Status::Lost) }; + +pub fn borrow_mut() -> ClientGuard { + let mut x = CLIENT.borrow_mut(); + match &mut *x { + Status::Borrowed => { + panic!("borrowed when borrowed"); + } + Status::Lost => { + let socket = match TRANSPORT.get() { + Transport::unix => crate::ipc::connect_unix(), + Transport::mmap => crate::ipc::connect_mmap(), + }; + *x = Status::Borrowed; + ClientGuard::new(Rpc::new(socket)) + } + x @ Status::Reset(_) => { + let Status::Reset(socket) = std::mem::replace(x, Status::Borrowed) else { + unreachable!() + }; + ClientGuard::new(Rpc::new(socket)) + } + } +} + +impl ClientGuard { + pub fn new(t: T) -> Self { + Self(ManuallyDrop::new(t)) + } +} + +impl Drop for ClientGuard { + fn drop(&mut self) { + let mut x = CLIENT.borrow_mut(); + match *x { + Status::Borrowed => { + if T::RESET { + unsafe { + *x = Status::Reset(ManuallyDrop::take(&mut self.0).to_socket()); + } + } else { + *x = Status::Lost; + } + } + Status::Lost => unreachable!(), + Status::Reset(_) => unreachable!(), + } + } +} diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 63682a3b7..f93df1b41 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -3,8 +3,8 @@ mod packet; pub mod server; pub mod transport; -use self::client::Rpc; use self::server::RpcHandler; +use service::prelude::*; use thiserror::Error; #[derive(Debug, Clone, Error)] @@ -18,6 +18,12 @@ Please check the full Postgresql log to get more information.\ Closed, } +impl FriendlyErrorLike for IpcError { + fn friendly(self) -> ! 
{ + panic!("pgvecto.rs: {}", self); + } +} + pub fn listen_unix() -> impl Iterator { std::iter::from_fn(move || { let socket = self::transport::Socket::Unix(self::transport::unix::accept()); @@ -32,12 +38,15 @@ pub fn listen_mmap() -> impl Iterator { }) } -pub fn connect_unix() -> Rpc { - let socket = self::transport::Socket::Unix(self::transport::unix::connect()); - self::client::Rpc::new(socket) +pub fn connect_unix() -> self::transport::Socket { + self::transport::Socket::Unix(self::transport::unix::connect()) +} + +pub fn connect_mmap() -> self::transport::Socket { + self::transport::Socket::Mmap(self::transport::mmap::connect()) } -pub fn connect_mmap() -> Rpc { - let socket = self::transport::Socket::Mmap(self::transport::mmap::connect()); - self::client::Rpc::new(socket) +pub fn init() { + self::transport::mmap::init(); + self::transport::unix::init(); } diff --git a/src/ipc/packet.rs b/src/ipc/packet.rs deleted file mode 100644 index 7ba476749..000000000 --- a/src/ipc/packet.rs +++ /dev/null @@ -1,101 +0,0 @@ -use crate::index::IndexOptions; -use crate::prelude::*; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum RpcPacket { - Create { - id: Id, - options: IndexOptions, - }, - Flush { - id: Id, - }, - Destory { - ids: Vec, - }, - Insert { - id: Id, - insert: (Vec, Pointer), - }, - Delete { - id: Id, - }, - Search { - id: Id, - search: (Vec, usize), - prefilter: bool, - }, - SearchVbase { - id: Id, - search: (Vec, usize), - }, - Stat { - id: Id, - }, - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum CreatePacket { - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum FlushPacket { - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum DestoryPacket { - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum InsertPacket { - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum DeletePacket { - Next { p: Pointer }, - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum DeleteNextPacket { - Leave { delete: bool }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SearchPacket { - Check { - p: Pointer, - }, - Leave { - result: Result, FriendlyError>, - }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SearchCheckPacket { - Leave { result: bool }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SearchVbasePacket { - Next { p: Pointer }, - Leave { result: Result<(), FriendlyError> }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum SearchVbaseNextPacket { - Leave { stop: bool }, -} - -#[derive(Debug, Serialize, Deserialize)] -pub enum StatPacket { - Leave { - result: Result, - }, -} diff --git a/src/ipc/packet/create.rs b/src/ipc/packet/create.rs new file mode 100644 index 000000000..14a19969b --- /dev/null +++ b/src/ipc/packet/create.rs @@ -0,0 +1,7 @@ +use serde::Deserialize; +use serde::Serialize; + +#[derive(Debug, Serialize, Deserialize)] +pub enum CreatePacket { + Leave {}, +} diff --git a/src/ipc/packet/delete.rs b/src/ipc/packet/delete.rs new file mode 100644 index 000000000..f1280ec61 --- /dev/null +++ b/src/ipc/packet/delete.rs @@ -0,0 +1,13 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum DeletePacket { + Test { p: Pointer }, + Leave { result: Result<(), FriendlyError> }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct 
DeleteTestPacket { + pub delete: bool, +} diff --git a/src/ipc/packet/destory.rs b/src/ipc/packet/destory.rs new file mode 100644 index 000000000..0c021ca99 --- /dev/null +++ b/src/ipc/packet/destory.rs @@ -0,0 +1,6 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub enum DestoryPacket { + Leave {}, +} diff --git a/src/ipc/packet/flush.rs b/src/ipc/packet/flush.rs new file mode 100644 index 000000000..d111f6f73 --- /dev/null +++ b/src/ipc/packet/flush.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum FlushPacket { + Leave { result: Result<(), FriendlyError> }, +} diff --git a/src/ipc/packet/insert.rs b/src/ipc/packet/insert.rs new file mode 100644 index 000000000..9ba2008d6 --- /dev/null +++ b/src/ipc/packet/insert.rs @@ -0,0 +1,7 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum InsertPacket { + Leave { result: Result<(), FriendlyError> }, +} diff --git a/src/ipc/packet/mod.rs b/src/ipc/packet/mod.rs new file mode 100644 index 000000000..2df63f0a6 --- /dev/null +++ b/src/ipc/packet/mod.rs @@ -0,0 +1,45 @@ +pub mod create; +pub mod delete; +pub mod destory; +pub mod flush; +pub mod insert; +pub mod search; +pub mod stat; +pub mod vbase; + +use serde::{Deserialize, Serialize}; +use service::index::IndexOptions; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum RpcPacket { + Create { + id: Id, + options: IndexOptions, + }, + Delete { + id: Id, + }, + Destory { + ids: Vec, + }, + Flush { + id: Id, + }, + Insert { + id: Id, + insert: (DynamicVector, Pointer), + }, + Search { + id: Id, + search: (DynamicVector, usize), + prefilter: bool, + }, + Stat { + id: Id, + }, + Vbase { + id: Id, + vector: DynamicVector, + }, +} diff --git a/src/ipc/packet/search.rs b/src/ipc/packet/search.rs new file mode 100644 index 000000000..a95c96b4f --- /dev/null +++ b/src/ipc/packet/search.rs @@ -0,0 +1,17 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum SearchPacket { + Check { + p: Pointer, + }, + Leave { + result: Result, FriendlyError>, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SearchCheckPacket { + pub result: bool, +} diff --git a/src/ipc/packet/stat.rs b/src/ipc/packet/stat.rs new file mode 100644 index 000000000..de388e032 --- /dev/null +++ b/src/ipc/packet/stat.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; +use service::index::IndexStat; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub enum StatPacket { + Leave { + result: Result, + }, +} diff --git a/src/ipc/packet/vbase.rs b/src/ipc/packet/vbase.rs new file mode 100644 index 000000000..bce914b32 --- /dev/null +++ b/src/ipc/packet/vbase.rs @@ -0,0 +1,21 @@ +use serde::{Deserialize, Serialize}; +use service::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub struct VbaseErrorPacket { + pub result: Result<(), FriendlyError>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum VbasePacket { + Next {}, + Leave {}, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct VbaseNextPacket { + pub p: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct VbaseLeavePacket {} diff --git a/src/ipc/server.rs b/src/ipc/server/mod.rs similarity index 57% rename from src/ipc/server.rs rename to src/ipc/server/mod.rs index 9e8ce181c..efafb6efd 100644 --- a/src/ipc/server.rs +++ 
b/src/ipc/server/mod.rs @@ -1,8 +1,9 @@ -use crate::index::IndexOptions; -use crate::ipc::packet::*; -use crate::ipc::transport::Socket; -use crate::ipc::IpcError; -use crate::prelude::*; +use super::packet::*; +use super::transport::Socket; +use super::IpcError; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::prelude::*; pub struct RpcHandler { socket: Socket, @@ -13,7 +14,7 @@ impl RpcHandler { Self { socket } } pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { + Ok(match self.socket.server_recv::()? { RpcPacket::Create { id, options } => RpcHandle::Create { id, options, @@ -46,13 +47,6 @@ impl RpcHandler { socket: self.socket, }, }, - RpcPacket::SearchVbase { id, search } => RpcHandle::SearchVbase { - id, - search, - x: SearchVbase { - socket: self.socket, - }, - }, RpcPacket::Flush { id } => RpcHandle::Flush { id, x: Flush { @@ -71,7 +65,13 @@ impl RpcHandler { socket: self.socket, }, }, - RpcPacket::Leave {} => RpcHandle::Leave {}, + RpcPacket::Vbase { id, vector } => RpcHandle::Vbase { + id, + vector, + x: Vbase { + socket: self.socket, + }, + }, }) } } @@ -84,18 +84,13 @@ pub enum RpcHandle { }, Search { id: Id, - search: (Vec, usize), + search: (DynamicVector, usize), prefilter: bool, x: Search, }, - SearchVbase { - id: Id, - search: (Vec, usize), - x: SearchVbase, - }, Insert { id: Id, - insert: (Vec, Pointer), + insert: (DynamicVector, Pointer), x: Insert, }, Delete { @@ -114,7 +109,11 @@ pub enum RpcHandle { id: Id, x: Stat, }, - Leave {}, + Vbase { + id: Id, + vector: DynamicVector, + x: Vbase, + }, } pub struct Create { @@ -123,8 +122,8 @@ pub struct Create { impl Create { pub fn leave(mut self) -> Result { - let packet = CreatePacket::Leave {}; - self.socket.send(packet)?; + let packet = create::CreatePacket::Leave {}; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) @@ -137,8 +136,8 @@ pub struct Insert { impl Insert { pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = InsertPacket::Leave { result }; - self.socket.send(packet)?; + let packet = insert::InsertPacket::Leave { result }; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) @@ -151,14 +150,15 @@ pub struct Delete { impl Delete { pub fn next(&mut self, p: Pointer) -> Result { - let packet = DeletePacket::Next { p }; - self.socket.send(packet)?; - let DeleteNextPacket::Leave { delete } = self.socket.recv::()?; + let packet = delete::DeletePacket::Test { p }; + self.socket.server_send(packet)?; + let delete::DeleteTestPacket { delete } = + self.socket.server_recv::()?; Ok(delete) } pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = DeletePacket::Leave { result }; - self.socket.send(packet)?; + let packet = delete::DeletePacket::Leave { result }; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) @@ -171,37 +171,18 @@ pub struct Search { impl Search { pub fn check(&mut self, p: Pointer) -> Result { - let packet = SearchPacket::Check { p }; - self.socket.send(packet)?; - let SearchCheckPacket::Leave { result } = self.socket.recv::()?; + let packet = search::SearchPacket::Check { p }; + self.socket.server_send(packet)?; + let search::SearchCheckPacket { result } = + self.socket.server_recv::()?; Ok(result) } pub fn leave( mut self, result: Result, FriendlyError>, ) -> Result { - let packet = SearchPacket::Leave { result }; - self.socket.send(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } -} - -pub struct 
SearchVbase { - socket: Socket, -} - -impl SearchVbase { - pub fn next(&mut self, p: Pointer) -> Result { - let packet = SearchVbasePacket::Next { p }; - self.socket.send(packet)?; - let SearchVbaseNextPacket::Leave { stop } = self.socket.recv::()?; - Ok(stop) - } - pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = SearchVbasePacket::Leave { result }; - self.socket.send(packet)?; + let packet = search::SearchPacket::Leave { result }; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) @@ -214,8 +195,8 @@ pub struct Flush { impl Flush { pub fn leave(mut self, result: Result<(), FriendlyError>) -> Result { - let packet = FlushPacket::Leave { result }; - self.socket.send(packet)?; + let packet = flush::FlushPacket::Leave { result }; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) @@ -228,8 +209,8 @@ pub struct Destory { impl Destory { pub fn leave(mut self) -> Result { - let packet = DestoryPacket::Leave {}; - self.socket.send(packet)?; + let packet = destory::DestoryPacket::Leave {}; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) @@ -243,12 +224,69 @@ pub struct Stat { impl Stat { pub fn leave( mut self, - result: Result, + result: Result, ) -> Result { - let packet = StatPacket::Leave { result }; - self.socket.send(packet)?; + let packet = stat::StatPacket::Leave { result }; + self.socket.server_send(packet)?; Ok(RpcHandler { socket: self.socket, }) } } + +pub struct Vbase { + socket: Socket, +} + +impl Vbase { + pub fn error(mut self, result: Result<(), FriendlyError>) -> Result { + self.socket + .server_send(vbase::VbaseErrorPacket { result })?; + Ok(VbaseHandler { + socket: self.socket, + }) + } +} + +pub struct VbaseHandler { + socket: Socket, +} + +impl VbaseHandler { + pub fn handle(mut self) -> Result { + Ok(match self.socket.server_recv::()? { + vbase::VbasePacket::Next {} => VbaseHandle::Next { + x: VbaseNext { + socket: self.socket, + }, + }, + vbase::VbasePacket::Leave {} => { + self.socket.server_send(vbase::VbaseLeavePacket {})?; + VbaseHandle::Leave { + x: RpcHandler { + socket: self.socket, + }, + } + } + }) + } +} + +pub enum VbaseHandle { + Next { x: VbaseNext }, + Leave { x: RpcHandler }, +} + +pub struct VbaseNext { + socket: Socket, +} + +impl VbaseNext { + pub fn leave(mut self, p: Option) -> Result { + let packet = vbase::VbaseNextPacket { p }; + self.socket.server_send(packet)?; + Ok(VbaseHandler { + socket: self.socket, + }) + } +} diff --git a/src/ipc/transport/mmap.rs b/src/ipc/transport/mmap.rs index d1bcbc5ca..a240dca11 100644 --- a/src/ipc/transport/mmap.rs +++ b/src/ipc/transport/mmap.rs @@ -1,4 +1,4 @@ -use crate::ipc::IpcError; +use super::IpcError; use crate::utils::file_socket::FileSocket; use crate::utils::os::{futex_wait, futex_wake, memfd_create, mmap_populate}; use rustix::fd::{AsFd, OwnedFd}; @@ -65,10 +65,7 @@ impl Socket { Err(e) => panic!("{:?}", e), } } - pub fn send(&mut self, packet: T) -> Result<(), IpcError> - where - T: Serialize, - { + pub fn send(&mut self, packet: T) -> Result<(), IpcError> { let buffer = bincode::serialize(&packet).expect("Failed to serialize"); unsafe { if self.is_server { @@ -79,10 +76,7 @@ impl Socket { } Ok(()) } - pub fn recv(&mut self) -> Result - where - T: for<'a> Deserialize<'a>, - { + pub fn recv Deserialize<'a>>(&mut self) -> Result { let buffer = unsafe { if self.is_server { (*self.addr).server_recv(|| self.test())? 
@@ -106,7 +100,7 @@ struct Channel { futex: AtomicU32, } -static_assertions::assert_eq_size!(Channel, [u8; BUFFER_SIZE]); +const _: () = assert!(std::mem::size_of::() == BUFFER_SIZE); impl Channel { unsafe fn client_recv(&self, test: impl Fn() -> bool) -> Result, IpcError> { @@ -132,30 +126,40 @@ impl Channel { { break; } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } Y => { if !test() { return Err(IpcError::Closed); } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } - _ => std::hint::unreachable_unchecked(), + _ => unsafe { std::hint::unreachable_unchecked() }, } } - let len = *self.len.get(); - let res = (*self.bytes.get())[0..len as usize].to_vec(); - Ok(res) + unsafe { + let len = *self.len.get(); + let res = (*self.bytes.get())[0..len as usize].to_vec(); + Ok(res) + } } unsafe fn client_send(&self, data: &[u8]) { const S: u32 = 0; const T: u32 = 1; const X: u32 = 2; debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X)); - *self.len.get() = data.len() as u32; - (*self.bytes.get())[0..data.len()].copy_from_slice(data); + unsafe { + *self.len.get() = data.len() as u32; + (*self.bytes.get())[0..data.len()].copy_from_slice(data); + } if X == self.futex.swap(T, Ordering::Release) { - futex_wake(&self.futex); + unsafe { + futex_wake(&self.futex); + } } } unsafe fn server_recv(&self, test: impl Fn() -> bool) -> Result, IpcError> { @@ -181,30 +185,40 @@ impl Channel { { break; } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } Y => { if !test() { return Err(IpcError::Closed); } - futex_wait(&self.futex, Y); + unsafe { + futex_wait(&self.futex, Y); + } } - _ => std::hint::unreachable_unchecked(), + _ => unsafe { std::hint::unreachable_unchecked() }, } } - let len = *self.len.get(); - let res = (*self.bytes.get())[0..len as usize].to_vec(); - Ok(res) + unsafe { + let len = *self.len.get(); + let res = (*self.bytes.get())[0..len as usize].to_vec(); + Ok(res) + } } unsafe fn server_send(&self, data: &[u8]) { const S: u32 = 1; const T: u32 = 0; const X: u32 = 3; debug_assert!(matches!(self.futex.load(Ordering::Relaxed), S | X)); - *self.len.get() = data.len() as u32; - (*self.bytes.get())[0..data.len()].copy_from_slice(data); + unsafe { + *self.len.get() = data.len() as u32; + (*self.bytes.get())[0..data.len()].copy_from_slice(data); + } if X == self.futex.swap(T, Ordering::Release) { - futex_wake(&self.futex); + unsafe { + futex_wake(&self.futex); + } } } } diff --git a/src/ipc/transport/mod.rs b/src/ipc/transport/mod.rs index d210db77c..1e559ca3e 100644 --- a/src/ipc/transport/mod.rs +++ b/src/ipc/transport/mod.rs @@ -10,13 +10,25 @@ pub enum Socket { } impl Socket { - pub fn send(&mut self, packet: T) -> Result<(), IpcError> { + pub fn client_send(&mut self, packet: T) -> Result<(), IpcError> { match self { Socket::Unix(x) => x.send(packet), Socket::Mmap(x) => x.send(packet), } } - pub fn recv Deserialize<'a>>(&mut self) -> Result { + pub fn client_recv Deserialize<'a>>(&mut self) -> Result { + match self { + Socket::Unix(x) => x.recv(), + Socket::Mmap(x) => x.recv(), + } + } + pub fn server_send(&mut self, packet: T) -> Result<(), IpcError> { + match self { + Socket::Unix(x) => x.send(packet), + Socket::Mmap(x) => x.send(packet), + } + } + pub fn server_recv Deserialize<'a>>(&mut self) -> Result { match self { Socket::Unix(x) => x.recv(), Socket::Mmap(x) => x.recv(), diff --git a/src/ipc/transport/unix.rs b/src/ipc/transport/unix.rs index 708afa3bf..8ba654baf 100644 --- a/src/ipc/transport/unix.rs 
+++ b/src/ipc/transport/unix.rs @@ -1,4 +1,4 @@ -use crate::ipc::IpcError; +use super::IpcError; use crate::utils::file_socket::FileSocket; use byteorder::{ReadBytesExt, WriteBytesExt}; use rustix::fd::AsFd; @@ -40,10 +40,7 @@ macro_rules! resolve_closed { } impl Socket { - pub fn send(&mut self, packet: T) -> Result<(), IpcError> - where - T: Serialize, - { + pub fn send(&mut self, packet: T) -> Result<(), IpcError> { use byteorder::NativeEndian as N; let buffer = bincode::serialize(&packet).expect("Failed to serialize"); let len = u32::try_from(buffer.len()).expect("Packet is too large."); @@ -51,10 +48,7 @@ impl Socket { resolve_closed!(self.stream.write_all(&buffer)); Ok(()) } - pub fn recv(&mut self) -> Result - where - T: for<'a> Deserialize<'a>, - { + pub fn recv Deserialize<'a>>(&mut self) -> Result { use byteorder::NativeEndian as N; let len = resolve_closed!(self.stream.read_u32::()); let mut buffer = vec![0u8; len as usize]; diff --git a/src/lib.rs b/src/lib.rs index 068d24b98..cf69461ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,16 @@ //! Postgres vector extension. //! //! Provides an easy-to-use extension for vector similarity search. -#![feature(core_intrinsics)] +#![feature(offset_of)] +#![feature(arbitrary_self_types)] +#![feature(try_blocks)] -mod algorithms; mod bgworker; +mod datatype; mod embedding; +mod gucs; mod index; mod ipc; -mod postgres; mod prelude; mod utils; @@ -19,27 +21,16 @@ pgrx::extension_sql_file!("./sql/finalize.sql", finalize); #[allow(non_snake_case)] #[pgrx::pg_guard] unsafe extern "C" fn _PG_init() { - use crate::prelude::*; - if pgrx::pg_sys::IsUnderPostmaster { + use service::prelude::*; + if unsafe { pgrx::pg_sys::IsUnderPostmaster } { FriendlyError::BadInit.friendly(); } - use pgrx::bgworkers::BackgroundWorkerBuilder; - use pgrx::bgworkers::BgWorkerStartTime; - BackgroundWorkerBuilder::new("vectors") - .set_function("vectors_main") - .set_library("vectors") - .set_argument(None) - .enable_shmem_access(None) - .set_start_time(BgWorkerStartTime::PostmasterStart) - .load(); - self::postgres::init(); - self::ipc::transport::unix::init(); - self::ipc::transport::mmap::init(); -} - -#[no_mangle] -extern "C" fn vectors_main(_arg: pgrx::pg_sys::Datum) { - let _ = std::panic::catch_unwind(crate::bgworker::main); + unsafe { + self::gucs::init(); + self::index::init(); + self::ipc::init(); + self::bgworker::init(); + } } #[cfg(not(any(target_os = "linux", target_os = "macos")))] diff --git a/src/postgres/casts.rs b/src/postgres/casts.rs deleted file mode 100644 index 26d1cf924..000000000 --- a/src/postgres/casts.rs +++ /dev/null @@ -1,21 +0,0 @@ -use super::datatype::{Vector, VectorInput, VectorOutput, VectorTypmod}; -use crate::prelude::Scalar; - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn cast_array_to_vector(array: pgrx::Array, typmod: i32, _explicit: bool) -> VectorOutput { - assert!(!array.is_empty()); - assert!(array.len() <= 65535); - assert!(!array.contains_nulls()); - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - let len = typmod.dims().unwrap_or(array.len() as u16); - let mut data = Vector::new_zeroed_in_postgres(len as usize); - for (i, x) in array.iter().enumerate() { - data[i] = x.unwrap_or(Scalar::NAN); - } - data -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn cast_vector_to_array(vector: VectorInput<'_>, _typmod: i32, _explicit: bool) -> Vec { - vector.data().to_vec() -} diff --git a/src/postgres/datatype.rs b/src/postgres/datatype.rs deleted file mode 100644 index 46b8056a0..000000000 
--- a/src/postgres/datatype.rs +++ /dev/null @@ -1,456 +0,0 @@ -use crate::prelude::*; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::Array; -use pgrx::FromDatum; -use pgrx::IntoDatum; -use serde::{Deserialize, Serialize}; -use std::alloc::Layout; -use std::cmp::Ordering; -use std::ffi::CStr; -use std::ffi::CString; -use std::num::NonZeroU16; -use std::ops::Deref; -use std::ops::DerefMut; -use std::ops::Index; -use std::ops::IndexMut; -use std::ptr::NonNull; - -#[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub enum VectorTypmod { - Any, - Dims(NonZeroU16), -} - -impl VectorTypmod { - pub fn parse_from_str(s: &str) -> Option { - use VectorTypmod::*; - if let Ok(x) = s.parse::() { - Some(Dims(x)) - } else { - None - } - } - pub fn parse_from_i32(x: i32) -> Option { - use VectorTypmod::*; - if x == -1 { - Some(Any) - } else if 1 <= x && x <= u16::MAX as i32 { - Some(Dims(NonZeroU16::new(x as u16).unwrap())) - } else { - None - } - } - pub fn into_option_string(self) -> Option { - use VectorTypmod::*; - match self { - Any => None, - Dims(x) => Some(i32::from(x.get()).to_string()), - } - } - pub fn into_i32(self) -> i32 { - use VectorTypmod::*; - match self { - Any => -1, - Dims(x) => i32::from(x.get()), - } - } - pub fn dims(self) -> Option { - use VectorTypmod::*; - match self { - Any => None, - Dims(dims) => Some(dims.get()), - } - } -} - -pgrx::extension_sql!( - r#" -CREATE TYPE vector ( - INPUT = vector_in, - OUTPUT = vector_out, - TYPMOD_IN = vector_typmod_in, - TYPMOD_OUT = vector_typmod_out, - STORAGE = EXTENDED, - INTERNALLENGTH = VARIABLE, - ALIGNMENT = double -); -"#, - name = "vector", - creates = [Type(Vector)], - requires = [vector_in, vector_out, vector_typmod_in, vector_typmod_out], -); - -#[repr(C, align(8))] -pub struct Vector { - varlena: u32, - len: u16, - phantom: [Scalar; 0], -} - -impl Vector { - fn varlena(size: usize) -> u32 { - (size << 2) as u32 - } - fn layout(len: usize) -> Layout { - u16::try_from(len).expect("Vector is too large."); - let layout_alpha = Layout::new::(); - let layout_beta = Layout::array::(len).unwrap(); - let layout = layout_alpha.extend(layout_beta).unwrap().0; - layout.pad_to_align() - } - pub fn new(slice: &[Scalar]) -> Box { - unsafe { - assert!(u16::try_from(slice.len()).is_ok()); - let layout = Vector::layout(slice.len()); - let ptr = std::alloc::alloc(layout) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - Box::from_raw(ptr) - } - } - pub fn new_in_postgres(slice: &[Scalar]) -> VectorOutput { - unsafe { - assert!(u16::try_from(slice.len()).is_ok()); - let layout = Vector::layout(slice.len()); - let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - VectorOutput(NonNull::new(ptr).unwrap()) - } - } - pub fn new_zeroed(len: usize) -> Box { - unsafe { - assert!(u16::try_from(len).is_ok()); - let 
layout = Vector::layout(len); - let ptr = std::alloc::alloc_zeroed(layout) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - Box::from_raw(ptr) - } - } - #[allow(dead_code)] - pub fn new_zeroed_in_postgres(len: usize) -> VectorOutput { - unsafe { - assert!(u64::try_from(len).is_ok()); - let layout = Vector::layout(len); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vector; - std::ptr::addr_of_mut!((*ptr).varlena).write(Vector::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - VectorOutput(NonNull::new(ptr).unwrap()) - } - } - pub fn len(&self) -> usize { - self.len as usize - } - pub fn data(&self) -> &[Scalar] { - debug_assert_eq!(self.varlena & 3, 0); - unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } - } - pub fn data_mut(&mut self) -> &mut [Scalar] { - debug_assert_eq!(self.varlena & 3, 0); - unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } - } - #[allow(dead_code)] - pub fn copy(&self) -> Box { - Vector::new(self.data()) - } - pub fn copy_into_postgres(&self) -> VectorOutput { - Vector::new_in_postgres(self.data()) - } -} - -impl Deref for Vector { - type Target = [Scalar]; - - fn deref(&self) -> &Self::Target { - self.data() - } -} - -impl DerefMut for Vector { - fn deref_mut(&mut self) -> &mut Self::Target { - self.data_mut() - } -} - -impl AsRef<[Scalar]> for Vector { - fn as_ref(&self) -> &[Scalar] { - self.data() - } -} - -impl AsMut<[Scalar]> for Vector { - fn as_mut(&mut self) -> &mut [Scalar] { - self.data_mut() - } -} - -impl Index for Vector { - type Output = Scalar; - - fn index(&self, index: usize) -> &Self::Output { - self.data().index(index) - } -} - -impl IndexMut for Vector { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - self.data_mut().index_mut(index) - } -} - -impl PartialEq for Vector { - fn eq(&self, other: &Self) -> bool { - if self.len() != other.len() { - return false; - } - let n = self.len(); - for i in 0..n { - if self[i] != other[i] { - return false; - } - } - true - } -} - -impl Eq for Vector {} - -impl PartialOrd for Vector { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Vector { - fn cmp(&self, other: &Self) -> Ordering { - use Ordering::*; - if let x @ Less | x @ Greater = self.len().cmp(&other.len()) { - return x; - } - let n = self.len(); - for i in 0..n { - if let x @ Less | x @ Greater = self[i].cmp(&other[i]) { - return x; - } - } - Equal - } -} - -pub enum VectorInput<'a> { - Owned(VectorOutput), - Borrowed(&'a Vector), -} - -impl<'a> VectorInput<'a> { - pub unsafe fn new(p: NonNull) -> Self { - let q = NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap(); - if p != q { - VectorInput::Owned(VectorOutput(q)) - } else { - VectorInput::Borrowed(p.as_ref()) - } - } -} - -impl Deref for VectorInput<'_> { - type Target = Vector; - - fn deref(&self) -> &Self::Target { - match self { - VectorInput::Owned(x) => x, - VectorInput::Borrowed(x) => x, - } - } -} - -pub struct VectorOutput(NonNull); - -impl VectorOutput { - pub fn into_raw(self) -> *mut Vector { - let result = self.0.as_ptr(); - std::mem::forget(self); - result - } -} - -impl Deref for VectorOutput { - type Target = Vector; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl DerefMut for VectorOutput { - fn deref_mut(&mut self) -> &mut Self::Target { - 
unsafe { self.0.as_mut() } - } -} - -impl Drop for VectorOutput { - fn drop(&mut self) { - unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); - } - } -} - -impl<'a> FromDatum for VectorInput<'a> { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - Some(VectorInput::new(ptr)) - } - } -} - -impl IntoDatum for VectorOutput { - fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) - } - - fn type_oid() -> Oid { - pgrx::wrappers::regtypein("vector") - } -} - -unsafe impl SqlTranslatable for VectorInput<'_> { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vector")))) - } -} - -unsafe impl SqlTranslatable for VectorOutput { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vector")))) - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_in(input: &CStr, _oid: Oid, typmod: i32) -> VectorOutput { - fn solve(option: Option, hint: &str) -> T { - if let Some(x) = option { - x - } else { - FriendlyError::BadVectorString { - hint: hint.to_string(), - } - .friendly() - } - } - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - enum State { - MatchingLeft, - Reading, - MatchedRight, - } - use State::*; - let input = input.to_bytes(); - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - let mut vector = Vec::::with_capacity(typmod.dims().unwrap_or(0) as usize); - let mut state = MatchingLeft; - let mut token: Option = None; - for &c in input { - match (state, c) { - (MatchingLeft, b'[') => { - state = Reading; - } - (Reading, b'0'..=b'9' | b'.' 
| b'e' | b'+' | b'-') => { - let token = token.get_or_insert(String::new()); - token.push(char::from_u32(c as u32).unwrap()); - } - (Reading, b',') => { - let token = solve(token.take(), "Expect a number."); - vector.push(solve(token.parse().ok(), "Bad number.")); - } - (Reading, b']') => { - if let Some(token) = token.take() { - vector.push(solve(token.parse().ok(), "Bad number.")); - } - state = MatchedRight; - } - (_, b' ') => {} - _ => { - FriendlyError::BadVectorString { - hint: format!("Bad charactor with ascii {:#x}.", c), - } - .friendly(); - } - } - } - if state != MatchedRight { - FriendlyError::BadVectorString { - hint: "Bad sequence.".to_string(), - } - .friendly(); - } - if vector.is_empty() || vector.len() > 65535 { - FriendlyError::BadVecForDims.friendly(); - } - if let Some(dims) = typmod.dims() { - if dims as usize != vector.len() { - FriendlyError::BadVecForUnmatchedDims { - value_dimensions: dims, - type_dimensions: vector.len() as u16, - } - .friendly(); - } - } - Vector::new_in_postgres(&vector) -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_out(vector: VectorInput<'_>) -> CString { - let mut buffer = String::new(); - buffer.push('['); - if let Some(&x) = vector.data().first() { - buffer.push_str(format!("{}", x).as_str()); - } - for &x in vector.data().iter().skip(1) { - buffer.push_str(format!(", {}", x).as_str()); - } - buffer.push(']'); - CString::new(buffer).unwrap() -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_typmod_in(list: Array<&CStr>) -> i32 { - if list.is_empty() { - -1 - } else if list.len() == 1 { - let s = list.get(0).unwrap().unwrap().to_str().unwrap(); - let typmod = VectorTypmod::parse_from_str(s) - .ok_or(FriendlyError::BadTypmod) - .friendly(); - typmod.into_i32() - } else { - FriendlyError::BadTypmod.friendly(); - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn vector_typmod_out(typmod: i32) -> CString { - let typmod = VectorTypmod::parse_from_i32(typmod).unwrap(); - match typmod.into_option_string() { - Some(s) => CString::new(format!("({})", s)).unwrap(), - None => CString::new("()").unwrap(), - } -} diff --git a/src/postgres/hook_transaction.rs b/src/postgres/hook_transaction.rs deleted file mode 100644 index 4343a31f8..000000000 --- a/src/postgres/hook_transaction.rs +++ /dev/null @@ -1,74 +0,0 @@ -use super::gucs::Transport; -use super::gucs::TRANSPORT; -use crate::ipc::client::Rpc; -use crate::ipc::{connect_mmap, connect_unix}; -use crate::prelude::*; -use crate::utils::cells::PgRefCell; -use std::cell::RefMut; -use std::collections::BTreeSet; - -static FLUSH_IF_COMMIT: PgRefCell> = unsafe { PgRefCell::new(BTreeSet::new()) }; - -static CLIENT: PgRefCell> = unsafe { PgRefCell::new(None) }; - -pub fn aborting() { - *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); -} - -pub fn committing() { - { - let flush_if_commit = FLUSH_IF_COMMIT.borrow(); - if flush_if_commit.len() != 0 { - client(|mut rpc| { - for id in flush_if_commit.iter().copied() { - rpc.flush(id).friendly().friendly(); - } - - rpc - }); - } - } - *FLUSH_IF_COMMIT.borrow_mut() = BTreeSet::new(); -} - -pub fn flush_if_commit(id: Id) { - FLUSH_IF_COMMIT.borrow_mut().insert(id); -} - -pub fn client(f: F) -where - F: FnOnce(Rpc) -> Rpc, -{ - let mut guard = CLIENT.borrow_mut(); - let client = guard.take().unwrap_or_else(|| match TRANSPORT.get() { - Transport::unix => connect_unix(), - Transport::mmap => connect_mmap(), - }); - let client = f(client); - *guard = Some(client); -} - -pub struct ClientGuard(RefMut<'static, 
Option>); - -pub fn client_guard() -> (Rpc, ClientGuard) { - let mut guard = CLIENT.borrow_mut(); - let client = guard.take().unwrap_or_else(|| match TRANSPORT.get() { - Transport::unix => connect_unix(), - Transport::mmap => connect_mmap(), - }); - (client, ClientGuard(guard)) -} - -impl ClientGuard { - pub fn reset(mut self, client: Rpc) { - *self.0 = Some(client); - } -} - -impl Drop for ClientGuard { - fn drop(&mut self) { - if self.0.is_none() { - panic!("ClientGuard was dropped without resetting the client"); - } - } -} diff --git a/src/postgres/index_scan.rs b/src/postgres/index_scan.rs deleted file mode 100644 index 8008a76fe..000000000 --- a/src/postgres/index_scan.rs +++ /dev/null @@ -1,273 +0,0 @@ -use super::gucs::ENABLE_PREFILTER; -use super::hook_transaction::{client, ClientGuard}; -use crate::ipc::client::SearchVbaseHandler; -use crate::postgres::datatype::VectorInput; -use crate::postgres::gucs::{K, VBASE_RANGE}; -use crate::postgres::hook_transaction::client_guard; -use crate::prelude::*; -use pgrx::FromDatum; - -pub struct Scanner { - pub index_scan_state: *mut pgrx::pg_sys::IndexScanState, - pub state: ScannerState, -} - -pub enum ScannerState { - Initial { - vector: Option>, - }, - Once { - data: Vec, - }, - Iter { - handler: SearchVbaseHandler, - guard: ClientGuard, - }, - Stop, -} - -pub unsafe fn make_scan( - index_relation: pgrx::pg_sys::Relation, - n_keys: std::os::raw::c_int, - n_orderbys: std::os::raw::c_int, -) -> pgrx::pg_sys::IndexScanDesc { - use pgrx::PgMemoryContexts; - - assert!(n_keys == 0); - assert!(n_orderbys == 1); - - let scan = pgrx::pg_sys::RelationGetIndexScan(index_relation, n_keys, n_orderbys); - - (*scan).xs_recheck = false; - (*scan).xs_recheckorderby = false; - - let scanner = Scanner { - index_scan_state: std::ptr::null_mut(), - state: ScannerState::Initial { vector: None }, - }; - - (*scan).opaque = PgMemoryContexts::CurrentMemoryContext.leak_and_drop_on_delete(scanner) as _; - - scan -} - -pub unsafe fn start_scan( - scan: pgrx::pg_sys::IndexScanDesc, - keys: pgrx::pg_sys::ScanKey, - n_keys: std::os::raw::c_int, - orderbys: pgrx::pg_sys::ScanKey, - n_orderbys: std::os::raw::c_int, -) { - use ScannerState::*; - - assert!((*scan).numberOfKeys == n_keys); - assert!((*scan).numberOfOrderBys == n_orderbys); - assert!(n_keys == 0); - assert!(n_orderbys == 1); - - if n_keys > 0 { - std::ptr::copy(keys, (*scan).keyData, n_keys as usize); - } - if n_orderbys > 0 { - std::ptr::copy(orderbys, (*scan).orderByData, n_orderbys as usize); - } - if n_orderbys > 0 { - let size = std::mem::size_of::(); - let size = size * (*scan).numberOfOrderBys as usize; - let data = pgrx::pg_sys::palloc0(size) as *mut _; - (*scan).xs_orderbyvals = data; - } - if n_orderbys > 0 { - let size = std::mem::size_of::(); - let size = size * (*scan).numberOfOrderBys as usize; - let data = pgrx::pg_sys::palloc(size) as *mut bool; - data.write_bytes(1, (*scan).numberOfOrderBys as usize); - (*scan).xs_orderbynulls = data; - } - let orderby = orderbys.add(0); - let argument = (*orderby).sk_argument; - let vector = VectorInput::from_datum(argument, false).unwrap(); - let vector = vector.to_vec(); - - let state = &mut (*((*scan).opaque as *mut Scanner)).state; - *state = Initial { - vector: Some(vector), - }; -} - -pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { - use ScannerState::*; - - let scanner = &mut *((*scan).opaque as *mut Scanner); - if matches!(scanner.state, Stop) { - return false; - } - - if matches!(scanner.state, Initial { .. 
}) { - let Initial { vector } = std::mem::replace(&mut scanner.state, Initial { vector: None }) - else { - unreachable!() - }; - - #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] - let oid = (*(*scan).indexRelation).rd_node.relNode; - #[cfg(feature = "pg16")] - let oid = (*(*scan).indexRelation).rd_locator.relNumber; - let id = Id::from_sys(oid); - let vector = vector.expect("`rescan` is never called."); - let index_scan_state = scanner.index_scan_state; - - if VBASE_RANGE.get() == 0 { - let prefilter = !index_scan_state.is_null() && ENABLE_PREFILTER.get(); - client(|rpc| { - let k = K.get() as _; - let mut handler = rpc.search(id, (vector, k), prefilter).friendly(); - let mut res; - let rpc = loop { - use crate::ipc::client::SearchHandle::*; - match handler.handle().friendly() { - Check { p, x } => { - let result = check(index_scan_state, p); - handler = x.leave(result).friendly(); - } - Leave { result, x } => { - res = result.friendly(); - break x; - } - } - }; - res.reverse(); - scanner.state = Once { data: res }; - rpc - }); - } else { - let range = VBASE_RANGE.get() as _; - let (rpc, guard) = client_guard(); - let handler = rpc.search_vbase(id, (vector, range)).friendly(); - scanner.state = Iter { handler, guard }; - } - } - - if let Once { data } = &mut scanner.state { - if let Some(p) = data.pop() { - (*scan).xs_heaptid = p.into_sys(); - return true; - } - scanner.state = Stop; - return false; - } - - let Iter { handler, guard } = std::mem::replace(&mut scanner.state, Stop) else { - unreachable!() - }; - use crate::ipc::client::SearchVbaseHandle::*; - match handler.handle().friendly() { - Next { p, x } => { - (*scan).xs_heaptid = p.into_sys(); - let handler = x.next().friendly(); - scanner.state = ScannerState::Iter { handler, guard }; - true - } - Leave { result, x } => { - result.friendly(); - guard.reset(x); - false - } - } -} - -pub unsafe fn end_scan(scan: pgrx::pg_sys::IndexScanDesc) { - use ScannerState::*; - - let scanner = &mut *((*scan).opaque as *mut Scanner); - if let Iter { handler, guard } = std::mem::replace(&mut scanner.state, Stop) { - use crate::ipc::client::SearchVbaseHandle::*; - match handler.handle().friendly() { - Next { p, x } => { - (*scan).xs_heaptid = p.into_sys(); - let client = x.stop().friendly(); - guard.reset(client); - } - Leave { result, x } => { - result.friendly(); - guard.reset(x); - } - } - } -} - -unsafe fn execute_boolean_qual( - state: *mut pgrx::pg_sys::ExprState, - econtext: *mut pgrx::pg_sys::ExprContext, -) -> bool { - use pgrx::PgMemoryContexts; - if state.is_null() { - return true; - } - assert!((*state).flags & pgrx::pg_sys::EEO_FLAG_IS_QUAL as u8 != 0); - let mut is_null = true; - pgrx::pg_sys::MemoryContextReset((*econtext).ecxt_per_tuple_memory); - let ret = PgMemoryContexts::For((*econtext).ecxt_per_tuple_memory) - .switch_to(|_| (*state).evalfunc.unwrap()(state, econtext, &mut is_null)); - assert!(!is_null); - bool::from_datum(ret, is_null).unwrap() -} - -unsafe fn check_quals(node: *mut pgrx::pg_sys::IndexScanState) -> bool { - let slot = (*node).ss.ss_ScanTupleSlot; - let econtext = (*node).ss.ps.ps_ExprContext; - (*econtext).ecxt_scantuple = slot; - if (*node).ss.ps.qual.is_null() { - return true; - } - let state = (*node).ss.ps.qual; - let econtext = (*node).ss.ps.ps_ExprContext; - execute_boolean_qual(state, econtext) -} - -unsafe fn check_mvcc(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool { - let scan_desc = (*node).iss_ScanDesc; - let heap_fetch = (*scan_desc).xs_heapfetch; 
- let index_relation = (*heap_fetch).rel; - let rd_tableam = (*index_relation).rd_tableam; - let snapshot = (*scan_desc).xs_snapshot; - let index_fetch_tuple = (*rd_tableam).index_fetch_tuple.unwrap(); - let mut all_dead = false; - let slot = (*node).ss.ss_ScanTupleSlot; - let mut heap_continue = false; - let found = index_fetch_tuple( - heap_fetch, - &mut p.into_sys(), - snapshot, - slot, - &mut heap_continue, - &mut all_dead, - ); - if found { - return true; - } - while heap_continue { - let found = index_fetch_tuple( - heap_fetch, - &mut p.into_sys(), - snapshot, - slot, - &mut heap_continue, - &mut all_dead, - ); - if found { - return true; - } - } - false -} - -unsafe fn check(node: *mut pgrx::pg_sys::IndexScanState, p: Pointer) -> bool { - if !check_mvcc(node, p) { - return false; - } - if !check_quals(node) { - return false; - } - true -} diff --git a/src/postgres/index_update.rs b/src/postgres/index_update.rs deleted file mode 100644 index e9d861920..000000000 --- a/src/postgres/index_update.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::postgres::hook_transaction::{client, flush_if_commit}; -use crate::prelude::*; - -pub fn update_insert(id: Id, vector: Vec, tid: pgrx::pg_sys::ItemPointerData) { - flush_if_commit(id); - let p = Pointer::from_sys(tid); - client(|mut rpc| { - rpc.insert(id, (vector, p)).friendly().friendly(); - rpc - }) -} - -pub fn update_delete(id: Id, hook: impl Fn(Pointer) -> bool) { - flush_if_commit(id); - client(|rpc| { - use crate::ipc::client::DeleteHandle; - let mut handler = rpc.delete(id).friendly(); - loop { - let handle = handler.handle().friendly(); - match handle { - DeleteHandle::Next { p, x } => { - handler = x.leave(hook(p)).friendly(); - } - DeleteHandle::Leave { result, x } => { - result.friendly(); - break x; - } - } - } - }) -} diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs deleted file mode 100644 index 0d240478d..000000000 --- a/src/postgres/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -mod casts; -pub mod datatype; -pub mod gucs; -mod hook_executor; -mod hook_transaction; -mod hooks; -mod index; -mod index_build; -mod index_scan; -mod index_setup; -mod index_update; -mod operators; -mod stat; - -pub unsafe fn init() { - self::gucs::init(); - self::hooks::init(); - self::index::init(); -} diff --git a/src/postgres/stat.rs b/src/postgres/stat.rs deleted file mode 100644 index 72f9b57d5..000000000 --- a/src/postgres/stat.rs +++ /dev/null @@ -1,38 +0,0 @@ -use super::hook_transaction::client; -use crate::prelude::*; - -pgrx::extension_sql!( - "\ -CREATE TYPE VectorIndexInfo AS ( - indexing BOOL, - idx_tuples INT, - idx_sealed_len INT, - idx_growing_len INT, - idx_write INT, - idx_sealed INT[], - idx_growing INT[], - idx_config TEXT -);", - name = "create_composites", -); - -#[pgrx::pg_extern(volatile, strict)] -fn vector_stat(oid: pgrx::pg_sys::Oid) -> pgrx::composite_type!("VectorIndexInfo") { - let id = Id::from_sys(oid); - let mut res = pgrx::prelude::PgHeapTuple::new_composite_type("VectorIndexInfo").unwrap(); - client(|mut rpc| { - let rpc_res = rpc.stat(id).unwrap().friendly(); - res.set_by_name("indexing", rpc_res.indexing).unwrap(); - res.set_by_name("idx_tuples", rpc_res.idx_tuples).unwrap(); - res.set_by_name("idx_sealed_len", rpc_res.idx_sealed_len) - .unwrap(); - res.set_by_name("idx_growing_len", rpc_res.idx_growing_len) - .unwrap(); - res.set_by_name("idx_write", rpc_res.idx_write).unwrap(); - res.set_by_name("idx_sealed", rpc_res.idx_sealed).unwrap(); - res.set_by_name("idx_growing", rpc_res.idx_growing).unwrap(); - 
res.set_by_name("idx_config", rpc_res.idx_config).unwrap(); - rpc - }); - res -} diff --git a/src/prelude.rs b/src/prelude.rs new file mode 100644 index 000000000..7a18d2954 --- /dev/null +++ b/src/prelude.rs @@ -0,0 +1,39 @@ +use service::prelude::*; + +pub trait FromSys { + fn from_sys(sys: T) -> Self; +} + +impl FromSys for Id { + fn from_sys(sys: pgrx::pg_sys::Oid) -> Self { + Self { + newtype: sys.as_u32(), + } + } +} + +impl FromSys for Pointer { + fn from_sys(sys: pgrx::pg_sys::ItemPointerData) -> Self { + let mut newtype = 0; + newtype |= (sys.ip_blkid.bi_hi as u64) << 32; + newtype |= (sys.ip_blkid.bi_lo as u64) << 16; + newtype |= sys.ip_posid as u64; + Self { newtype } + } +} + +pub trait IntoSys { + fn into_sys(self) -> T; +} + +impl IntoSys for Pointer { + fn into_sys(self) -> pgrx::pg_sys::ItemPointerData { + pgrx::pg_sys::ItemPointerData { + ip_blkid: pgrx::pg_sys::BlockIdData { + bi_hi: ((self.newtype >> 32) & 0xffff) as u16, + bi_lo: ((self.newtype >> 16) & 0xffff) as u16, + }, + ip_posid: (self.newtype & 0xffff) as u16, + } + } +} diff --git a/src/prelude/distance.rs b/src/prelude/distance.rs deleted file mode 100644 index 433b386ed..000000000 --- a/src/prelude/distance.rs +++ /dev/null @@ -1,516 +0,0 @@ -use crate::prelude::*; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; - -mod sealed { - pub trait Sealed {} -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Distance { - L2, - Cosine, - Dot, -} - -impl Distance { - pub fn distance(self, lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - match self { - Distance::L2 => distance_squared_l2(lhs, rhs), - Distance::Cosine => distance_cosine(lhs, rhs) * (-1.0), - Distance::Dot => distance_dot(lhs, rhs) * (-1.0), - } - } - pub fn elkan_k_means_normalize(self, vector: &mut [Scalar]) { - match self { - Distance::L2 => (), - Distance::Cosine => l2_normalize(vector), - Distance::Dot => l2_normalize(vector), - } - } - pub fn elkan_k_means_distance(self, lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - match self { - Distance::L2 => distance_squared_l2(lhs, rhs).sqrt(), - Distance::Cosine => distance_dot(lhs, rhs).acos(), - Distance::Dot => distance_dot(lhs, rhs).acos(), - } - } - pub fn scalar_quantization_distance( - self, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - ) -> Scalar { - scalar_quantization_distance(self, dims, max, min, lhs, rhs) - } - pub fn scalar_quantization_distance2( - self, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[u8], - rhs: &[u8], - ) -> Scalar { - scalar_quantization_distance2(self, dims, max, min, lhs, rhs) - } - pub fn product_quantization_distance( - self, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - ) -> Scalar { - product_quantization_distance(self, dims, ratio, centroids, lhs, rhs) - } - pub fn product_quantization_distance2( - self, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[u8], - rhs: &[u8], - ) -> Scalar { - product_quantization_distance2(self, dims, ratio, centroids, lhs, rhs) - } - pub fn product_quantization_distance_with_delta( - self, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - delta: &[Scalar], - ) -> Scalar { - product_quantization_distance_with_delta(self, dims, ratio, centroids, lhs, rhs, delta) - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_squared_l2(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - 
FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut d2 = Scalar::Z; - for i in 0..n { - let d = lhs[i] - rhs[i]; - d2 += d * d; - } - d2 -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_cosine(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - xy += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - xy / (x2 * y2).sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_dot(lhs: &[Scalar], rhs: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - for i in 0..n { - xy += lhs[i] * rhs[i]; - } - xy -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn xy_x2_y2(lhs: &[Scalar], rhs: &[Scalar]) -> (Scalar, Scalar, Scalar) { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - xy += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - (xy, x2, y2) -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn length(vector: &[Scalar]) -> Scalar { - let n = vector.len(); - let mut dot = Scalar::Z; - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn l2_normalize(vector: &mut [Scalar]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_squared_l2_delta(lhs: &[Scalar], rhs: &[Scalar], del: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut d2 = Scalar::Z; - for i in 0..n { - let d = lhs[i] - (rhs[i] + del[i]); - d2 += d * d; - } - d2 -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn xy_x2_y2_delta(lhs: &[Scalar], rhs: &[Scalar], del: &[Scalar]) -> (Scalar, Scalar, Scalar) { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..n { - xy += lhs[i] * (rhs[i] + del[i]); - x2 += lhs[i] * lhs[i]; - y2 += (rhs[i] + del[i]) * (rhs[i] + del[i]); - } - (xy, x2, y2) -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn distance_dot_delta(lhs: &[Scalar], rhs: &[Scalar], del: &[Scalar]) -> Scalar { - if lhs.len() != rhs.len() { - FriendlyError::DifferentVectorDims { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); - let mut xy = Scalar::Z; - for i in 0..n { - xy += lhs[i] * 
(rhs[i] + del[i]); - } - xy -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn scalar_quantization_distance( - distance: Distance, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let mut result = Scalar::Z; - for i in 0..dims as usize { - let _x = lhs[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - result += (_x - _y) * (_x - _y); - } - result - } - Distance::Cosine => { - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..dims as usize { - let _x = lhs[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - x2 += _x * _x; - y2 += _y * _y; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let mut xy = Scalar::Z; - for i in 0..dims as usize { - let _x = lhs[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn scalar_quantization_distance2( - distance: Distance, - dims: u16, - max: &[Scalar], - min: &[Scalar], - lhs: &[u8], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let mut result = Scalar::Z; - for i in 0..dims as usize { - let _x = Scalar(lhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - result += (_x - _y) * (_x - _y); - } - result - } - Distance::Cosine => { - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..dims as usize { - let _x = Scalar(lhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - x2 += _x * _x; - y2 += _y * _y; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let mut xy = Scalar::Z; - for i in 0..dims as usize { - let _x = Scalar(lhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - let _y = Scalar(rhs[i] as Float / 256.0) * (max[i] - min[i]) + min[i]; - xy += _x * _y; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn product_quantization_distance( - distance: Distance, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let width = dims.div_ceil(ratio); - let mut result = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); - } - result - } - Distance::Cosine => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); - xy += _xy; - x2 += _x2; - y2 += _y2; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp 
= rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = distance_dot(lhs, rhs); - xy += _xy; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn product_quantization_distance2( - distance: Distance, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[u8], - rhs: &[u8], -) -> Scalar { - match distance { - Distance::L2 => { - let width = dims.div_ceil(ratio); - let mut result = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhsp = lhs[i as usize] as usize * dims as usize; - let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); - } - result - } - Distance::Cosine => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhsp = lhs[i as usize] as usize * dims as usize; - let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = xy_x2_y2(lhs, rhs); - xy += _xy; - x2 += _x2; - y2 += _y2; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhsp = lhs[i as usize] as usize * dims as usize; - let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = distance_dot(lhs, rhs); - xy += _xy; - } - xy * (-1.0) - } - } -} - -#[inline(always)] -#[multiversion::multiversion(targets = "simd")] -fn product_quantization_distance_with_delta( - distance: Distance, - dims: u16, - ratio: u16, - centroids: &[Scalar], - lhs: &[Scalar], - rhs: &[u8], - delta: &[Scalar], -) -> Scalar { - match distance { - Distance::L2 => { - let width = dims.div_ceil(ratio); - let mut result = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let del = &delta[(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2_delta(lhs, rhs, del); - } - result - } - Distance::Cosine => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - let mut x2 = Scalar::Z; - let mut y2 = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let del = &delta[(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = xy_x2_y2_delta(lhs, rhs, del); - xy += _xy; - x2 += _x2; - y2 += _y2; - } - xy / (x2 * y2).sqrt() * (-1.0) - } - Distance::Dot => { - let width = dims.div_ceil(ratio); - let mut xy = Scalar::Z; - for i in 0..width { - let k = std::cmp::min(ratio, dims - ratio * i); - let lhs = &lhs[(i * ratio) as usize..][..k as usize]; - let rhsp = rhs[i as usize] as usize * dims as usize; - let rhs = ¢roids[rhsp..][(i 
* ratio) as usize..][..k as usize]; - let del = &delta[(i * ratio) as usize..][..k as usize]; - let _xy = distance_dot_delta(lhs, rhs, del); - xy += _xy; - } - xy * (-1.0) - } - } -} diff --git a/src/prelude/mod.rs b/src/prelude/mod.rs deleted file mode 100644 index d740d177d..000000000 --- a/src/prelude/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -mod distance; -mod error; -mod filter; -mod heap; -mod scalar; -mod stat; -mod sys; - -pub use self::distance::Distance; -pub use self::error::{Friendly, FriendlyError}; -pub use self::filter::{Filter, Payload}; -pub use self::heap::{Heap, HeapElement}; -pub use self::scalar::{Float, Scalar}; -pub use self::stat::VectorIndexInfo; -pub use self::sys::{Id, Pointer}; diff --git a/src/prelude/scalar.rs b/src/prelude/scalar.rs deleted file mode 100644 index 99a6338d1..000000000 --- a/src/prelude/scalar.rs +++ /dev/null @@ -1,295 +0,0 @@ -use bytemuck::{Pod, Zeroable}; -use pgrx::pg_sys::{Datum, Oid}; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::FunctionMetadataTypeEntity; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::{FromDatum, IntoDatum}; -use serde::{Deserialize, Serialize}; -use std::cmp::Ordering; -use std::fmt::{Debug, Display}; -use std::num::ParseFloatError; -use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign}; -use std::str::FromStr; - -pub type Float = f32; - -#[derive(Clone, Copy, Default, Serialize, Deserialize)] -#[repr(transparent)] -#[serde(transparent)] -pub struct Scalar(pub Float); - -impl Scalar { - pub const INFINITY: Self = Self(Float::INFINITY); - pub const NEG_INFINITY: Self = Self(Float::NEG_INFINITY); - pub const NAN: Self = Self(Float::NAN); - pub const Z: Self = Self(0.0); - - #[inline(always)] - pub fn acos(self) -> Self { - Self(self.0.acos()) - } - - #[inline(always)] - pub fn sqrt(self) -> Self { - Self(self.0.sqrt()) - } -} - -unsafe impl Zeroable for Scalar {} - -unsafe impl Pod for Scalar {} - -impl Debug for Scalar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Debug::fmt(&self.0, f) - } -} - -impl Display for Scalar { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -impl From for Scalar { - fn from(value: Float) -> Self { - Self(value) - } -} - -impl From for Float { - fn from(Scalar(float): Scalar) -> Self { - float - } -} - -impl PartialEq for Scalar { - fn eq(&self, other: &Self) -> bool { - self.0.total_cmp(&other.0) == Ordering::Equal - } -} - -impl Eq for Scalar {} - -impl PartialOrd for Scalar { - #[inline(always)] - fn partial_cmp(&self, other: &Self) -> Option { - Some(Ord::cmp(self, other)) - } -} - -impl Ord for Scalar { - #[inline(always)] - fn cmp(&self, other: &Self) -> Ordering { - self.0.total_cmp(&other.0) - } -} - -impl Add for Scalar { - type Output = Scalar; - - #[inline(always)] - fn add(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fadd_fast(self.0, rhs).into() } - } -} - -impl AddAssign for Scalar { - fn add_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs) } - } -} - -impl Add for Scalar { - type Output = Scalar; - - #[inline(always)] - fn add(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fadd_fast(self.0, rhs.0).into() } - } -} - -impl AddAssign for Scalar 
{ - #[inline(always)] - fn add_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fadd_fast(self.0, rhs.0) } - } -} - -impl Sub for Scalar { - type Output = Scalar; - - #[inline(always)] - fn sub(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fsub_fast(self.0, rhs).into() } - } -} - -impl SubAssign for Scalar { - #[inline(always)] - fn sub_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs) } - } -} - -impl Sub for Scalar { - type Output = Scalar; - - #[inline(always)] - fn sub(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fsub_fast(self.0, rhs.0).into() } - } -} - -impl SubAssign for Scalar { - #[inline(always)] - fn sub_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fsub_fast(self.0, rhs.0) } - } -} - -impl Mul for Scalar { - type Output = Scalar; - - #[inline(always)] - fn mul(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fmul_fast(self.0, rhs).into() } - } -} - -impl MulAssign for Scalar { - #[inline(always)] - fn mul_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs) } - } -} - -impl Mul for Scalar { - type Output = Scalar; - - #[inline(always)] - fn mul(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fmul_fast(self.0, rhs.0).into() } - } -} - -impl MulAssign for Scalar { - #[inline(always)] - fn mul_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fmul_fast(self.0, rhs.0) } - } -} - -impl Div for Scalar { - type Output = Scalar; - - #[inline(always)] - fn div(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::fdiv_fast(self.0, rhs).into() } - } -} - -impl DivAssign for Scalar { - #[inline(always)] - fn div_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs) } - } -} - -impl Div for Scalar { - type Output = Scalar; - - #[inline(always)] - fn div(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::fdiv_fast(self.0, rhs.0).into() } - } -} - -impl DivAssign for Scalar { - #[inline(always)] - fn div_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::fdiv_fast(self.0, rhs.0) } - } -} - -impl Rem for Scalar { - type Output = Scalar; - - #[inline(always)] - fn rem(self, rhs: Float) -> Scalar { - unsafe { std::intrinsics::frem_fast(self.0, rhs).into() } - } -} - -impl RemAssign for Scalar { - #[inline(always)] - fn rem_assign(&mut self, rhs: Float) { - unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) } - } -} - -impl Rem for Scalar { - type Output = Scalar; - - #[inline(always)] - fn rem(self, rhs: Scalar) -> Scalar { - unsafe { std::intrinsics::frem_fast(self.0, rhs.0).into() } - } -} - -impl RemAssign for Scalar { - #[inline(always)] - fn rem_assign(&mut self, rhs: Scalar) { - unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs.0) } - } -} - -impl FromStr for Scalar { - type Err = ParseFloatError; - - fn from_str(s: &str) -> Result { - Float::from_str(s).map(|x| x.into()) - } -} - -impl FromDatum for Scalar { - const GET_TYPOID: bool = false; - - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, typoid: Oid) -> Option { - Float::from_polymorphic_datum(datum, is_null, typoid).map(Self) - } -} - -impl IntoDatum for Scalar { - fn into_datum(self) -> Option { - Float::into_datum(self.0) - } - - fn type_oid() -> Oid { - Float::type_oid() - } -} - -unsafe impl SqlTranslatable for Scalar { - fn type_name() -> &'static str { - Float::type_name() - } - fn argument_sql() -> Result { - Float::argument_sql() - } - fn 
return_sql() -> Result { - Float::return_sql() - } - fn variadic() -> bool { - Float::variadic() - } - fn optional() -> bool { - Float::optional() - } - fn entity() -> FunctionMetadataTypeEntity { - Float::entity() - } -} diff --git a/src/prelude/stat.rs b/src/prelude/stat.rs deleted file mode 100644 index b325b60ed..000000000 --- a/src/prelude/stat.rs +++ /dev/null @@ -1,13 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorIndexInfo { - pub indexing: bool, - pub idx_tuples: i32, - pub idx_sealed_len: i32, - pub idx_growing_len: i32, - pub idx_write: i32, - pub idx_sealed: Vec, - pub idx_growing: Vec, - pub idx_config: String, -} diff --git a/src/sql/bootstrap.sql b/src/sql/bootstrap.sql index a12734ab2..5310d6794 100644 --- a/src/sql/bootstrap.sql +++ b/src/sql/bootstrap.sql @@ -1 +1,2 @@ CREATE TYPE vector; +CREATE TYPE vecf16; diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index bdc484b03..3e59ad0c1 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -1,21 +1,36 @@ CREATE CAST (real[] AS vector) - WITH FUNCTION cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; + WITH FUNCTION vecf32_cast_array_to_vector(real[], integer, boolean) AS IMPLICIT; CREATE CAST (vector AS real[]) - WITH FUNCTION cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; + WITH FUNCTION vecf32_cast_vector_to_array(vector, integer, boolean) AS IMPLICIT; -CREATE OPERATOR CLASS l2_ops +CREATE ACCESS METHOD vectors TYPE INDEX HANDLER vectors_amhandler; +COMMENT ON ACCESS METHOD vectors IS 'pgvecto.rs index access method'; + +CREATE OPERATOR CLASS vector_l2_ops FOR TYPE vector USING vectors AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops; -CREATE OPERATOR CLASS dot_ops +CREATE OPERATOR CLASS vector_dot_ops FOR TYPE vector USING vectors AS OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops; -CREATE OPERATOR CLASS cosine_ops +CREATE OPERATOR CLASS vector_cos_ops FOR TYPE vector USING vectors AS OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops; +CREATE OPERATOR CLASS vecf16_l2_ops + FOR TYPE vecf16 USING vectors AS + OPERATOR 1 <-> (vecf16, vecf16) FOR ORDER BY float_ops; + +CREATE OPERATOR CLASS vecf16_dot_ops + FOR TYPE vecf16 USING vectors AS + OPERATOR 1 <#> (vecf16, vecf16) FOR ORDER BY float_ops; + +CREATE OPERATOR CLASS vecf16_cos_ops + FOR TYPE vecf16 USING vectors AS + OPERATOR 1 <=> (vecf16, vecf16) FOR ORDER BY float_ops; + CREATE VIEW pg_vector_index_info AS SELECT C.oid AS tablerelid, @@ -27,4 +42,4 @@ CREATE VIEW pg_vector_index_info AS pg_index X ON C.oid = X.indrelid JOIN pg_class I ON I.oid = X.indexrelid JOIN pg_am A ON A.oid = I.relam - WHERE A.amname = 'vectors'; \ No newline at end of file + WHERE A.amname = 'vectors'; diff --git a/src/utils/cells.rs b/src/utils/cells.rs index e9f004570..edaf4b22a 100644 --- a/src/utils/cells.rs +++ b/src/utils/cells.rs @@ -1,4 +1,4 @@ -use std::cell::{Cell, RefCell, UnsafeCell}; +use std::cell::{Cell, RefCell}; pub struct PgCell(Cell); @@ -36,28 +36,3 @@ impl PgRefCell { self.0.borrow() } } - -#[repr(transparent)] -pub struct SyncUnsafeCell { - value: UnsafeCell, -} - -unsafe impl Sync for SyncUnsafeCell {} - -impl SyncUnsafeCell { - pub const fn new(value: T) -> Self { - Self { - value: UnsafeCell::new(value), - } - } -} - -impl SyncUnsafeCell { - pub fn get(&self) -> *mut T { - self.value.get() - } - - pub fn get_mut(&mut self) -> &mut T { - self.value.get_mut() - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 813ce1d83..462c6ba81 100644 --- 
a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,9 +1,3 @@ pub mod cells; -pub mod clean; -pub mod dir_ops; -pub mod file_atomic; pub mod file_socket; -pub mod file_wal; -pub mod mmap_array; pub mod os; -pub mod vec2; diff --git a/src/utils/os.rs b/src/utils/os.rs index 44316be89..d77f71899 100644 --- a/src/utils/os.rs +++ b/src/utils/os.rs @@ -8,18 +8,22 @@ pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { tv_sec: 15, tv_nsec: 0, }; - libc::syscall( - libc::SYS_futex, - futex.as_ptr(), - libc::FUTEX_WAIT, - value, - &FUTEX_TIMEOUT, - ); + unsafe { + libc::syscall( + libc::SYS_futex, + futex.as_ptr(), + libc::FUTEX_WAIT, + value, + &FUTEX_TIMEOUT, + ); + } } #[cfg(target_os = "linux")] pub unsafe fn futex_wake(futex: &AtomicU32) { - libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX); + unsafe { + libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX); + } } #[cfg(target_os = "linux")] @@ -34,34 +38,40 @@ pub fn memfd_create() -> std::io::Result { #[cfg(target_os = "linux")] pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { use std::ptr::null_mut; - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED | MapFlags::POPULATE, - fd, - 0, - )?) + unsafe { + Ok(rustix::mm::mmap( + null_mut(), + len, + ProtFlags::READ | ProtFlags::WRITE, + MapFlags::SHARED | MapFlags::POPULATE, + fd, + 0, + )?) + } } #[cfg(target_os = "macos")] pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { const ULOCK_TIMEOUT: u32 = 15_000_000; - ulock_sys::__ulock_wait( - ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, - futex.as_ptr().cast(), - value as _, - ULOCK_TIMEOUT, - ); + unsafe { + ulock_sys::__ulock_wait( + ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, + futex.as_ptr().cast(), + value as _, + ULOCK_TIMEOUT, + ); + } } #[cfg(target_os = "macos")] pub unsafe fn futex_wake(futex: &AtomicU32) { - ulock_sys::__ulock_wake( - ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, - futex.as_ptr().cast(), - 0, - ); + unsafe { + ulock_sys::__ulock_wake( + ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, + futex.as_ptr().cast(), + 0, + ); + } } #[cfg(target_os = "macos")] @@ -87,12 +97,14 @@ pub fn memfd_create() -> std::io::Result { #[cfg(target_os = "macos")] pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { use std::ptr::null_mut; - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED, - fd, - 0, - )?) + unsafe { + Ok(rustix::mm::mmap( + null_mut(), + len, + ProtFlags::READ | ProtFlags::WRITE, + MapFlags::SHARED, + fd, + 0, + )?) + } } diff --git a/tests/sqllogictest/error.slt b/tests/sqllogictest/error.slt index 466e547f7..7e4428da4 100644 --- a/tests/sqllogictest/error.slt +++ b/tests/sqllogictest/error.slt @@ -5,7 +5,7 @@ statement ok CREATE TABLE t (val vector(3)); statement ok -CREATE INDEX ON t USING vectors (val l2_ops); +CREATE INDEX ON t USING vectors (val vector_l2_ops); statement error The given vector is invalid for input. 
 INSERT INTO t (val) VALUES ('[0, 1, 2, 3]');
diff --git a/tests/sqllogictest/flat.slt b/tests/sqllogictest/flat.slt
index d666518ab..601219db1 100644
--- a/tests/sqllogictest/flat.slt
+++ b/tests/sqllogictest/flat.slt
@@ -8,7 +8,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.flat]");
 
 statement ok
diff --git a/tests/sqllogictest/hnsw.slt b/tests/sqllogictest/hnsw.slt
index 07f666de0..51f394828 100644
--- a/tests/sqllogictest/hnsw.slt
+++ b/tests/sqllogictest/hnsw.slt
@@ -11,7 +11,7 @@ INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM gene
 # And because of borrow checker, we can't remove this table before restarting the postgres.
 # Maybe we need better error handling.
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.hnsw]");
 
 statement ok
diff --git a/tests/sqllogictest/ivf.slt b/tests/sqllogictest/ivf.slt
index 56b7aef8c..1f938f990 100644
--- a/tests/sqllogictest/ivf.slt
+++ b/tests/sqllogictest/ivf.slt
@@ -12,7 +12,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.ivf]");
 
 statement ok
@@ -47,7 +47,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.ivf.quantization.product]");
 
 statement ok
diff --git a/tests/sqllogictest/quantization.slt b/tests/sqllogictest/quantization.slt
index af96edeb4..7908d7703 100644
--- a/tests/sqllogictest/quantization.slt
+++ b/tests/sqllogictest/quantization.slt
@@ -9,7 +9,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.hnsw.quantization.product]");
 
 statement ok
@@ -41,7 +41,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.hnsw.quantization.scalar]");
 
 statement ok
diff --git a/tests/sqllogictest/reindex.slt b/tests/sqllogictest/reindex.slt
index de3b7e374..b871a5c9f 100644
--- a/tests/sqllogictest/reindex.slt
+++ b/tests/sqllogictest/reindex.slt
@@ -8,7 +8,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.hnsw]");
 
 statement ok
diff --git a/tests/sqllogictest/update.slt b/tests/sqllogictest/update.slt
index 0306f7262..8fd44673b 100644
--- a/tests/sqllogictest/update.slt
+++ b/tests/sqllogictest/update.slt
@@ -8,7 +8,7 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 1000);
 
 statement ok
-CREATE INDEX CONCURRENTLY ON t USING vectors (val l2_ops);
+CREATE INDEX CONCURRENTLY ON t USING vectors (val vector_l2_ops);
 
 statement ok
 UPDATE t SET val = ARRAY[0.2, random(), random()]::real[] WHERE val = (SELECT val FROM t ORDER BY val <-> '[0.1,0.1,0.1]' LIMIT 1);
diff --git a/tests/sqllogictest/vbase.slt b/tests/sqllogictest/vbase.slt
index e8b1ea7c4..7121c9a56 100644
--- a/tests/sqllogictest/vbase.slt
+++ b/tests/sqllogictest/vbase.slt
@@ -8,14 +8,14 @@ statement ok
 INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[] FROM generate_series(1, 100000);
 
 statement ok
-CREATE INDEX ON t USING vectors (val l2_ops)
+CREATE INDEX ON t USING vectors (val vector_l2_ops)
 WITH (options = "[indexing.hnsw]");
 
 statement ok
 INSERT INTO t (val) VALUES ('[0.6,0.6,0.6]');
 
 statement ok
-SET vectors.vbase_range=86;
+SET vectors.enable_vbase=on;
 
 query I
 SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5]' limit 100) t2;
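Note on the heap TID packing introduced in src/prelude.rs above: the FromSys/IntoSys impls for Pointer squeeze an item pointer into a single u64, with bi_hi in bits 32-47, bi_lo in bits 16-31 and ip_posid in bits 0-15. The standalone sketch below illustrates that round trip; the ItemPointer struct here is a hypothetical stand-in for pgrx::pg_sys::ItemPointerData and the sketch is not part of the patch.

    // Hypothetical stand-in for pgrx::pg_sys::ItemPointerData, for illustration only.
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct ItemPointer {
        bi_hi: u16,
        bi_lo: u16,
        ip_posid: u16,
    }

    // Pack the block id and offset into one u64, mirroring the FromSys impl for Pointer.
    fn pack(tid: ItemPointer) -> u64 {
        ((tid.bi_hi as u64) << 32) | ((tid.bi_lo as u64) << 16) | (tid.ip_posid as u64)
    }

    // Unpack the u64 back into its three 16-bit fields, mirroring the IntoSys impl.
    fn unpack(newtype: u64) -> ItemPointer {
        ItemPointer {
            bi_hi: ((newtype >> 32) & 0xffff) as u16,
            bi_lo: ((newtype >> 16) & 0xffff) as u16,
            ip_posid: (newtype & 0xffff) as u16,
        }
    }

    fn main() {
        let tid = ItemPointer { bi_hi: 0x0001, bi_lo: 0x02a3, ip_posid: 7 };
        let packed = pack(tid);
        assert_eq!(unpack(packed), tid);
        println!("packed = {:#018x}", packed);
    }

Since each field is 16 bits, packing and unpacking are lossless, which is what lets the extension hand a single integer key to the background worker.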
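Note on the unsafe-block churn in the Channel futex code and in src/utils/os.rs above: every unsafe operation inside an unsafe fn is now wrapped in an explicit unsafe { ... } block. This is the style enforced by the unsafe_op_in_unsafe_fn lint; the lint configuration itself is not visible in this diff, so treat that motivation as an assumption. A minimal standalone illustration:

    #![deny(unsafe_op_in_unsafe_fn)]

    // With the lint denied, an `unsafe fn` no longer provides an implicit unsafe
    // context for its body: each unsafe operation needs its own explicit block.
    unsafe fn read(p: *const u32) -> u32 {
        // Without this inner block, the deny above rejects the raw-pointer dereference.
        unsafe { *p }
    }

    fn main() {
        let x = 42u32;
        // Calling an `unsafe fn` still requires an unsafe block at the call site.
        let y = unsafe { read(&x) };
        assert_eq!(y, 42);
    }

The explicit blocks make each unsafe operation auditable on its own instead of letting it inherit a function-wide unsafe context.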