Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GEMM Strides/Extents to CBLAS mapping #146

Open
crtrott opened this issue Feb 4, 2022 · 1 comment
Open

GEMM Strides/Extents to CBLAS mapping #146

crtrott opened this issue Feb 4, 2022 · 1 comment

Comments

@crtrott
Copy link
Member

crtrott commented Feb 4, 2022

Does this seem right? For each matrix, the table lists (extent(0), extent(1)), (stride(0), stride(1)).

C A B de facto layouts CBLAS call
(M,N), (1,M) (M,K), (1,M) (K,N), (1,K) (left,left,left) gemm('N','N', M, N, K, 1., A.data(), K, B.data(), N, 1., C.data(), N)
(M,N), (1,M) (M,K), (K,1) (K,N), (1,K) (left,right,left) gemm('T','N', N, M, K, 1., A.data(), M, B.data(), N, 1., C.data(), N)
(M,N), (1,M) (M,K), (1,M) (K,N), (N,1) (left,left,right) gemm('N','T', M, N, K, 1., A.data(), K, B.data(), K, 1., C.data(), N)
(M,N), (1,M) (M,K), (K,1) (K,N), (N,1) (left,right,right) gemm('T','T', M, N, K, 1., A.data(), M, B.data(), K, 1., C.data(), N)
(M,N), (N,1) (M,K), (1, M) (K,N), (1,K) (right,left,left) gemm('T','T', N, M, K, 1., B.data(), N, A.data(), K, 1., C.data(), M)
(M,N), (N,1) (M,K), (K, 1) (K,N), (1,K) (right,right,left) gemm('T','N', N, M, K, 1., B.data(), N, A.data(), M, 1., C.data(), M)
(M,N), (N,1) (M,K), (1, M) (K,N), (N,1) (right,left,right) gemm('N','T', N, M, K, 1., B.data(), K, A.data(), K, 1., C.data(), M)
(M,N), (N,1) (M,K), (K,1) (K,N), (N,1) (right,right,right) gemm('N','N', N, M, K, 1., B.data(), K, A.data(), M, 1., C.data(), M)
@crtrott
Copy link
Member Author

crtrott commented Feb 15, 2022

Here is the test code:

#include <cstdio>

#include<Kokkos_Core.hpp>
#include<Kokkos_Random.hpp>

// Fortran BLAS double-precision matrix multiply (column-major storage).
// Argument order: transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc.
extern "C" void dgemm_(const char*, const char*, int*, int*, int*, double*, double*, int*, double*, int*, double*, double*, int*);

// Computes C = A * B (alpha = 1, beta = 0) with a single dgemm_ call.
//
// C, A and B are rank-2 view-like objects providing extent(i), stride(i) and
// data() (convertible to double*). Each operand must have unit stride in one
// of its two dimensions:
//   stride(0) == 1  -> treated as column-major, passed to BLAS untransposed
//   stride(1) == 1  -> treated as row-major, i.e. the transpose as seen by BLAS
// When C itself is row-major the product is formed as C^T = B^T * A^T, which
// swaps the operand order and the M/N arguments.
//
// Leading dimensions are taken from the extents, so the storage is assumed to
// be contiguous (no padding between columns/rows).
//
// If the extents are inconsistent or an operand has no unit stride, a message
// is written to stderr and C is left untouched.
template<class AT, class BT, class CT>
void gemm(CT C, AT A, BT B) {
//  printf("C: %i %i %i %i\n",C.extent_int(0),C.extent_int(1),int(C.stride(0)),int(C.stride(1)));
//  printf("A: %i %i %i %i\n",A.extent_int(0),A.extent_int(1),int(A.stride(0)),int(A.stride(1)));
//  printf("B: %i %i %i %i\n",B.extent_int(0),B.extent_int(1),int(B.stride(0)),int(B.stride(1)));

  int M = static_cast<int>(C.extent(0));
  int N = static_cast<int>(C.extent(1));
  int K = static_cast<int>(A.extent(1));

  // Refuse shapes that do not describe (M x K) * (K x N) = (M x N); handing
  // them to BLAS would make it read or write outside the operands.
  if(static_cast<int>(A.extent(0)) != M ||
     static_cast<int>(B.extent(0)) != K ||
     static_cast<int>(B.extent(1)) != N) {
    std::fprintf(stderr, "gemm: incompatible extents, expected A(M,K) B(K,N) C(M,N); nothing computed\n");
    return;
  }

  // An empty result has nothing to compute.
  if(M == 0 || N == 0) return;

  // Classify every operand exactly once. When an extent is 1 both strides can
  // be 1; preferring the stride(0) test keeps the transpose flag consistent
  // with the leading dimension chosen below.
  const bool a_col_major = (A.stride(0) == 1);
  const bool b_col_major = (B.stride(0) == 1);
  const bool c_col_major = (C.stride(0) == 1);

  if((!a_col_major && A.stride(1) != 1) ||
     (!b_col_major && B.stride(1) != 1) ||
     (!c_col_major && C.stride(1) != 1)) {
    std::fprintf(stderr, "gemm: every operand needs unit stride in one dimension; nothing computed\n");
    return;
  }

  // Leading dimension of the column-major matrix BLAS sees: the row count for
  // a column-major operand, the column count for a row-major one. Kept at 1 or
  // more so a zero extent does not produce a leading dimension of 0.
  auto leading_dim = [](bool col_major, int rows, int cols) {
    const int ld = col_major ? rows : cols;
    return ld > 0 ? ld : 1;
  };
  int LDA = leading_dim(a_col_major, M, K);
  int LDB = leading_dim(b_col_major, K, N);
  int LDC = leading_dim(c_col_major, M, N);

  double alpha = 1., beta = 0.;
  double* A_data = A.data();
  double* B_data = B.data();
  double* C_data = C.data();

  if(c_col_major) {
    // C = op(A) * op(B): a row-major operand is the transpose of what BLAS sees.
    dgemm_(a_col_major ? "N" : "T", b_col_major ? "N" : "T",
           &M,&N,&K,&alpha,A_data,&LDA,B_data,&LDB,&beta,C_data,&LDC);
  } else {
    // Row-major C: compute C^T = B^T * A^T, so the operands swap places and a
    // column-major operand is now the one that needs the transpose flag.
    dgemm_(b_col_major ? "T" : "N", a_col_major ? "T" : "N",
           &N,&M,&K,&alpha,B_data,&LDB,A_data,&LDA,&beta,C_data,&LDC);
  }
}

// Runs one gemm() check for the layout combination <LC, LA, LB> (layouts of
// C, A and B respectively).
//
// Allocates A (M x K), B (K x N) and two M x N result views, fills A and B
// with random values, builds a reference product C2 with a plain triple loop,
// calls gemm(C, A, B), and prints the number of entries where C and C2
// disagree. Entry (3,3) of both results is also printed when it exists.
//
// NOTE(review): gemm() hands the view data pointers to a host BLAS routine, so
// this presumably expects a host-accessible memory space - confirm.
template<class LC, class LA, class LB>
void testgemm(int M, int N, int K) {
  Kokkos::View<double**,LA> A("A",M,K);
  Kokkos::View<double**,LB> B("B",K,N);
  Kokkos::View<double**,LC> C("C",M,N),C2("C2",M,N);

  Kokkos::Random_XorShift64_Pool<> g(1321);
  Kokkos::fill_random(A,g,1.0);
  Kokkos::fill_random(B,g,1.0);

  Kokkos::parallel_for("CreateReference",
    Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0,0}, {C.extent(0), C.extent(1)}),
    KOKKOS_LAMBDA(int i, int j) {
      // Accumulate locally and store once instead of updating the view on
      // every iteration; K equals A.extent(1) by construction above.
      double sum = 0;
      for(int k=0; k<K; k++) {
        sum += A(i,k)*B(k,j);
      }
      C2(i,j) = sum;
    }
  );
  gemm(C,A,B);
  int total_errors = 0;
  Kokkos::parallel_reduce("CheckEquivalence",
    Kokkos::MDRangePolicy<Kokkos::Rank<2>>({0,0}, {C.extent(0), C.extent(1)}),
    KOKKOS_LAMBDA(int i, int j, int& errors) {
      const double ref  = C2(i,j);
      const double diff = C(i,j) - ref;
      // Two-sided check with a tolerance that scales with the reference value:
      // a result that is too small must count as an error just like one that
      // is too large. Written as a negated "within tolerance" test so that a
      // NaN result is counted as well.
      const double tol = 1e-13 * (1.0 + (ref < 0 ? -ref : ref));
      if(!(diff <= tol && diff >= -tol)) errors++;
      if(i==3 && j==3) printf("%lf %lf\n",C(i,j),C2(i,j));
  },total_errors);
  printf("Errors: %i\n",total_errors);
}

int main(int argc, char* argv[]) {
  Kokkos::initialize(argc,argv);
  {
  int N = 200, M = 57, K=113;
//  int N = 100, M = 100, K=100;

  testgemm<Kokkos::LayoutLeft,Kokkos::LayoutLeft,Kokkos::LayoutLeft>(N,M,K);
  testgemm<Kokkos::LayoutLeft,Kokkos::LayoutRight,Kokkos::LayoutLeft>(N,M,K);
  testgemm<Kokkos::LayoutLeft,Kokkos::LayoutLeft,Kokkos::LayoutRight>(N,M,K);
  testgemm<Kokkos::LayoutLeft,Kokkos::LayoutRight,Kokkos::LayoutRight>(N,M,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutLeft,Kokkos::LayoutLeft>(N,M,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutRight,Kokkos::LayoutLeft>(N,M,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutLeft,Kokkos::LayoutRight>(N,M,K);
  testgemm<Kokkos::LayoutRight,Kokkos::LayoutRight,Kokkos::LayoutRight>(N,M,K);

  }
  Kokkos::finalize();
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant