Skip to content

Commit

Permalink
Fix l2 operator bug when used with int arrays (#243)
Browse files Browse the repository at this point in the history
* Add test showing implicit type casts are banned for indexes

* Fix l2 operator bug when used with int arrays
  • Loading branch information
Ngalstyan4 authored Dec 11, 2023
1 parent 6c0d2d8 commit 896ed5a
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 58 deletions.
14 changes: 1 addition & 13 deletions sql/lantern.sql
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,7 @@ BEGIN
CREATE FUNCTION cos_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME', 'vector_cos_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE FUNCTION hamming_dist(vector, vector) RETURNS float8
AS 'MODULE_PATHNAME', 'vector_hamming_dist' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;

CREATE OPERATOR <+> (
LEFTARG = vector, RIGHTARG = vector, PROCEDURE = hamming_dist,
COMMUTATOR = '<+>'
);
-- pgvector's vector type requires floats and we cannot define hamming distance for floats

CREATE OPERATOR CLASS dist_vec_l2sq_ops
DEFAULT FOR TYPE vector USING lantern_hnsw AS
Expand All @@ -162,12 +156,6 @@ BEGIN
OPERATOR 2 <=> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 2 cos_dist(vector, vector);

CREATE OPERATOR CLASS dist_vec_hamming_ops
FOR TYPE vector USING lantern_hnsw AS
OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 1 hamming_dist(vector, vector),
OPERATOR 2 <+> (vector, vector) FOR ORDER BY float_ops,
FUNCTION 2 hamming_dist(vector, vector);
END IF;


Expand Down
5 changes: 5 additions & 0 deletions sql/updates/0.0.9--0.0.10.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Upgrade 0.0.9 -> 0.0.10: hamming distance support for pgvector's vector type goes for good
-- (vector elements are floats, and hamming distance is not defined for floats).
-- Removed objects: operator class dist_vec_hamming_ops, operator <+>(vector, vector)
-- and its backing function hamming_dist(vector, vector).
-- cos_dist(vector, vector) is NOT dropped: it still backs the <=> operator and dist_vec_cos_ops.

-- kept from the original script: covers a class registered under the "hnsw" access method
DROP OPERATOR CLASS IF EXISTS dist_vec_hamming_ops USING hnsw CASCADE;

-- the class is created with USING lantern_hnsw, so it has to be dropped from that access method;
-- guarded because DROP OPERATOR CLASS IF EXISTS still fails when the access method itself is missing
DO $$
BEGIN
    IF EXISTS (SELECT 1 FROM pg_am WHERE amname = 'lantern_hnsw') THEN
        DROP OPERATOR CLASS IF EXISTS dist_vec_hamming_ops USING lantern_hnsw CASCADE;
    END IF;
END
$$;

-- drop the operator before the function it depends on;
-- IF EXISTS keeps the upgrade from failing when the operator was never created
DROP OPERATOR IF EXISTS <+> (vector, vector) CASCADE;
DROP FUNCTION IF EXISTS hamming_dist(vector, vector);
12 changes: 7 additions & 5 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,22 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri
}

float4 result;
bool is_int_array = (metric_kind == usearch_metric_hamming_k);

if(is_int_array) {
if(metric_kind == usearch_metric_hamming_k) {
// when computing hamming distance, array element type must be an integer type
if(ARR_ELEMTYPE(a) != INT4OID || ARR_ELEMTYPE(b) != INT4OID) {
elog(ERROR, "expected integer array but got array with element type %d", ARR_ELEMTYPE(a));
}
int32 *ax_int = (int32 *)ARR_DATA_PTR(a);
int32 *bx_int = (int32 *)ARR_DATA_PTR(b);

// calling usearch_scalar_f32_k here even though it's an integer array is fine
// the hamming distance in usearch actually ignores the scalar type
// and it will get casted appropriately in usearch even with this scalar type
result = usearch_dist(ax_int, bx_int, metric_kind, a_dim, usearch_scalar_f32_k);

} else {
float4 *ax = (float4 *)ARR_DATA_PTR(a);
float4 *bx = (float4 *)ARR_DATA_PTR(b);
float4 *ax = ToFloat4Array(a);
float4 *bx = ToFloat4Array(b);

result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
}
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ PGDLLEXPORT Datum vector_hamming_dist(PG_FUNCTION_ARGS);

HnswColumnType GetColumnTypeFromOid(Oid oid);
HnswColumnType GetIndexColumnType(Relation index);
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions);
void* DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions);

#define LDB_UNUSED(x) (void)(x)

Expand Down
22 changes: 22 additions & 0 deletions src/hnsw/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "utils.h"

#include <assert.h>
#include <catalog/pg_type_d.h>
#include <math.h>
#include <miscadmin.h>
#include <regex.h>
Expand Down Expand Up @@ -81,3 +82,24 @@ void CheckMem(int limit, Relation index, usearch_index_t uidx, uint32 n_nodes, c
elog(WARNING, "%s", msg);
}
}

// Returns the elements of `arr` as a contiguous float4 buffer.
//
// - float4 arrays: returns a pointer straight into the array's data area (no copy is made,
//   so the result aliases `arr`)
// - int4 arrays: palloc's a new buffer and converts every element to float4
// - any other element type: raises ERROR
//
// NOTE(review): the data area is read as a dense array; null elements are not checked here,
// presumably callers reject such arrays beforehand - confirm.
float4 *ToFloat4Array(ArrayType *arr)
{
    Oid element_type = ARR_ELEMTYPE(arr);

    if(element_type == FLOAT4OID) {
        return (float4 *)ARR_DATA_PTR(arr);
    }

    if(element_type == INT4OID) {
        int arr_dim = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr));

        // the destination holds float4 values, so size it by float4 rather than by the source type
        float4 *result = palloc(arr_dim * sizeof(float4));
        int32  *typed_src = (int32 *)ARR_DATA_PTR(arr);

        for(int i = 0; i < arr_dim; i++) {
            result[ i ] = (float4)typed_src[ i ];
        }
        return result;
    }

    // Oid is an unsigned integer type, hence %u
    elog(ERROR, "unsupported element type: %u", element_type);
    return NULL;  // not reached: elog(ERROR) does not return; keeps every path returning a value
}
2 changes: 2 additions & 0 deletions src/hnsw/utils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LDB_HNSW_UTILS_H
#define LDB_HNSW_UTILS_H
#include <access/amapi.h>
#include <utils/array.h>

#include "options.h"
#include "usearch.h"
Expand All @@ -9,6 +10,7 @@ void CheckMem(int limit, Relation index, usearch_index_t uidx, uint32
void LogUsearchOptions(usearch_init_options_t *opts);
void PopulateUsearchOpts(Relation index, usearch_init_options_t *opts);
usearch_label_t GetUsearchLabel(ItemPointer itemPtr);
float4 *ToFloat4Array(ArrayType *arr);

static inline void ldb_invariant(bool condition, const char *msg, ...)
{
Expand Down
9 changes: 8 additions & 1 deletion test/expected/hnsw_insert.out
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ set work_mem = '64kB';
set client_min_messages = 'ERROR';
CREATE TABLE small_world (
id SERIAL PRIMARY KEY,
v REAL[2]
v REAL[2] -- this demonstates that postgres actually does not enforce real[] length as we actually insert vectors of length 3
);
CREATE TABLE small_world_int (
id SERIAL PRIMARY KEY,
v INTEGER[]
);
CREATE INDEX ON small_world USING hnsw (v) WITH (dim=3);
INFO: done init usearch index
Expand All @@ -28,6 +32,9 @@ INSERT INTO small_world (v) VALUES ('{0,0,1}'), ('{0,1,0}');
INSERT INTO small_world (v) VALUES (NULL);
-- Attempt to insert a row with an incorrect vector length
\set ON_ERROR_STOP off
-- Cannot create an hnsw index with implicit typecasts (trying to cast integer[] to real[], in this case)
CREATE INDEX ON small_world_int USING hnsw (v dist_l2sq_ops) WITH (dim=3);
ERROR: operator class "dist_l2sq_ops" does not accept data type integer[]
INSERT INTO small_world (v) VALUES ('{1,1,1,1}');
ERROR: Wrong number of dimensions: 4 instead of 3 expected
\set ON_ERROR_STOP on
Expand Down
58 changes: 58 additions & 0 deletions test/expected/hnsw_operators.out
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ INFO: done init usearch index
INFO: inserted 2 elements
INFO: done saving 2 vectors
-- should rewrite operator
SET lantern.pgvector_compat=FALSE;
SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1];
v
---------
Expand All @@ -27,6 +28,63 @@ ERROR: Operator <-> is invalid outside of ORDER BY context
SET lantern.pgvector_compat=TRUE;
SET enable_seqscan=OFF;
\set ON_ERROR_STOP on
-- one-off vector distance calculations should work with relevant operator
-- with integer arrays:
SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4];
?column?
----------
29
(1 row)

-- with float arrays:
SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4]::real[];
?column?
----------
29
(1 row)

SELECT ARRAY[0,0,0]::real[] <-> ARRAY[2,3,-4]::real[];
?column?
----------
29
(1 row)

SELECT '{1,0,1}' <-> '{0,1,0}'::integer[];
?column?
----------
3
(1 row)

SELECT '{1,0,1}' <=> '{0,1,0}'::integer[];
?column?
----------
1
(1 row)

SELECT ROUND(num::NUMERIC, 2) FROM (SELECT '{1,1,1}' <=> '{0,1,0}'::INTEGER[] AS num) _sub;
round
-------
0.42
(1 row)

SELECT ARRAY[.1,0,0] <=> ARRAY[0,.5,0];
?column?
----------
1
(1 row)

SELECT cos_dist(ARRAY[.1,0,0]::real[], ARRAY[0,.5,0]::real[]);
cos_dist
----------
1
(1 row)

SELECT ARRAY[1,0,0] <+> ARRAY[0,1,0];
?column?
----------
2
(1 row)

-- NOW THIS IS TRIGGERING INDEX SCAN AS WELL
-- BECAUSE WE ARE REGISTERING <-> FOR ALL OPERATOR CLASSES
-- IDEALLY THIS SHOULD NOT TRIGGER INDEX SCAN WHEN lantern.pgvector_compat=TRUE
Expand Down
27 changes: 0 additions & 27 deletions test/expected/hnsw_vector.out
Original file line number Diff line number Diff line change
Expand Up @@ -321,30 +321,3 @@ FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7;
Order By: (v <=> '[0,1,0]'::vector)
(3 rows)

-- hamming index
CREATE INDEX hamming_idx ON small_world USING lantern_hnsw (v dist_vec_hamming_ops);
INFO: done init usearch index
INFO: inserted 8 elements
INFO: done saving 8 vectors
SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7;
dist
-------
0.00
7.00
7.00
7.00
14.00
14.00
14.00
(7 rows)

EXPLAIN (COSTS FALSE) SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7;
QUERY PLAN
---------------------------------------------------
Limit
-> Index Scan using hamming_idx on small_world
Order By: (v <+> '[0,1,0]'::vector)
(3 rows)

10 changes: 9 additions & 1 deletion test/sql/hnsw_insert.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@ set client_min_messages = 'ERROR';

CREATE TABLE small_world (
id SERIAL PRIMARY KEY,
v REAL[2]
v REAL[2] -- this demonstates that postgres actually does not enforce real[] length as we actually insert vectors of length 3
);

CREATE TABLE small_world_int (
id SERIAL PRIMARY KEY,
v INTEGER[]
);

CREATE INDEX ON small_world USING hnsw (v) WITH (dim=3);
SELECT _lantern_internal.validate_index('small_world_v_idx', false);

Expand All @@ -21,6 +27,8 @@ INSERT INTO small_world (v) VALUES (NULL);

-- Attempt to insert a row with an incorrect vector length
\set ON_ERROR_STOP off
-- Cannot create an hnsw index with implicit typecasts (trying to cast integer[] to real[], in this case)
CREATE INDEX ON small_world_int USING hnsw (v dist_l2sq_ops) WITH (dim=3);
INSERT INTO small_world (v) VALUES ('{1,1,1,1}');
\set ON_ERROR_STOP on

Expand Down
14 changes: 14 additions & 0 deletions test/sql/hnsw_operators.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ CREATE TABLE op_test (v REAL[]);
INSERT INTO op_test (v) VALUES (ARRAY[0,0,0]), (ARRAY[1,1,1]);
CREATE INDEX cos_idx ON op_test USING hnsw(v dist_cos_ops);
-- should rewrite operator
SET lantern.pgvector_compat=FALSE;
SELECT * FROM op_test ORDER BY v <-> ARRAY[1,1,1];

-- should throw error
Expand All @@ -20,6 +21,19 @@ SET lantern.pgvector_compat=TRUE;
SET enable_seqscan=OFF;
\set ON_ERROR_STOP on

-- one-off vector distance calculations should work with relevant operator
-- with integer arrays:
SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4];
-- with float arrays:
SELECT ARRAY[0,0,0] <-> ARRAY[2,3,-4]::real[];
SELECT ARRAY[0,0,0]::real[] <-> ARRAY[2,3,-4]::real[];
SELECT '{1,0,1}' <-> '{0,1,0}'::integer[];
SELECT '{1,0,1}' <=> '{0,1,0}'::integer[];
SELECT ROUND(num::NUMERIC, 2) FROM (SELECT '{1,1,1}' <=> '{0,1,0}'::INTEGER[] AS num) _sub;
SELECT ARRAY[.1,0,0] <=> ARRAY[0,.5,0];
SELECT cos_dist(ARRAY[.1,0,0]::real[], ARRAY[0,.5,0]::real[]);
SELECT ARRAY[1,0,0] <+> ARRAY[0,1,0];

-- NOW THIS IS TRIGGERING INDEX SCAN AS WELL
-- BECAUSE WE ARE REGISTERING <-> FOR ALL OPERATOR CLASSES
-- IDEALLY THIS SHOULD NOT TRIGGER INDEX SCAN WHEN lantern.pgvector_compat=TRUE
Expand Down
11 changes: 1 addition & 10 deletions test/sql/hnsw_vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,4 @@ SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7;

EXPLAIN (COSTS FALSE) SELECT ROUND(cos_dist(v, '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7;

-- hamming index
CREATE INDEX hamming_idx ON small_world USING lantern_hnsw (v dist_vec_hamming_ops);

SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7;

EXPLAIN (COSTS FALSE) SELECT ROUND((v <+> '[0,1,0]'::VECTOR)::numeric, 2) as dist
FROM small_world ORDER BY v <+> '[0,1,0]'::VECTOR LIMIT 7;
FROM small_world ORDER BY v <=> '[0,1,0]'::VECTOR LIMIT 7;

0 comments on commit 896ed5a

Please sign in to comment.