Skip to content

Commit

Permalink
Add more tests for index expressions, fix index expressions for vecto…
Browse files Browse the repository at this point in the history
…r type
  • Loading branch information
var77 committed Sep 25, 2023
1 parent cd0e4a8 commit 346f3d0
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 25 deletions.
3 changes: 2 additions & 1 deletion scripts/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,14 @@ pgvector_installed=$($PSQL -U $DB_USER -d postgres -c "SELECT 1 FROM pg_availabl
rm -rf $TMP_OUTDIR/schedule.txt
if [ -n "$FILTER" ]; then
if [[ "$pgvector_installed" == "1" ]]; then
TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -e 's/^\(test:\|test_pgvector:\)//' | tr " " "\n" | sed -e '/^$/d')
TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -E -e 's/^test:|test_pgvector://' | tr " " "\n" | sed -e '/^$/d')
else
TEST_FILES=$(cat schedule.txt | grep '^test:' | sed -e 's/^test://' | tr " " "\n" | sed -e '/^$/d')
fi

while IFS= read -r f; do
if [[ $f == *"$FILTER"* ]]; then
echo "HERE $f"
echo "test: $f" >> $TMP_OUTDIR/schedule.txt
fi
done <<< "$TEST_FILES"
Expand Down
27 changes: 18 additions & 9 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <catalog/namespace.h>
#include <commands/vacuum.h>
#include <float.h>
#include <hnsw/utils.h>
#include <math.h>
#include <utils/guc.h>
#include <utils/lsyscache.h>
Expand Down Expand Up @@ -343,25 +344,33 @@ Datum vector_l2sq_dist(PG_FUNCTION_ARGS)
}

/*
* Get data type of index
*/
HnswColumnType GetIndexColumnType(Relation index)
* Get data type for give oid
* */
HnswColumnType GetColumnTypeFromOid(Oid oid)
{
TupleDesc indexTupDesc = RelationGetDescr(index);
Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0);
Oid columnType = attr->atttypid;
ldb_invariant(OidIsValid(oid), "Invalid oid passed");

if(columnType == FLOAT4ARRAYOID) {
if(oid == FLOAT4ARRAYOID) {
return REAL_ARRAY;
} else if(columnType == TypenameGetTypid("vector")) {
} else if(oid == TypenameGetTypid("vector")) {
return VECTOR;
} else if(columnType == INT4ARRAYOID) {
} else if(oid == INT4ARRAYOID) {
return INT_ARRAY;
} else {
return UNKNOWN;
}
}

/*
* Get data type of index
*/
HnswColumnType GetIndexColumnType(Relation index)
{
TupleDesc indexTupDesc = RelationGetDescr(index);
Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0);
return GetColumnTypeFromOid(attr->atttypid);
}

/*
* Given vector data and vector type, convert it to a float4 array
*/
Expand Down
1 change: 1 addition & 0 deletions src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ PGDLLEXPORT Datum vector_l2sq_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS);

HnswColumnType GetColumnTypeFromOid(Oid oid);
HnswColumnType GetIndexColumnType(Relation index);
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions);

Expand Down
39 changes: 28 additions & 11 deletions src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTup
Datum result;
bool isNull;
ArrayType *array;
int n_items = HNSW_DEFAULT_DIM;
TupleTableSlot *slot;
TupleDesc tupdesc = RelationGetDescr(heap);

Expand All @@ -169,21 +168,34 @@ static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTup

// Evaluate the expression for the first row
result = ExecEvalExprSwitchContext(exprstate, econtext, &isNull);
array = DatumGetArrayTypeP(result);

// Release tuple descriptor
ReleaseTupleDesc(tupdesc);

// Check if the result is not null
if(!isNull) {
// todo check if datum is array
array = DatumGetArrayTypePCopy(result);
n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
return n_items;
} else {
elog(ERROR, "Expression did not result in an array");
Oid oid = get_array_type(array->elemtype);

if(!OidIsValid(oid)) {
// Oid 0 can only be in a case when the datum will be vector type
// As postgres will guard that arbitrary data type won't be returned
// from the function expression
// so we should be safe here to case result into vector in this case
Vector *vector = DatumGetVector(result);
// if the vector will not have dimensions or somehow it
// won't be casted the code will go to ldb_invariant case
if(vector->dim > 0) {
return vector->dim;
}
}

return n_items;
// Check if the result is not null and is supported type
// There is a guard in postgres that wont' allow passing
// Anything else from the defined operator class types
// Throwing an error like: ERROR: data type text has no default operator class for access method "hnsw"
// So this case will be marked as invariant
ldb_invariant(!isNull && GetColumnTypeFromOid(oid) != UNKNOWN,
"Expression used in CREATE INDEX statement did not result in hnsw-index compatible array");

return ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
}

static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexInfo)
Expand Down Expand Up @@ -216,6 +228,11 @@ static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexI
}

if(indexInfo->ii_Expressions != NULL) {
// We don't suport multicolumn indexes
// So trying to pass multiple expressions on index creation
// Will result an error before getting here
ldb_invariant(indexInfo->ii_Expressions->length == 1,
"Index expressions can not be greater than 1 as multicolumn indexes are not supported");
Expr *indexpr_item = lfirst(list_head(indexInfo->ii_Expressions));
n_items = GetArrayLengthFromExpression(indexpr_item, heap, tuple);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/executor_start.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags)
}

standard_ExecutorStart(queryDesc, eflags);
}
}
14 changes: 13 additions & 1 deletion src/hooks/post_parse.c
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ static bool operator_used_incorrectly_walker(Node *node, OperatorUsedCorrectlyCo
Node *arg2 = (Node *)lsecond(opExpr->args);
bool isVar1 = IsA(arg1, Var);
bool isVar2 = IsA(arg2, Var);
/* There is a case when operator is used with index
* that was created via expression (CREATE INDEX ON t USING hnsw (func(id)) WITH (M=2))
* in this case the query may look like this
* SELECT id FROM test ORDER BY func(id) <-> ARRAY[0,0,0] LIMIT 2
* or like this
* SELECT id FROM test ORDER BY func(id) <-> func(n) LIMIT 2
* we should check if IsA(arg1, FuncExpr) || IsA(arg2, FuncExpr)
* if true we may go and check the oid of function result to see if it is an array type
* we also can check that the argument of FuncExpr is at least one of the arg1 and arg2
* will contain column of the table (e.g iterate over list and check IsA(arg, Var))
* so the function will not be called on constant arguments on both sides
*/
if(isVar1 && isVar2) {
return false;
} else if(!isVar1 && !isVar2) {
Expand Down Expand Up @@ -145,4 +157,4 @@ void post_parse_analyze_hook_with_operator_check(ParseState *pstate,
list_free(sort_group_refs);
}
list_free(oidList);
}
}
48 changes: 47 additions & 1 deletion test/expected/hnsw_create_expr.out
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,54 @@ BEGIN
RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
result_length INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');

-- Calculate the length of the result array
result_length := 3 + n;

FOR i IN 1..result_length
LOOP
IF i <= 3 THEN
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
ELSE
real_array := array_append(real_array, CAST(i - 3 AS REAL));
END IF;
END LOOP;

RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$
BEGIN
RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);
\set enable_seqscan = off;
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2);
ERROR: invalid attnum: 0
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
\set ON_ERROR_STOP off
-- This should result in an error that dimensions does not match
CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2);
INFO: done init usearch index
ERROR: Wrong number of dimensions: 4 instead of 3 expected
-- This should result in an error that data type text has no default operator class
CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2);
ERROR: data type text has no default operator class for access method "hnsw"
-- This should result in error about multicolumn expressions support
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2);
ERROR: access method "hnsw" does not support multicolumn indexes
-- This currently results in an error about using the operator outside of index
-- This case should be fixed
SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2;
ERROR: Operator <-> has no standalone meaning and is reserved for use in vector index lookups only
21 changes: 21 additions & 0 deletions test/expected/hnsw_vector.out
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,24 @@ SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1;
ERROR: Expected vector with dimension 3, got 4
SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);
ERROR: expected equally sized vectors but got vectors with dimensions 2 and 3
-- Test creating index with expression
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);
CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
FOR i IN 1..length(binary_string)
LOOP
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
END LOOP;
RETURN real_array::vector;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
49 changes: 48 additions & 1 deletion test/sql/hnsw_create_expr.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,56 @@ BEGIN
END;
$$ LANGUAGE plpgsql IMMUTABLE;

CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
result_length INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');

-- Calculate the length of the result array
result_length := 3 + n;

FOR i IN 1..result_length
LOOP
IF i <= 3 THEN
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
ELSE
real_array := array_append(real_array, CAST(i - 3 AS REAL));
END IF;
END LOOP;

RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;

CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$
BEGIN
RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
END;
$$ LANGUAGE plpgsql IMMUTABLE;


CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);

\set enable_seqscan = off;

CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2);

SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2;
\set ON_ERROR_STOP off

-- This should result in an error that dimensions does not match
CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2);

-- This should result in an error that data type text has no default operator class
CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2);

-- This should result in error about multicolumn expressions support
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2);

-- This currently results in an error about using the operator outside of index
-- This case should be fixed
SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2;
21 changes: 21 additions & 0 deletions test/sql/hnsw_vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,24 @@ SELECT ARRAY[1,2,3] <-> ARRAY[3,2,1];
-- Expect error due to mismatching vector dimensions
SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1;
SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);

-- Test creating index with expression
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);

CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
FOR i IN 1..length(binary_string)
LOOP
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
END LOOP;
RETURN real_array::vector;
END;
$$ LANGUAGE plpgsql IMMUTABLE;

CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);

0 comments on commit 346f3d0

Please sign in to comment.