Skip to content

Commit

Permalink
Fix op calsses for lantern_hnsw am (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
var77 authored Sep 30, 2023
1 parent 4c80eca commit 29af187
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 44 deletions.
113 changes: 69 additions & 44 deletions sql/lantern.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,74 @@
-- Definitions concerning our hnsw-based index data strucuture

CREATE FUNCTION hnsw_handler(internal) RETURNS index_am_handler
AS 'MODULE_PATHNAME' LANGUAGE C;

-- functions
CREATE FUNCTION ldb_generic_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION ldb_generic_dist(integer[], integer[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION l2sq_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION cos_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION hamming_dist(integer[], integer[]) RETURNS integer
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- operators
CREATE OPERATOR <-> (
LEFTARG = real[], RIGHTARG = real[], PROCEDURE = ldb_generic_dist,
COMMUTATOR = '<->'
);

CREATE OPERATOR <-> (
LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = ldb_generic_dist,
COMMUTATOR = '<->'
);

-- operator classes
CREATE OR REPLACE FUNCTION _create_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$
DECLARE
dist_l2sq_ops TEXT;
dist_cos_ops TEXT;
dist_hamming_ops TEXT;
BEGIN
-- Construct the SQL statement to create the operator classes dynamically.
dist_l2sq_ops := '
CREATE OPERATOR CLASS dist_l2sq_ops
DEFAULT FOR TYPE real[] USING ' || access_method_name || ' AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 l2sq_dist(real[], real[]);
';

dist_cos_ops := '
CREATE OPERATOR CLASS dist_cos_ops
FOR TYPE real[] USING ' || access_method_name || ' AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 cos_dist(real[], real[]);
';

dist_hamming_ops := '
CREATE OPERATOR CLASS dist_hamming_ops
FOR TYPE integer[] USING ' || access_method_name || ' AS
OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops,
FUNCTION 1 hamming_dist(integer[], integer[]);
';

-- Execute the dynamic SQL statement.
EXECUTE dist_l2sq_ops;
EXECUTE dist_cos_ops;
EXECUTE dist_hamming_ops;

RETURN TRUE;
END;
$$ LANGUAGE plpgsql VOLATILE;


-- Create access method
DO $BODY$
DECLARE
hnsw_am_exists boolean;
Expand Down Expand Up @@ -41,55 +107,14 @@ BEGIN


IF hnsw_am_exists THEN
PERFORM _create_ldb_operator_classes('lantern_hnsw');
RAISE WARNING 'Access method(index type) "hnsw" already exists. Creating lantern_hnsw access method';
ELSE
-- create access method
CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnsw_handler;
COMMENT ON ACCESS METHOD hnsw IS 'LanternDB access method for vector embeddings, based on the hnsw algorithm';
PERFORM _create_ldb_operator_classes('hnsw');
END IF;
END;
$BODY$
LANGUAGE plpgsql;

-- functions
CREATE FUNCTION ldb_generic_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION ldb_generic_dist(integer[], integer[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION l2sq_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION cos_dist(real[], real[]) RETURNS real
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION hamming_dist(integer[], integer[]) RETURNS integer
AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

-- operators
CREATE OPERATOR <-> (
LEFTARG = real[], RIGHTARG = real[], PROCEDURE = ldb_generic_dist,
COMMUTATOR = '<->'
);

CREATE OPERATOR <-> (
LEFTARG = integer[], RIGHTARG = integer[], PROCEDURE = ldb_generic_dist,
COMMUTATOR = '<->'
);

-- operator classes
CREATE OPERATOR CLASS dist_l2sq_ops
DEFAULT FOR TYPE real[] USING hnsw AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 l2sq_dist(real[], real[]);

CREATE OPERATOR CLASS dist_cos_ops
FOR TYPE real[] USING hnsw AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 cos_dist(real[], real[]);

CREATE OPERATOR CLASS dist_hamming_ops
FOR TYPE integer[] USING hnsw AS
OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops,
FUNCTION 1 hamming_dist(integer[], integer[]);
38 changes: 38 additions & 0 deletions sql/updates/0.0.4-latest.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,45 @@ SELECT EXISTS (
WHERE typname = 'vector'
) INTO pgvector_exists;

CREATE OR REPLACE FUNCTION _create_ldb_operator_classes(access_method_name TEXT) RETURNS BOOLEAN AS $$
DECLARE
dist_l2sq_ops TEXT;
dist_cos_ops TEXT;
dist_hamming_ops TEXT;
BEGIN
-- Construct the SQL statement to create the operator classes dynamically.
dist_l2sq_ops := '
CREATE OPERATOR CLASS dist_l2sq_ops
DEFAULT FOR TYPE real[] USING ' || access_method_name || ' AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 l2sq_dist(real[], real[]);
';

dist_cos_ops := '
CREATE OPERATOR CLASS dist_cos_ops
FOR TYPE real[] USING ' || access_method_name || ' AS
OPERATOR 1 <-> (real[], real[]) FOR ORDER BY float_ops,
FUNCTION 1 cos_dist(real[], real[]);
';

dist_hamming_ops := '
CREATE OPERATOR CLASS dist_hamming_ops
FOR TYPE integer[] USING ' || access_method_name || ' AS
OPERATOR 1 <-> (integer[], integer[]) FOR ORDER BY float_ops,
FUNCTION 1 hamming_dist(integer[], integer[]);
';

-- Execute the dynamic SQL statement.
EXECUTE dist_l2sq_ops;
EXECUTE dist_cos_ops;
EXECUTE dist_hamming_ops;

RETURN TRUE;
END;
$$ LANGUAGE plpgsql VOLATILE;

IF pgvector_exists THEN
CREATE FUNCTION l2sq_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME', 'vector_l2sq_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
PERFORM _create_ldb_operator_classes('lantern_hnsw');
END IF
48 changes: 48 additions & 0 deletions test/expected/hnsw_vector.out
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,51 @@ CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) W
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
-- Make sure that lantern_hnsw is working correctly alongside pgvector
CREATE TABLE small_world_arr (id SERIAL PRIMARY KEY, v REAL[]);
INSERT INTO small_world_arr (v) VALUES ('{0,0,0}'), ('{0,0,1}'), ('{0,0,2}');
CREATE INDEX l2_idx ON small_world_arr USING lantern_hnsw(v) WITH (dim=3, m=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
EXPLAIN (COSTS FALSE) SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
QUERY PLAN
--------------------------------------------
Index Scan using l2_idx on small_world_arr
Order By: (v <-> '{0,0,0}'::real[])
(2 rows)

SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
id
----
1
2
3
(3 rows)

DROP INDEX l2_idx;
CREATE INDEX cos_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
id
----
1
2
3
(3 rows)

DROP INDEX cos_idx;
CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
id
----
1
2
3
(3 rows)

13 changes: 13 additions & 0 deletions test/sql/hnsw_vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,16 @@ END;
$$ LANGUAGE plpgsql IMMUTABLE;

CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);

-- Make sure that lantern_hnsw is working correctly alongside pgvector
CREATE TABLE small_world_arr (id SERIAL PRIMARY KEY, v REAL[]);
INSERT INTO small_world_arr (v) VALUES ('{0,0,0}'), ('{0,0,1}'), ('{0,0,2}');
CREATE INDEX l2_idx ON small_world_arr USING lantern_hnsw(v) WITH (dim=3, m=2);
EXPLAIN (COSTS FALSE) SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
DROP INDEX l2_idx;
CREATE INDEX cos_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=2);
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];
DROP INDEX cos_idx;
CREATE INDEX ham_idx ON small_world_arr USING lantern_hnsw(v) WITH (m=3);
SELECT id FROM small_world_arr ORDER BY v <-> ARRAY[0,0,0];

0 comments on commit 29af187

Please sign in to comment.