Skip to content

Commit

Permalink
review changes
Browse files Browse the repository at this point in the history
Signed-off-by: Dhanus M Lal <[email protected]>
  • Loading branch information
DhanusML committed Sep 23, 2024
1 parent 0293b6f commit 9617b65
Showing 1 changed file with 48 additions and 36 deletions.
84 changes: 48 additions & 36 deletions cpp/daal/src/externals/service_stat_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,24 +176,22 @@ struct RefStatistics<double, cpu>
{
int errcode = 0;
daal::internal::ref::OpenBlas<double, cpu> blasInst;
double accWtOld = *nPreviousObservations;
double accWt = *nPreviousObservations + nVectors;
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
double beta = 0.0;
double alpha;
const double accWtOld = *nPreviousObservations;
const double accWt = *nPreviousObservations + nVectors;
constexpr DAAL_INT one = 1;
if (accWtOld != 0)
{
double * sumOld = daal::services::internal::service_malloc<double, cpu>(nFeatures, sizeof(double));
double * const sumOld = daal::services::internal::service_malloc<double, cpu>(nFeatures, sizeof(double));
DAAL_CHECK_MALLOC(sumOld);
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
// S_old S_old^t/accWtOld
alpha = 1.0 / accWtOld;
beta = 1.0;
const double alpha = 1.0 / accWtOld;
const double beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
daal::services::daal_free(sumOld);
}
Expand All @@ -216,15 +214,23 @@ struct RefStatistics<double, cpu>
}

// -S S^t/accWt
alpha = -1.0 / accWt;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
{
const double alpha = -1.0 / accWt;
const double beta = accWtOld != 0 ? 1.0 : 0.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
}

// X X^t
transb = 'T';
alpha = 1.0;
beta = 1.0;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
{
constexpr double alpha = 1.0;
constexpr double beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'T';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
}

return errcode;
}
Expand All @@ -250,7 +256,7 @@ struct RefStatistics<double, cpu>
// E(x-\mu)^2 = E(x^2) - \mu^2
int errcode = 0;
double * sum = (double *)daal::services::internal::service_calloc<double, cpu>(nFeatures, sizeof(double));
if (!sum) return -4;
DAAL_CHECK_MALLOC(sum);
daal::services::internal::service_memset<double, cpu>(variance, double(0), nFeatures);
DAAL_INT feature_ptr, vec_ptr;
double wtInv = (double)1 / nVectors;
Expand Down Expand Up @@ -337,24 +343,22 @@ struct RefStatistics<float, cpu>
{
int errcode = 0;
daal::internal::ref::OpenBlas<float, cpu> blasInst;
float accWtOld = *nPreviousObservations;
float accWt = *nPreviousObservations + nVectors;
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
float beta = 0.0;
float alpha;
const float accWtOld = *nPreviousObservations;
const float accWt = *nPreviousObservations + nVectors;
constexpr DAAL_INT one = 1;
if (accWtOld != 0)
{
float * sumOld = daal::services::internal::service_malloc<float, cpu>(nFeatures, sizeof(float));
float * const sumOld = daal::services::internal::service_malloc<float, cpu>(nFeatures, sizeof(float));
DAAL_CHECK_MALLOC(sumOld);
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
// S_old S_old^t/accWtOld
alpha = 1.0 / accWtOld;
beta = 1.0;
const float alpha = 1.0 / accWtOld;
const float beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
daal::services::daal_free(sumOld);
}
Expand All @@ -377,15 +381,23 @@ struct RefStatistics<float, cpu>
}

// -S S^t/accWt
alpha = -1.0 / accWt;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
{
const float alpha = -1.0 / accWt;
const float beta = accWtOld != 0 ? 1.0 : 0.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
}

// X X^t
transb = 'T';
alpha = 1.0;
beta = 1.0;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
{
constexpr float alpha = 1.0;
constexpr float beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'T';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
}

return errcode;
}
Expand All @@ -411,7 +423,7 @@ struct RefStatistics<float, cpu>
// E(x-\mu)^2 = E(x^2) - \mu^2
int errcode = 0;
float * sum = (float *)daal::services::internal::service_calloc<float, cpu>(nFeatures, sizeof(float));
if (!sum) return -4;
DAAL_CHECK_MALLOC(sum);
daal::services::internal::service_memset<float, cpu>(variance, float(0), nFeatures);
DAAL_INT feature_ptr, vec_ptr;
float wtInv = (float)1 / nVectors;
Expand Down

0 comments on commit 9617b65

Please sign in to comment.