Skip to content

Commit

Permalink
Fix hamming_dist calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
therealdarkknight authored Oct 12, 2023
1 parent 9e3e17b commit 0c0aceb
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 55 deletions.
42 changes: 27 additions & 15 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,26 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri
elog(ERROR, "expected equally sized arrays but got arrays with dimensions %d and %d", a_dim, b_dim);
}

float4 *ax = (float4 *)ARR_DATA_PTR(a);
float4 *bx = (float4 *)ARR_DATA_PTR(b);
float4 result;
bool is_int_array = (metric_kind == usearch_metric_hamming_k);

return usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
if(is_int_array) {
int32 *ax_int = (int32 *)ARR_DATA_PTR(a);
int32 *bx_int = (int32 *)ARR_DATA_PTR(b);

// calling usearch_scalar_f32_k here even though it's an integer array is fine
// the hamming distance in usearch actually ignores the scalar type
// and it will get casted appropriately in usearch even with this scalar type
result = usearch_dist(ax_int, bx_int, metric_kind, a_dim, usearch_scalar_f32_k);

} else {
float4 *ax = (float4 *)ARR_DATA_PTR(a);
float4 *bx = (float4 *)ARR_DATA_PTR(b);

result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
}

return result;
}

static float8 vector_dist(Vector *a, Vector *b, usearch_metric_kind_t metric_kind)
Expand Down Expand Up @@ -330,7 +346,7 @@ Datum hamming_dist(PG_FUNCTION_ARGS)
{
ArrayType *a = PG_GETARG_ARRAYTYPE_P(0);
ArrayType *b = PG_GETARG_ARRAYTYPE_P(1);
PG_RETURN_INT32(array_dist(a, b, usearch_metric_hamming_k));
PG_RETURN_INT32((int32)array_dist(a, b, usearch_metric_hamming_k));
}

PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_l2sq_dist);
Expand Down Expand Up @@ -371,36 +387,32 @@ HnswColumnType GetIndexColumnType(Relation index)
}

/*
* Given vector data and vector type, convert it to a float4 array
* Given vector data and vector type, read it as either a float4 or int32 array and return as void*
*/
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions)
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions)
{
if(type == VECTOR) {
Vector *vector = DatumGetVector(datum);
if(vector->dim != dimensions) {
elog(ERROR, "Expected vector with dimension %d, got %d", dimensions, vector->dim);
}
return vector->x;
return (void *)vector->x;
} else if(type == REAL_ARRAY) {
ArrayType *array = DatumGetArrayTypePCopy(datum);
int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(array_dim != dimensions) {
elog(ERROR, "Expected real array with dimension %d, got %d", dimensions, array_dim);
}
return (float4 *)ARR_DATA_PTR(array);
return (void *)((float4 *)ARR_DATA_PTR(array));
} else if(type == INT_ARRAY) {
ArrayType *array = DatumGetArrayTypePCopy(datum);
int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(array_dim != dimensions) {
elog(ERROR, "Expected int array with dimension %d, got %d", dimensions, array_dim);
}
int *intArray = (int *)ARR_DATA_PTR(array);
float4 *floatArray = (float4 *)palloc(sizeof(float) * array_dim);
for(int i = 0; i < array_dim; i++) {
floatArray[ i ] = (float)intArray[ i ];
}
// todo:: free this array
return floatArray;

int32 *intArray = (int32 *)ARR_DATA_PTR(array);
return (void *)intArray;
} else {
elog(ERROR, "Unsupported type");
}
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS);

HnswColumnType GetColumnTypeFromOid(Oid oid);
HnswColumnType GetIndexColumnType(Relation index);
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions);
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions);

#define LDB_UNUSED(x) (void)(x)

Expand Down
7 changes: 4 additions & 3 deletions src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,17 @@ static void AddTupleToUsearchIndex(ItemPointer tid, Datum *values, HnswBuildStat
usearch_error_t error = NULL;
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ]));
usearch_scalar_kind_t usearch_scalar;
float4 *vector = DatumGetSizedFloatArray(value, buildstate->columnType, buildstate->dimensions);

void *vector = DatumGetSizedArray(value, buildstate->columnType, buildstate->dimensions);
switch(buildstate->columnType) {
case REAL_ARRAY:
case VECTOR:
usearch_scalar = usearch_scalar_f32_k;
break;
case INT_ARRAY:
// q:: I think in this case we need to do a type conversion from int to float
// before passing the buffer to usearch
// this is fine, since we only use integer arrays with hamming distance metric
// and hamming distance in usearch doesn't care about scalar type
// also, usearch will appropriately cast integer arrays even with this scalar type
usearch_scalar = usearch_scalar_f32_k;
break;
default:
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ bool ldb_aminsert(Relation index,
assert(!error);

datum = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ]));
float4 *vector = DatumGetSizedFloatArray(datum, insertstate->columnType, opts.dimensions);
void *vector = DatumGetSizedArray(datum, insertstate->columnType, opts.dimensions);

#if LANTERNDB_COPYNODES
// currently not fully ported to the latest changes
Expand Down
8 changes: 4 additions & 4 deletions src/hnsw/scan.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
if(scanstate->first) {
int num_returned;
Datum value;
float4 *vec;
void *vec;
usearch_error_t error = NULL;
int k = ldb_hnsw_init_k;

Expand All @@ -183,7 +183,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));

vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions);
vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions);

if(scanstate->distances == NULL) {
scanstate->distances = palloc(k * sizeof(float));
Expand All @@ -209,7 +209,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
if(scanstate->current == scanstate->count) {
int num_returned;
Datum value;
float4 *vec;
void *vec;
usearch_error_t error = NULL;
int k = scanstate->count * 2;
int index_size = usearch_size(scanstate->usearch_index, &error);
Expand All @@ -221,7 +221,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)

value = scan->orderByData->sk_argument;

vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions);
vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions);

/* double k and reallocate arrays to account for increased size */
scanstate->distances = repalloc(scanstate->distances, k * sizeof(float));
Expand Down
21 changes: 20 additions & 1 deletion test/expected/hnsw_dist_func.out
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ INFO: inserted 0 elements
INFO: done saving 0 vectors
INSERT INTO small_world_l2 SELECT id, v FROM small_world;
INSERT INTO small_world_cos SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world;
SET enable_seqscan = false;
-- Verify that the distance functions work (check distances)
SELECT ROUND(l2sq_dist(v, '{0,1,0}')::numeric, 2) FROM small_world_l2 ORDER BY v <-> '{0,1,0}';
Expand Down Expand Up @@ -220,3 +220,22 @@ WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT id, COUNT
ERROR: Operator <-> can only be used inside of an index
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t;
ERROR: Operator <-> can only be used inside of an index
-- Check that hamming distance query results are sorted correctly
CREATE TABLE extra_small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO extra_small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON extra_small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
INFO: done init usearch index
INFO: inserted 4 elements
INFO: done saving 4 vectors
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM extra_small_world_ham ORDER BY v <-> '{0,0}';
round
-------
0.00
2.00
2.00
4.00
(4 rows)

19 changes: 0 additions & 19 deletions test/expected/hnsw_todo.out
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,6 @@ SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist
FROM small_world_l2
ORDER BY vector_int <-> array[0,1,0] LIMIT 7;
ERROR: Operator <-> can only be used inside of an index
-- this result is not sorted correctly
CREATE TABLE small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
INFO: done init usearch index
INFO: inserted 4 elements
INFO: done saving 4 vectors
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY v <-> '{0,0}';
round
-------
0.00
2.00
4.00
2.00
(4 rows)

--- Test scenarious ---
-----------------------------------------
-- Case:
Expand Down
13 changes: 11 additions & 2 deletions test/sql/hnsw_dist_func.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=3);

INSERT INTO small_world_l2 SELECT id, v FROM small_world;
INSERT INTO small_world_cos SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world;

SET enable_seqscan = false;

Expand Down Expand Up @@ -88,4 +88,13 @@ SELECT 1 FROM test1 ORDER BY v <-> (SELECT '{1,3}'::real[]);
SELECT t2_results.id FROM test1 t1 JOIN LATERAL (SELECT t2.id FROM test2 t2 ORDER BY t1.v <-> t2.v LIMIT 1) t2_results ON TRUE;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT DISTINCT id FROM t;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}' LIMIT 1) SELECT id, COUNT(*) FROM t GROUP BY 1;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t;
WITH t AS (SELECT id FROM test1 ORDER BY v <-> '{1,2}') SELECT id FROM t UNION SELECT id FROM t;

-- Check that hamming distance query results are sorted correctly
CREATE TABLE extra_small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO extra_small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON extra_small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM extra_small_world_ham ORDER BY v <-> '{0,0}';
9 changes: 0 additions & 9 deletions test/sql/hnsw_todo.sql
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ SELECT id, ROUND(l2sq_dist(vector_int, array[0,1,0])::numeric, 2) as dist
FROM small_world_l2
ORDER BY vector_int <-> array[0,1,0] LIMIT 7;

-- this result is not sorted correctly
CREATE TABLE small_world_ham (
id SERIAL PRIMARY KEY,
v INT[2]
);
INSERT INTO small_world_ham (v) VALUES ('{0,0}'), ('{1,1}'), ('{2,2}'), ('{3,3}');
CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=2);
SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY v <-> '{0,0}';

--- Test scenarious ---
-----------------------------------------
-- Case:
Expand Down

0 comments on commit 0c0aceb

Please sign in to comment.