Skip to content

Commit

Permalink
Infer array dimensions from index expressions (#175)
Browse files Browse the repository at this point in the history
* Add create expression test

* Infer array dimensions from index expressions

* Fix ExecStoreHeapTuple for pg11

* Add more tests for index expressions, fix index expressions for vector type

* Add expr result oid check with invariant, add comments in sql, move failing tests to todo

---------

Co-authored-by: Di Qi <[email protected]>
  • Loading branch information
var77 and dqii authored Sep 28, 2023
1 parent 5d9ffe0 commit 4c80eca
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 29 deletions.
2 changes: 2 additions & 0 deletions scripts/extern_defined.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ SED_PATTERN='s/@/@/p' # noop pattern
nm -D --with-symbol-versions $PG_BIN | grep " T " | awk '{print $3}' | sed -e "$SED_PATTERN"
# global bss symbol in postgres
nm -D --with-symbol-versions $PG_BIN | grep " B " | awk '{print $3}' | sed -e "$SED_PATTERN"
# postgres initialized data (.data section), global symbols
nm -D --with-symbol-versions $PG_BIN | grep " D " | awk '{print $3}' | sed -e "$SED_PATTERN"
# postgres weak symbols
nm -D --with-symbol-versions $PG_BIN | grep " w " | awk '{print $2}' | sed -e "$SED_PATTERN"

Expand Down
3 changes: 2 additions & 1 deletion scripts/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,14 @@ pgvector_installed=$($PSQL -U $DB_USER -d postgres -c "SELECT 1 FROM pg_availabl
rm -rf $TMP_OUTDIR/schedule.txt
if [ -n "$FILTER" ]; then
if [[ "$pgvector_installed" == "1" ]]; then
TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -e 's/^\(test:\|test_pgvector:\)//' | tr " " "\n" | sed -e '/^$/d')
TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -E -e 's/^test:|test_pgvector://' | tr " " "\n" | sed -e '/^$/d')
else
TEST_FILES=$(cat schedule.txt | grep '^test:' | sed -e 's/^test://' | tr " " "\n" | sed -e '/^$/d')
fi

while IFS= read -r f; do
if [[ $f == *"$FILTER"* ]]; then
echo "HERE $f"
echo "test: $f" >> $TMP_OUTDIR/schedule.txt
fi
done <<< "$TEST_FILES"
Expand Down
26 changes: 17 additions & 9 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -343,25 +343,33 @@ Datum vector_l2sq_dist(PG_FUNCTION_ARGS)
}

/*
* Get data type of index
*/
HnswColumnType GetIndexColumnType(Relation index)
* Get data type for the given oid
* */
HnswColumnType GetColumnTypeFromOid(Oid oid)
{
TupleDesc indexTupDesc = RelationGetDescr(index);
Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0);
Oid columnType = attr->atttypid;
ldb_invariant(OidIsValid(oid), "Invalid oid passed");

if(columnType == FLOAT4ARRAYOID) {
if(oid == FLOAT4ARRAYOID) {
return REAL_ARRAY;
} else if(columnType == TypenameGetTypid("vector")) {
} else if(oid == TypenameGetTypid("vector")) {
return VECTOR;
} else if(columnType == INT4ARRAYOID) {
} else if(oid == INT4ARRAYOID) {
return INT_ARRAY;
} else {
return UNKNOWN;
}
}

/*
 * Resolve the HnswColumnType of an index.
 *
 * Only the first attribute of the index tuple descriptor is inspected;
 * its type oid is mapped through GetColumnTypeFromOid.
 */
HnswColumnType GetIndexColumnType(Relation index)
{
    Form_pg_attribute firstAttr = TupleDescAttr(RelationGetDescr(index), 0);

    return GetColumnTypeFromOid(firstAttr->atttypid);
}

/*
* Given vector data and vector type, convert it to a float4 array
*/
Expand Down
1 change: 1 addition & 0 deletions src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ PGDLLEXPORT Datum vector_l2sq_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS);

HnswColumnType GetColumnTypeFromOid(Oid oid);
HnswColumnType GetIndexColumnType(Relation index);
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions);

Expand Down
96 changes: 86 additions & 10 deletions src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
#include <catalog/index.h>
#include <catalog/namespace.h>
#include <catalog/pg_type.h>
#include <executor/executor.h>
#include <funcapi.h>
#include <nodes/execnodes.h>
#include <storage/bufmgr.h>
#include <utils/array.h>
#include <utils/lsyscache.h>
#include <utils/memutils.h>

#ifdef _WIN32
#define access _access
#else
Expand Down Expand Up @@ -132,7 +136,69 @@ static void BuildCallback(
MemoryContextReset(buildstate->tmpCtx);
}

static int GetArrayLengthFromHeap(Relation heap, int indexCol)
/*
 * Evaluate an index expression against one heap tuple and return the number
 * of elements of the resulting array (or the dimension of the resulting vector).
 *
 * expression - the index expression to evaluate
 * heap       - relation the tuple comes from; its descriptor backs the temporary slot
 * tuple      - heap tuple to evaluate against; it stays owned by the caller
 *
 * Returns HNSW_DEFAULT_DIM when the expression evaluates to NULL for this
 * tuple, so there is nothing to measure.
 * Trips an invariant when the expression result type is not one of the
 * hnsw-compatible types.
 */
static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTuple tuple)
{
    ExprContext    *econtext;
    ExprState      *exprstate;
    EState         *estate;
    Datum           result;
    bool            isNull;
    Oid             resultOid;
    HnswColumnType  columnType;
    TupleTableSlot *slot;
    TupleDesc       tupdesc = RelationGetDescr(heap);
    int             n_items = HNSW_DEFAULT_DIM;

#if PG_VERSION_NUM >= 120000
    slot = MakeSingleTupleTableSlot(tupdesc, &TTSOpsHeapTuple);
#else
    slot = MakeSingleTupleTableSlot(tupdesc);
#endif

    // Create an expression context and an executor state to prepare the expression in
    econtext = CreateStandaloneExprContext();
    estate = CreateExecutorState();

    // Build the expression state for the index expression
    exprstate = ExecPrepareExpr(expression, estate);

    // Store the tuple without transferring ownership (shouldFree = false)
#if PG_VERSION_NUM >= 120000
    ExecStoreHeapTuple(tuple, slot, false);
#else
    ExecStoreTuple(tuple, slot, InvalidBuffer, false);
#endif
    // Set up the tuple for the expression evaluation
    econtext->ecxt_scantuple = slot;

    // Evaluate the expression for the given row
    result = ExecEvalExprSwitchContext(exprstate, econtext, &isNull);

    // Get the return type information
    get_expr_result_type((Node *)exprstate->expr, &resultOid, NULL);
    columnType = GetColumnTypeFromOid(resultOid);

    // The result must be inspected before the cleanup below:
    // it may live in memory owned by the expression context
    if(columnType == REAL_ARRAY || columnType == INT_ARRAY) {
        // A NULL result carries no array to measure, keep the default
        if(!isNull) {
            ArrayType *array = DatumGetArrayTypeP(result);
            n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
        }
    } else if(columnType == VECTOR) {
        if(!isNull) {
            Vector *vector = DatumGetVector(result);
            n_items = vector->dim;
        }
    } else {
        // Check if the result is not null and is supported type
        // There is a guard in postgres that won't allow passing
        // anything else from the defined operator class types,
        // throwing an error like: ERROR: data type text has no default operator class for access method "hnsw"
        // So this case will be marked as invariant
        ldb_invariant(!isNull && columnType != UNKNOWN,
                      "Expression used in CREATE INDEX statement did not result in hnsw-index compatible array");
    }

    // Dropping the slot also releases the tuple descriptor pin it took,
    // so no separate ReleaseTupleDesc is needed
    ExecDropSingleTupleTableSlot(slot);
    FreeExprContext(econtext, true);
    // Frees exprstate as well, it was allocated by ExecPrepareExpr inside estate
    FreeExecutorState(estate);

    return n_items;
}

static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexInfo)
{
#if PG_VERSION_NUM < 120000
HeapScanDesc scan;
Expand Down Expand Up @@ -161,19 +227,29 @@ static int GetArrayLengthFromHeap(Relation heap, int indexCol)
return n_items;
}

// Get the indexed column out of the row and return it's dimensions
datum = heap_getattr(tuple, indexCol, RelationGetDescr(heap), &isNull);
if(!isNull) {
array = DatumGetArrayTypePCopy(datum);
n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(indexInfo->ii_Expressions != NULL) {
// We don't support multicolumn indexes
// So trying to pass multiple expressions on index creation
// Will result in an error before getting here
ldb_invariant(indexInfo->ii_Expressions->length == 1,
"Index expressions can not be greater than 1 as multicolumn indexes are not supported");
Expr *indexpr_item = lfirst(list_head(indexInfo->ii_Expressions));
n_items = GetArrayLengthFromExpression(indexpr_item, heap, tuple);
} else {
// Get the indexed column out of the row and return its dimensions
datum = heap_getattr(tuple, indexCol, RelationGetDescr(heap), &isNull);
if(!isNull) {
array = DatumGetArrayTypePCopy(datum);
n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
}
}

heap_endscan(scan);

return n_items;
}

int GetHnswIndexDimensions(Relation index)
int GetHnswIndexDimensions(Relation index, IndexInfo *indexInfo)
{
HnswColumnType columnType = GetIndexColumnType(index);

Expand All @@ -195,7 +271,7 @@ int GetHnswIndexDimensions(Relation index)
#else
heap = table_open(index->rd_index->indrelid, AccessShareLock);
#endif
opt_dim = GetArrayLengthFromHeap(heap, attrNum);
opt_dim = GetArrayLengthFromHeap(heap, attrNum, indexInfo);
opts = (ldb_HnswOptions *)index->rd_options;
if(opts != NULL) {
opts->dim = opt_dim;
Expand Down Expand Up @@ -248,7 +324,7 @@ static int InferDimension(Relation heap, IndexInfo *indexInfo)
}

indexCol = indexInfo->ii_IndexAttrNumbers[ 0 ];
return GetArrayLengthFromHeap(heap, indexCol);
return GetArrayLengthFromHeap(heap, indexCol, indexInfo);
}

/*
Expand All @@ -260,7 +336,7 @@ static void InitBuildState(HnswBuildState *buildstate, Relation heap, Relation i
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->columnType = GetIndexColumnType(index);
buildstate->dimensions = GetHnswIndexDimensions(index);
buildstate->dimensions = GetHnswIndexDimensions(index, indexInfo);
buildstate->index_file_path = ldb_HnswGetIndexFilePath(index);

// If a dimension wasn't specified try to infer it
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/build.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ typedef struct HnswBuildState

IndexBuildResult *ldb_ambuild(Relation heap, Relation index, IndexInfo *indexInfo);
void ldb_ambuildunlogged(Relation index);
int GetHnswIndexDimensions(Relation index);
int GetHnswIndexDimensions(Relation index, IndexInfo *indexInfo);
void CheckHnswIndexDimensions(Relation index, Datum arrayDatum, int deimensions);
// todo: does this render my check unnecessary
#endif // LDB_HNSW_BUILD_H
3 changes: 1 addition & 2 deletions src/hnsw/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ bool ldb_aminsert(Relation index,
HnswIndexTuple *new_tuple;
usearch_init_options_t opts = {0};
LDB_UNUSED(heap);
LDB_UNUSED(indexInfo);
#if PG_VERSION_NUM >= 140000
LDB_UNUSED(indexUnchanged);
#endif
Expand Down Expand Up @@ -103,7 +102,7 @@ bool ldb_aminsert(Relation index,
hdr = (HnswIndexHeaderPage *)PageGetContents(hdr_page);
assert(hdr->magicNumber == LDB_WAL_MAGIC_NUMBER);

opts.dimensions = GetHnswIndexDimensions(index);
opts.dimensions = GetHnswIndexDimensions(index, indexInfo);
CheckHnswIndexDimensions(index, values[ 0 ], opts.dimensions);
PopulateUsearchOpts(index, &opts);
opts.retriever_ctx = ldb_wal_retriever_area_init(index, hdr);
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/executor_start.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags)
}

standard_ExecutorStart(queryDesc, eflags);
}
}
14 changes: 13 additions & 1 deletion src/hooks/post_parse.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ static bool operator_used_incorrectly_walker(Node *node, OperatorUsedCorrectlyCo
Node *arg2 = (Node *)lsecond(opExpr->args);
bool isVar1 = IsA(arg1, Var);
bool isVar2 = IsA(arg2, Var);
/* There is a case when the operator is used with an index
* that was created via an expression (CREATE INDEX ON t USING hnsw (func(id)) WITH (M=2)).
* In this case the query may look like this:
* SELECT id FROM test ORDER BY func(id) <-> ARRAY[0,0,0] LIMIT 2
* or like this:
* SELECT id FROM test ORDER BY func(id) <-> func(n) LIMIT 2
* We should check if IsA(arg1, FuncExpr) || IsA(arg2, FuncExpr);
* if true, we may go and check the oid of the function result to see if it is an array type.
* We can also check that the FuncExpr arguments of at least one of arg1 and arg2
* contain a column of the table (e.g. iterate over the list and check IsA(arg, Var)),
* so that the function is not called with constant arguments on both sides.
*/
if(isVar1 && isVar2) {
return false;
} else if(isVar1 && !isVar2) {
Expand Down Expand Up @@ -168,4 +180,4 @@ void post_parse_analyze_hook_with_operator_check(ParseState *pstate,
list_free(sort_group_refs);
}
list_free(oidList);
}
}
66 changes: 64 additions & 2 deletions test/expected/hnsw_create_expr.out
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
/*
This function, int_to_fixed_binary_real_array(n INT), will create a 3-dimensional float array (REAL[]).
It fills the array with the first 3 bits of the passed integer 'n' by converting 'n' to binary,
left-padding it to 3 digits, and then converting each digit to a REAL value.
For example, int_to_fixed_binary_real_array(1); will result in the array {0,0,1},
and int_to_fixed_binary_real_array(2); will result in {0,1,0}.
*/
CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$
DECLARE
binary_string TEXT;
Expand All @@ -12,12 +19,67 @@ BEGIN
RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
/*
This function, int_to_dynamic_binary_real_array(n INT), will create a 3+n dimensional float array (REAL[]).
It first fills the first 3 elements of the array with the first 3 bits of the passed integer 'n'
(using a similar binary conversion as the previous function), and then adds elements sequentially from 4 to 'n+3'.
For example, int_to_dynamic_binary_real_array(3); will result in the array {0,1,1,1,2,3},
and int_to_dynamic_binary_real_array(4); will result in the array {1,0,0,1,2,3,4}.
*/
CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
result_length INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');

-- Calculate the length of the result array
result_length := 3 + n;

FOR i IN 1..result_length
LOOP
IF i <= 3 THEN
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
ELSE
real_array := array_append(real_array, CAST(i - 3 AS REAL));
END IF;
END LOOP;

RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
/*
This simple function, int_to_string(n INT), converts the integer 'n' to a 3-character text representation
by converting 'n' to binary and left-padding it to 3 digits with '0's.
For example, int_to_string(1); will return '001', and int_to_string(2); will return '010'.
*/
CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$
BEGIN
RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);
\set enable_seqscan = off;
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2, dim=3);
-- This should success
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2;
\set ON_ERROR_STOP off
-- This should result in an error that dimensions does not match
CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2);
INFO: done init usearch index
ERROR: Wrong number of dimensions: 4 instead of 3 expected
-- This should result in an error that data type text has no default operator class
CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2);
ERROR: data type text has no default operator class for access method "hnsw"
-- This should result in error about multicolumn expressions support
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2);
ERROR: access method "hnsw" does not support multicolumn indexes
-- This currently results in an error about using the operator outside of index
-- This case should be fixed
SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2;
ERROR: Operator <-> can only be used inside of an index
22 changes: 22 additions & 0 deletions test/expected/hnsw_todo.out
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,25 @@ SELECT ROUND(l2sq_dist(v, :'v1001')::numeric, 2) FROM sift_base1k order by v <->
249285.00
(1 row)

---- Query on expression based index is failing to check correct <-> operator usage --------
CREATE OR REPLACE FUNCTION int_to_fixed_binary_real_array(n INT) RETURNS REAL[] AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
FOR i IN 1..length(binary_string)
LOOP
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
END LOOP;
RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);
\set enable_seqscan = off;
-- This currently results in an error about using the operator outside of index
-- This case should be fixed
SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> '{0,0,0}'::REAL[] LIMIT 2;
ERROR: Operator <-> can only be used inside of an index
Loading

0 comments on commit 4c80eca

Please sign in to comment.