Skip to content

Commit

Permalink
Added xcp ref implementation
Browse files Browse the repository at this point in the history
This routine computes the matrix of cross product
of data stored in column major format, in batches.

For matrix X of dimensions p x n, the i,j th entry
of the cross product matrix is

C_ij = \sum_k (x_ik-\mu_i) (x_jk-\mu_k)

where x_ij is the jth element of the ith row, of the matrix X.

Implementation uses the BLAS routine GEMM.

Signed-off-by: Dhanus M Lal <[email protected]>
  • Loading branch information
DhanusML committed Sep 10, 2024
1 parent c56aef7 commit 478255a
Showing 1 changed file with 104 additions and 2 deletions.
106 changes: 104 additions & 2 deletions cpp/daal/src/externals/service_stat_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define __SERVICE_STAT_REF_H__

#include "src/externals/service_memory.h"
#include "src/externals/service_blas_ref.h"

typedef void (*func_type)(DAAL_INT, DAAL_INT, DAAL_INT, void *);
extern "C"
Expand Down Expand Up @@ -173,7 +174,57 @@ struct RefStatistics<double, cpu>
static int xcp(double * data, __int64 nFeatures, __int64 nVectors, double * nPreviousObservations, double * sum, double * crossProduct,
__int64 method)
{
int errcode = 0;
int errcode = 0;
double * sumOld = NULL;
double accWtOld = *nPreviousObservations;
double accWt = *nPreviousObservations + nVectors;
if (accWtOld)
{
sumOld = daal::services::internal::service_calloc<double, cpu>(nFeatures, sizeof(double));
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
}
for (DAAL_INT i = 0; i < nVectors; ++i)
{
for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum
{
if (accWtOld != 0)
{
sum[j] += data[i * nFeatures + j];
}
else
{
if (i == 0)
sum[j] = data[i * nFeatures + j]; //overwrite the current sum
else
sum[j] += data[i * nFeatures + j];
}
}
}
daal::internal::ref::OpenBlas<double, cpu> blasInst;
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
double alpha = -1.0 / accWt;
double beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product
// -S S^t/accWt
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
beta = 1.0;
if (accWtOld)
{
alpha = 1.0 / accWtOld;
// S_old S_old^t/accWtOld
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
}
transb = 'T';
alpha = 1.0;

// X X^t
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
if (accWtOld) daal::services::daal_free(sumOld);

return errcode;
}
Expand Down Expand Up @@ -284,7 +335,58 @@ struct RefStatistics<float, cpu>
static int xcp(float * data, __int64 nFeatures, __int64 nVectors, float * nPreviousObservations, float * sum, float * crossProduct,
__int64 method)
{
int errcode = 0;
int errcode = 0;
float * sumOld = NULL;
float accWtOld = *nPreviousObservations;
float accWt = *nPreviousObservations + nVectors;
if (accWtOld)
{
sumOld = daal::services::internal::service_calloc<float, cpu>(nFeatures, sizeof(float));
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
}
for (DAAL_INT i = 0; i < nVectors; ++i)
{
for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum
{
if (accWtOld != 0)
{
sum[j] += data[i * nFeatures + j];
}
else
{
if (i == 0)
sum[j] = data[i * nFeatures + j]; //overwrite the current sum
else
sum[j] += data[i * nFeatures + j];
}
}
}
daal::internal::ref::OpenBlas<float, cpu> blasInst;
DAAL_INT one = 1;
char transa = 'N';
char transb = 'N';
float alpha = -1.0 / accWt;
float beta = accWtOld ? 1.0 : 0.0; // if accWtOld = 0, overwrite the cross product
// -S S^t/accWt
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
beta = 1.0;
if (accWtOld)
{
alpha = 1.0 / accWtOld;
// S_old S_old^t/accWtOld
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
}
transb = 'T';
alpha = 1.0;

// X X^t
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);

if (accWtOld) daal::services::daal_free(sumOld);

return errcode;
}
Expand Down

0 comments on commit 478255a

Please sign in to comment.