Skip to content

Commit

Permalink
Pre-evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Nov 25, 2024
1 parent e6afdbb commit f9bb0c3
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 38 deletions.
3 changes: 2 additions & 1 deletion include/abstract_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ template <typename data_t> class AbstractDataStore
float *distances, AbstractScratch<data_t> *scratch_space = nullptr) const = 0;
// Specific overload for index.cpp.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const = 0;
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space,
float threshold = FLT_MAX) const = 0;
virtual float get_distance(const location_t loc1, const location_t loc2) const = 0;

// stats of the data stored in store
Expand Down
43 changes: 28 additions & 15 deletions include/distance.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#pragma once
#include "windows_customizations.h"
#include <cstring>
#include <climits>
#include <cfloat>

namespace diskann
{
Expand All @@ -20,7 +22,7 @@ template <typename T> class Distance
}

// distance comparison function
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const = 0;
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length, float threshold = FLT_MAX) const = 0;

// Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB,
Expand Down Expand Up @@ -77,7 +79,8 @@ class DistanceCosineInt8 : public Distance<int8_t>
DistanceCosineInt8() : Distance<int8_t>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length,
float threshold = FLT_MAX) const;
};

class DistanceL2Int8 : public Distance<int8_t>
Expand All @@ -86,7 +89,9 @@ class DistanceL2Int8 : public Distance<int8_t>
DistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size) const;
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size,
float threshold = FLT_MAX) const;

};

// AVX implementations. Borrowed from HNSW code.
Expand All @@ -96,7 +101,8 @@ class AVXDistanceL2Int8 : public Distance<int8_t>
AVXDistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length,
float threshold = FLT_MAX) const;
};

class DistanceCosineFloat : public Distance<float>
Expand All @@ -105,7 +111,8 @@ class DistanceCosineFloat : public Distance<float>
DistanceCosineFloat() : Distance<float>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length,
float threshold = FLT_MAX) const;
};

class DistanceL2Float : public Distance<float>
Expand All @@ -116,9 +123,11 @@ class DistanceL2Float : public Distance<float>
}

#ifdef _WINDOWS
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const;
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size,
float threshold = FLT_MAX) const;
#else
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size) const __attribute__((hot));
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size,
float threshold = FLT_MAX) const __attribute__((hot));
#endif
};

Expand All @@ -128,7 +137,8 @@ class AVXDistanceL2Float : public Distance<float>
AVXDistanceL2Float() : Distance<float>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length,
float threshold = FLT_MAX) const;
};

template <typename T> class SlowDistanceL2 : public Distance<T>
Expand All @@ -137,7 +147,7 @@ template <typename T> class SlowDistanceL2 : public Distance<T>
SlowDistanceL2() : Distance<T>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length, float threshold = FLT_MAX) const;
};

class SlowDistanceCosineUInt8 : public Distance<uint8_t>
Expand All @@ -146,7 +156,8 @@ class SlowDistanceCosineUInt8 : public Distance<uint8_t>
SlowDistanceCosineUInt8() : Distance<uint8_t>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length,
float threshold = FLT_MAX) const;
};

class DistanceL2UInt8 : public Distance<uint8_t>
Expand All @@ -155,7 +166,8 @@ class DistanceL2UInt8 : public Distance<uint8_t>
DistanceL2UInt8() : Distance<uint8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size) const;
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size,
float threshold = FLT_MAX) const;
};

template <typename T> class DistanceInnerProduct : public Distance<T>
Expand All @@ -170,7 +182,7 @@ template <typename T> class DistanceInnerProduct : public Distance<T>
}
inline float inner_product(const T *a, const T *b, unsigned size) const;

inline float compare(const T *a, const T *b, unsigned size) const
inline float compare(const T *a, const T *b, unsigned size, float threshold = FLT_MAX) const
{
float result = inner_product(a, b, size);
// if (result < 0)
Expand All @@ -189,7 +201,7 @@ template <typename T> class DistanceFastL2 : public DistanceInnerProduct<T>
{
}
float norm(const T *a, unsigned size) const;
float compare(const T *a, const T *b, float norm, unsigned size) const;
float compare(const T *a, const T *b, float norm, unsigned size, float threshold = FLT_MAX) const;
};

class AVXDistanceInnerProductFloat : public Distance<float>
Expand All @@ -198,7 +210,7 @@ class AVXDistanceInnerProductFloat : public Distance<float>
AVXDistanceInnerProductFloat() : Distance<float>(diskann::Metric::INNER_PRODUCT)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const;
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length, float threshold = FLT_MAX) const;
};

class AVXNormalizedCosineDistanceFloat : public Distance<float>
Expand All @@ -213,7 +225,8 @@ class AVXNormalizedCosineDistanceFloat : public Distance<float>
AVXNormalizedCosineDistanceFloat() : Distance<float>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length) const
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length,
float threshold = FLT_MAX) const
{
// Inner product returns negative values to indicate distance.
// This will ensure that cosine is between -1 and 1.
Expand Down
2 changes: 1 addition & 1 deletion include/in_mem_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_
const uint32_t location_count, float *distances,
AbstractScratch<data_t> *scratch) const override;
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space, float threshold = FLT_MAX) const override;

virtual location_t calculate_medoid() const override;

Expand Down
5 changes: 5 additions & 0 deletions include/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ class NeighborPriorityQueue
return _data[pre];
}

// Returns a copy of the entry in the last occupied slot, _data[_size - 1].
// Despite the name, the result is a whole Neighbor; callers read its distance field.
// NOTE(review): presumably _data is kept in ascending-distance order, which would make
// this the farthest candidate currently held -- confirm against insert().
// NOTE(review): there is no guard for an empty queue; with _size == 0 the index
// _size - 1 is out of range, so callers must only invoke this on a non-empty queue.
Neighbor farthest_distance() const
{
    const auto last_slot = _size - 1;
    return _data[last_slot];
}

bool has_unexpanded_node() const
{
return _cur < _size;
Expand Down
2 changes: 1 addition & 1 deletion include/pq_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ template <typename data_t> class PQDataStore : public AbstractDataStore<data_t>
// NOTE: Caller must invoke "PQDistance->preprocess_query" ONCE before calling
// this function.
virtual void get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const override;
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space, float threshold = FLT_MAX) const override;

// We are returning the distance function that is used for full precision
// vectors here, not the PQ distance function. This is because the callers
Expand Down
47 changes: 33 additions & 14 deletions src/distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ template <typename T> size_t Distance<T>::get_required_alignment() const
// Cosine distance functions.
//

float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t length) const
float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t length, float threshold) const
{
#ifdef _WINDOWS
return diskann::CosineSimilarity2<int8_t>(a, b, length);
Expand All @@ -82,7 +82,7 @@ float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t len
#endif
}

float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t length) const
float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t length, float threshold) const
{
#ifdef _WINDOWS
return diskann::CosineSimilarity2<float>(a, b, length);
Expand All @@ -99,7 +99,8 @@ float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t leng
#endif
}

float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t length) const
float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t length,
float threshold) const
{
int magA = 0, magB = 0, scalarProduct = 0;
for (uint32_t i = 0; i < length; i++)
Expand All @@ -116,7 +117,7 @@ float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint3
// L2 distance functions.
//

float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) const
float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size, float threshold) const
{
#ifdef _WINDOWS
#ifdef USE_AVX2
Expand All @@ -129,6 +130,10 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c
pX += 32;
pY += 32;
size -= 32;
// Early exit: once the running sum held in r exceeds the caller's threshold,
// report FLT_MAX instead of finishing the accumulation.
// Deliberately silent: this check runs on every 32-byte step of the loop, and a
// flushed stream write here would dominate the cost of the comparison itself.
if (_mm256_reduce_add_ps(r) > threshold) {
    return FLT_MAX;
}
}
while (size > 0)
{
Expand All @@ -137,6 +142,12 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c
pX += 4;
pY += 4;
size -= 4;

// Early exit for the tail loop: once the running sum held in r exceeds the
// caller's threshold, report FLT_MAX instead of finishing the accumulation.
// Deliberately silent: this check runs on every 4-byte step of the loop, and a
// flushed stream write here would dominate the cost of the comparison itself.
if (_mm256_reduce_add_ps(r) > threshold) {
    return FLT_MAX;
}

}
r = _mm256_hadd_ps(_mm256_hadd_ps(r, r), r);
return r.m256_f32[0] + r.m256_f32[4];
Expand All @@ -160,7 +171,7 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size) c
#endif
}

float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size) const
float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size, float threshold) const
{
uint32_t result = 0;
#ifndef _WINDOWS
Expand All @@ -174,12 +185,12 @@ float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size
}

#ifndef _WINDOWS
float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) const
float DistanceL2Float::compare(const float *a, const float *b, uint32_t size, float threshold) const
{
a = (const float *)__builtin_assume_aligned(a, 32);
b = (const float *)__builtin_assume_aligned(b, 32);
#else
float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) const
float DistanceL2Float::compare(const float *a, const float *b, uint32_t size, float threshold) const
{
#endif

Expand All @@ -204,6 +215,11 @@ float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) co
__m256 tmp_vec = _mm256_sub_ps(a_vec, b_vec);

sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum);

if (_mm256_reduce_add_ps(sum) > threshold) {
//diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
return FLT_MAX;
}
}

// horizontal add sum
Expand All @@ -220,7 +236,8 @@ float DistanceL2Float::compare(const float *a, const float *b, uint32_t size) co
return result;
}

template <typename T> float SlowDistanceL2<T>::compare(const T *a, const T *b, uint32_t length) const
template <typename T>
float SlowDistanceL2<T>::compare(const T *a, const T *b, uint32_t length, float threshold) const
{
float result = 0.0f;
for (uint32_t i = 0; i < length; i++)
Expand All @@ -231,7 +248,7 @@ template <typename T> float SlowDistanceL2<T>::compare(const T *a, const T *b, u
}

#ifdef _WINDOWS
float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t length) const
float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t length, float threshold) const
{
__m128 r = _mm_setzero_ps();
__m128i r1;
Expand Down Expand Up @@ -269,7 +286,7 @@ float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t leng
return res;
}

float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t length) const
float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t length, float threshold) const
{
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
Expand All @@ -288,11 +305,11 @@ float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t lengt
return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + sum.m128_f32[3];
}
#else
float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, uint32_t) const
float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, uint32_t, float threshold) const
{
return 0;
}
float AVXDistanceL2Float::compare(const float *, const float *, uint32_t) const
float AVXDistanceL2Float::compare(const float *, const float *, uint32_t, float threshold) const
{
return 0;
}
Expand Down Expand Up @@ -411,7 +428,8 @@ template <typename T> float DistanceInnerProduct<T>::inner_product(const T *a, c
return result;
}

template <typename T> float DistanceFastL2<T>::compare(const T *a, const T *b, float norm, uint32_t size) const
template <typename T>
float DistanceFastL2<T>::compare(const T *a, const T *b, float norm, uint32_t size, float threshold) const
{
float result = -2 * DistanceInnerProduct<T>::inner_product(a, b, size);
result += norm;
Expand Down Expand Up @@ -519,7 +537,8 @@ template <typename T> float DistanceFastL2<T>::norm(const T *a, uint32_t size) c
return result;
}

float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint32_t size) const
float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint32_t size,
float threshold) const
{
float result = 0.0f;
#define AVX_DOT(addr1, addr2, dest, tmp1, tmp2) \
Expand Down
4 changes: 2 additions & 2 deletions src/in_mem_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,12 @@ float InMemDataStore<data_t>::get_distance(const location_t loc1, const location

// Batch distance computation: for each position i in ids, compares preprocessed_query
// against the stored vector at _data + ids[i] * _aligned_dim and writes the result to
// distances[i].
//
// preprocessed_query: query vector handed straight to the distance function.
// ids:                locations of the stored vectors to compare against.
// distances:          output; indexed up to ids.size() - 1 and never resized here, so the
//                     caller must provide room for at least ids.size() entries.
// scratch_space:      not used by this overload.
// threshold:          forwarded unchanged to the distance function's compare().
template <typename data_t>
void InMemDataStore<data_t>::get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
                                          std::vector<float> &distances, AbstractScratch<data_t> *scratch_space,
                                          float threshold) const
{
    // Unsigned index: ids.size() is unsigned, so a signed counter would trigger a
    // signed/unsigned comparison.
    for (size_t i = 0; i < ids.size(); i++)
    {
        distances[i] = _distance_fn->compare(preprocessed_query, _data + ids[i] * _aligned_dim,
                                             static_cast<uint32_t>(this->_aligned_dim), threshold);
    }
}

Expand Down
6 changes: 3 additions & 3 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -835,8 +835,8 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
};

// Lambda to batch compute query<-> node distances in PQ space
auto compute_dists = [this, scratch, pq_dists](const std::vector<uint32_t> &ids, std::vector<float> &dists_out) {
_pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch);
auto compute_dists = [this, scratch, pq_dists](const std::vector<uint32_t> &ids, std::vector<float> &dists_out, float threshold) {
_pq_data_store->get_distance(scratch->aligned_query(), ids, dists_out, scratch, threshold);
};

// Initialize the candidate pool with starting points
Expand Down Expand Up @@ -963,7 +963,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
}

assert(dist_scratch.capacity() >= id_scratch.size());
compute_dists(id_scratch, dist_scratch);
compute_dists(id_scratch, dist_scratch, best_L_nodes.farthest_distance().distance);
cmps += (uint32_t)id_scratch.size();

// Insert <id, dist> pairs into the pool of candidates
Expand Down
2 changes: 1 addition & 1 deletion src/pq_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ void PQDataStore<data_t>::get_distance(const data_t *preprocessed_query, const l

template <typename data_t>
void PQDataStore<data_t>::get_distance(const data_t *preprocessed_query, const std::vector<location_t> &ids,
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space) const
std::vector<float> &distances, AbstractScratch<data_t> *scratch_space, float threshold) const
{
if (scratch_space == nullptr)
{
Expand Down

0 comments on commit f9bb0c3

Please sign in to comment.