diff --git a/scripts/run_all_tests.sh b/scripts/run_all_tests.sh index de9df7351..7b6e79cda 100755 --- a/scripts/run_all_tests.sh +++ b/scripts/run_all_tests.sh @@ -71,13 +71,14 @@ pgvector_installed=$($PSQL -U $DB_USER -d postgres -c "SELECT 1 FROM pg_availabl rm -rf $TMP_OUTDIR/schedule.txt if [ -n "$FILTER" ]; then if [[ "$pgvector_installed" == "1" ]]; then - TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -e 's/^\(test:\|test_pgvector:\)//' | tr " " "\n" | sed -e '/^$/d') + TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -E -e 's/^test:|test_pgvector://' | tr " " "\n" | sed -e '/^$/d') else TEST_FILES=$(cat schedule.txt | grep '^test:' | sed -e 's/^test://' | tr " " "\n" | sed -e '/^$/d') fi while IFS= read -r f; do if [[ $f == *"$FILTER"* ]]; then + echo "HERE $f" echo "test: $f" >> $TMP_OUTDIR/schedule.txt fi done <<< "$TEST_FILES" diff --git a/src/hnsw.c b/src/hnsw.c index 9276db57e..42e70e4e4 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -343,25 +344,33 @@ Datum vector_l2sq_dist(PG_FUNCTION_ARGS) } /* - * Get data type of index - */ -HnswColumnType GetIndexColumnType(Relation index) + * Get data type for the given oid + * */ +HnswColumnType GetColumnTypeFromOid(Oid oid) { - TupleDesc indexTupDesc = RelationGetDescr(index); - Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0); - Oid columnType = attr->atttypid; + ldb_invariant(OidIsValid(oid), "Invalid oid passed"); - if(columnType == FLOAT4ARRAYOID) { + if(oid == FLOAT4ARRAYOID) { return REAL_ARRAY; - } else if(columnType == TypenameGetTypid("vector")) { + } else if(oid == TypenameGetTypid("vector")) { return VECTOR; - } else if(columnType == INT4ARRAYOID) { + } else if(oid == INT4ARRAYOID) { return INT_ARRAY; } else { return UNKNOWN; } } +/* + * Get data type of index + */ +HnswColumnType GetIndexColumnType(Relation index) +{ + TupleDesc indexTupDesc = RelationGetDescr(index); + 
Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0); + return GetColumnTypeFromOid(attr->atttypid); +} + /* * Given vector data and vector type, convert it to a float4 array */ diff --git a/src/hnsw.h b/src/hnsw.h index 2f448687c..0542f705e 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -32,6 +32,7 @@ PGDLLEXPORT Datum vector_l2sq_dist(PG_FUNCTION_ARGS); PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS); PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS); +HnswColumnType GetColumnTypeFromOid(Oid oid); HnswColumnType GetIndexColumnType(Relation index); float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions); diff --git a/src/hnsw/build.c b/src/hnsw/build.c index a4f13b8d4..fb32c0f69 100644 --- a/src/hnsw/build.c +++ b/src/hnsw/build.c @@ -142,7 +142,6 @@ static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTup Datum result; bool isNull; ArrayType *array; - int n_items = HNSW_DEFAULT_DIM; TupleTableSlot *slot; TupleDesc tupdesc = RelationGetDescr(heap); @@ -169,21 +168,34 @@ static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTup // Evaluate the expression for the first row result = ExecEvalExprSwitchContext(exprstate, econtext, &isNull); + array = DatumGetArrayTypeP(result); // Release tuple descriptor ReleaseTupleDesc(tupdesc); - // Check if the result is not null - if(!isNull) { - // todo check if datum is array - array = DatumGetArrayTypePCopy(result); - n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); - return n_items; - } else { - elog(ERROR, "Expression did not result in an array"); + Oid oid = get_array_type(array->elemtype); + + if(!OidIsValid(oid)) { + // Oid 0 can only be in a case when the datum will be vector type + // As postgres will guard that arbitrary data type won't be returned + // from the function expression + // so we should be safe here to cast result into vector in this case + Vector *vector = DatumGetVector(result); + // if the vector will not have dimensions or 
somehow it + // won't be cast the code will go to ldb_invariant case + if(vector->dim > 0) { + return vector->dim; + } } - - return n_items; + // Check if the result is not null and is supported type + // There is a guard in postgres that won't allow passing + // Anything other than the defined operator class types + // Throwing an error like: ERROR: data type text has no default operator class for access method "hnsw" + // So this case will be marked as invariant + ldb_invariant(!isNull && GetColumnTypeFromOid(oid) != UNKNOWN, + "Expression used in CREATE INDEX statement did not result in hnsw-index compatible array"); + + return ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); } static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexInfo) @@ -216,6 +228,11 @@ static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexI } if(indexInfo->ii_Expressions != NULL) { + // We don't support multicolumn indexes + // So trying to pass multiple expressions on index creation + // Will result in an error before getting here + ldb_invariant(indexInfo->ii_Expressions->length == 1, + "Index expressions can not be greater than 1 as multicolumn indexes are not supported"); Expr *indexpr_item = lfirst(list_head(indexInfo->ii_Expressions)); n_items = GetArrayLengthFromExpression(indexpr_item, heap, tuple); } else { diff --git a/src/hooks/executor_start.c b/src/hooks/executor_start.c index 866cae499..b323606d2 100644 --- a/src/hooks/executor_start.c +++ b/src/hooks/executor_start.c @@ -83,4 +83,4 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags) } standard_ExecutorStart(queryDesc, eflags); -} \ No newline at end of file +} diff --git a/src/hooks/post_parse.c b/src/hooks/post_parse.c index 2a08ec6c9..85455686a 100644 --- a/src/hooks/post_parse.c +++ b/src/hooks/post_parse.c @@ -88,6 +88,18 @@ static bool operator_used_incorrectly_walker(Node *node, OperatorUsedCorrectlyCo Node *arg2 = (Node 
*)lsecond(opExpr->args); bool isVar1 = IsA(arg1, Var); bool isVar2 = IsA(arg2, Var); + /* There is a case when operator is used with index + * that was created via expression (CREATE INDEX ON t USING hnsw (func(id)) WITH (M=2)) + * in this case the query may look like this + * SELECT id FROM test ORDER BY func(id) <-> ARRAY[0,0,0] LIMIT 2 + * or like this + * SELECT id FROM test ORDER BY func(id) <-> func(n) LIMIT 2 + * we should check if IsA(arg1, FuncExpr) || IsA(arg2, FuncExpr) + * if true we may go and check the oid of function result to see if it is an array type + * we also can check that the argument of FuncExpr is at least one of the arg1 and arg2 + * will contain column of the table (e.g iterate over list and check IsA(arg, Var)) + * so the function will not be called on constant arguments on both sides + */ if(isVar1 && isVar2) { return false; } else if(!isVar1 && !isVar2) { @@ -145,4 +157,4 @@ void post_parse_analyze_hook_with_operator_check(ParseState *pstate, list_free(sort_group_refs); } list_free(oidList); -} \ No newline at end of file +} diff --git a/test/expected/hnsw_create_expr.out b/test/expected/hnsw_create_expr.out index 089818ac6..ccf76725f 100644 --- a/test/expected/hnsw_create_expr.out +++ b/test/expected/hnsw_create_expr.out @@ -12,8 +12,54 @@ BEGIN RETURN real_array; END; $$ LANGUAGE plpgsql IMMUTABLE; +CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; + result_length INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + + -- Calculate the length of the result array + result_length := 3 + n; + + FOR i IN 1..result_length + LOOP + IF i <= 3 THEN + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + ELSE + real_array := array_append(real_array, CAST(i - 3 AS REAL)); + END IF; + END LOOP; + + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; +CREATE OR REPLACE FUNCTION 
int_to_string(n INT) RETURNS TEXT AS $$ +BEGIN + RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); +END; +$$ LANGUAGE plpgsql IMMUTABLE; CREATE TABLE test_table (id INTEGER); INSERT INTO test_table VALUES (0), (1), (7); \set enable_seqscan = off; CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2); -ERROR: invalid attnum: 0 +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors +\set ON_ERROR_STOP off +-- This should result in an error that dimensions does not match +CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2); +INFO: done init usearch index +ERROR: Wrong number of dimensions: 4 instead of 3 expected +-- This should result in an error that data type text has no default operator class +CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2); +ERROR: data type text has no default operator class for access method "hnsw" +-- This should result in error about multicolumn expressions support +CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2); +ERROR: access method "hnsw" does not support multicolumn indexes +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; +ERROR: Operator <-> has no standalone meaning and is reserved for use in vector index lookups only diff --git a/test/expected/hnsw_vector.out b/test/expected/hnsw_vector.out index 814da5500..344716cdd 100644 --- a/test/expected/hnsw_vector.out +++ b/test/expected/hnsw_vector.out @@ -173,3 +173,24 @@ SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1; ERROR: Expected vector with dimension 3, got 4 SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector); ERROR: expected equally sized vectors but got vectors with dimensions 2 and 3 +-- Test creating 
index with expression +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); +CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array::vector; +END; +$$ LANGUAGE plpgsql IMMUTABLE; +CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2); +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors diff --git a/test/sql/hnsw_create_expr.sql b/test/sql/hnsw_create_expr.sql index 732960425..823a89f9b 100644 --- a/test/sql/hnsw_create_expr.sql +++ b/test/sql/hnsw_create_expr.sql @@ -13,9 +13,56 @@ BEGIN END; $$ LANGUAGE plpgsql IMMUTABLE; +CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; + result_length INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + + -- Calculate the length of the result array + result_length := 3 + n; + + FOR i IN 1..result_length + LOOP + IF i <= 3 THEN + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + ELSE + real_array := array_append(real_array, CAST(i - 3 AS REAL)); + END IF; + END LOOP; + + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$ +BEGIN + RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); +END; +$$ LANGUAGE plpgsql IMMUTABLE; + + CREATE TABLE test_table (id INTEGER); INSERT INTO test_table VALUES (0), (1), (7); + \set enable_seqscan = off; + CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2); -SELECT id FROM test_table ORDER BY 
int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; \ No newline at end of file +\set ON_ERROR_STOP off + +-- This should result in an error that dimensions does not match +CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2); + +-- This should result in an error that data type text has no default operator class +CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2); + +-- This should result in error about multicolumn expressions support +CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2); + +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; diff --git a/test/sql/hnsw_vector.sql b/test/sql/hnsw_vector.sql index e7aa0285c..cfe282ab9 100644 --- a/test/sql/hnsw_vector.sql +++ b/test/sql/hnsw_vector.sql @@ -78,3 +78,24 @@ SELECT ARRAY[1,2,3] <-> ARRAY[3,2,1]; -- Expect error due to mismatching vector dimensions SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1; SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector); + +-- Test creating index with expression +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); + +CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array::vector; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);