diff --git a/sql/lantern.sql b/sql/lantern.sql index a04c5837d..ab1fe7d25 100644 --- a/sql/lantern.sql +++ b/sql/lantern.sql @@ -1,8 +1,74 @@ -- Definitions concerning our hnsw-based index data strucuture - CREATE FUNCTION hnsw_handler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; +-- functions +CREATE FUNCTION ldb_generic_dist(real[], real[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION ldb_generic_dist(integer[], integer[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION l2sq_dist(real[], real[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION cos_dist(real[], real[]) RETURNS real + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION hamming_dist(integer[], integer[]) RETURNS integer + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +-- operators +CREATE OPERATOR <-> ( + LEFTARG = real[], RIGHTARG = real[], PROCEDURE = ldb_generic_dist, + COMMUTATOR = '<->' +); + +CREATE OPERATOR <-> ( + LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = ldb_generic_dist, + COMMUTATOR = '<->' +); + +-- operator classes +CREATE OR REPLACE FUNCTION _create_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$ +DECLARE + dist_l2sq_ops TEXT; + dist_cos_ops TEXT; + dist_hamming_ops TEXT; +BEGIN + -- Construct the SQL statement to create the operator classes dynamically. + dist_l2sq_ops := ' + CREATE OPERATOR CLASS dist_l2sq_ops + DEFAULT FOR TYPE real[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 1 l2sq_dist(real[], real[]); + '; + + dist_cos_ops := ' + CREATE OPERATOR CLASS dist_cos_ops + FOR TYPE real[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 1 cos_dist(real[], real[]); + '; + + dist_hamming_ops := ' + CREATE OPERATOR CLASS dist_hamming_ops + FOR TYPE integer[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops, + FUNCTION 1 hamming_dist(integer[], integer[]); + '; + + -- Execute the dynamic SQL statement. + EXECUTE dist_l2sq_ops; + EXECUTE dist_cos_ops; + EXECUTE dist_hamming_ops; + + RETURN TRUE; +END; +$$ LANGUAGE plpgsql VOLATILE; + + +-- Create access method DO $BODY$ DECLARE hnsw_am_exists boolean; @@ -41,55 +107,14 @@ BEGIN IF hnsw_am_exists THEN + PERFORM _create_ldb_operator_classes('lantern_hnsw'); RAISE WARNING 'Access method(index type) "hnsw" already exists. Creating lantern_hnsw access method'; ELSE -- create access method CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnsw_handler; COMMENT ON ACCESS METHOD hnsw IS 'LanternDB access method for vector embeddings, based on the hnsw algorithm'; + PERFORM _create_ldb_operator_classes('hnsw'); END IF; END; $BODY$ LANGUAGE plpgsql; - --- functions -CREATE FUNCTION ldb_generic_dist(real[], real[]) RETURNS real - AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - -CREATE FUNCTION ldb_generic_dist(integer[], integer[]) RETURNS real - AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - -CREATE FUNCTION l2sq_dist(real[], real[]) RETURNS real - AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - -CREATE FUNCTION cos_dist(real[], real[]) RETURNS real - AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - -CREATE FUNCTION hamming_dist(integer[], integer[]) RETURNS integer - AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; - --- operators -CREATE OPERATOR <-> ( - LEFTARG = real[], RIGHTARG = real[], PROCEDURE = ldb_generic_dist, - COMMUTATOR = '<->' -); - -CREATE OPERATOR <-> ( - LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = ldb_generic_dist, - COMMUTATOR = '<->' -); - --- operator classes -CREATE OPERATOR CLASS dist_l2sq_ops - DEFAULT FOR TYPE real[] USING hnsw AS - OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, - FUNCTION 1 l2sq_dist(real[], real[]); - -CREATE OPERATOR CLASS dist_cos_ops - FOR TYPE real[] USING hnsw AS - OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, - FUNCTION 1 cos_dist(real[], real[]); - -CREATE OPERATOR CLASS dist_hamming_ops - FOR TYPE integer[] USING hnsw AS - OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops, - FUNCTION 1 hamming_dist(integer[], integer[]); diff --git a/sql/updates/0.0.4-latest.sql b/sql/updates/0.0.4-latest.sql index 5affc003b..4654ea412 100644 --- a/sql/updates/0.0.4-latest.sql +++ b/sql/updates/0.0.4-latest.sql @@ -9,7 +9,45 @@ SELECT EXISTS ( WHERE typname = 'vector' ) INTO pgvector_exists; +CREATE OR REPLACE FUNCTION _create_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$ +DECLARE + dist_l2sq_ops TEXT; + dist_cos_ops TEXT; + dist_hamming_ops TEXT; +BEGIN + -- Construct the SQL statement to create the operator classes dynamically. + dist_l2sq_ops := ' + CREATE OPERATOR CLASS dist_l2sq_ops + DEFAULT FOR TYPE real[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 1 l2sq_dist(real[], real[]); + '; + + dist_cos_ops := ' + CREATE OPERATOR CLASS dist_cos_ops + FOR TYPE real[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops, + FUNCTION 1 cos_dist(real[], real[]); + '; + + dist_hamming_ops := ' + CREATE OPERATOR CLASS dist_hamming_ops + FOR TYPE integer[] USING ' || access_method_name || ' AS + OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops, + FUNCTION 1 hamming_dist(integer[], integer[]); + '; + + -- Execute the dynamic SQL statement. + EXECUTE dist_l2sq_ops; + EXECUTE dist_cos_ops; + EXECUTE dist_hamming_ops; + + RETURN TRUE; +END; +$$ LANGUAGE plpgsql VOLATILE; + IF pgvector_exists THEN CREATE FUNCTION l2sq_dist(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME', 'vector_l2sq_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + PERFORM _create_ldb_operator_classes('lantern_hnsw'); END IF diff --git a/test/expected/hnsw_vector.out b/test/expected/hnsw_vector.out index a02876065..a405c8b65 100644 --- a/test/expected/hnsw_vector.out +++ b/test/expected/hnsw_vector.out @@ -194,3 +194,51 @@ CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) W INFO: done init usearch index INFO: inserted 3 elements INFO: done saving 3 vectors +-- Make sure that lantern_hnsw is working correctly alongside pgvector +CREATE TABLE small_world_arr (id SERIAL PRIMARY KEY, v REAL[]); +INSERT INTO small_world_arr (v) VALUES ('{0,0,0}'), ('{0,0,1}'), ('{0,0,2}'); +CREATE INDEX l2_idx ON small_world_arr USING lantern_hnsw(v) WITH (dim=3, m=2); +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors +EXPLAIN (COSTS FALSE) SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; + QUERY PLAN +-------------------------------------------- + Index Scan using l2_idx on small_world_arr + Order By: (v <-> '{0,0,0}'::real[]) +(2 rows) + +SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; + id +---- + 1 + 2 + 3 +(3 rows) + +DROP INDEX l2_idx; +CREATE INDEX cos_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=2); +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors +SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; + id +---- + 1 + 2 + 3 +(3 rows) + +DROP INDEX cos_idx; +CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3); +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors +SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; + id +---- + 1 + 2 + 3 +(3 rows) + diff --git a/test/sql/hnsw_vector.sql b/test/sql/hnsw_vector.sql index cfe282ab9..25393ae89 100644 --- a/test/sql/hnsw_vector.sql +++ b/test/sql/hnsw_vector.sql @@ -99,3 +99,16 @@ END; $$ LANGUAGE plpgsql IMMUTABLE; CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2); + +-- Make sure that lantern_hnsw is working correctly alongside pgvector +CREATE TABLE small_world_arr (id SERIAL PRIMARY KEY, v REAL[]); +INSERT INTO small_world_arr (v) VALUES ('{0,0,0}'), ('{0,0,1}'), ('{0,0,2}'); +CREATE INDEX l2_idx ON small_world_arr USING lantern_hnsw(v) WITH (dim=3, m=2); +EXPLAIN (COSTS FALSE) SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; +SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; +DROP INDEX l2_idx; +CREATE INDEX cos_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=2); +SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0]; +DROP INDEX cos_idx; +CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3); +SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];