Skip to content

Commit

Permalink
fixed hamming distance calculation on sql function as well as interna…
Browse files Browse the repository at this point in the history
…l use in the index (builds, inserts, scans), and updated tests
  • Loading branch information
therealdarkknight committed Oct 10, 2023
1 parent 7225b4c commit 3dba819
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 46 deletions.
54 changes: 20 additions & 34 deletions src/hnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,33 +291,23 @@ static float4 array_dist(ArrayType *a, ArrayType *b, usearch_metric_kind_t metri
elog(ERROR, "expected equally sized arrays but got arrays with dimensions %d and %d", a_dim, b_dim);
}

float4 *ax;
float4 *bx;
float4 result;
bool is_int_array = (metric_kind == usearch_metric_hamming_k);

bool convert_to_int = (metric_kind == usearch_metric_hamming_k);
if(is_int_array) {
int32 *ax_int = (int32 *)ARR_DATA_PTR(a);
int32 *bx_int = (int32 *)ARR_DATA_PTR(b);

if(convert_to_int) {
int32 *ax_int = (int32*) ARR_DATA_PTR(a);
int32 *bx_int = (int32*) ARR_DATA_PTR(b);
// calling usearch_scalar_f32_k here even though it's an integer array is fine
// the hamming distance in usearch actually ignores the scalar type
// and it will get casted appropriately in usearch even with this scalar type
result = usearch_dist(ax_int, bx_int, metric_kind, a_dim, usearch_scalar_f32_k);

ax = (float4*) palloc(a_dim * sizeof(float4));
bx = (float4*) palloc(b_dim * sizeof(float4));

for (int i = 0; i < a_dim; i++) {
ax[i] = (float4) ax_int[i];
bx[i] = (float4) bx_int[i];
}
}
else {
ax = (float4*)ARR_DATA_PTR(a);
bx = (float4*)ARR_DATA_PTR(b);
}

float4 result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
} else {
float4 *ax = (float4 *)ARR_DATA_PTR(a);
float4 *bx = (float4 *)ARR_DATA_PTR(b);

if(convert_to_int) {
pfree(ax);
pfree(bx);
result = usearch_dist(ax, bx, metric_kind, a_dim, usearch_scalar_f32_k);
}

return result;
Expand Down Expand Up @@ -389,36 +379,32 @@ HnswColumnType GetIndexColumnType(Relation index)
}

/*
* Given vector data and vector type, convert it to a float4 array
* Given vector data and vector type, read it as either a float4 or int32 array and return as void*
*/
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions)
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions)
{
if(type == VECTOR) {
Vector *vector = DatumGetVector(datum);
if(vector->dim != dimensions) {
elog(ERROR, "Expected vector with dimension %d, got %d", dimensions, vector->dim);
}
return vector->x;
return (void *)vector->x;
} else if(type == REAL_ARRAY) {
ArrayType *array = DatumGetArrayTypePCopy(datum);
int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(array_dim != dimensions) {
elog(ERROR, "Expected real array with dimension %d, got %d", dimensions, array_dim);
}
return (float4 *)ARR_DATA_PTR(array);
return (void *)((float4 *)ARR_DATA_PTR(array));
} else if(type == INT_ARRAY) {
ArrayType *array = DatumGetArrayTypePCopy(datum);
int array_dim = ArrayGetNItems(ARR_NDIM(array), ARR_DIMS(array));
if(array_dim != dimensions) {
elog(ERROR, "Expected int array with dimension %d, got %d", dimensions, array_dim);
}
int *intArray = (int *)ARR_DATA_PTR(array);
float4 *floatArray = (float4 *)palloc(sizeof(float) * array_dim);
for(int i = 0; i < array_dim; i++) {
floatArray[ i ] = (float)intArray[ i ];
}
// todo:: free this array
return floatArray;

int32 *intArray = (int32 *)ARR_DATA_PTR(array);
return (void *)intArray;
} else {
elog(ERROR, "Unsupported type");
}
Expand Down
5 changes: 4 additions & 1 deletion src/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ PGDLLEXPORT Datum hamming_dist(PG_FUNCTION_ARGS);
PGDLLEXPORT Datum cos_dist(PG_FUNCTION_ARGS);

HnswColumnType GetIndexColumnType(Relation index);
float4 *DatumGetSizedFloatArray(Datum datum, HnswColumnType type, int dimensions);
void *DatumGetSizedArray(Datum datum, HnswColumnType type, int dimensions);




#define LDB_UNUSED(x) (void)(x)

Expand Down
7 changes: 4 additions & 3 deletions src/hnsw/build.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ static void AddTupleToUsearchIndex(ItemPointer tid, Datum *values, HnswBuildStat
usearch_error_t error = NULL;
Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ]));
usearch_scalar_kind_t usearch_scalar;
float4 *vector = DatumGetSizedFloatArray(value, buildstate->columnType, buildstate->dimensions);

void *vector = DatumGetSizedArray(value, buildstate->columnType, buildstate->dimensions);
switch(buildstate->columnType) {
case REAL_ARRAY:
case VECTOR:
usearch_scalar = usearch_scalar_f32_k;
break;
case INT_ARRAY:
// q:: I think in this case we need to do a type conversion from int to float
// before passing the buffer to usearch
// this is fine, since we only use integer arrays with hamming distance metric
// and hamming distance in usearch doesn't care about scalar type
// also, usearch will appropriately cast integer arrays even with this scalar type
usearch_scalar = usearch_scalar_f32_k;
break;
default:
Expand Down
2 changes: 1 addition & 1 deletion src/hnsw/insert.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ bool ldb_aminsert(Relation index,
assert(!error);

datum = PointerGetDatum(PG_DETOAST_DATUM(values[ 0 ]));
float4 *vector = DatumGetSizedFloatArray(datum, insertstate->columnType, opts.dimensions);
void *vector = DatumGetSizedArray(datum, insertstate->columnType, opts.dimensions);

#if LANTERNDB_COPYNODES
// currently not fully ported to the latest changes
Expand Down
8 changes: 4 additions & 4 deletions src/hnsw/scan.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
if(scanstate->first) {
int num_returned;
Datum value;
float4 *vec;
void *vec;
usearch_error_t error = NULL;
int k = ldb_hnsw_init_k;

Expand All @@ -183,7 +183,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value)));
Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value)));

vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions);
vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions);

if(scanstate->distances == NULL) {
scanstate->distances = palloc(k * sizeof(float));
Expand All @@ -209,7 +209,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)
if(scanstate->current == scanstate->count) {
int num_returned;
Datum value;
float4 *vec;
void *vec;
usearch_error_t error = NULL;
int k = scanstate->count * 2;
int index_size = usearch_size(scanstate->usearch_index, &error);
Expand All @@ -221,7 +221,7 @@ bool ldb_amgettuple(IndexScanDesc scan, ScanDirection dir)

value = scan->orderByData->sk_argument;

vec = DatumGetSizedFloatArray(value, scanstate->columnType, scanstate->dimensions);
vec = DatumGetSizedArray(value, scanstate->columnType, scanstate->dimensions);

/* double k and reallocate arrays to account for increased size */
scanstate->distances = repalloc(scanstate->distances, k * sizeof(float));
Expand Down
2 changes: 1 addition & 1 deletion test/expected/hnsw_dist_func.out
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ INFO: inserted 0 elements
INFO: done saving 0 vectors
INSERT INTO small_world_l2 SELECT id, v FROM small_world;
INSERT INTO small_world_cos SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world;
SET enable_seqscan = false;
-- Verify that the distance functions work (check distances)
SELECT ROUND(l2sq_dist(v, '{0,1,0}')::numeric, 2) FROM small_world_l2 ORDER BY v <-> '{0,1,0}';
Expand Down
2 changes: 1 addition & 1 deletion test/expected/hnsw_todo.out
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ SELECT ROUND(hamming_dist(v, '{0,0}')::numeric, 2) FROM small_world_ham ORDER BY
-------
0.00
2.00
4.00
2.00
4.00
(4 rows)

--- Test scenarious ---
Expand Down
2 changes: 1 addition & 1 deletion test/sql/hnsw_dist_func.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CREATE INDEX ON small_world_ham USING hnsw (v dist_hamming_ops) WITH (dim=3);

INSERT INTO small_world_l2 SELECT id, v FROM small_world;
INSERT INTO small_world_cos SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, v FROM small_world;
INSERT INTO small_world_ham SELECT id, ARRAY[CAST(v[1] AS INTEGER), CAST(v[2] AS INTEGER), CAST(v[3] AS INTEGER)] FROM small_world;

SET enable_seqscan = false;

Expand Down

0 comments on commit 3dba819

Please sign in to comment.