diff --git a/cpp/daal/src/externals/service_stat_ref.h b/cpp/daal/src/externals/service_stat_ref.h index 0ff35505527..be71b5921ba 100644 --- a/cpp/daal/src/externals/service_stat_ref.h +++ b/cpp/daal/src/externals/service_stat_ref.h @@ -26,6 +26,7 @@ #define __SERVICE_STAT_REF_H__ #include "src/externals/service_memory.h" +#include "src/externals/service_blas_ref.h" typedef void (*func_type)(DAAL_INT, DAAL_INT, DAAL_INT, void *); extern "C" @@ -173,7 +174,57 @@ struct RefStatistics static int xcp(double * data, __int64 nFeatures, __int64 nVectors, double * nPreviousObservations, double * sum, double * crossProduct, __int64 method) { - int errcode = 0; + int errcode = 0; + double * sumOld = NULL; + double accWtOld = *nPreviousObservations; + double accWt = *nPreviousObservations + nVectors; + if (accWtOld) + { + sumOld = daal::services::internal::service_calloc(nFeatures, sizeof(double)); + for (DAAL_INT i = 0; i < nFeatures; ++i) + { + sumOld[i] = sum[i]; + } + } + for (DAAL_INT i = 0; i < nVectors; ++i) + { + for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum + { + if (accWtOld != 0) + { + sum[j] += data[i * nFeatures + j]; + } + else + { + if (i == 0) + sum[j] = data[i * nFeatures + j]; //overwrite the current sum + else + sum[j] += data[i * nFeatures + j]; + } + } + } + daal::internal::ref::OpenBlas blasInst; + DAAL_INT one = 1; + char transa = 'N'; + char transb = 'N'; + double alpha = -1.0 / accWt; + double beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product + // -S S^t/accWt + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); + beta = 1.0; + if (accWtOld) + { + alpha = 1.0 / accWtOld; + // S_old S_old^t/accWtOld + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); + } + transb = 'T'; + alpha = 1.0; + + // X X^t + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, + &nFeatures); + if (accWtOld) daal::services::daal_free(sumOld); return errcode; } @@ -284,7 +335,58 @@ struct RefStatistics static int xcp(float * data, __int64 nFeatures, __int64 nVectors, float * nPreviousObservations, float * sum, float * crossProduct, __int64 method) { - int errcode = 0; + int errcode = 0; + float * sumOld = NULL; + float accWtOld = *nPreviousObservations; + float accWt = *nPreviousObservations + nVectors; + if (accWtOld) + { + sumOld = daal::services::internal::service_calloc(nFeatures, sizeof(float)); + for (DAAL_INT i = 0; i < nFeatures; ++i) + { + sumOld[i] = sum[i]; + } + } + for (DAAL_INT i = 0; i < nVectors; ++i) + { + for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum + { + if (accWtOld != 0) + { + sum[j] += data[i * nFeatures + j]; + } + else + { + if (i == 0) + sum[j] = data[i * nFeatures + j]; //overwrite the current sum + else + sum[j] += data[i * nFeatures + j]; + } + } + } + daal::internal::ref::OpenBlas blasInst; + DAAL_INT one = 1; + char transa = 'N'; + char transb = 'N'; + float alpha = -1.0 / accWt; + float beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product + // -S S^t/accWt + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); + beta = 1.0; + if (accWtOld) + { + alpha = 1.0 / accWtOld; + // S_old S_old^t/accWtOld + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); + } + transb = 'T'; + alpha = 1.0; + + // X X^t + blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, + &nFeatures); + + if (accWtOld) daal::services::daal_free(sumOld); return errcode; }