From 0293b6f6c12676dbfc42083fa1e6a234d0594192 Mon Sep 17 00:00:00 2001 From: Dhanus M Lal Date: Fri, 13 Sep 2024 17:23:31 +0530 Subject: [PATCH] refactor and check for malloc fail Signed-off-by: Dhanus M Lal --- cpp/daal/src/externals/service_stat_ref.h | 81 +++++++++++------------ 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/cpp/daal/src/externals/service_stat_ref.h b/cpp/daal/src/externals/service_stat_ref.h index be71b5921ba..e0b30fa9cf9 100644 --- a/cpp/daal/src/externals/service_stat_ref.h +++ b/cpp/daal/src/externals/service_stat_ref.h @@ -174,17 +174,28 @@ struct RefStatistics static int xcp(double * data, __int64 nFeatures, __int64 nVectors, double * nPreviousObservations, double * sum, double * crossProduct, __int64 method) { - int errcode = 0; - double * sumOld = NULL; + int errcode = 0; + daal::internal::ref::OpenBlas blasInst; double accWtOld = *nPreviousObservations; double accWt = *nPreviousObservations + nVectors; - if (accWtOld) + DAAL_INT one = 1; + char transa = 'N'; + char transb = 'N'; + double beta = 0.0; + double alpha; + if (accWtOld != 0) { - sumOld = daal::services::internal::service_calloc(nFeatures, sizeof(double)); + double * sumOld = daal::services::internal::service_malloc(nFeatures, sizeof(double)); + DAAL_CHECK_MALLOC(sumOld); for (DAAL_INT i = 0; i < nFeatures; ++i) { sumOld[i] = sum[i]; } + // S_old S_old^t/accWtOld + alpha = 1.0 / accWtOld; + beta = 1.0; + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); + daal::services::daal_free(sumOld); } for (DAAL_INT i = 0; i < nVectors; ++i) { @@ -203,28 +214,17 @@ struct RefStatistics } } } - daal::internal::ref::OpenBlas blasInst; - DAAL_INT one = 1; - char transa = 'N'; - char transb = 'N'; - double alpha = -1.0 / accWt; - double beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product + // -S S^t/accWt + alpha = -1.0 / accWt; blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); - beta = 1.0; - if (accWtOld) - { - alpha = 1.0 / accWtOld; - // S_old S_old^t/accWtOld - blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); - } - transb = 'T'; - alpha = 1.0; // X X^t + transb = 'T'; + alpha = 1.0; + beta = 1.0; blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, &nFeatures); - if (accWtOld) daal::services::daal_free(sumOld); return errcode; } @@ -335,17 +335,28 @@ struct RefStatistics static int xcp(float * data, __int64 nFeatures, __int64 nVectors, float * nPreviousObservations, float * sum, float * crossProduct, __int64 method) { - int errcode = 0; - float * sumOld = NULL; + int errcode = 0; + daal::internal::ref::OpenBlas blasInst; float accWtOld = *nPreviousObservations; float accWt = *nPreviousObservations + nVectors; - if (accWtOld) + DAAL_INT one = 1; + char transa = 'N'; + char transb = 'N'; + float beta = 0.0; + float alpha; + if (accWtOld != 0) { - sumOld = daal::services::internal::service_calloc(nFeatures, sizeof(float)); + float * sumOld = daal::services::internal::service_malloc(nFeatures, sizeof(float)); + DAAL_CHECK_MALLOC(sumOld); for (DAAL_INT i = 0; i < nFeatures; ++i) { sumOld[i] = sum[i]; } + // S_old S_old^t/accWtOld + alpha = 1.0 / accWtOld; + beta = 1.0; + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); + daal::services::daal_free(sumOld); } for (DAAL_INT i = 0; i < nVectors; ++i) { @@ -364,30 +375,18 @@ struct RefStatistics } } } - daal::internal::ref::OpenBlas blasInst; - DAAL_INT one = 1; - char transa = 'N'; - char transb = 'N'; - float alpha = -1.0 / accWt; - float beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product + // -S S^t/accWt + alpha = -1.0 / accWt; blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); - beta = 1.0; - if (accWtOld) - { - alpha = 1.0 / accWtOld; - // S_old S_old^t/accWtOld - blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); - } - transb = 'T'; - alpha = 1.0; // X X^t + transb = 'T'; + alpha = 1.0; + beta = 1.0; blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, &nFeatures); - if (accWtOld) daal::services::daal_free(sumOld); - return errcode; }