Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer array dimensions from index expressions #175

Merged
merged 5 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scripts/extern_defined.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ SED_PATTERN='s/@/@/p' # noop pattern
nm -D --with-symbol-versions $PG_BIN | grep " T " | awk '{print $3}' | sed -e "$SED_PATTERN"
# global bss symbol in postgres
nm -D --with-symbol-versions $PG_BIN | grep " B " | awk '{print $3}' | sed -e "$SED_PATTERN"
# postgres Initialized data (bbs), global symbols
nm -D --with-symbol-versions $PG_BIN | grep " D " | awk '{print $3}' | sed -e "$SED_PATTERN"
# postgres weak symbols
nm -D --with-symbol-versions $PG_BIN | grep " w " | awk '{print $2}' | sed -e "$SED_PATTERN"

Expand Down
3 changes: 2 additions & 1 deletion scripts/run_all_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,14 @@ pgvector_installed=$($PSQL -U $DB_USER -d postgres -c "SELECT 1 FROM pg_availabl
rm -rf $TMP_OUTDIR/schedule.txt
if [ -n "$FILTER" ]; then
if [[ "$pgvector_installed" == "1" ]]; then
TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -e 's/^\(test:\|test_pgvector:\)//' | tr " " "\n" | sed -e '/^$/d')
TEST_FILES=$(cat schedule.txt | grep -E '^(test:|test_pgvector:)' | sed -E -e 's/^test:|test_pgvector://' | tr " " "\n" | sed -e '/^$/d')
else
TEST_FILES=$(cat schedule.txt | grep '^test:' | sed -e 's/^test://' | tr " " "\n" | sed -e '/^$/d')
fi

while IFS= read -r f; do
if [[ $f == *"$FILTER"* ]]; then
echo "HERE $f"
echo "test: $f" >> $TMP_OUTDIR/schedule.txt
fi
done <<< "$TEST_FILES"
Expand Down
27 changes: 18 additions & 9 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <catalog/namespace.h>
#include <commands/vacuum.h>
#include <float.h>
#include <hnsw/utils.h>
var77 marked this conversation as resolved.
Show resolved Hide resolved
#include <math.h>
#include <utils/guc.h>
#include <utils/lsyscache.h>
Expand Down Expand Up @@ -343,25 +344,33 @@ Datum vector_l2sq_dist(PG_FUNCTION_ARGS)
}

/*
* Get data type of index
*/
HnswColumnType GetIndexColumnType(Relation index)
* Get data type for give oid
* */
HnswColumnType GetColumnTypeFromOid(Oid oid)
{
TupleDesc indexTupDesc = RelationGetDescr(index);
Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0);
Oid columnType = attr->atttypid;
ldb_invariant(OidIsValid(oid), "Invalid oid passed");

if(columnType == FLOAT4ARRAYOID) {
if(oid == FLOAT4ARRAYOID) {
return REAL_ARRAY;
} else if(columnType == TypenameGetTypid("vector")) {
} else if(oid == TypenameGetTypid("vector")) {
return VECTOR;
} else if(columnType == INT4ARRAYOID) {
} else if(oid == INT4ARRAYOID) {
return INT_ARRAY;
} else {
return UNKNOWN;
}
}

/*
* Get data type of index
*/
HnswColumnType GetIndexColumnType(Relation index)
{
TupleDesc indexTupDesc = RelationGetDescr(index);
Form_pg_attribute attr = TupleDescAttr(indexTupDesc, 0);
return GetColumnTypeFromOid(attr->atttypid);
}

/*
* Given vector data and vector type, convert it to a float4 array
*/
Expand Down
1 change: 1 addition & 0 deletions src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ PGDLLEXPORT Datum vector_l2sq_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS);

HnswColumnType GetColumnTypeFromOid(Oid oid);
HnswColumnType GetIndexColumnType(Relation index);
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions);

Expand Down
96 changes: 86 additions & 10 deletions src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <catalog/index.h>
#include <catalog/namespace.h>
#include <catalog/pg_type.h>
#include <executor/executor.h>
#include <nodes/execnodes.h>
#include <storage/bufmgr.h>
#include <utils/array.h>
#include <utils/lsyscache.h>
Expand Down Expand Up @@ -132,7 +134,71 @@ static void BuildCallback(
MemoryContextReset(buildstate->tmpCtx);
}

static int GetArrayLengthFromHeap(Relation heap, int indexCol)
static int GetArrayLengthFromExpression(Expr *expression, Relation heap, HeapTuple tuple)
{
ExprContext *econtext;
ExprState *exprstate;
EState *estate;
Datum result;
bool isNull;
ArrayType *array;
TupleTableSlot *slot;
TupleDesc tupdesc = RelationGetDescr(heap);

#if PG_VERSION_NUM >= 120000
slot = MakeSingleTupleTableSlot(tupdesc, &TTSOpsHeapTuple);
#else
slot = MakeSingleTupleTableSlot(tupdesc);
#endif

// Create an expression context
econtext = CreateStandaloneExprContext();
estate = CreateExecutorState();

// Build the expression state for your expression
exprstate = ExecPrepareExpr(expression, estate);

#if PG_VERSION_NUM >= 120000
ExecStoreHeapTuple(tuple, slot, false);
#else
ExecStoreTuple(tuple, slot, InvalidBuffer, false);
#endif
// Set up the tuple for the expression evaluation
econtext->ecxt_scantuple = slot;

// Evaluate the expression for the first row
result = ExecEvalExprSwitchContext(exprstate, econtext, &isNull);
var77 marked this conversation as resolved.
Show resolved Hide resolved
array = DatumGetArrayTypeP(result);

// Release tuple descriptor
ReleaseTupleDesc(tupdesc);

Oid oid = get_array_type(array->elemtype);

if(!OidIsValid(oid)) {
// Oid 0 can only be in a case when the datum will be vector type
// As postgres will guard that arbitrary data type won't be returned
// from the function expression
// so we should be safe here to case result into vector in this case
var77 marked this conversation as resolved.
Show resolved Hide resolved
Vector *vector = DatumGetVector(result);
// if the vector will not have dimensions or somehow it
// won't be casted the code will go to ldb_invariant case
if(vector->dim > 0) {
return vector->dim;
}
}
// Check if the result is not null and is supported type
// There is a guard in postgres that wont' allow passing
// Anything else from the defined operator class types
// Throwing an error like: ERROR: data type text has no default operator class for access method "hnsw"
// So this case will be marked as invariant
ldb_invariant(!isNull && GetColumnTypeFromOid(oid) != UNKNOWN,
"Expression used in CREATE INDEX statement did not result in hnsw-index compatible array");

return ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
}

static int GetArrayLengthFromHeap(Relation heap, int indexCol, IndexInfo *indexInfo)
{
#if PG_VERSION_NUM < 120000
HeapScanDesc scan;
Expand Down Expand Up @@ -161,19 +227,29 @@ static int GetArrayLengthFromHeap(Relation heap, int indexCol)
return n_items;
}

// Get the indexed column out of the row and return it's dimensions
datum = heap_getattr(tuple, indexCol, RelationGetDescr(heap), &isNull);
if(!isNull) {
array = DatumGetArrayTypePCopy(datum);
n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(indexInfo->ii_Expressions != NULL) {
// We don't suport multicolumn indexes
// So trying to pass multiple expressions on index creation
// Will result an error before getting here
ldb_invariant(indexInfo->ii_Expressions->length == 1,
"Index expressions can not be greater than 1 as multicolumn indexes are not supported");
Expr *indexpr_item = lfirst(list_head(indexInfo->ii_Expressions));
var77 marked this conversation as resolved.
Show resolved Hide resolved
n_items = GetArrayLengthFromExpression(indexpr_item, heap, tuple);
} else {
// Get the indexed column out of the row and return it's dimensions
datum = heap_getattr(tuple, indexCol, RelationGetDescr(heap), &isNull);
if(!isNull) {
array = DatumGetArrayTypePCopy(datum);
n_items = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
}
}

heap_endscan(scan);

return n_items;
}

int GetHnswIndexDimensions(Relation index)
int GetHnswIndexDimensions(Relation index, IndexInfo *indexInfo)
{
HnswColumnType columnType = GetIndexColumnType(index);

Expand All @@ -195,7 +271,7 @@ int GetHnswIndexDimensions(Relation index)
#else
heap = table_open(index->rd_index->indrelid, AccessShareLock);
#endif
opt_dim = GetArrayLengthFromHeap(heap, attrNum);
opt_dim = GetArrayLengthFromHeap(heap, attrNum, indexInfo);
opts = (ldb_HnswOptions *)index->rd_options;
if(opts != NULL) {
opts->dim = opt_dim;
Expand Down Expand Up @@ -248,7 +324,7 @@ static int InferDimension(Relation heap, IndexInfo *indexInfo)
}

indexCol = indexInfo->ii_IndexAttrNumbers[ 0 ];
return GetArrayLengthFromHeap(heap, indexCol);
return GetArrayLengthFromHeap(heap, indexCol, indexInfo);
}

/*
Expand All @@ -260,7 +336,7 @@ static void InitBuildState(HnswBuildState *buildstate, Relation heap, Relation i
buildstate->index = index;
buildstate->indexInfo = indexInfo;
buildstate->columnType = GetIndexColumnType(index);
buildstate->dimensions = GetHnswIndexDimensions(index);
buildstate->dimensions = GetHnswIndexDimensions(index, indexInfo);
buildstate->index_file_path = ldb_HnswGetIndexFilePath(index);

// If a dimension wasn't specified try to infer it
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/build.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ typedef struct HnswBuildState

IndexBuildResult *ldb_ambuild(Relation heap, Relation index, IndexInfo *indexInfo);
void ldb_ambuildunlogged(Relation index);
int GetHnswIndexDimensions(Relation index);
int GetHnswIndexDimensions(Relation index, IndexInfo *indexInfo);
void CheckHnswIndexDimensions(Relation index, Datum arrayDatum, int deimensions);
// todo: does this render my check unnecessary
#endif // LDB_HNSW_BUILD_H
3 changes: 1 addition & 2 deletions src/hnsw/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ bool ldb_aminsert(Relation index,
HnswIndexTuple *new_tuple;
usearch_init_options_t opts = {0};
LDB_UNUSED(heap);
LDB_UNUSED(indexInfo);
#if PG_VERSION_NUM >= 140000
LDB_UNUSED(indexUnchanged);
#endif
Expand Down Expand Up @@ -103,7 +102,7 @@ bool ldb_aminsert(Relation index,
hdr = (HnswIndexHeaderPage *)PageGetContents(hdr_page);
assert(hdr->magicNumber == LDB_WAL_MAGIC_NUMBER);

opts.dimensions = GetHnswIndexDimensions(index);
opts.dimensions = GetHnswIndexDimensions(index, indexInfo);
CheckHnswIndexDimensions(index, values[ 0 ], opts.dimensions);
PopulateUsearchOpts(index, &opts);
opts.retriever_ctx = ldb_wal_retriever_area_init(index, hdr);
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/executor_start.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ void ExecutorStart_hook_with_operator_check(QueryDesc *queryDesc, int eflags)
}

standard_ExecutorStart(queryDesc, eflags);
}
}
14 changes: 13 additions & 1 deletion src/hooks/post_parse.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ static bool operator_used_incorrectly_walker(Node *node, OperatorUsedCorrectlyCo
Node *arg2 = (Node *)lsecond(opExpr->args);
bool isVar1 = IsA(arg1, Var);
bool isVar2 = IsA(arg2, Var);
/* There is a case when operator is used with index
* that was created via expression (CREATE INDEX ON t USING hnsw (func(id)) WITH (M=2))
* in this case the query may look like this
* SELECT id FROM test ORDER BY func(id) <-> ARRAY[0,0,0] LIMIT 2
* or like this
* SELECT id FROM test ORDER BY func(id) <-> func(n) LIMIT 2
* we should check if IsA(arg1, FuncExpr) || IsA(arg2, FuncExpr)
* if true we may go and check the oid of function result to see if it is an array type
* we also can check that the argument of FuncExpr is at least one of the arg1 and arg2
* will contain column of the table (e.g iterate over list and check IsA(arg, Var))
* so the function will not be called with constant arguments on both sides
*/
if(isVar1 && isVar2) {
return false;
} else if(isVar1 && !isVar2) {
Expand Down Expand Up @@ -168,4 +180,4 @@ void post_parse_analyze_hook_with_operator_check(ParseState *pstate,
list_free(sort_group_refs);
}
list_free(oidList);
}
}
45 changes: 44 additions & 1 deletion test/expected/hnsw_create_expr.out
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,55 @@ BEGIN
RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE OR REPLACE FUNCTION int_to_dynamic_binary_real_array(n INT) RETURNS REAL[] AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
result_length INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');

-- Calculate the length of the result array
result_length := 3 + n;

FOR i IN 1..result_length
LOOP
IF i <= 3 THEN
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
ELSE
real_array := array_append(real_array, CAST(i - 3 AS REAL));
END IF;
END LOOP;

RETURN real_array;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE OR REPLACE FUNCTION int_to_string(n INT) RETURNS TEXT AS $$
BEGIN
RETURN lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);
\set enable_seqscan = off;
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2, dim=3);
-- This should success
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id)) WITH (M=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
\set ON_ERROR_STOP off
-- This should result in an error that dimensions does not match
CREATE INDEX ON test_table USING hnsw (int_to_dynamic_binary_real_array(id)) WITH (M=2);
INFO: done init usearch index
ERROR: Wrong number of dimensions: 4 instead of 3 expected
-- This should result in an error that data type text has no default operator class
CREATE INDEX ON test_table USING hnsw (int_to_string(id)) WITH (M=2);
ERROR: data type text has no default operator class for access method "hnsw"
-- This should result in error about multicolumn expressions support
CREATE INDEX ON test_table USING hnsw (int_to_fixed_binary_real_array(id), int_to_dynamic_binary_real_array(id)) WITH (M=2);
ERROR: access method "hnsw" does not support multicolumn indexes
-- This currently results in an error about using the operator outside of index
-- This case should be fixed
SELECT id FROM test_table ORDER BY int_to_fixed_binary_real_array(id) <-> int_to_fixed_binary_real_array(0) LIMIT 2;
ERROR: Operator <-> can only be used inside of an index
21 changes: 21 additions & 0 deletions test/expected/hnsw_vector.out
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,24 @@ SELECT 1 FROM small_world ORDER BY v <-> '[0,1,0,1]' LIMIT 1;
ERROR: Expected vector with dimension 3, got 4
SELECT l2sq_dist('[1,1]'::vector, '[0,1,0]'::vector);
ERROR: expected equally sized vectors but got vectors with dimensions 2 and 3
-- Test creating index with expression
CREATE TABLE test_table (id INTEGER);
INSERT INTO test_table VALUES (0), (1), (7);
CREATE OR REPLACE FUNCTION int_to_fixed_binary_vector(n INT) RETURNS vector AS $$
DECLARE
binary_string TEXT;
real_array REAL[] := '{}';
i INT;
BEGIN
binary_string := lpad(CAST(n::BIT(3) AS TEXT), 3, '0');
FOR i IN 1..length(binary_string)
LOOP
real_array := array_append(real_array, CAST(substring(binary_string, i, 1) AS REAL));
END LOOP;
RETURN real_array::vector;
END;
$$ LANGUAGE plpgsql IMMUTABLE;
CREATE INDEX ON test_table USING lantern_hnsw (int_to_fixed_binary_vector(id)) WITH (M=2);
INFO: done init usearch index
INFO: inserted 3 elements
INFO: done saving 3 vectors
Loading