From 6223b7af1816c51d1a21c0addeeaa37e58af8153 Mon Sep 17 00:00:00 2001 From: Varik Matevosyan Date: Sun, 10 Dec 2023 00:51:29 +0400 Subject: [PATCH] Add flag to disable operator rewriting hooks and make pgvector-compatible (#240) * Add flag to disable operator rewriting hooks * Add operators for cosine and hamming distances to work on pgvector compatibility mode * Add update sql file * Run pgvector tests in pgvector_compat mode * Fix vector tests * Chown pgvector dir for postgres * remove pgvector directory before installing * Fix update path * Keep original hooks every time the pgvector_compat guc is changed * Reset original hooks only if changed in fini * Set pgvector_compat to TRUE by default and update tests * Update README * Fix brew symlink issue * Remove symlink before brew install * Ignore brew install error --- CMakeLists.txt | 3 +- README.md | 15 ++- ci/scripts/build-linux.sh | 1 + ci/scripts/build-mac.sh | 2 +- ci/scripts/build.sh | 4 +- ci/scripts/run-tests-linux.sh | 23 ++++- sql/lantern.sql | 68 +++++++++++-- sql/updates/0.0.8--0.0.9.sql | 169 +++++++++++++++++++++++++++++++ src/hnsw.c | 43 ++++++++ src/hnsw.h | 4 + src/hnsw/options.c | 31 +++++- src/hnsw/options.h | 1 + src/hooks/executor_start.c | 6 ++ src/hooks/op_rewrite.c | 1 - src/hooks/post_parse.c | 3 +- test/expected/ext_relocation.out | 14 ++- test/expected/hnsw_dist_func.out | 5 +- test/expected/hnsw_operators.out | 116 +++++++++++++++++++++ test/expected/hnsw_select.out | 3 +- test/expected/hnsw_todo.out | 16 +-- test/expected/hnsw_vector.out | 106 +++++++++++++++++++ test/schedule.txt | 2 +- test/sql/hnsw_dist_func.sql | 5 +- test/sql/hnsw_operators.sql | 62 ++++++++++++ test/sql/hnsw_select.sql | 3 +- test/sql/hnsw_todo.sql | 3 +- test/sql/hnsw_vector.sql | 41 ++++++++ 27 files changed, 711 insertions(+), 39 deletions(-) create mode 100644 sql/updates/0.0.8--0.0.9.sql create mode 100644 test/expected/hnsw_operators.out create mode 100644 test/sql/hnsw_operators.sql diff --git 
a/CMakeLists.txt b/CMakeLists.txt index a566feb07..4f950b8bc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.3) -set(LANTERNDB_VERSION 0.0.8) +set(LANTERNDB_VERSION 0.0.9) project( LanternDB @@ -189,6 +189,7 @@ set (_update_files sql/updates/0.0.5--0.0.6.sql sql/updates/0.0.6--0.0.7.sql sql/updates/0.0.7--0.0.8.sql + sql/updates/0.0.8--0.0.9.sql ) add_custom_command( diff --git a/README.md b/README.md index bee8eb6a7..f59986755 100644 --- a/README.md +++ b/README.md @@ -81,16 +81,27 @@ FROM small_world ORDER BY vector <-> ARRAY[0,0,0] LIMIT 1; ### A note on operators and operator classes -Lantern supports several distance functions in the index. You only need to specify the distance function used for a column at index creation time. Lantern will automatically infer the distance function to use for search so you always use `<->` operator in search queries. +Lantern supports several distance functions in the index and it has 2 modes for operators: + +1. `lantern.pgvector_compat=TRUE` (default) + In this mode there are 3 operators available `<->` (l2sq), `<=>` (cosine), `<+>` (hamming). + You need to use the right operator in order to trigger an index scan + +2. `lantern.pgvector_compat=FALSE` + In this mode you only need to specify the distance function used for a column at index creation time. Lantern will automatically infer the distance function to use for search so you always use `<->` operator in search queries. Note that the operator `<->` is intended exclusively for use with index lookups. If you expect to not use the index in a query, just use the distance function directly (e.g. `l2sq_dist(v1, v2)`) +> To switch between modes set `lantern.pgvector_compat` variable to `TRUE` or `FALSE`.
+ There are four defined operator classes that can be employed during index creation: - **`dist_l2sq_ops`**: Default for the type `real[]` - **`dist_vec_l2sq_ops`**: Default for the type `vector` - **`dist_cos_ops`**: Applicable to the type `real[]` -- **`dist_hamming_ops`**: Applicable for the type `integer[]` +- **`dist_vec_cos_ops`**: Applicable to the type `vector` +- **`dist_hamming_ops`**: Applicable to the type `integer[]` +- **`dist_vec_hamming_ops`**: Applicable to the type `vector` ### Index Construction Parameters diff --git a/ci/scripts/build-linux.sh b/ci/scripts/build-linux.sh index 3da4cec7f..81f0869ba 100755 --- a/ci/scripts/build-linux.sh +++ b/ci/scripts/build-linux.sh @@ -45,4 +45,5 @@ function cleanup_environment() { # Chown to postgres for running tests chown -R postgres:postgres /tmp/lantern + chown -R postgres:postgres /tmp/pgvector } diff --git a/ci/scripts/build-mac.sh b/ci/scripts/build-mac.sh index ce39eb773..5229dbd46 100755 --- a/ci/scripts/build-mac.sh +++ b/ci/scripts/build-mac.sh @@ -8,7 +8,7 @@ function setup_locale_and_install_packages() { } function setup_postgres() { - cmd="brew install postgresql@${PG_VERSION} clang-format" + cmd="brew install postgresql@${PG_VERSION} clang-format || true" # ignoring brew linking errors if [[ $USER == "root" ]] then # Runner is github CI user diff --git a/ci/scripts/build.sh b/ci/scripts/build.sh index ec3ed1916..c6a98b5d4 100755 --- a/ci/scripts/build.sh +++ b/ci/scripts/build.sh @@ -41,7 +41,9 @@ function install_external_dependencies() { PGVECTOR_VERSION=0.5.0 wget -O pgvector.tar.gz https://github.com/pgvector/pgvector/archive/refs/tags/v${PGVECTOR_VERSION}.tar.gz tar xzf pgvector.tar.gz - pushd pgvector-${PGVECTOR_VERSION} + rm -rf pgvector || true + mv pgvector-${PGVECTOR_VERSION} pgvector + pushd pgvector make && make install popd popd diff --git a/ci/scripts/run-tests-linux.sh b/ci/scripts/run-tests-linux.sh index 5f445a6bc..63ef1d806 100755 --- a/ci/scripts/run-tests-linux.sh +++ 
b/ci/scripts/run-tests-linux.sh @@ -22,12 +22,33 @@ function wait_for_pg(){ done } +function run_pgvector_tests(){ + pushd /tmp/pgvector + # Add lantern to load-extension in pgregress + sed -i '/REGRESS_OPTS \=/ s/$/ --load-extension lantern/' Makefile + + # Set pgvector_compat flag in test files + for file in ./test/sql/*; do + echo 'SET lantern.pgvector_compat=TRUE;' | cat - $file > temp && mv temp $file + done + + # Set pgvector_compat flag in result files + for file in ./test/expected/*.out; do + echo 'SET lantern.pgvector_compat=TRUE;' | cat - $file > temp && mv temp $file + done + + # Run tests + make installcheck + popd +} + function run_db_tests(){ if [[ "$RUN_TESTS" == "1" ]] then cd $WORKDIR/build && \ make test && \ - make test-client + make test-client && \ + run_pgvector_tests && \ killall postgres && \ gcovr -r $WORKDIR/src/ --object-directory $WORKDIR/build/ --xml /tmp/coverage.xml fi diff --git a/sql/lantern.sql b/sql/lantern.sql index 1a78f189b..6bc39f26a 100644 --- a/sql/lantern.sql +++ b/sql/lantern.sql @@ -6,29 +6,52 @@ CREATE FUNCTION hnsw_handler(internal) RETURNS index_am_handler CREATE FUNCTION ldb_generic_dist(real[], real[]) RETURNS real AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; -CREATE FUNCTION ldb_generic_dist(integer[], integer[]) RETURNS real - AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - CREATE FUNCTION l2sq_dist(real[], real[]) RETURNS real AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +-- this function is needed, as we should also use <-> operator +-- with integer[] type (to overwrite hamming dist function in our hooks) +-- and if we do not create l2sq_dist for integer[] type it will fail to cast in pgvector_compat mode +CREATE FUNCTION l2sq_dist(integer[], integer[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION cos_dist(real[], real[]) RETURNS real AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +-- 
functions _with_guard suffix are used to forbid operator usage +-- if operator hooks are enabled (lantern.pgvector_compat=FALSE) +CREATE FUNCTION cos_dist_with_guard(real[], real[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION hamming_dist(integer[], integer[]) RETURNS integer AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION hamming_dist_with_guard(integer[], integer[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- operators CREATE OPERATOR <-> ( - LEFTARG = real[], RIGHTARG = real[], PROCEDURE = ldb_generic_dist, + LEFTARG = real[], RIGHTARG = real[], PROCEDURE = l2sq_dist, COMMUTATOR = '<->' ); CREATE OPERATOR <-> ( - LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = ldb_generic_dist, + LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = l2sq_dist, COMMUTATOR = '<->' ); +CREATE OPERATOR <=> ( + LEFTARG = real[], RIGHTARG = real[], PROCEDURE = cos_dist_with_guard, + COMMUTATOR = '<=>' +); + +CREATE OPERATOR <+> ( + LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = hamming_dist_with_guard, + COMMUTATOR = '<+>' +); + + CREATE SCHEMA _lantern_internal; CREATE FUNCTION _lantern_internal.validate_index(index regclass, print_info boolean DEFAULT true) RETURNS VOID @@ -56,14 +79,20 @@ BEGIN CREATE OPERATOR CLASS dist_cos_ops FOR TYPE real[] USING ' || access_method_name || ' AS OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, - FUNCTION 1 cos_dist(real[], real[]); + FUNCTION 1 cos_dist(real[], real[]), + -- it is important to set the function with guard the second + -- as op rewriting hook takes the first function to use + OPERATOR 2 <=> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 2 cos_dist_with_guard(real[], real[]); '; dist_hamming_ops := ' CREATE OPERATOR CLASS dist_hamming_ops FOR TYPE integer[] USING ' || access_method_name || ' AS OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops, - FUNCTION 1 
hamming_dist(integer[], integer[]); + FUNCTION 1 hamming_dist(integer[], integer[]), + OPERATOR 2 <+> (integer[], integer[]) FOR ORDER BY integer_ops, + FUNCTION 2 hamming_dist_with_guard(integer[], integer[]); '; -- Execute the dynamic SQL statement. @@ -107,10 +136,35 @@ BEGIN CREATE FUNCTION l2sq_dist(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME', 'vector_l2sq_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION cos_dist(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME', 'vector_cos_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + CREATE FUNCTION hamming_dist(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME', 'vector_hamming_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + CREATE OPERATOR <+> ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = hamming_dist, + COMMUTATOR = '<+>' + ); + CREATE OPERATOR CLASS dist_vec_l2sq_ops DEFAULT FOR TYPE vector USING lantern_hnsw AS OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, FUNCTION 1 l2sq_dist(vector, vector); + + CREATE OPERATOR CLASS dist_vec_cos_ops + FOR TYPE vector USING lantern_hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 cos_dist(vector, vector), + OPERATOR 2 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 2 cos_dist(vector, vector); + + CREATE OPERATOR CLASS dist_vec_hamming_ops + FOR TYPE vector USING lantern_hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 hamming_dist(vector, vector), + OPERATOR 2 <+> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 2 hamming_dist(vector, vector); END IF; diff --git a/sql/updates/0.0.8--0.0.9.sql b/sql/updates/0.0.8--0.0.9.sql new file mode 100644 index 000000000..054887c92 --- /dev/null +++ b/sql/updates/0.0.8--0.0.9.sql @@ -0,0 +1,169 @@ +DO $BODY$ +DECLARE + pgvector_exists boolean; + am_name TEXT; + r pg_indexes%ROWTYPE; + indexes_cursor REFCURSOR; + index_names TEXT[] := '{}'; + index_definitions TEXT[] := '{}'; +BEGIN + -- Function to recreate 
operator classes for specified access method + CREATE OR REPLACE FUNCTION _lantern_internal._recreate_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$ + DECLARE + dist_l2sq_ops TEXT; + dist_l2sq_ops_drop TEXT; + dist_cos_ops TEXT; + dist_cos_ops_drop TEXT; + dist_hamming_ops TEXT; + dist_hamming_ops_drop TEXT; + BEGIN + + -- Construct the SQL statement to create the operator classes dynamically. + dist_l2sq_ops_drop := 'DROP OPERATOR CLASS IF EXISTS dist_l2sq_ops USING ' || access_method_name || ' CASCADE;'; + dist_l2sq_ops := ' + CREATE OPERATOR CLASS dist_l2sq_ops + DEFAULT FOR TYPE real[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 1 l2sq_dist(real[], real[]); + '; + + dist_cos_ops_drop := 'DROP OPERATOR CLASS IF EXISTS dist_cos_ops USING ' || access_method_name || ' CASCADE;'; + dist_cos_ops := ' + CREATE OPERATOR CLASS dist_cos_ops + FOR TYPE real[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 1 cos_dist(real[], real[]), + -- it is important to set the function with guard the second + -- as op rewriting hook takes the first function to use + OPERATOR 2 <=> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 2 cos_dist_with_guard(real[], real[]); + '; + + + dist_hamming_ops_drop := 'DROP OPERATOR CLASS IF EXISTS dist_hamming_ops USING ' || access_method_name || ' CASCADE;'; + dist_hamming_ops := ' + CREATE OPERATOR CLASS dist_hamming_ops + FOR TYPE integer[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops, + FUNCTION 1 hamming_dist(integer[], integer[]), + OPERATOR 2 <+> (integer[], integer[]) FOR ORDER BY integer_ops, + FUNCTION 2 hamming_dist_with_guard(integer[], integer[]); + '; + + + -- Execute the dynamic SQL statement. 
+ EXECUTE dist_l2sq_ops_drop; + EXECUTE dist_l2sq_ops; + EXECUTE dist_cos_ops_drop; + EXECUTE dist_cos_ops; + EXECUTE dist_hamming_ops_drop; + EXECUTE dist_hamming_ops; + + RETURN TRUE; + END; + $$ LANGUAGE plpgsql VOLATILE; + + -- Check if the vector type from pgvector exists + SELECT EXISTS ( + SELECT 1 + FROM pg_type + WHERE typname = 'vector' + ) INTO pgvector_exists; + + am_name := 'hnsw'; + + + IF pgvector_exists THEN + CREATE FUNCTION cos_dist(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME', 'vector_cos_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + CREATE FUNCTION hamming_dist(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME', 'vector_hamming_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + + CREATE OPERATOR <+> ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = hamming_dist, + COMMUTATOR = '<+>' + ); + + CREATE OPERATOR CLASS dist_vec_cos_ops + FOR TYPE vector USING lantern_hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 cos_dist(vector, vector), + OPERATOR 2 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 2 cos_dist(vector, vector); + + CREATE OPERATOR CLASS dist_vec_hamming_ops + FOR TYPE vector USING lantern_hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 hamming_dist(vector, vector), + OPERATOR 2 <+> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 2 hamming_dist(vector, vector); + + am_name := 'lantern_hnsw'; + END IF; + + -- this function is needed, as we should also use <-> operator + -- with integer[] type (to overwrite hamming dist function in our hooks) + -- and if we do not create l2sq_dist for integer[] type it will fail to cast in pgvector_compat mode + CREATE OR REPLACE FUNCTION l2sq_dist(integer[], integer[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + -- functions _with_guard suffix are used to forbid operator usage + -- if operator hooks are enabled (lantern.pgvector_compat=FALSE) + CREATE FUNCTION
cos_dist_with_guard(real[], real[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + CREATE FUNCTION hamming_dist_with_guard(integer[], integer[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + + + -- keep existing indexes to reindex as we should drop indexes in order to change operator classes + OPEN indexes_cursor FOR SELECT * FROM pg_indexes WHERE indexdef ILIKE '%USING ' || am_name || '%'; + -- Fetch index names into the array + LOOP + FETCH indexes_cursor INTO r; + EXIT WHEN NOT FOUND; + + -- Append index name to the array + index_names := array_append(index_names, r.indexname); + index_definitions := array_append(index_definitions, r.indexdef); + END LOOP; + + CLOSE indexes_cursor; + + -- operators + DROP OPERATOR <->(real[], real[]) CASCADE; + CREATE OPERATOR <-> ( + LEFTARG = real[], RIGHTARG = real[], PROCEDURE = l2sq_dist, + COMMUTATOR = '<->' + ); + + DROP OPERATOR <->(integer[], integer[]) CASCADE; + CREATE OPERATOR <-> ( + LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = l2sq_dist, + COMMUTATOR = '<->' + ); + + CREATE OPERATOR <=> ( + LEFTARG = real[], RIGHTARG = real[], PROCEDURE = cos_dist_with_guard, + COMMUTATOR = '<=>' + ); + + CREATE OPERATOR <+> ( + LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = hamming_dist_with_guard, + COMMUTATOR = '<+>' + ); + + PERFORM _lantern_internal._recreate_ldb_operator_classes(am_name); + + SET client_min_messages TO NOTICE; + -- reindex indexes + FOR i IN 1..coalesce(array_length(index_names, 1), 0) LOOP + RAISE NOTICE 'Reindexing index %', index_names[i]; + EXECUTE index_definitions[i]; + RAISE NOTICE 'Reindexed index: %', index_names[i]; + END LOOP; +END; +$BODY$ +LANGUAGE plpgsql; diff --git a/src/hnsw.c b/src/hnsw.c index 59e233924..d46085ef0 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -325,6 +325,13 @@ static float8 vector_dist(Vector *a, Vector *b, usearch_metric_kind_t metric_kin return usearch_dist(a->x, b->x, metric_kind, 
a->dim, usearch_scalar_f32_k); } +static void pgvector_compat_guard() +{ + if(!ldb_pgvector_compat) { + elog(ERROR, "Operator can only be used when lantern.pgvector_compat=TRUE"); + } +} + PGDLLEXPORT PG_FUNCTION_INFO_V1(ldb_generic_dist); Datum ldb_generic_dist(PG_FUNCTION_ARGS) { PG_RETURN_NULL(); } @@ -344,6 +351,15 @@ Datum cos_dist(PG_FUNCTION_ARGS) PG_RETURN_FLOAT4(array_dist(a, b, usearch_metric_cos_k)); } +PGDLLEXPORT PG_FUNCTION_INFO_V1(cos_dist_with_guard); +Datum cos_dist_with_guard(PG_FUNCTION_ARGS) +{ + pgvector_compat_guard(); + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + PG_RETURN_FLOAT4(array_dist(a, b, usearch_metric_cos_k)); +} + PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_dist); Datum hamming_dist(PG_FUNCTION_ARGS) { @@ -352,6 +368,15 @@ Datum hamming_dist(PG_FUNCTION_ARGS) PG_RETURN_INT32((int32)array_dist(a, b, usearch_metric_hamming_k)); } +PGDLLEXPORT PG_FUNCTION_INFO_V1(hamming_dist_with_guard); +Datum hamming_dist_with_guard(PG_FUNCTION_ARGS) +{ + pgvector_compat_guard(); + ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); + ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); + PG_RETURN_INT32((int32)array_dist(a, b, usearch_metric_hamming_k)); +} + PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_l2sq_dist); Datum vector_l2sq_dist(PG_FUNCTION_ARGS) { @@ -361,6 +386,24 @@ Datum vector_l2sq_dist(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8((double)vector_dist(a, b, usearch_metric_l2sq_k)); } +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_cos_dist); +Datum vector_cos_dist(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + + PG_RETURN_FLOAT8((double)vector_dist(a, b, usearch_metric_cos_k)); +} + +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_hamming_dist); +Datum vector_hamming_dist(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + + PG_RETURN_FLOAT8((double)vector_dist(a, b, usearch_metric_hamming_k)); +} + PGDLLEXPORT 
PG_FUNCTION_INFO_V1(lantern_internal_validate_index); Datum lantern_internal_validate_index(PG_FUNCTION_ARGS) { diff --git a/src/hnsw.h b/src/hnsw.h index 570ed6bf2..d3b5edc49 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -30,7 +30,11 @@ PGDLLEXPORT void _PG_fini(void); PGDLLEXPORT Datum l2sq_dist(PG_FUNCTION_ARGS); PGDLLEXPORT Datum vector_l2sq_dist(PG_FUNCTION_ARGS); PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum hamming_dist_with_guard(PG_FUNCTION_ARGS); PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum cos_dist_with_guard(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum vector_cos_dist(PG_FUNCTION_ARGS); +PGDLLEXPORT Datum vector_hamming_dist(PG_FUNCTION_ARGS); HnswColumnType GetColumnTypeFromOid(Oid oid); HnswColumnType GetIndexColumnType(Relation index); diff --git a/src/hnsw/options.c b/src/hnsw/options.c index 2019878a0..bbd99d823 100644 --- a/src/hnsw/options.c +++ b/src/hnsw/options.c @@ -31,6 +31,10 @@ static relopt_kind ldb_hnsw_index_withopts; int ldb_hnsw_init_k; int ldb_hnsw_ef_search; +// if this variable is set to true +// our operator rewriting hooks will be disabled +bool ldb_pgvector_compat; + // this variable is only set during testing and controls whether // certain elog() calls are made // see ldb_dlog() definition and callsites for details @@ -91,9 +95,9 @@ usearch_metric_kind_t ldb_HnswGetMetricKind(Relation index) if(fnaddr == l2sq_dist || fnaddr == vector_l2sq_dist) { return usearch_metric_l2sq_k; - } else if(fnaddr == hamming_dist) { + } else if(fnaddr == hamming_dist || fnaddr == vector_hamming_dist) { return usearch_metric_hamming_k; - } else if(fnaddr == cos_dist) { + } else if(fnaddr == cos_dist || fnaddr == vector_cos_dist) { return usearch_metric_cos_k; } else { elog(ERROR, "could not find distance function for index"); @@ -252,12 +256,31 @@ void _PG_init(void) NULL, NULL, NULL); + + DefineCustomBoolVariable("lantern.pgvector_compat", + "Whether or not the operator <-> should automatically detect the right 
distance function", + "set this to 1 to disable operator rewriting hooks", + &ldb_pgvector_compat, + true, + PGC_USERSET, + 0, + NULL, + NULL, + NULL); } // Called with extension unload. void _PG_fini(void) { // Return back the original hook value. - post_parse_analyze_hook = original_post_parse_analyze_hook; - ExecutorStart_hook = original_ExecutorStart_hook; + // This check is because there might be case if while we stop the hooks (in pgvector_compat mode) + // Another extension will be loaded and it will overwrite the hooks + // And when lantern extension will be unloaded it will set the hooks to original values + // Overwriting the current changed hooks set by another extension + if(ExecutorStart_hook == ExecutorStart_hook_with_operator_check) { + ExecutorStart_hook = original_ExecutorStart_hook; + } + if(post_parse_analyze_hook == post_parse_analyze_hook_with_operator_check) { + post_parse_analyze_hook = original_post_parse_analyze_hook; + } } diff --git a/src/hnsw/options.h b/src/hnsw/options.h index 75eb91851..f1b57e9a3 100644 --- a/src/hnsw/options.h +++ b/src/hnsw/options.h @@ -54,5 +54,6 @@ bytea* ldb_amoptions(Datum reloptions, bool validate); extern int ldb_hnsw_init_k; extern int ldb_hnsw_ef_search; extern bool ldb_is_test; +extern bool ldb_pgvector_compat; #endif // LDB_HNSW_OPTIONS_H diff --git a/src/hooks/executor_start.c b/src/hooks/executor_start.c index 20b264016..edfbcadf8 100644 --- a/src/hooks/executor_start.c +++ b/src/hooks/executor_start.c @@ -9,6 +9,7 @@ #include #include +#include "../hnsw/options.h" #include "../hnsw/utils.h" #include "op_rewrite.h" #include "plan_tree_walker.h" @@ -71,6 +72,11 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags) original_ExecutorStart_hook(queryDesc, eflags); } + if(ldb_pgvector_compat) { + standard_ExecutorStart(queryDesc, eflags); + return; + } + if(creating_extension) { // this is true in only CREATE EXTENSION and ALTER EXTENSION UPDATE commands // these statements are 
guaranteed to not use our operators and state necessary diff --git a/src/hooks/op_rewrite.c b/src/hooks/op_rewrite.c index 9b8e1c60d..12b777e98 100644 --- a/src/hooks/op_rewrite.c +++ b/src/hooks/op_rewrite.c @@ -172,7 +172,6 @@ static Oid get_func_id_from_index(Relation index) // it doesn't enforce this invariant. Ideally we would call SearchCatCache1 directly but postgres doesn't expose // necessary constants CatCList *opList = SearchSysCacheList1(AMPROCNUM, ObjectIdGetDatum(opclassOid)); - assert(opList->n_members == 1); HeapTuple opTuple = &opList->members[ 0 ]->tuple; if(!HeapTupleIsValid(opTuple)) { index_close(index, AccessShareLock); diff --git a/src/hooks/post_parse.c b/src/hooks/post_parse.c index 99bd459fd..339820a2d 100644 --- a/src/hooks/post_parse.c +++ b/src/hooks/post_parse.c @@ -11,6 +11,7 @@ #include #include +#include "../hnsw/options.h" #include "utils.h" post_parse_analyze_hook_type original_post_parse_analyze_hook = NULL; @@ -171,7 +172,7 @@ void post_parse_analyze_hook_with_operator_check(ParseState *pstate, #endif } - if(creating_extension) { + if(ldb_pgvector_compat || creating_extension) { return; } diff --git a/test/expected/ext_relocation.out b/test/expected/ext_relocation.out index 9e21d69ec..97f6b08b4 100644 --- a/test/expected/ext_relocation.out +++ b/test/expected/ext_relocation.out @@ -39,12 +39,14 @@ ORDER BY 1, 3, 2; schema1 | reindex_lantern_indexes | _lantern_internal schema1 | validate_index | _lantern_internal schema1 | cos_dist | schema1 + schema1 | cos_dist_with_guard | schema1 schema1 | hamming_dist | schema1 + schema1 | hamming_dist_with_guard | schema1 schema1 | hnsw_handler | schema1 schema1 | l2sq_dist | schema1 + schema1 | l2sq_dist | schema1 schema1 | ldb_generic_dist | schema1 - schema1 | ldb_generic_dist | schema1 -(10 rows) +(12 rows) -- show all the extension operators SELECT ne.nspname AS extschema, op.oprname, np.nspname AS proschema @@ -59,7 +61,9 @@ ORDER BY 1, 3; -----------+---------+----------- schema1 | 
<-> | schema1 schema1 | <-> | schema1 -(2 rows) + schema1 | <=> | schema1 + schema1 | <+> | schema1 +(4 rows) SET search_path TO public, schema1; -- extension function is accessible @@ -102,7 +106,9 @@ ORDER BY 1, 3; -----------+---------+----------- schema1 | <-> | schema1 schema1 | <-> | schema1 -(2 rows) + schema1 | <=> | schema1 + schema1 | <+> | schema1 +(4 rows) SET search_path TO public, schema2; --extension access method is still accessible since access methods are not schema-qualified diff --git a/test/expected/hnsw_dist_func.out b/test/expected/hnsw_dist_func.out index 1b91440d4..ee0d9be57 100644 --- a/test/expected/hnsw_dist_func.out +++ b/test/expected/hnsw_dist_func.out @@ -34,7 +34,8 @@ INFO: done saving 0 vectors INSERT INTO small_world_l2 SELECT id, v FROM small_world; INSERT INTO small_world_cos SELECT id, v FROM small_world; INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world; -SET enable_seqscan = false; +SET enable_seqscan=FALSE; +SET lantern.pgvector_compat=FALSE; -- Verify that the distance functions work (check distances) SELECT ROUND(l2sq_dist(v, '{0,1,0}')::numeric, 2) FROM small_world_l2 ORDER BY v <-> '{0,1,0}'; round @@ -133,7 +134,7 @@ SELECT 1 FROM small_world_cos ORDER BY v <-> '{0,1,0,1}' LIMIT 1; ERROR: Expected real array with dimension 3, got 4 SELECT 1 FROM small_world_ham ORDER BY v <-> '{0,1,0,1}' LIMIT 1; ERROR: Expected int array with dimension 3, got 4 -SELECT l2sq_dist('{1,1}', '{0,1,0}'); +SELECT l2sq_dist('{1,1}'::REAL[], '{0,1,0}'::REAL[]); ERROR: expected equally sized arrays but got arrays with dimensions 2 and 3 SELECT cos_dist('{1,1}', '{0,1,0}'); ERROR: expected equally sized arrays but got arrays with dimensions 2 and 3 diff --git a/test/expected/hnsw_operators.out b/test/expected/hnsw_operators.out new file mode 100644 index 000000000..b8971acad --- /dev/null +++ b/test/expected/hnsw_operators.out @@ -0,0 +1,116 @@ +-- Validate that 
lantern.pgvector_compat disables the operator rewriting hooks +CREATE TABLE op_test (v REAL[]); +INSERT INTO op_test (v) VALUES (ARRAY[0,0,0]), (ARRAY[1,1,1]); +CREATE INDEX cos_idx ON op_test USING hnsw(v dist_cos_ops); +INFO: done init usearch index +INFO: inserted 2 elements +INFO: done saving 2 vectors +-- should rewrite operator +SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + v +--------- + {1,1,1} + {0,0,0} +(2 rows) + +-- should throw error +\set ON_ERROR_STOP off +SET lantern.pgvector_compat=FALSE; +SELECT * FROM op_test ORDER BY v <=> ARRAY[1,1,1]; +ERROR: Operator can only be used when lantern.pgvector_compat=TRUE +-- should throw error +SELECT * FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; +ERROR: Operator can only be used when lantern.pgvector_compat=TRUE +-- should throw error +SELECT v <-> ARRAY[1,1,1] FROM op_test ORDER BY v <-> ARRAY[1,1,1]; +ERROR: Operator <-> is invalid outside of ORDER BY context +SET lantern.pgvector_compat=TRUE; +SET enable_seqscan=OFF; +\set ON_ERROR_STOP on +-- NOW THIS IS TRIGGERING INDEX SCAN AS WELL +-- BECAUSE WE ARE REGISTERING <-> FOR ALL OPERATOR CLASSES +-- IDEALLY THIS SHOULD NOT TRIGGER INDEX SCAN WHEN lantern.pgvector_compat=TRUE +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + QUERY PLAN +--------------------------------------- + Index Scan using cos_idx on op_test + Order By: (v <-> '{1,1,1}'::real[]) +(2 rows) + +-- should sort with index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + QUERY PLAN +--------------------------------------- + Index Scan using cos_idx on op_test + Order By: (v <=> '{1,1,1}'::real[]) +(2 rows) + +-- should sort without index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; + QUERY PLAN +--------------------------------------------------------- + Sort + Sort Key: (((v)::integer[] <+> '{1,1,1}'::integer[])) + -> Seq Scan on op_test +(3 rows) + +-- should not throw error +\set 
ON_ERROR_STOP on +SELECT v <=> ARRAY[1,1,1] FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + ?column? +---------- + 0 + 1 +(2 rows) + +-- should not throw error +SELECT v::INTEGER[] <+> ARRAY[1,1,1] FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; + ?column? +---------- + 0 + 3 +(2 rows) + +-- should not throw error +SELECT v <-> ARRAY[1,1,1] FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + ?column? +---------- + 0 + 3 +(2 rows) + +RESET ALL; +-- Set false twice to verify that no crash is happening +SET lantern.pgvector_compat=FALSE; +SET lantern.pgvector_compat=FALSE; +\set ON_ERROR_STOP off +-- should rewrite operator +SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + v +--------- + {1,1,1} + {0,0,0} +(2 rows) + +SET lantern.pgvector_compat=TRUE; +SET enable_seqscan=OFF; +CREATE INDEX hamming_idx ON op_test USING hnsw(cast(v as INTEGER[]) dist_hamming_ops); +INFO: done init usearch index +INFO: inserted 2 elements +INFO: done saving 2 vectors +-- should sort with cos_idx index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + QUERY PLAN +--------------------------------------- + Index Scan using cos_idx on op_test + Order By: (v <=> '{1,1,1}'::real[]) +(2 rows) + +-- should sort with hamming_idx index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; + QUERY PLAN +------------------------------------------------------- + Index Scan using hamming_idx on op_test + Order By: ((v)::integer[] <+> '{1,1,1}'::integer[]) +(2 rows) + diff --git a/test/expected/hnsw_select.out b/test/expected/hnsw_select.out index 8b4b42fc7..e8c5b3cc4 100644 --- a/test/expected/hnsw_select.out +++ b/test/expected/hnsw_select.out @@ -39,7 +39,8 @@ CREATE INDEX ON test1 USING hnsw (v); INFO: done init usearch index INFO: inserted 1 elements INFO: done saving 1 vectors -SET enable_seqscan = false; +SET enable_seqscan=FALSE; +SET lantern.pgvector_compat=FALSE; -- Verify that basic queries still work given our query parser and planner 
hooks SELECT 0 + 1; ?column? diff --git a/test/expected/hnsw_todo.out b/test/expected/hnsw_todo.out index e65164191..c6cdf62c7 100644 --- a/test/expected/hnsw_todo.out +++ b/test/expected/hnsw_todo.out @@ -14,7 +14,8 @@ INSERT INTO small_world_l2 (id, vector) VALUES ('101', '{1,0,1}'), ('110', '{1,1,0}'), ('111', '{1,1,1}'); -SET enable_seqscan = false; +SET enable_seqscan=FALSE; +SET lantern.pgvector_compat=FALSE; \set ON_ERROR_STOP off CREATE INDEX ON small_world_l2 USING hnsw (vector dist_l2sq_ops); INFO: done init usearch index @@ -38,14 +39,13 @@ EXPLAIN (COSTS FALSE) SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist FROM small_world_l2 ORDER BY vector_int <-> array[0,1,0] LIMIT 7; - QUERY PLAN ------------------------------------------------------------------------ + QUERY PLAN +------------------------------------------------------------------------ Limit - -> Result - -> Sort - Sort Key: (l2sq_dist(vector_int, '{0,1,0}'::integer[])) - -> Seq Scan on small_world_l2 -(5 rows) + -> Sort + Sort Key: (public.l2sq_dist(vector_int, '{0,1,0}'::integer[])) + -> Seq Scan on small_world_l2 +(4 rows) --- Test scenarious --- ----------------------------------------- diff --git a/test/expected/hnsw_vector.out b/test/expected/hnsw_vector.out index a405c8b65..5d2013dc6 100644 --- a/test/expected/hnsw_vector.out +++ b/test/expected/hnsw_vector.out @@ -10,6 +10,7 @@ CREATE EXTENSION vector; SET client_min_messages=ERROR; CREATE EXTENSION lantern; RESET client_min_messages; +SET lantern.pgvector_compat=FALSE; -- Verify basic functionality of pgvector SELECT '[1,2,3]'::vector; vector @@ -242,3 +243,108 @@ SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; 3 (3 rows) +-- Test pgvector in lantern.pgvector_compat=TRUE mode +DROP TABLE small_world; +\ir utils/small_world_vector.sql +CREATE TABLE small_world ( + id VARCHAR(3), + b BOOLEAN, + v VECTOR(3) +); +INSERT INTO small_world (id, b, v) VALUES + ('000', TRUE, '[0,0,0]'), + ('001', TRUE, 
'[0,0,1]'), + ('010', FALSE, '[0,1,0]'), + ('011', TRUE, '[0,1,1]'), + ('100', FALSE, '[1,0,0]'), + ('101', FALSE, '[1,0,1]'), + ('110', FALSE, '[1,1,0]'), + ('111', TRUE, '[1,1,1]'); +-- Distance functions +SET lantern.pgvector_compat=TRUE; +SET enable_seqscan=OFF; +-- Note: +-- For l2sqs and cosine distances in SELECT statement +-- It is better to use the function by name like cos_dist or l2sq_dist +-- As operators for vector types are provided from pgvector, so it may cause undefined behaviour +-- l2sq index +CREATE INDEX l2_idx ON small_world USING lantern_hnsw (v) WITH (dim=3, M=5, ef=20, ef_construction=20); +INFO: done init usearch index +INFO: inserted 8 elements +INFO: done saving 8 vectors +SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7; + dist +------ + 0.00 + 1.00 + 1.00 + 1.00 + 2.00 + 2.00 + 2.00 +(7 rows) + +EXPLAIN (COSTS FALSE) SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7; + QUERY PLAN +---------------------------------------------- + Limit + -> Index Scan using l2_idx on small_world + Order By: (v <-> '[0,1,0]'::vector) +(3 rows) + +-- cosine index +CREATE INDEX cos_idx ON small_world USING lantern_hnsw (v dist_vec_cos_ops); +INFO: done init usearch index +INFO: inserted 8 elements +INFO: done saving 8 vectors +SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; + dist +------ + 0.00 + 0.29 + 0.29 + 0.42 + 1.00 + 1.00 + 1.00 +(7 rows) + +EXPLAIN (COSTS FALSE) SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; + QUERY PLAN +----------------------------------------------- + Limit + -> Index Scan using cos_idx on small_world + Order By: (v <=> '[0,1,0]'::vector) +(3 rows) + +-- hamming index +CREATE INDEX hamming_idx ON small_world USING lantern_hnsw 
(v dist_vec_hamming_ops); +INFO: done init usearch index +INFO: inserted 8 elements +INFO: done saving 8 vectors +SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; + dist +------- + 0.00 + 7.00 + 7.00 + 7.00 + 14.00 + 14.00 + 14.00 +(7 rows) + +EXPLAIN (COSTS FALSE) SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; + QUERY PLAN +--------------------------------------------------- + Limit + -> Index Scan using hamming_idx on small_world + Order By: (v <+> '[0,1,0]'::vector) +(3 rows) + diff --git a/test/schedule.txt b/test/schedule.txt index d8c5c008a..e93defec8 100644 --- a/test/schedule.txt +++ b/test/schedule.txt @@ -3,5 +3,5 @@ # - every test that needs to be run iff pgvector is installed appears in a 'test_pgvector:' line # - 'test' lines may have multiple space-separated tests. All tests in a single 'test' line will be run in parallel -test: hnsw_config hnsw_correct hnsw_create hnsw_create_expr hnsw_dist_func hnsw_insert hnsw_select hnsw_todo hnsw_index_from_file hnsw_cost_estimate ext_relocation hnsw_ef_search hnsw_failure_point +test: hnsw_config hnsw_correct hnsw_create hnsw_create_expr hnsw_dist_func hnsw_insert hnsw_select hnsw_todo hnsw_index_from_file hnsw_cost_estimate ext_relocation hnsw_ef_search hnsw_failure_point hnsw_operators test_pgvector: hnsw_vector diff --git a/test/sql/hnsw_dist_func.sql b/test/sql/hnsw_dist_func.sql index 48b5f1f5b..757334070 100644 --- a/test/sql/hnsw_dist_func.sql +++ b/test/sql/hnsw_dist_func.sql @@ -16,7 +16,8 @@ INSERT INTO small_world_l2 SELECT id, v FROM small_world; INSERT INTO small_world_cos SELECT id, v FROM small_world; INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world; -SET enable_seqscan = false; +SET enable_seqscan=FALSE; +SET lantern.pgvector_compat=FALSE; -- Verify that the distance 
functions work (check distances) SELECT ROUND(l2sq_dist(v, '{0,1,0}')::numeric, 2) FROM small_world_l2 ORDER BY v <-> '{0,1,0}'; @@ -39,7 +40,7 @@ EXPLAIN (COSTS false) SELECT id FROM small_world_ham ORDER BY v <-> '{0,1,0}'; SELECT 1 FROM small_world_l2 ORDER BY v <-> '{0,1,0,1}' LIMIT 1; SELECT 1 FROM small_world_cos ORDER BY v <-> '{0,1,0,1}' LIMIT 1; SELECT 1 FROM small_world_ham ORDER BY v <-> '{0,1,0,1}' LIMIT 1; -SELECT l2sq_dist('{1,1}', '{0,1,0}'); +SELECT l2sq_dist('{1,1}'::REAL[], '{0,1,0}'::REAL[]); SELECT cos_dist('{1,1}', '{0,1,0}'); SELECT hamming_dist('{1,1}', '{0,1,0}'); diff --git a/test/sql/hnsw_operators.sql b/test/sql/hnsw_operators.sql new file mode 100644 index 000000000..ba6514e1e --- /dev/null +++ b/test/sql/hnsw_operators.sql @@ -0,0 +1,62 @@ +-- Validate that lantern.pgvector_compat disables the operator rewriting hooks +CREATE TABLE op_test (v REAL[]); +INSERT INTO op_test (v) VALUES (ARRAY[0,0,0]), (ARRAY[1,1,1]); +CREATE INDEX cos_idx ON op_test USING hnsw(v dist_cos_ops); +-- should rewrite operator +SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + +-- should throw error +\set ON_ERROR_STOP off +SET lantern.pgvector_compat=FALSE; +SELECT * FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + +-- should throw error +SELECT * FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; + +-- should throw error +SELECT v <-> ARRAY[1,1,1] FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + +SET lantern.pgvector_compat=TRUE; +SET enable_seqscan=OFF; +\set ON_ERROR_STOP on + +-- NOW THIS IS TRIGGERING INDEX SCAN AS WELL +-- BECAUSE WE ARE REGISTERING <-> FOR ALL OPERATOR CLASSES +-- IDEALLY THIS SHOULD NOT TRIGGER INDEX SCAN WHEN lantern.pgvector_compat=TRUE +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + +-- should sort with index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + +-- should sort without index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; + +-- should 
not throw error +\set ON_ERROR_STOP on + +SELECT v <=> ARRAY[1,1,1] FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + +-- should not throw error +SELECT v::INTEGER[] <+> ARRAY[1,1,1] FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; + +-- should not throw error +SELECT v <-> ARRAY[1,1,1] FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + +RESET ALL; +-- Set false twice to verify that no crash is happening +SET lantern.pgvector_compat=FALSE; +SET lantern.pgvector_compat=FALSE; +\set ON_ERROR_STOP off +-- should rewrite operator +SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1]; + +SET lantern.pgvector_compat=TRUE; +SET enable_seqscan=OFF; + +CREATE INDEX hamming_idx ON op_test USING hnsw(cast(v as INTEGER[]) dist_hamming_ops); + +-- should sort with cos_idx index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v <=> ARRAY[1,1,1]; + +-- should sort with hamming_idx index +EXPLAIN (COSTS FALSE) SELECT * FROM op_test ORDER BY v::INTEGER[] <+> ARRAY[1,1,1]; diff --git a/test/sql/hnsw_select.sql b/test/sql/hnsw_select.sql index 5bd9c90cf..6f8132f66 100644 --- a/test/sql/hnsw_select.sql +++ b/test/sql/hnsw_select.sql @@ -15,7 +15,8 @@ INSERT INTO test1 (v) VALUES ('{5,3}'); INSERT INTO test2 (v) VALUES ('{5,4}'); CREATE INDEX ON test1 USING hnsw (v); -SET enable_seqscan = false; +SET enable_seqscan=FALSE; +SET lantern.pgvector_compat=FALSE; -- Verify that basic queries still work given our query parser and planner hooks SELECT 0 + 1; diff --git a/test/sql/hnsw_todo.sql b/test/sql/hnsw_todo.sql index 9a6e637b1..8f5113254 100644 --- a/test/sql/hnsw_todo.sql +++ b/test/sql/hnsw_todo.sql @@ -17,7 +17,8 @@ INSERT INTO small_world_l2 (id, vector) VALUES ('110', '{1,1,0}'), ('111', '{1,1,1}'); -SET enable_seqscan = false; +SET enable_seqscan=FALSE; +SET lantern.pgvector_compat=FALSE; \set ON_ERROR_STOP off diff --git a/test/sql/hnsw_vector.sql b/test/sql/hnsw_vector.sql index 25393ae89..8d000a8a4 100644 --- a/test/sql/hnsw_vector.sql +++ b/test/sql/hnsw_vector.sql @@ -11,6 +11,7 @@ 
CREATE EXTENSION vector; SET client_min_messages=ERROR; CREATE EXTENSION lantern; RESET client_min_messages; +SET lantern.pgvector_compat=FALSE; -- Verify basic functionality of pgvector SELECT '[1,2,3]'::vector; @@ -112,3 +113,43 @@ SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; DROP INDEX cos_idx; CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3); SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; + +-- Test pgvector in lantern.pgvector_compat=TRUE mode +DROP TABLE small_world; +\ir utils/small_world_vector.sql + +-- Distance functions +SET lantern.pgvector_compat=TRUE; +SET enable_seqscan=OFF; + +-- Note: +-- For l2sqs and cosine distances in SELECT statement +-- It is better to use the function by name like cos_dist or l2sq_dist +-- As operators for vector types are provided from pgvector, so it may cause undefined behaviour + +-- l2sq index +CREATE INDEX l2_idx ON small_world USING lantern_hnsw (v) WITH (dim=3, M=5, ef=20, ef_construction=20); + +SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7; + +EXPLAIN (COSTS FALSE) SELECT ROUND(l2sq_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <-> '[0,1,0]'::VECTOR LIMIT 7; + +-- cosine index +CREATE INDEX cos_idx ON small_world USING lantern_hnsw (v dist_vec_cos_ops); + +SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; + +EXPLAIN (COSTS FALSE) SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7; + +-- hamming index +CREATE INDEX hamming_idx ON small_world USING lantern_hnsw (v dist_vec_hamming_ops); + +SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7; + +EXPLAIN (COSTS FALSE) SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist +FROM small_world 
ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7;