Skip to content

Commit

Permalink
Pre-evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
gopal-msr committed Nov 25, 2024
1 parent f9bb0c3 commit ddc8823
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 34 deletions.
32 changes: 17 additions & 15 deletions include/distance.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template <typename T> class Distance
}

// distance comparison function
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length, float threshold = FLT_MAX) const = 0;
DISKANN_DLLEXPORT virtual float compare(const T * __restrict a, const T * __restrict b, uint32_t length, float threshold = FLT_MAX) const = 0;

// Needed only for COSINE-BYTE and INNER_PRODUCT-BYTE
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, const float normA, const float normB,
Expand Down Expand Up @@ -79,7 +79,7 @@ class DistanceCosineInt8 : public Distance<int8_t>
DistanceCosineInt8() : Distance<int8_t>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length,
DISKANN_DLLEXPORT virtual float compare(const int8_t * __restrict a, const int8_t * __restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

Expand All @@ -89,7 +89,7 @@ class DistanceL2Int8 : public Distance<int8_t>
DistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t size,
DISKANN_DLLEXPORT virtual float compare(const int8_t *__restrict a, const int8_t *__restrict b, uint32_t size,
float threshold = FLT_MAX) const;

};
Expand All @@ -101,7 +101,7 @@ class AVXDistanceL2Int8 : public Distance<int8_t>
AVXDistanceL2Int8() : Distance<int8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const int8_t *a, const int8_t *b, uint32_t length,
DISKANN_DLLEXPORT virtual float compare(const int8_t *__restrict a, const int8_t *__restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

Expand All @@ -111,7 +111,7 @@ class DistanceCosineFloat : public Distance<float>
DistanceCosineFloat() : Distance<float>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length,
DISKANN_DLLEXPORT virtual float compare(const float *__restrict a, const float *__restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

Expand All @@ -123,10 +123,10 @@ class DistanceL2Float : public Distance<float>
}

#ifdef _WINDOWS
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size,
DISKANN_DLLEXPORT virtual float compare(const float *__restrict a, const float *__restrict b, uint32_t size,
float threshold = FLT_MAX) const;
#else
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t size,
DISKANN_DLLEXPORT virtual float compare(const float *__restrict a, const float *__restrict b, uint32_t size,
float threshold = FLT_MAX) const __attribute__((hot));
#endif
};
Expand All @@ -137,7 +137,7 @@ class AVXDistanceL2Float : public Distance<float>
AVXDistanceL2Float() : Distance<float>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length,
DISKANN_DLLEXPORT virtual float compare(const float *__restrict a, const float *__restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

Expand All @@ -147,7 +147,8 @@ template <typename T> class SlowDistanceL2 : public Distance<T>
SlowDistanceL2() : Distance<T>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const T *a, const T *b, uint32_t length, float threshold = FLT_MAX) const;
DISKANN_DLLEXPORT virtual float compare(const T *__restrict a, const T *__restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

class SlowDistanceCosineUInt8 : public Distance<uint8_t>
Expand All @@ -156,7 +157,7 @@ class SlowDistanceCosineUInt8 : public Distance<uint8_t>
SlowDistanceCosineUInt8() : Distance<uint8_t>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t length,
DISKANN_DLLEXPORT virtual float compare(const uint8_t *__restrict a, const uint8_t *__restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

Expand All @@ -166,7 +167,7 @@ class DistanceL2UInt8 : public Distance<uint8_t>
DistanceL2UInt8() : Distance<uint8_t>(diskann::Metric::L2)
{
}
DISKANN_DLLEXPORT virtual float compare(const uint8_t *a, const uint8_t *b, uint32_t size,
DISKANN_DLLEXPORT virtual float compare(const uint8_t *__restrict a, const uint8_t *__restrict b, uint32_t size,
float threshold = FLT_MAX) const;
};

Expand All @@ -182,7 +183,7 @@ template <typename T> class DistanceInnerProduct : public Distance<T>
}
inline float inner_product(const T *a, const T *b, unsigned size) const;

inline float compare(const T *a, const T *b, unsigned size, float threshold = FLT_MAX) const
inline float compare(const T *__restrict a, const T *__restrict b, unsigned size, float threshold = FLT_MAX) const
{
float result = inner_product(a, b, size);
// if (result < 0)
Expand All @@ -201,7 +202,7 @@ template <typename T> class DistanceFastL2 : public DistanceInnerProduct<T>
{
}
float norm(const T *a, unsigned size) const;
float compare(const T *a, const T *b, float norm, unsigned size, float threshold = FLT_MAX) const;
float compare(const T *__restrict a, const T *__restrict b, float norm, unsigned size, float threshold = FLT_MAX) const;
};

class AVXDistanceInnerProductFloat : public Distance<float>
Expand All @@ -210,7 +211,8 @@ class AVXDistanceInnerProductFloat : public Distance<float>
AVXDistanceInnerProductFloat() : Distance<float>(diskann::Metric::INNER_PRODUCT)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length, float threshold = FLT_MAX) const;
DISKANN_DLLEXPORT virtual float compare(const float *__restrict a, const float *__restrict b, uint32_t length,
float threshold = FLT_MAX) const;
};

class AVXNormalizedCosineDistanceFloat : public Distance<float>
Expand All @@ -225,7 +227,7 @@ class AVXNormalizedCosineDistanceFloat : public Distance<float>
AVXNormalizedCosineDistanceFloat() : Distance<float>(diskann::Metric::COSINE)
{
}
DISKANN_DLLEXPORT virtual float compare(const float *a, const float *b, uint32_t length,
DISKANN_DLLEXPORT virtual float compare(const float *__restrict a, const float *__restrict b, uint32_t length,
float threshold = FLT_MAX) const
{
// Inner product returns negative values to indicate distance.
Expand Down
48 changes: 29 additions & 19 deletions src/distance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ template <typename T> size_t Distance<T>::get_required_alignment() const
// Cosine distance functions.
//

float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t length, float threshold) const
float DistanceCosineInt8::compare(const int8_t *__restrict a, const int8_t *__restrict b, uint32_t length,
float threshold) const
{
#ifdef _WINDOWS
return diskann::CosineSimilarity2<int8_t>(a, b, length);
Expand All @@ -82,7 +83,8 @@ float DistanceCosineInt8::compare(const int8_t *a, const int8_t *b, uint32_t len
#endif
}

float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t length, float threshold) const
float DistanceCosineFloat::compare(const float *__restrict a, const float *__restrict b, uint32_t length,
float threshold) const
{
#ifdef _WINDOWS
return diskann::CosineSimilarity2<float>(a, b, length);
Expand All @@ -99,7 +101,7 @@ float DistanceCosineFloat::compare(const float *a, const float *b, uint32_t leng
#endif
}

float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t length,
float SlowDistanceCosineUInt8::compare(const uint8_t *__restrict a, const uint8_t *__restrict b, uint32_t length,
float threshold) const
{
int magA = 0, magB = 0, scalarProduct = 0;
Expand All @@ -117,7 +119,7 @@ float SlowDistanceCosineUInt8::compare(const uint8_t *a, const uint8_t *b, uint3
// L2 distance functions.
//

float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size, float threshold) const
float DistanceL2Int8::compare(const int8_t *__restrict a, const int8_t *__restrict b, uint32_t size, float threshold) const
{
#ifdef _WINDOWS
#ifdef USE_AVX2
Expand All @@ -131,7 +133,7 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size, f
pY += 32;
size -= 32;
if (_mm256_reduce_add_ps(r) > threshold) {
diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
//diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
return FLT_MAX;
}
}
Expand All @@ -144,7 +146,7 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size, f
size -= 4;

if (_mm256_reduce_add_ps(r) > threshold) {
diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
//diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
return FLT_MAX;
}

Expand All @@ -171,7 +173,8 @@ float DistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t size, f
#endif
}

float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size, float threshold) const
float DistanceL2UInt8::compare(const uint8_t *__restrict a, const uint8_t *__restrict b, uint32_t size,
float threshold) const
{
uint32_t result = 0;
#ifndef _WINDOWS
Expand All @@ -185,12 +188,12 @@ float DistanceL2UInt8::compare(const uint8_t *a, const uint8_t *b, uint32_t size
}

#ifndef _WINDOWS
float DistanceL2Float::compare(const float *a, const float *b, uint32_t size, float threshold) const
float DistanceL2Float::compare(const float *__restrict a, const float *__restrict b, uint32_t size, float threshold) const
{
a = (const float *)__builtin_assume_aligned(a, 32);
b = (const float *)__builtin_assume_aligned(b, 32);
#else
float DistanceL2Float::compare(const float *a, const float *b, uint32_t size, float threshold) const
float DistanceL2Float::compare(const float *__restrict a, const float *__restrict b, uint32_t size, float threshold) const
{
#endif

Expand All @@ -216,9 +219,13 @@ float DistanceL2Float::compare(const float *a, const float *b, uint32_t size, fl

sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum);

if (_mm256_reduce_add_ps(sum) > threshold) {
//diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
return FLT_MAX;
if (j == (niters/2) || j == (3* niters/4) || j == (7* niters/8) )
{
if (_mm256_reduce_add_ps(sum) > threshold)
{
//diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl;
return FLT_MAX;
}
}
}

Expand All @@ -237,7 +244,7 @@ float DistanceL2Float::compare(const float *a, const float *b, uint32_t size, fl
}

template <typename T>
float SlowDistanceL2<T>::compare(const T *a, const T *b, uint32_t length, float threshold) const
float SlowDistanceL2<T>::compare(const T *__restrict a, const T *__restrict b, uint32_t length, float threshold) const
{
float result = 0.0f;
for (uint32_t i = 0; i < length; i++)
Expand All @@ -248,7 +255,8 @@ float SlowDistanceL2<T>::compare(const T *a, const T *b, uint32_t length, float
}

#ifdef _WINDOWS
float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t length, float threshold) const
float AVXDistanceL2Int8::compare(const int8_t *__restrict a, const int8_t *__restrict b, uint32_t length,
float threshold) const
{
__m128 r = _mm_setzero_ps();
__m128i r1;
Expand Down Expand Up @@ -286,7 +294,8 @@ float AVXDistanceL2Int8::compare(const int8_t *a, const int8_t *b, uint32_t leng
return res;
}

float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t length, float threshold) const
float AVXDistanceL2Float::compare(const float *__restrict a, const float *__restrict b, uint32_t length,
float threshold) const
{
__m128 diff, v1, v2;
__m128 sum = _mm_set1_ps(0);
Expand All @@ -305,11 +314,11 @@ float AVXDistanceL2Float::compare(const float *a, const float *b, uint32_t lengt
return sum.m128_f32[0] + sum.m128_f32[1] + sum.m128_f32[2] + sum.m128_f32[3];
}
#else
float AVXDistanceL2Int8::compare(const int8_t *, const int8_t *, uint32_t, float threshold) const
float AVXDistanceL2Int8::compare(const int8_t *restrict, const int8_t *restrict, uint32_t, float threshold) const
{
return 0;
}
float AVXDistanceL2Float::compare(const float *, const float *, uint32_t, float threshold) const
float AVXDistanceL2Float::compare(const float *restrict, const float *restrict, uint32_t, float threshold) const
{
return 0;
}
Expand Down Expand Up @@ -429,7 +438,8 @@ template <typename T> float DistanceInnerProduct<T>::inner_product(const T *a, c
}

template <typename T>
float DistanceFastL2<T>::compare(const T *a, const T *b, float norm, uint32_t size, float threshold) const
float DistanceFastL2<T>::compare(const T *__restrict a, const T *__restrict b, float norm, uint32_t size,
float threshold) const
{
float result = -2 * DistanceInnerProduct<T>::inner_product(a, b, size);
result += norm;
Expand Down Expand Up @@ -537,7 +547,7 @@ template <typename T> float DistanceFastL2<T>::norm(const T *a, uint32_t size) c
return result;
}

float AVXDistanceInnerProductFloat::compare(const float *a, const float *b, uint32_t size,
float AVXDistanceInnerProductFloat::compare(const float *__restrict a, const float *__restrict b, uint32_t size,
float threshold) const
{
float result = 0.0f;
Expand Down

0 comments on commit ddc8823

Please sign in to comment.