From 4c80eca1a536a8ecde22ea39e80002343e8f4501 Mon Sep 17 00:00:00 2001 From: Varik Matevosyan Date: Thu, 28 Sep 2023 13:10:37 +0400 Subject: [PATCH] Infer array dimensions from index expressions (#175) * Add create expression test * Infer array dimensions from index expressions * Fix ExecStoreHeapTuple for pg11 * Add more tests for index expressions, fix index expressions for vector type * Add expr result oid check with invariant, add comments in sql, move failing tests to todo --------- Co-authored-by: Di Qi --- scripts/extern_defined.sh | 2 + scripts/run_all_tests.sh | 3 +- src/hnsw.c | 26 +++++--- src/hnsw.h | 1 + src/hnsw/build.c | 96 ++++++++++++++++++++++++++---- src/hnsw/build.h | 2 +- src/hnsw/insert.c | 3 +- src/hooks/executor_start.c | 2 +- src/hooks/post_parse.c | 14 ++++- test/expected/hnsw_create_expr.out | 66 +++++++++++++++++++- test/expected/hnsw_todo.out | 22 +++++++ test/expected/hnsw_vector.out | 21 +++++++ test/sql/hnsw_create_expr.sql | 70 +++++++++++++++++++++- test/sql/hnsw_todo.sql | 25 ++++++++ test/sql/hnsw_vector.sql | 21 +++++++ 15 files changed, 345 insertions(+), 29 deletions(-) diff --git a/scripts/extern_defined.sh b/scripts/extern_defined.sh index 8e1eefe5f..f48f8626e 100755 --- a/scripts/extern_defined.sh +++ b/scripts/extern_defined.sh @@ -39,6 +39,8 @@ SED_PATTERN='s/@/@/p' # noop pattern nm -D --with-symbol-versions $PG_BIN | grep " T " | awk '{print $3}' | sed -e "$SED_PATTERN" # global bss symbol in postgres nm -D --with-symbol-versions $PG_BIN | grep " B " | awk '{print $3}' | sed -e "$SED_PATTERN" +# postgres initialized data (data section), global symbols +nm -D --with-symbol-versions $PG_BIN | grep " D " | awk '{print $3}' | sed -e "$SED_PATTERN" # postgres weak symbols nm -D --with-symbol-versions $PG_BIN | grep " w " | awk '{print $2}' | sed -e "$SED_PATTERN" diff --git a/scripts/run_all_tests.sh b/scripts/run_all_tests.sh index de9df7351..7b6e79cda 100755 --- a/scripts/run_all_tests.sh +++ b/scripts/run_all_tests.sh @@ -71,13 
+71,14 @@ pgvector_installed=$($PSQL -U $DB_USER -d postgres -c "SELECT 1 FROM pg_availabl rm -rf $TMP_OUTDIR/schedule.txt if [ -n "$FILTER" ]; then if [[ "$pgvector_installed" == "1" ]]; then - TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -e 's/^\(test:\|test_pgvector:\)//' | tr " " "\n" | sed -e '/^$/d') + TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -E -e 's/^test:|test_pgvector://' | tr " " "\n" | sed -e '/^$/d') else TEST_FILES=$(cat schedule.txt | grep '^test:' | sed -e 's/^test://' | tr " " "\n" | sed -e '/^$/d') fi while IFS= read -r f; do if [[ $f == *"$FILTER"* ]]; then + echo "HERE $f" echo "test: $f" >> $TMP_OUTDIR/schedule.txt fi done <<< "$TEST_FILES" diff --git a/src/hnsw.c b/src/hnsw.c index 9276db57e..ae811311e 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -343,25 +343,33 @@ Datum vector_l2sq_dist(PG_FUNCTION_ARGS) } /* - * Get data type of index - */ -HnswColumnType GetIndexColumnType(Relation index) + * Get data type for the given oid + * */ +HnswColumnType GetColumnTypeFromOid(Oid oid) { - TupleDesc indexTupDesc = RelationGetDescr(index); - Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0); - Oid columnType = attr->atttypid; + ldb_invariant(OidIsValid(oid), "Invalid oid passed"); - if(columnType == FLOAT4ARRAYOID) { + if(oid == FLOAT4ARRAYOID) { return REAL_ARRAY; - } else if(columnType == TypenameGetTypid("vector")) { + } else if(oid == TypenameGetTypid("vector")) { return VECTOR; - } else if(columnType == INT4ARRAYOID) { + } else if(oid == INT4ARRAYOID) { return INT_ARRAY; } else { return UNKNOWN; } } +/* + * Get data type of index + */ +HnswColumnType GetIndexColumnType(Relation index) +{ + TupleDesc indexTupDesc = RelationGetDescr(index); + Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0); + return GetColumnTypeFromOid(attr->atttypid); +} + /* * Given vector data and vector type, convert it to a float4 array */ diff --git a/src/hnsw.h b/src/hnsw.h index 2f448687c..0542f705e 100644 
--- a/src/hnsw.h +++ b/src/hnsw.h @@ -32,6 +32,7 @@ PGDLLEXPORT Datum vector_l2sq_dist(PG_FUNCTION_ARGS); PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS); PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS); +HnswColumnType GetColumnTypeFromOid(Oid oid); HnswColumnType GetIndexColumnType(Relation index); float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions); diff --git a/src/hnsw/build.c b/src/hnsw/build.c index 5178bbd41..89bc5d739 100644 --- a/src/hnsw/build.c +++ b/src/hnsw/build.c @@ -7,10 +7,14 @@ #include #include #include +#include +#include +#include #include #include #include #include + #ifdef _WIN32 #define access _access #else @@ -132,7 +136,69 @@ static void BuildCallback( MemoryContextReset(buildstate->tmpCtx); } -static int GetArrayLengthFromHeap(Relation heap, int indexCol) +static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTuple tuple) +{ + ExprContext *econtext; + ExprState *exprstate; + EState *estate; + Datum result; + bool isNull; + Oid resultOid; + TupleTableSlot *slot; + TupleDesc tupdesc = RelationGetDescr(heap); + +#if PG_VERSION_NUM >= 120000 + slot = MakeSingleTupleTableSlot(tupdesc, &TTSOpsHeapTuple); +#else + slot = MakeSingleTupleTableSlot(tupdesc); +#endif + + // Create an expression context + econtext = CreateStandaloneExprContext(); + estate = CreateExecutorState(); + + // Build the expression state for your expression + exprstate = ExecPrepareExpr(expression, estate); + +#if PG_VERSION_NUM >= 120000 + ExecStoreHeapTuple(tuple, slot, false); +#else + ExecStoreTuple(tuple, slot, InvalidBuffer, false); +#endif + // Set up the tuple for the expression evaluation + econtext->ecxt_scantuple = slot; + + // Evaluate the expression for the first row + result = ExecEvalExprSwitchContext(exprstate, econtext, &isNull); + + // Release tuple descriptor + ReleaseTupleDesc(tupdesc); + + // Get the return type information + get_expr_result_type((Node *)exprstate->expr, &resultOid, NULL); + + 
HnswColumnType columnType = GetColumnTypeFromOid(resultOid); + + if(columnType == REAL_ARRAY || columnType == INT_ARRAY) { + ArrayType *array = DatumGetArrayTypeP(result); + return ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + } else if(columnType == VECTOR) { + Vector *vector = DatumGetVector(result); + return vector->dim; + } else { + // Check if the result is not null and is supported type + // There is a guard in postgres that won't allow passing + // Anything else from the defined operator class types + // Throwing an error like: ERROR: data type text has no default operator class for access method "hnsw" + // So this case will be marked as invariant + ldb_invariant(!isNull && columnType != UNKNOWN, + "Expression used in CREATE INDEX statement did not result in hnsw-index compatible array"); + } + + return HNSW_DEFAULT_DIM; +} + +static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexInfo) { #if PG_VERSION_NUM < 120000 HeapScanDesc scan; @@ -161,11 +227,21 @@ static int GetArrayLengthFromHeap(Relation heap, int indexCol) return n_items; } - // Get the indexed column out of the row and return it's dimensions - datum = heap_getattr(tuple, indexCol, RelationGetDescr(heap), &isNull); - if(!isNull) { - array = DatumGetArrayTypePCopy(datum); - n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + if(indexInfo->ii_Expressions != NULL) { + // We don't support multicolumn indexes + // So trying to pass multiple expressions on index creation + // Will result in an error before getting here + ldb_invariant(indexInfo->ii_Expressions->length == 1, + "Index expressions can not be greater than 1 as multicolumn indexes are not supported"); + Expr *indexpr_item = lfirst(list_head(indexInfo->ii_Expressions)); + n_items = GetArrayLengthFromExpression(indexpr_item, heap, tuple); + } else { + // Get the indexed column out of the row and return its dimensions + datum = heap_getattr(tuple, indexCol, RelationGetDescr(heap), &isNull); + if(!isNull) { 
+ array = DatumGetArrayTypePCopy(datum); + n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); + } } heap_endscan(scan); @@ -173,7 +249,7 @@ static int GetArrayLengthFromHeap(Relation heap, int indexCol) return n_items; } -int GetHnswIndexDimensions(Relation index) +int GetHnswIndexDimensions(Relation index, IndexInfo *indexInfo) { HnswColumnType columnType = GetIndexColumnType(index); @@ -195,7 +271,7 @@ int GetHnswIndexDimensions(Relation index) #else heap = table_open(index->rd_index->indrelid, AccessShareLock); #endif - opt_dim = GetArrayLengthFromHeap(heap, attrNum); + opt_dim = GetArrayLengthFromHeap(heap, attrNum, indexInfo); opts = (ldb_HnswOptions *)index->rd_options; if(opts != NULL) { opts->dim = opt_dim; @@ -248,7 +324,7 @@ static int InferDimension(Relation heap, IndexInfo *indexInfo) } indexCol = indexInfo->ii_IndexAttrNumbers[ 0 ]; - return GetArrayLengthFromHeap(heap, indexCol); + return GetArrayLengthFromHeap(heap, indexCol, indexInfo); } /* @@ -260,7 +336,7 @@ static void InitBuildState(HnswBuildState *buildstate, Relation heap, Relation i buildstate->index = index; buildstate->indexInfo = indexInfo; buildstate->columnType = GetIndexColumnType(index); - buildstate->dimensions = GetHnswIndexDimensions(index); + buildstate->dimensions = GetHnswIndexDimensions(index, indexInfo); buildstate->index_file_path = ldb_HnswGetIndexFilePath(index); // If a dimension wasn't specified try to infer it diff --git a/src/hnsw/build.h b/src/hnsw/build.h index c5e227433..272bd394b 100644 --- a/src/hnsw/build.h +++ b/src/hnsw/build.h @@ -35,7 +35,7 @@ typedef struct HnswBuildState IndexBuildResult *ldb_ambuild(Relation heap, Relation index, IndexInfo *indexInfo); void ldb_ambuildunlogged(Relation index); -int GetHnswIndexDimensions(Relation index); +int GetHnswIndexDimensions(Relation index, IndexInfo *indexInfo); void CheckHnswIndexDimensions(Relation index, Datum arrayDatum, int deimensions); // todo: does this render my check unnecessary #endif // 
LDB_HNSW_BUILD_H diff --git a/src/hnsw/insert.c b/src/hnsw/insert.c index 94cd19b47..d66e9b919 100644 --- a/src/hnsw/insert.c +++ b/src/hnsw/insert.c @@ -71,7 +71,6 @@ bool ldb_aminsert(Relation index, HnswIndexTuple *new_tuple; usearch_init_options_t opts = {0}; LDB_UNUSED(heap); - LDB_UNUSED(indexInfo); #if PG_VERSION_NUM >= 140000 LDB_UNUSED(indexUnchanged); #endif @@ -103,7 +102,7 @@ bool ldb_aminsert(Relation index, hdr = (HnswIndexHeaderPage *)PageGetContents(hdr_page); assert(hdr->magicNumber == LDB_WAL_MAGIC_NUMBER); - opts.dimensions = GetHnswIndexDimensions(index); + opts.dimensions = GetHnswIndexDimensions(index, indexInfo); CheckHnswIndexDimensions(index, values[ 0 ], opts.dimensions); PopulateUsearchOpts(index, &opts); opts.retriever_ctx = ldb_wal_retriever_area_init(index, hdr); diff --git a/src/hooks/executor_start.c b/src/hooks/executor_start.c index 4a55ffb71..6067facfb 100644 --- a/src/hooks/executor_start.c +++ b/src/hooks/executor_start.c @@ -83,4 +83,4 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags) } standard_ExecutorStart(queryDesc, eflags); -} \ No newline at end of file +} diff --git a/src/hooks/post_parse.c b/src/hooks/post_parse.c index 637e59883..81dcba0c8 100644 --- a/src/hooks/post_parse.c +++ b/src/hooks/post_parse.c @@ -104,6 +104,18 @@ static bool operator_used_incorrectly_walker(Node *node, OperatorUsedCorrectlyCo Node *arg2 = (Node *)lsecond(opExpr->args); bool isVar1 = IsA(arg1, Var); bool isVar2 = IsA(arg2, Var); + /* There is a case when operator is used with index + * that was created via expression (CREATE INDEX ON t USING hnsw (func(id)) WITH (M=2)) + * in this case the query may look like this + * SELECT id FROM test ORDER BY func(id) <-> ARRAY[0,0,0] LIMIT 2 + * or like this + * SELECT id FROM test ORDER BY func(id) <-> func(n) LIMIT 2 + * we should check if IsA(arg1, FuncExpr) || IsA(arg2, FuncExpr) + * if true we may go and check the oid of function result to see if it is an array type + 
* we also can check that the argument of FuncExpr is at least one of the arg1 and arg2 + * will contain column of the table (e.g iterate over list and check IsA(arg, Var)) + * so the function will not be called with constant arguments on both sides + */ if(isVar1 && isVar2) { return false; } else if(isVar1 && !isVar2) { @@ -168,4 +180,4 @@ void post_parse_analyze_hook_with_operator_check(ParseState *pstate, list_free(sort_group_refs); } list_free(oidList); -} \ No newline at end of file +} diff --git a/test/expected/hnsw_create_expr.out b/test/expected/hnsw_create_expr.out index 3c9b26f18..70a942c85 100644 --- a/test/expected/hnsw_create_expr.out +++ b/test/expected/hnsw_create_expr.out @@ -1,3 +1,10 @@ +/* +This function, int_to_fixed_binary_real_array(n INT), will create a 3-dimensional float array (REAL[]). +It fills the array with the first 3 bits of the passed integer 'n' by converting 'n' to binary, +left-padding it to 3 digits, and then converting each digit to a REAL value. +For example, int_to_fixed_binary_real_array(1); will result in the array {0,0,1}, +and int_to_fixed_binary_real_array(2); will result in {0,1,0}. +*/ CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ DECLARE binary_string TEXT; @@ -12,12 +19,67 @@ BEGIN RETURN real_array; END; $$ LANGUAGE plpgsql IMMUTABLE; +/* +This function, int_to_dynamic_binary_real_array(n INT), will create a 3+n dimensional float array (REAL[]). +It first fills the first 3 elements of the array with the first 3 bits of the passed integer 'n' +(using a similar binary conversion as the previous function), and then adds elements sequentially from 4 to 'n+3'. +For example, int_to_dynamic_binary_real_array(3); will result in the array {0,1,1,1,2,3}, +and int_to_dynamic_binary_real_array(4); will result in the array {1,0,0,1,2,3,4}. 
+*/ +CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; + result_length INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + + -- Calculate the length of the result array + result_length := 3 + n; + + FOR i IN 1..result_length + LOOP + IF i <= 3 THEN + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + ELSE + real_array := array_append(real_array, CAST(i - 3 AS REAL)); + END IF; + END LOOP; + + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; +/* +This simple function, int_to_string(n INT), converts the integer 'n' to a 3-character text representation +by converting 'n' to binary and left-padding it to 3 digits with '0's. +For example, int_to_string(1); will return '001', and int_to_string(2); will return '010'. +*/ +CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$ +BEGIN + RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); +END; +$$ LANGUAGE plpgsql IMMUTABLE; CREATE TABLE test_table (id INTEGER); INSERT INTO test_table VALUES (0), (1), (7); \set enable_seqscan = off; -CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2, dim=3); +-- This should success +CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2); INFO: done init usearch index INFO: inserted 3 elements INFO: done saving 3 vectors -SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; +\set ON_ERROR_STOP off +-- This should result in an error that dimensions does not match +CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2); +INFO: done init usearch index +ERROR: Wrong number of dimensions: 4 instead of 3 expected +-- This should result in an error that data type text has no default operator class +CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH 
(M=2); +ERROR: data type text has no default operator class for access method "hnsw" +-- This should result in error about multicolumn expressions support +CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2); +ERROR: access method "hnsw" does not support multicolumn indexes +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; ERROR: Operator <-> can only be used inside of an index diff --git a/test/expected/hnsw_todo.out b/test/expected/hnsw_todo.out index 9b50c6920..8eb3b1b75 100644 --- a/test/expected/hnsw_todo.out +++ b/test/expected/hnsw_todo.out @@ -103,3 +103,25 @@ SELECT ROUND(l2sq_dist(v, :'v1001')::numeric, 2) FROM sift_base1k order by v <-> 249285.00 (1 row) +---- Query on expression based index is failing to check correct <-> operator usage -------- +CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); +\set enable_seqscan = off; +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; +ERROR: Operator <-> can only be used inside of an index diff --git a/test/expected/hnsw_vector.out b/test/expected/hnsw_vector.out index 5cb787f4d..a02876065 100644 --- a/test/expected/hnsw_vector.out +++ b/test/expected/hnsw_vector.out @@ -173,3 
+173,24 @@ SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1; ERROR: Expected vector with dimension 3, got 4 SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector); ERROR: expected equally sized vectors but got vectors with dimensions 2 and 3 +-- Test creating index with expression +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); +CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array::vector; +END; +$$ LANGUAGE plpgsql IMMUTABLE; +CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2); +INFO: done init usearch index +INFO: inserted 3 elements +INFO: done saving 3 vectors diff --git a/test/sql/hnsw_create_expr.sql b/test/sql/hnsw_create_expr.sql index eca96c69b..ecabfbaf9 100644 --- a/test/sql/hnsw_create_expr.sql +++ b/test/sql/hnsw_create_expr.sql @@ -1,3 +1,10 @@ +/* +This function, int_to_fixed_binary_real_array(n INT), will create a 3-dimensional float array (REAL[]). +It fills the array with the first 3 bits of the passed integer 'n' by converting 'n' to binary, +left-padding it to 3 digits, and then converting each digit to a REAL value. +For example, int_to_fixed_binary_real_array(1); will result in the array {0,0,1}, +and int_to_fixed_binary_real_array(2); will result in {0,1,0}. +*/ CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ DECLARE binary_string TEXT; @@ -13,9 +20,68 @@ BEGIN END; $$ LANGUAGE plpgsql IMMUTABLE; +/* +This function, int_to_dynamic_binary_real_array(n INT), will create a 3+n dimensional float array (REAL[]). 
+It first fills the first 3 elements of the array with the first 3 bits of the passed integer 'n' +(using a similar binary conversion as the previous function), and then adds elements sequentially from 4 to 'n+3'. +For example, int_to_dynamic_binary_real_array(3); will result in the array {0,1,1,1,2,3}, +and int_to_dynamic_binary_real_array(4); will result in the array {1,0,0,1,2,3,4}. +*/ +CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; + result_length INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + + -- Calculate the length of the result array + result_length := 3 + n; + + FOR i IN 1..result_length + LOOP + IF i <= 3 THEN + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + ELSE + real_array := array_append(real_array, CAST(i - 3 AS REAL)); + END IF; + END LOOP; + + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +/* +This simple function, int_to_string(n INT), converts the integer 'n' to a 3-character text representation +by converting 'n' to binary and left-padding it to 3 digits with '0's. +For example, int_to_string(1); will return '001', and int_to_string(2); will return '010'. 
+*/ +CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$ +BEGIN + RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); +END; +$$ LANGUAGE plpgsql IMMUTABLE; + + CREATE TABLE test_table (id INTEGER); INSERT INTO test_table VALUES (0), (1), (7); + \set enable_seqscan = off; -CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2, dim=3); -SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2; \ No newline at end of file +-- This should success +CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2); + +\set ON_ERROR_STOP off +-- This should result in an error that dimensions does not match +CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2); + +-- This should result in an error that data type text has no default operator class +CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2); + +-- This should result in error about multicolumn expressions support +CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2); + +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; diff --git a/test/sql/hnsw_todo.sql b/test/sql/hnsw_todo.sql index 34f240a2c..2159d7eba 100644 --- a/test/sql/hnsw_todo.sql +++ b/test/sql/hnsw_todo.sql @@ -72,3 +72,28 @@ CREATE INDEX hnsw_l2_index ON sift_base1k USING hnsw (v) WITH (_experimental_ind -- So the usearch index can not find 1,1,1,1,1.. 
vector in the index and wrong results will be returned -- This is an expected behaviour for now SELECT ROUND(l2sq_dist(v, :'v1001')::numeric, 2) FROM sift_base1k order by v <-> :'v1001' LIMIT 1; + +---- Query on expression based index is failing to check correct <-> operator usage -------- +CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); + +\set enable_seqscan = off; +-- This currently results in an error about using the operator outside of index +-- This case should be fixed +SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2; + diff --git a/test/sql/hnsw_vector.sql b/test/sql/hnsw_vector.sql index e7aa0285c..cfe282ab9 100644 --- a/test/sql/hnsw_vector.sql +++ b/test/sql/hnsw_vector.sql @@ -78,3 +78,24 @@ SELECT ARRAY[1,2,3] <-> ARRAY[3,2,1]; -- Expect error due to mismatching vector dimensions SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1; SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector); + +-- Test creating index with expression +CREATE TABLE test_table (id INTEGER); +INSERT INTO test_table VALUES (0), (1), (7); + +CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$ +DECLARE + binary_string TEXT; + real_array REAL[] := '{}'; + i INT; +BEGIN + binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0'); + FOR i IN 1..length(binary_string) + LOOP + real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL)); + END LOOP; + RETURN real_array::vector; +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +CREATE 
INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);