diff --git a/src/hnsw.c b/src/hnsw.c index ae811311e..6e475ff1e 100644 --- a/src/hnsw.c +++ b/src/hnsw.c @@ -291,10 +291,26 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri elog(ERROR, "expected equally sized arrays but got arrays with dimensions %d and %d", a_dim, b_dim); } - float4 *ax = (float4 *)ARR_DATA_PTR(a); - float4 *bx = (float4 *)ARR_DATA_PTR(b); + float4 result; + bool is_int_array = (metric_kind == usearch_metric_hamming_k); - return usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k); + if(is_int_array) { + int32 *ax_int = (int32 *)ARR_DATA_PTR(a); + int32 *bx_int = (int32 *)ARR_DATA_PTR(b); + + // calling usearch_scalar_f32_k here even though it's an integer array is fine + // the hamming distance in usearch actually ignores the scalar type + // and it will get casted appropriately in usearch even with this scalar type + result = usearch_dist(ax_int, bx_int, metric_kind, a_dim, usearch_scalar_f32_k); + + } else { + float4 *ax = (float4 *)ARR_DATA_PTR(a); + float4 *bx = (float4 *)ARR_DATA_PTR(b); + + result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k); + } + + return result; } static float8 vector_dist(Vector *a, Vector *b, usearch_metric_kind_t metric_kind) @@ -330,7 +346,7 @@ Datum hamming_dist(PG_FUNCTION_ARGS) { ArrayType *a = PG_GETARG_ARRAYTYPE_P(0); ArrayType *b = PG_GETARG_ARRAYTYPE_P(1); - PG_RETURN_INT32(array_dist(a, b, usearch_metric_hamming_k)); + PG_RETURN_INT32((int32)array_dist(a, b, usearch_metric_hamming_k)); } PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_l2sq_dist); @@ -371,36 +387,32 @@ HnswColumnType GetIndexColumnType(Relation index) } /* - * Given vector data and vector type, convert it to a float4 array + * Given vector data and vector type, read it as either a float4 or int32 array and return as void* */ -float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions) +void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions) { if(type == VECTOR) { Vector *vector = DatumGetVector(datum); if(vector->dim != dimensions) { elog(ERROR, "Expected vector with dimension %d, got %d", dimensions, vector->dim); } - return vector->x; + return (void *)vector->x; } else if(type == REAL_ARRAY) { ArrayType *array = DatumGetArrayTypePCopy(datum); int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); if(array_dim != dimensions) { elog(ERROR, "Expected real array with dimension %d, got %d", dimensions, array_dim); } - return (float4 *)ARR_DATA_PTR(array); + return (void *)((float4 *)ARR_DATA_PTR(array)); } else if(type == INT_ARRAY) { ArrayType *array = DatumGetArrayTypePCopy(datum); int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array)); if(array_dim != dimensions) { elog(ERROR, "Expected int array with dimension %d, got %d", dimensions, array_dim); } - int *intArray = (int *)ARR_DATA_PTR(array); - float4 *floatArray = (float4 *)palloc(sizeof(float) * array_dim); - for(int i = 0; i < array_dim; i++) { - floatArray[ i ] = (float)intArray[ i ]; - } - // todo:: free this array - return floatArray; + + int32 *intArray = (int32 *)ARR_DATA_PTR(array); + return (void *)intArray; } else { elog(ERROR, "Unsupported type"); } diff --git a/src/hnsw.h b/src/hnsw.h index 0542f705e..570ed6bf2 100644 --- a/src/hnsw.h +++ b/src/hnsw.h @@ -34,7 +34,7 @@ PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS); HnswColumnType GetColumnTypeFromOid(Oid oid); HnswColumnType GetIndexColumnType(Relation index); -float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions); +void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions); #define LDB_UNUSED(x) (void)(x) diff --git a/src/hnsw/build.c b/src/hnsw/build.c index 89bc5d739..13e53d217 100644 --- a/src/hnsw/build.c +++ b/src/hnsw/build.c @@ -65,16 +65,17 @@ static void AddTupleToUsearchIndex(ItemPointer tid, Datum *values, HnswBuildStat usearch_error_t error = NULL; Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ])); usearch_scalar_kind_t usearch_scalar; - float4 *vector = DatumGetSizedFloatArray(value, buildstate->columnType, buildstate->dimensions); + void *vector = DatumGetSizedArray(value, buildstate->columnType, buildstate->dimensions); switch(buildstate->columnType) { case REAL_ARRAY: case VECTOR: usearch_scalar = usearch_scalar_f32_k; break; case INT_ARRAY: - // q:: I think in this case we need to do a type conversion from int to float - // before passing the buffer to usearch + // this is fine, since we only use integer arrays with hamming distance metric + // and hamming distance in usearch doesn't care about scalar type + // also, usearch will appropriately cast integer arrays even with this scalar type usearch_scalar = usearch_scalar_f32_k; break; default: diff --git a/src/hnsw/insert.c b/src/hnsw/insert.c index d66e9b919..a669fdcb6 100644 --- a/src/hnsw/insert.c +++ b/src/hnsw/insert.c @@ -132,7 +132,7 @@ bool ldb_aminsert(Relation index, assert(!error); datum = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ])); - float4 *vector = DatumGetSizedFloatArray(datum, insertstate->columnType, opts.dimensions); + void *vector = DatumGetSizedArray(datum, insertstate->columnType, opts.dimensions); #if LANTERNDB_COPYNODES // currently not fully ported to the latest changes diff --git a/src/hnsw/scan.c b/src/hnsw/scan.c index 1bf384f0f..175b0063c 100644 --- a/src/hnsw/scan.c +++ b/src/hnsw/scan.c @@ -162,7 +162,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir) if(scanstate->first) { int num_returned; Datum value; - float4 *vec; + void *vec; usearch_error_t error = NULL; int k = ldb_hnsw_init_k; @@ -183,7 +183,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir) Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); - vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions); + vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions); if(scanstate->distances == NULL) { scanstate->distances = palloc(k * sizeof(float)); @@ -209,7 +209,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir) if(scanstate->current == scanstate->count) { int num_returned; Datum value; - float4 *vec; + void *vec; usearch_error_t error = NULL; int k = scanstate->count * 2; int index_size = usearch_size(scanstate->usearch_index, &error); @@ -221,7 +221,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir) value = scan->orderByData->sk_argument; - vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions); + vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions); /* double k and reallocate arrays to account for increased size */ scanstate->distances = repalloc(scanstate->distances, k * sizeof(float)); diff --git a/test/expected/hnsw_dist_func.out b/test/expected/hnsw_dist_func.out index 8f91465dd..78acab3a5 100644 --- a/test/expected/hnsw_dist_func.out +++ b/test/expected/hnsw_dist_func.out @@ -33,7 +33,7 @@ INFO: inserted 0 elements INFO: done saving 0 vectors INSERT INTO small_world_l2 SELECT id, v FROM small_world; INSERT INTO small_world_cos SELECT id, v FROM small_world; -INSERT INTO small_world_ham SELECT id, v FROM small_world; +INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world; SET enable_seqscan = false; -- Verify that the distance functions work (check distances) SELECT ROUND(l2sq_dist(v, '{0,1,0}')::numeric, 2) FROM small_world_l2 ORDER BY v <-> '{0,1,0}'; @@ -220,3 +220,22 @@ WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT id, COUNT ERROR: Operator <-> can only be used inside of an index WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t; ERROR: Operator <-> can only be used inside of an index +-- Check that hamming distance query results are sorted correctly +CREATE TABLE extra_small_world_ham ( + id SERIAL PRIMARY KEY, + v INT[2] +); +INSERT INTO extra_small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}'); +CREATE INDEX ON extra_small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2); +INFO: done init usearch index +INFO: inserted 4 elements +INFO: done saving 4 vectors +SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM extra_small_world_ham ORDER BY v <-> '{0,0}'; + round +------- + 0.00 + 2.00 + 2.00 + 4.00 +(4 rows) + diff --git a/test/expected/hnsw_todo.out b/test/expected/hnsw_todo.out index 8eb3b1b75..70f0824ee 100644 --- a/test/expected/hnsw_todo.out +++ b/test/expected/hnsw_todo.out @@ -29,25 +29,6 @@ SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist FROM small_world_l2 ORDER BY vector_int <-> array[0,1,0] LIMIT 7; ERROR: Operator <-> can only be used inside of an index --- this result is not sorted correctly -CREATE TABLE small_world_ham ( - id SERIAL PRIMARY KEY, - v INT[2] -); -INSERT INTO small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}'); -CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2); -INFO: done init usearch index -INFO: inserted 4 elements -INFO: done saving 4 vectors -SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY v <-> '{0,0}'; - round -------- - 0.00 - 2.00 - 4.00 - 2.00 -(4 rows) - --- Test scenarious --- ----------------------------------------- -- Case: diff --git a/test/sql/hnsw_dist_func.sql b/test/sql/hnsw_dist_func.sql index 4a0b61409..4184fb35f 100644 --- a/test/sql/hnsw_dist_func.sql +++ b/test/sql/hnsw_dist_func.sql @@ -14,7 +14,7 @@ CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=3); INSERT INTO small_world_l2 SELECT id, v FROM small_world; INSERT INTO small_world_cos SELECT id, v FROM small_world; -INSERT INTO small_world_ham SELECT id, v FROM small_world; +INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world; SET enable_seqscan = false; @@ -88,4 +88,13 @@ SELECT 1 FROM test1 ORDER BY v <-> (SELECT '{1,3}'::real[]); SELECT t2_results.id FROM test1 t1 JOIN LATERAL (SELECT t2.id FROM test2 t2 ORDER BY t1.v <-> t2.v LIMIT 1) t2_results ON TRUE; WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT DISTINCT id FROM t; WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT id, COUNT(*) FROM t GROUP BY 1; -WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t; \ No newline at end of file +WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t; + +-- Check that hamming distance query results are sorted correctly +CREATE TABLE extra_small_world_ham ( + id SERIAL PRIMARY KEY, + v INT[2] +); +INSERT INTO extra_small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}'); +CREATE INDEX ON extra_small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2); +SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM extra_small_world_ham ORDER BY v <-> '{0,0}'; \ No newline at end of file diff --git a/test/sql/hnsw_todo.sql b/test/sql/hnsw_todo.sql index 2159d7eba..07753c60d 100644 --- a/test/sql/hnsw_todo.sql +++ b/test/sql/hnsw_todo.sql @@ -32,15 +32,6 @@ SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist FROM small_world_l2 ORDER BY vector_int <-> array[0,1,0] LIMIT 7; --- this result is not sorted correctly -CREATE TABLE small_world_ham ( - id SERIAL PRIMARY KEY, - v INT[2] -); -INSERT INTO small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}'); -CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2); -SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY v <-> '{0,0}'; - --- Test scenarious --- ----------------------------------------- -- Case: