Skip to content

Commit

Permalink
Merge branch 'haol/distance-cpp-remove-omp' into 'mqdb-dev'
Browse files Browse the repository at this point in the history
remove OMP and use manual sgemm to avoid debug error

Closes #17

See merge request mqdb/faiss!34
  • Loading branch information
Linpeng Tang committed Feb 1, 2024
2 parents bd549be + c4b150c commit 603cbb2
Showing 1 changed file with 50 additions and 13 deletions.
63 changes: 50 additions & 13 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
#define FINTEGER long
#endif

// Remove all OpenMP to avoid CPU interference

extern "C" {

/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
Expand All @@ -52,6 +54,41 @@ int sgemm_(
FINTEGER* ldc);
}

int sgemm_manual(
const char* transa,
const char* transb,
FINTEGER* m,
FINTEGER* n,
FINTEGER* k,
const float* alpha,
const float* a,
FINTEGER* lda,
const float* b,
FINTEGER* ldb,
float* beta,
float* c,
FINTEGER* ldc
){
for (int i = 0; i < *m; ++i) {
for (int j = 0; j < *n; ++j) {
float result = 0.0;
for (int l = 0; l < (*k); ++l) {
result += (*alpha) * a[i * (*lda) + l] * b[l * (*ldb) + j];
}
c[i * (*ldc) + j] = (*beta) * c[i * (*ldc) + j] + result;
}
}
return 0;
}


#ifdef NDEBUG
#define sgemm_func sgemm_
#else
// use sgemm_manual in debug mode to avoid sleep error in ClickHouse
#define sgemm_func sgemm_manual
#endif

namespace faiss {

/***************************************************************************
Expand All @@ -64,7 +101,7 @@ void fvec_norms_L2(
const float* __restrict x,
size_t d,
size_t nx) {
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = 0; i < nx; i++) {
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
}
Expand All @@ -75,13 +112,13 @@ void fvec_norms_L2sqr(
const float* __restrict x,
size_t d,
size_t nx) {
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = 0; i < nx; i++)
nr[i] = fvec_norm_L2sqr(x + i * d, d);
}

void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = 0; i < nx; i++) {
float* __restrict xi = x + i * d;

Expand Down Expand Up @@ -289,7 +326,7 @@ void exhaustive_L2sqr_blas_default_impl(
ip_block.get(),
&nyi);
}
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = i0; i < i1; i++) {
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

Expand Down Expand Up @@ -370,7 +407,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
{
float one = 1, zero = 0;
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
sgemm_("Transpose",
sgemm_func("Transpose",
"Not transpose",
&nyi,
&nxi,
Expand All @@ -384,7 +421,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
ip_block.get(),
&nyi);
}
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = i0; i < i1; i++) {
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);

Expand Down Expand Up @@ -824,7 +861,7 @@ void pairwise_indexed_L2sqr(
const float* y,
const int64_t* iy,
float* dis) {
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t j = 0; j < n; j++) {
if (ix[j] >= 0 && iy[j] >= 0) {
dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
Expand All @@ -840,7 +877,7 @@ void pairwise_indexed_inner_product(
const float* y,
const int64_t* iy,
float* dis) {
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t j = 0; j < n; j++) {
if (ix[j] >= 0 && iy[j] >= 0) {
dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
Expand All @@ -865,7 +902,7 @@ void knn_inner_products_by_idx(
ld_ids = ny;
}

#pragma omp parallel for if (nx > 100)
//#pragma omp parallel for if (nx > 100)
for (int64_t i = 0; i < nx; i++) {
const float* x_ = x + i * d;
const int64_t* idsi = ids + i * ld_ids;
Expand Down Expand Up @@ -901,7 +938,7 @@ void knn_L2sqr_by_idx(
if (ld_ids < 0) {
ld_ids = ny;
}
#pragma omp parallel for if (nx > 100)
//#pragma omp parallel for if (nx > 100)
for (int64_t i = 0; i < nx; i++) {
const float* x_ = x + i * d;
const int64_t* __restrict idsi = ids + i * ld_ids;
Expand Down Expand Up @@ -941,11 +978,11 @@ void pairwise_L2sqr(
// store in beginning of distance matrix to avoid malloc
float* b_norms = dis;

#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = 0; i < nb; i++)
b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);

#pragma omp parallel for
//#pragma omp parallel for
for (int64_t i = 1; i < nq; i++) {
float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
for (int64_t j = 0; j < nb; j++)
Expand Down Expand Up @@ -984,7 +1021,7 @@ void inner_product_to_L2sqr(
const float* nr2,
size_t n1,
size_t n2) {
#pragma omp parallel for
//#pragma omp parallel for
for (int64_t j = 0; j < n1; j++) {
float* disj = dis + j * n2;
for (size_t i = 0; i < n2; i++)
Expand Down

0 comments on commit 603cbb2

Please sign in to comment.