Skip to content

Commit

Permalink
refactor and check for malloc fail
Browse files Browse the repository at this point in the history
Signed-off-by: Dhanus M Lal <[email protected]>
  • Loading branch information
DhanusML committed Sep 13, 2024
1 parent 478255a commit 0293b6f
Showing 1 changed file with 40 additions and 41 deletions.
81 changes: 40 additions & 41 deletions cpp/daal/src/externals/service_stat_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,28 @@ struct RefStatistics<double, cpu>
static int xcp(double * data, __int64 nFeatures, __int64 nVectors, double * nPreviousObservations, double * sum, double * crossProduct,
__int64 method)
{
int errcode = 0;
double * sumOld = NULL;
int errcode = 0;
daal::internal::ref::OpenBlas<double, cpu> blasInst;
double accWtOld = *nPreviousObservations;
double accWt = *nPreviousObservations + nVectors;
if (accWtOld)
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
double beta = 0.0;
double alpha;
if (accWtOld != 0)
{
sumOld = daal::services::internal::service_calloc<double, cpu>(nFeatures, sizeof(double));
double * sumOld = daal::services::internal::service_malloc<double, cpu>(nFeatures, sizeof(double));
DAAL_CHECK_MALLOC(sumOld);
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
// S_old S_old^t/accWtOld
alpha = 1.0 / accWtOld;
beta = 1.0;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
daal::services::daal_free(sumOld);
}
for (DAAL_INT i = 0; i < nVectors; ++i)
{
Expand All @@ -203,28 +214,17 @@ struct RefStatistics<double, cpu>
}
}
}
daal::internal::ref::OpenBlas<double, cpu> blasInst;
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
double alpha = -1.0 / accWt;
double beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product

// -S S^t/accWt
alpha = -1.0 / accWt;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
beta = 1.0;
if (accWtOld)
{
alpha = 1.0 / accWtOld;
// S_old S_old^t/accWtOld
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
}
transb = 'T';
alpha = 1.0;

// X X^t
transb = 'T';
alpha = 1.0;
beta = 1.0;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
if (accWtOld) daal::services::daal_free(sumOld);

return errcode;
}
Expand Down Expand Up @@ -335,17 +335,28 @@ struct RefStatistics<float, cpu>
static int xcp(float * data, __int64 nFeatures, __int64 nVectors, float * nPreviousObservations, float * sum, float * crossProduct,
__int64 method)
{
int errcode = 0;
float * sumOld = NULL;
int errcode = 0;
daal::internal::ref::OpenBlas<float, cpu> blasInst;
float accWtOld = *nPreviousObservations;
float accWt = *nPreviousObservations + nVectors;
if (accWtOld)
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
float beta = 0.0;
float alpha;
if (accWtOld != 0)
{
sumOld = daal::services::internal::service_calloc<float, cpu>(nFeatures, sizeof(float));
float * sumOld = daal::services::internal::service_malloc<float, cpu>(nFeatures, sizeof(float));
DAAL_CHECK_MALLOC(sumOld);
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
// S_old S_old^t/accWtOld
alpha = 1.0 / accWtOld;
beta = 1.0;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
daal::services::daal_free(sumOld);
}
for (DAAL_INT i = 0; i < nVectors; ++i)
{
Expand All @@ -364,30 +375,18 @@ struct RefStatistics<float, cpu>
}
}
}
daal::internal::ref::OpenBlas<float, cpu> blasInst;
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
float alpha = -1.0 / accWt;
float beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product

// -S S^t/accWt
alpha = -1.0 / accWt;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
beta = 1.0;
if (accWtOld)
{
alpha = 1.0 / accWtOld;
// S_old S_old^t/accWtOld
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
}
transb = 'T';
alpha = 1.0;

// X X^t
transb = 'T';
alpha = 1.0;
beta = 1.0;
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);

if (accWtOld) daal::services::daal_free(sumOld);

return errcode;
}

Expand Down

0 comments on commit 0293b6f

Please sign in to comment.