From ada64ba7e5aa9ef19caae03d4825bce8fee218f1 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Tue, 26 Nov 2024 01:13:37 +0530 Subject: [PATCH] Pre-evaluation --- src/distance.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/distance.cpp b/src/distance.cpp index 542bf534a..02bc9e8a9 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -201,7 +201,7 @@ float DistanceL2Float::compare(const float *__restrict a, const float *__restric #ifdef USE_AVX2 // assume size is divisible by 8 uint16_t niters = (uint16_t)(size / 8); - uint32_t half = niters / 2, three_fourths = (niters * 3) / 4, seven_eights = (niters * 7) / 8; + uint32_t half = niters / 2, two_thirds = (niters * 2)/3, three_fourths = (niters * 3) / 4, seven_eights = (niters * 7) / 8; __m256 sum = _mm256_setzero_ps(); for (uint16_t j = 0; j < niters; j++) { @@ -220,11 +220,12 @@ float DistanceL2Float::compare(const float *__restrict a, const float *__restric sum = _mm256_fmadd_ps(tmp_vec, tmp_vec, sum); - if (j == half || j == three_fourths || j == seven_eights ) + if (j == half || j == two_thirds || j == three_fourths /* || j == seven_eights*/) { if (_mm256_reduce_add_ps(sum) > threshold) { - //diskann::cout << "Breaking because sum exceeded threshold: " << threshold << std::endl; + //diskann::cout << "Breaking at " << j << " instead of " << niters + // << " because sum exceeded threshold : " << threshold << std::endl; return FLT_MAX; } }