Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix operator classes for lantern_hnsw access method #180

Merged
merged 1 commit into from
Sep 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 69 additions & 44 deletions sql/lantern.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,74 @@
-- Definitions concerning our hnsw-based index data structure

-- Handler function for the index access method: a C symbol in the extension
-- module that accepts `internal` and returns `index_am_handler`.
CREATE FUNCTION hnsw_handler(internal)
    RETURNS index_am_handler
    LANGUAGE C
    AS 'MODULE_PATHNAME';

-- functions
-- Distance functions implemented in C by the extension module.
-- Every one of them is declared IMMUTABLE, STRICT and PARALLEL SAFE.

-- ldb_generic_dist is the procedure behind the <-> operator; there is one
-- overload per supported array element type, both returning real.
CREATE FUNCTION ldb_generic_dist(real[], real[])
    RETURNS real
    LANGUAGE C
    IMMUTABLE STRICT PARALLEL SAFE
    AS 'MODULE_PATHNAME';

CREATE FUNCTION ldb_generic_dist(integer[], integer[])
    RETURNS real
    LANGUAGE C
    IMMUTABLE STRICT PARALLEL SAFE
    AS 'MODULE_PATHNAME';

-- Support function 1 of the dist_l2sq_ops operator class.
CREATE FUNCTION l2sq_dist(real[], real[])
    RETURNS real
    LANGUAGE C
    IMMUTABLE STRICT PARALLEL SAFE
    AS 'MODULE_PATHNAME';

-- Support function 1 of the dist_cos_ops operator class.
CREATE FUNCTION cos_dist(real[], real[])
    RETURNS real
    LANGUAGE C
    IMMUTABLE STRICT PARALLEL SAFE
    AS 'MODULE_PATHNAME';

-- Support function 1 of the dist_hamming_ops operator class; unlike the
-- others it returns integer.
CREATE FUNCTION hamming_dist(integer[], integer[])
    RETURNS integer
    LANGUAGE C
    IMMUTABLE STRICT PARALLEL SAFE
    AS 'MODULE_PATHNAME';

-- operators
-- <-> on real[] operands, evaluated by ldb_generic_dist(real[], real[]).
-- The operator is declared as its own commutator.
CREATE OPERATOR <-> (
    PROCEDURE  = ldb_generic_dist,
    LEFTARG    = real[],
    RIGHTARG   = real[],
    COMMUTATOR = '<->'
);

-- <-> on integer[] operands, evaluated by ldb_generic_dist(integer[], integer[]).
-- The operator is declared as its own commutator.
CREATE OPERATOR <-> (
    PROCEDURE  = ldb_generic_dist,
    LEFTARG    = integer[],
    RIGHTARG   = integer[],
    COMMUTATOR = '<->'
);

-- operator classes
-- Creates the three Lantern operator classes (dist_l2sq_ops, dist_cos_ops,
-- dist_hamming_ops) for the index access method named by access_method_name.
--
-- Parameters:
--   access_method_name  name of an existing index access method. It is
--                       spliced into the DDL as an identifier, so it must be
--                       given in its exact catalog spelling (for example
--                       'hnsw' or 'lantern_hnsw').
-- Returns:
--   TRUE once all three CREATE OPERATOR CLASS statements have succeeded.
-- Raises:
--   an exception when access_method_name is NULL, plus any error raised by
--   CREATE OPERATOR CLASS itself (unknown access method, class already
--   exists, ...).
CREATE OR REPLACE FUNCTION _create_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$
BEGIN
    -- Fail with a clear message instead of the generic error a NULL
    -- query string would produce.
    IF access_method_name IS NULL THEN
        RAISE EXCEPTION 'access_method_name must not be NULL';
    END IF;

    -- The access method cannot be passed as a bind parameter because it is an
    -- identifier, so the DDL is built dynamically. format() with %I quotes the
    -- name as an identifier, which keeps arbitrary text from being injected
    -- into the statement.

    -- DEFAULT operator class for real[].
    EXECUTE format('
        CREATE OPERATOR CLASS dist_l2sq_ops
        DEFAULT FOR TYPE real[] USING %I AS
        OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
        FUNCTION 1 l2sq_dist(real[], real[]);
    ', access_method_name);

    -- Non-default operator class for real[].
    EXECUTE format('
        CREATE OPERATOR CLASS dist_cos_ops
        FOR TYPE real[] USING %I AS
        OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
        FUNCTION 1 cos_dist(real[], real[]);
    ', access_method_name);

    -- Operator class for integer[].
    EXECUTE format('
        CREATE OPERATOR CLASS dist_hamming_ops
        FOR TYPE integer[] USING %I AS
        OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops,
        FUNCTION 1 hamming_dist(integer[], integer[]);
    ', access_method_name);

    RETURN TRUE;
END;
$$ LANGUAGE plpgsql VOLATILE;


-- Create access method
DO $BODY$
DECLARE
hnsw_am_exists boolean;
Expand Down Expand Up @@ -41,55 +107,14 @@ BEGIN


IF hnsw_am_exists THEN
PERFORM _create_ldb_operator_classes('lantern_hnsw');
RAISE WARNING 'Access method(index type) "hnsw" already exists. Creating lantern_hnsw access method';
ELSE
-- create access method
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnsw_handler;
COMMENT ON ACCESS METHOD hnsw IS 'LanternDB access method for vector embeddings, based on the hnsw algorithm';
PERFORM _create_ldb_operator_classes('hnsw');
END IF;
END;
$BODY$
LANGUAGE plpgsql;

-- functions
-- Distance functions exported by the extension's C module. All of them are
-- declared IMMUTABLE STRICT PARALLEL SAFE.

-- Procedure behind the <-> operator for real[] operands.
CREATE FUNCTION ldb_generic_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- Procedure behind the <-> operator for integer[] operands.
CREATE FUNCTION ldb_generic_dist(integer[], integer[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- Support function 1 of the dist_l2sq_ops operator class.
CREATE FUNCTION l2sq_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- Support function 1 of the dist_cos_ops operator class.
CREATE FUNCTION cos_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- Support function 1 of the dist_hamming_ops operator class; returns integer
-- rather than real.
CREATE FUNCTION hamming_dist(integer[], integer[]) RETURNS integer
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- operators
-- <-> for real[] operands; evaluated by ldb_generic_dist and declared as its
-- own commutator.
CREATE OPERATOR <-> (
LEFTARG = real[], RIGHTARG = real[], PROCEDURE = ldb_generic_dist,
COMMUTATOR = '<->'
);

-- <-> for integer[] operands; evaluated by ldb_generic_dist and declared as
-- its own commutator.
CREATE OPERATOR <-> (
LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = ldb_generic_dist,
COMMUTATOR = '<->'
);

-- operator classes
-- Operator classes bound directly to the access method named hnsw. Each one
-- registers <-> as ordering operator 1 (sorted with float_ops) and a distance
-- function as support function 1.

-- DEFAULT class for real[]: used when an index on a real[] column names no
-- operator class.
CREATE OPERATOR CLASS dist_l2sq_ops
DEFAULT FOR TYPE real[] USING hnsw AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 l2sq_dist(real[], real[]);

-- Non-default class for real[]; must be named explicitly in CREATE INDEX.
CREATE OPERATOR CLASS dist_cos_ops
FOR TYPE real[] USING hnsw AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 cos_dist(real[], real[]);

-- Class for integer[]; not marked DEFAULT, so it must be named explicitly.
CREATE OPERATOR CLASS dist_hamming_ops
FOR TYPE integer[] USING hnsw AS
OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops,
FUNCTION 1 hamming_dist(integer[], integer[]);
38 changes: 38 additions & 0 deletions sql/updates/0.0.4-latest.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,45 @@ SELECT EXISTS (
WHERE typname = 'vector'
) INTO pgvector_exists;

-- Creates the three Lantern operator classes (dist_l2sq_ops, dist_cos_ops,
-- dist_hamming_ops) for the index access method named by access_method_name.
--
-- Parameters:
--   access_method_name  name of an existing index access method. It is
--                       spliced into the DDL as an identifier, so it must be
--                       given in its exact catalog spelling (for example
--                       'lantern_hnsw').
-- Returns:
--   TRUE once all three CREATE OPERATOR CLASS statements have succeeded.
-- Raises:
--   an exception when access_method_name is NULL, plus any error raised by
--   CREATE OPERATOR CLASS itself (unknown access method, class already
--   exists, ...).
CREATE OR REPLACE FUNCTION _create_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$
BEGIN
    -- Fail with a clear message instead of the generic error a NULL
    -- query string would produce.
    IF access_method_name IS NULL THEN
        RAISE EXCEPTION 'access_method_name must not be NULL';
    END IF;

    -- An identifier cannot be supplied as a bind parameter, so the DDL is
    -- assembled dynamically. format() with %I quotes the name as an
    -- identifier, preventing arbitrary text from being injected.

    -- DEFAULT operator class for real[].
    EXECUTE format('
        CREATE OPERATOR CLASS dist_l2sq_ops
        DEFAULT FOR TYPE real[] USING %I AS
        OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
        FUNCTION 1 l2sq_dist(real[], real[]);
    ', access_method_name);

    -- Non-default operator class for real[].
    EXECUTE format('
        CREATE OPERATOR CLASS dist_cos_ops
        FOR TYPE real[] USING %I AS
        OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
        FUNCTION 1 cos_dist(real[], real[]);
    ', access_method_name);

    -- Operator class for integer[].
    EXECUTE format('
        CREATE OPERATOR CLASS dist_hamming_ops
        FOR TYPE integer[] USING %I AS
        OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops,
        FUNCTION 1 hamming_dist(integer[], integer[]);
    ', access_method_name);

    RETURN TRUE;
END;
$$ LANGUAGE plpgsql VOLATILE;

IF pgvector_exists THEN
CREATE FUNCTION l2sq_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME', 'vector_l2sq_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
PERFORM _create_ldb_operator_classes('lantern_hnsw');
END IF
48 changes: 48 additions & 0 deletions test/expected/hnsw_vector.out
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,51 @@ CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) W
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
-- Make sure that lantern_hnsw is working correctly alongside pgvector
CREATE TABLE small_world_arr (id SERIAL PRIMARY KEY, v REAL[]);
INSERT INTO small_world_arr (v) VALUES ('{0,0,0}'), ('{0,0,1}'), ('{0,0,2}');
CREATE INDEX l2_idx ON small_world_arr USING lantern_hnsw(v) WITH (dim=3, m=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
EXPLAIN (COSTS FALSE) SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
QUERY PLAN
--------------------------------------------
Index Scan using l2_idx on small_world_arr
Order By: (v <-> '{0,0,0}'::real[])
(2 rows)

SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
id
----
1
2
3
(3 rows)

DROP INDEX l2_idx;
CREATE INDEX cos_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
id
----
1
2
3
(3 rows)

DROP INDEX cos_idx;
CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
id
----
1
2
3
(3 rows)

13 changes: 13 additions & 0 deletions test/sql/hnsw_vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ END;
$$ LANGUAGE plpgsql IMMUTABLE;

CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);

-- Make sure that lantern_hnsw is working correctly alongside pgvector
CREATE TABLE small_world_arr (id SERIAL PRIMARY KEY, v REAL[]);
INSERT INTO small_world_arr (v) VALUES ('{0,0,0}'), ('{0,0,1}'), ('{0,0,2}');
CREATE INDEX l2_idx ON small_world_arr USING lantern_hnsw(v) WITH (dim=3, m=2);
EXPLAIN (COSTS FALSE) SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
DROP INDEX l2_idx;
CREATE INDEX cos_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=2);
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
DROP INDEX cos_idx;
CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3);
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];