diff --git a/src/hnsw.c b/src/hnsw.c index 42e70e4e4..ae811311e 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include diff --git a/src/hnsw/build.c b/src/hnsw/build.c index fb32c0f69..89bc5d739 100644 --- a/src/hnsw/build.c +++ b/src/hnsw/build.c @@ -8,11 +8,13 @@ #include #include #include +#include #include #include #include #include #include + #ifdef _WIN32 #define access _access #else @@ -141,7 +143,7 @@ static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTup EState *estate; Datum result; bool isNull; - ArrayType *array; + Oid resultOid; TupleTableSlot *slot; TupleDesc tupdesc = RelationGetDescr(heap); @@ -168,34 +170,32 @@ static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTup // Evaluate the expression for the first row result = ExecEvalExprSwitchContext(exprstate, econtext, &isNull); - array = DatumGetArrayTypeP(result); // Release tuple descriptor ReleaseTupleDesc(tupdesc); - Oid oid = get_array_type(array->elemtype); + // Get the return type information + get_expr_result_type((Node *)exprstate->expr, &resultOid, NULL); + + HnswColumnType columnType = GetColumnTypeFromOid(resultOid); - if(!OidIsValid(oid)) { - // Oid 0 can only be in a case when the datum will be vector type - // As postgres will guard that arbitrary data type won't be returned - // from the function expression - // so we should be safe here to case result into vector in this case + if(columnType == REAL_ARRAY || columnType == INT_ARRAY) { + ArrayType *array = DatumGetArrayTypeP(result); + return ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + } else if(columnType == VECTOR) { Vector *vector = DatumGetVector(result); - // if the vector will not have dimensions or somehow it - // won't be casted the code will go to ldb_invariant case - if(vector->dim > 0) { - return vector->dim; - } + return vector->dim; + } else { + // Check if the result is not null and is supported type + // There is a guard in postgres that wont' allow passing + // Anything else from the defined operator class types + // Throwing an error like: ERROR: data type text has no default operator class for access method "hnsw" + // So this case will be marked as invariant + ldb_invariant(!isNull && columnType != UNKNOWN, + "Expression used in CREATE INDEX statement did not result in hnsw-index compatible array"); } - // Check if the result is not null and is supported type - // There is a guard in postgres that wont' allow passing - // Anything else from the defined operator class types - // Throwing an error like: ERROR: data type text has no default operator class for access method "hnsw" - // So this case will be marked as invariant - ldb_invariant(!isNull && GetColumnTypeFromOid(oid) != UNKNOWN, - "Expression used in CREATE INDEX statement did not result in hnsw-index compatible array"); - - return ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + + return HNSW_DEFAULT_DIM; } static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexInfo) diff --git a/test/expected/hnsw_create_expr.out b/test/expected/hnsw_create_expr.out index 17595fddc..70a942c85 100644 --- a/test/expected/hnsw_create_expr.out +++ b/test/expected/hnsw_create_expr.out @@ -1,3 +1,10 @@ +/* +This function, int_to_fixed_binary_real_array(n INT), will create a 3-dimensional float array (REAL[]). +It fills the array with the first 3 bits of the passed integer 'n' by converting 'n' to binary, +left-padding it to 3 digits, and then converting each digit to a REAL value. +For example, int_to_fixed_binary_real_array(1); will result in the array {0,0,1}, +and int_to_fixed_binary_real_array(2); will result in {0,1,0}. +*/ CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ DECLARE binary_string TEXT; @@ -12,6 +19,13 @@ BEGIN RETURN real_array; END; $$ LANGUAGE plpgsql IMMUTABLE; +/* +This function, int_to_dynamic_binary_real_array(n INT), will create a 3+n dimensional float array (REAL[]). +It first fills the first 3 elements of the array with the first 3 bits of the passed integer 'n' +(using a similar binary conversion as the previous function), and then adds elements sequentially from 4 to 'n+3'. +For example, int_to_dynamic_binary_real_array(3); will result in the array {0,1,1,1,2,3}, +and int_to_dynamic_binary_real_array(4); will result in the array {1,0,0,1,2,3,4}. +*/ CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$ DECLARE binary_string TEXT; @@ -36,6 +50,11 @@ BEGIN RETURN real_array; END; $$ LANGUAGE plpgsql IMMUTABLE; +/* +This simple function, int_to_string(n INT), converts the integer 'n' to a 3-character text representation +by converting 'n' to binary and left-padding it to 3 digits with '0's. +For example, int_to_string(1); will return '001', and int_to_string(2); will return '010'. +*/ CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$ BEGIN RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); @@ -62,5 +81,5 @@ CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_t ERROR: access method "hnsw" does not support multicolumn indexes -- This currently results in an error about using the operator outside of index -- This case should be fixed -SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; ERROR: Operator <-> can only be used inside of an index diff --git a/test/expected/hnsw_todo.out b/test/expected/hnsw_todo.out index 9b50c6920..8eb3b1b75 100644 --- a/test/expected/hnsw_todo.out +++ b/test/expected/hnsw_todo.out @@ -103,3 +103,25 @@ SELECT ROUND(l2sq_dist(v, :'v1001')::numeric, 2) FROM sift_base1k order by v <-> 249285.00 (1 row) +---- Query on expression based index is failing to check correct <-> operator usage -------- +CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); +\set enable_seqscan = off; +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; +ERROR: Operator <-> can only be used inside of an index diff --git a/test/sql/hnsw_create_expr.sql b/test/sql/hnsw_create_expr.sql index 0504c0074..ecabfbaf9 100644 --- a/test/sql/hnsw_create_expr.sql +++ b/test/sql/hnsw_create_expr.sql @@ -1,3 +1,10 @@ +/* +This function, int_to_fixed_binary_real_array(n INT), will create a 3-dimensional float array (REAL[]). +It fills the array with the first 3 bits of the passed integer 'n' by converting 'n' to binary, +left-padding it to 3 digits, and then converting each digit to a REAL value. +For example, int_to_fixed_binary_real_array(1); will result in the array {0,0,1}, +and int_to_fixed_binary_real_array(2); will result in {0,1,0}. +*/ CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ DECLARE binary_string TEXT; @@ -13,6 +20,13 @@ BEGIN END; $$ LANGUAGE plpgsql IMMUTABLE; +/* +This function, int_to_dynamic_binary_real_array(n INT), will create a 3+n dimensional float array (REAL[]). +It first fills the first 3 elements of the array with the first 3 bits of the passed integer 'n' +(using a similar binary conversion as the previous function), and then adds elements sequentially from 4 to 'n+3'. +For example, int_to_dynamic_binary_real_array(3); will result in the array {0,1,1,1,2,3}, +and int_to_dynamic_binary_real_array(4); will result in the array {1,0,0,1,2,3,4}. +*/ CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$ DECLARE binary_string TEXT; @@ -38,6 +52,11 @@ BEGIN END; $$ LANGUAGE plpgsql IMMUTABLE; +/* +This simple function, int_to_string(n INT), converts the integer 'n' to a 3-character text representation +by converting 'n' to binary and left-padding it to 3 digits with '0's. +For example, int_to_string(1); will return '001', and int_to_string(2); will return '010'. +*/ CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$ BEGIN RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); @@ -65,4 +84,4 @@ CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_t -- This currently results in an error about using the operator outside of index -- This case should be fixed -SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; diff --git a/test/sql/hnsw_todo.sql b/test/sql/hnsw_todo.sql index 34f240a2c..2159d7eba 100644 --- a/test/sql/hnsw_todo.sql +++ b/test/sql/hnsw_todo.sql @@ -72,3 +72,28 @@ CREATE INDEX hnsw_l2_index ON sift_base1k USING hnsw (v) WITH (_experimental_ind -- So the usearch index can not find 1,1,1,1,1.. vector in the index and wrong results will be returned -- This is an expected behaviour for now SELECT ROUND(l2sq_dist(v, :'v1001')::numeric, 2) FROM sift_base1k order by v <-> :'v1001' LIMIT 1; + +---- Query on expression based index is failing to check correct <-> operator usage -------- +CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); + +\set enable_seqscan = off; +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; +