From 9617b65ff1478efbb5b6622d561b75b41895529d Mon Sep 17 00:00:00 2001 From: Dhanus M Lal Date: Mon, 23 Sep 2024 13:31:26 +0530 Subject: [PATCH] review changes Signed-off-by: Dhanus M Lal --- cpp/daal/src/externals/service_stat_ref.h | 84 +++++++++++++---------- 1 file changed, 48 insertions(+), 36 deletions(-) diff --git a/cpp/daal/src/externals/service_stat_ref.h b/cpp/daal/src/externals/service_stat_ref.h index e0b30fa9cf9..81363ce3079 100644 --- a/cpp/daal/src/externals/service_stat_ref.h +++ b/cpp/daal/src/externals/service_stat_ref.h @@ -176,24 +176,22 @@ struct RefStatistics { int errcode = 0; daal::internal::ref::OpenBlas blasInst; - double accWtOld = *nPreviousObservations; - double accWt = *nPreviousObservations + nVectors; - DAAL_INT one = 1; - char transa = 'N'; - char transb = 'N'; - double beta = 0.0; - double alpha; + const double accWtOld = *nPreviousObservations; + const double accWt = *nPreviousObservations + nVectors; + constexpr DAAL_INT one = 1; if (accWtOld != 0) { - double * sumOld = daal::services::internal::service_malloc(nFeatures, sizeof(double)); + double * const sumOld = daal::services::internal::service_malloc(nFeatures, sizeof(double)); DAAL_CHECK_MALLOC(sumOld); for (DAAL_INT i = 0; i < nFeatures; ++i) { sumOld[i] = sum[i]; } // S_old S_old^t/accWtOld - alpha = 1.0 / accWtOld; - beta = 1.0; + const double alpha = 1.0 / accWtOld; + const double beta = 1.0; + constexpr char transa = 'N'; + constexpr char transb = 'N'; blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); daal::services::daal_free(sumOld); } @@ -216,15 +214,23 @@ struct RefStatistics } // -S S^t/accWt - alpha = -1.0 / accWt; - blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); + { + const double alpha = -1.0 / accWt; + const double beta = accWtOld != 0 ? 1.0 : 0.0; + constexpr char transa = 'N'; + constexpr char transb = 'N'; + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); + } // X X^t - transb = 'T'; - alpha = 1.0; - beta = 1.0; - blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, - &nFeatures); + { + constexpr double alpha = 1.0; + constexpr double beta = 1.0; + constexpr char transa = 'N'; + constexpr char transb = 'T'; + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, + &nFeatures); + } return errcode; } @@ -250,7 +256,7 @@ struct RefStatistics // E(x-\mu)^2 = E(x^2) - \mu^2 int errcode = 0; double * sum = (double *)daal::services::internal::service_calloc(nFeatures, sizeof(double)); - if (!sum) return -4; + DAAL_CHECK_MALLOC(sum); daal::services::internal::service_memset(variance, double(0), nFeatures); DAAL_INT feature_ptr, vec_ptr; double wtInv = (double)1 / nVectors; @@ -337,24 +343,22 @@ struct RefStatistics { int errcode = 0; daal::internal::ref::OpenBlas blasInst; - float accWtOld = *nPreviousObservations; - float accWt = *nPreviousObservations + nVectors; - DAAL_INT one = 1; - char transa = 'N'; - char transb = 'N'; - float beta = 0.0; - float alpha; + const float accWtOld = *nPreviousObservations; + const float accWt = *nPreviousObservations + nVectors; + constexpr DAAL_INT one = 1; if (accWtOld != 0) { - float * sumOld = daal::services::internal::service_malloc(nFeatures, sizeof(float)); + float * const sumOld = daal::services::internal::service_malloc(nFeatures, sizeof(float)); DAAL_CHECK_MALLOC(sumOld); for (DAAL_INT i = 0; i < nFeatures; ++i) { sumOld[i] = sum[i]; } // S_old S_old^t/accWtOld - alpha = 1.0 / accWtOld; - beta = 1.0; + const float alpha = 1.0 / accWtOld; + const float beta = 1.0; + constexpr char transa = 'N'; + constexpr char transb = 'N'; blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); daal::services::daal_free(sumOld); } @@ -377,15 +381,23 @@ struct RefStatistics } // -S S^t/accWt - alpha = -1.0 / accWt; - blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); + { + const float alpha = -1.0 / accWt; + const float beta = accWtOld != 0 ? 1.0 : 0.0; + constexpr char transa = 'N'; + constexpr char transb = 'N'; + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); + } // X X^t - transb = 'T'; - alpha = 1.0; - beta = 1.0; - blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, - &nFeatures); + { + constexpr float alpha = 1.0; + constexpr float beta = 1.0; + constexpr char transa = 'N'; + constexpr char transb = 'T'; + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, + &nFeatures); + } return errcode; } @@ -411,7 +423,7 @@ struct RefStatistics // E(x-\mu)^2 = E(x^2) - \mu^2 int errcode = 0; float * sum = (float *)daal::services::internal::service_calloc(nFeatures, sizeof(float)); - if (!sum) return -4; + DAAL_CHECK_MALLOC(sum); daal::services::internal::service_memset(variance, float(0), nFeatures); DAAL_INT feature_ptr, vec_ptr; float wtInv = (float)1 / nVectors;