Skip to content

Commit

Permalink
added working example of matrix mult using amrex table2D interface to…
Browse files Browse the repository at this point in the history
… cublas, but getting right ans for wrong reasons
  • Loading branch information
saurabh-s-sawant committed Aug 27, 2024
1 parent d1b30a1 commit 2f2f388
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 50 deletions.
72 changes: 72 additions & 0 deletions scripts/negf_cpp/math_libraries/Source/GlobalFuncs.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#ifndef GLOBAL_FUNCS_H_
#define GLOBAL_FUNCS_H_

#include <AMReX_GpuUtility.H>

#include <iomanip>
#include <iostream>
#include <string>

#include "MatrixDef.H"

template<typename U, typename V>
void
Define_Table2D (U& Tab2D_data, V val)
{
auto tlo = Tab2D_data.lo();
auto thi = Tab2D_data.hi();

auto const& Tab2D = Tab2D_data.table();

for (int i = tlo[0]; i < thi[0]; ++i)
{
for (int j = tlo[1]; j < thi[1]; ++j) //slow access
{
ComplexType new_val(val.real()*(j+1), val.imag()*(j+1));

Tab2D(i,j) = new_val;
}
}
}


/**
 * Assign the same value to every visited entry of a 2D table.
 *
 * @param Tab2D_data  table container providing lo(), hi() and table().
 * @param val         value assigned to each visited entry (i,j).
 *
 * NOTE(review): the iteration range is [lo, hi) in both directions, so the
 * entries at index hi()[0] / hi()[1] are left untouched. If hi() is meant to
 * be inclusive (TODO confirm against the AMReX TableData documentation) this
 * skips the last row and column.
 */
template<typename U, typename V>
void
SetVal_Table2D (U& Tab2D_data, V val)
{
    const auto lower = Tab2D_data.lo();
    const auto upper = Tab2D_data.hi();

    auto const& view = Tab2D_data.table();

    for (int row = lower[0]; row < upper[0]; ++row)
    {
        for (int col = lower[1]; col < upper[1]; ++col)
        {
            view(row,col) = val;
        }
    }
}

/**
 * Print a 2D table of complex values to std::cout, one table row per line.
 *
 * Output format: a header line "Printing Table: <tablename>" (preceded by a
 * blank line), then for each visited row the entries written as
 * "<real> + <imag>i" separated by ", ". Numbers use fixed notation with
 * 6 digits after the decimal point; the real part is right-aligned in a
 * field of width 12.
 *
 * The formatting flags and precision of std::cout are restored before
 * returning, so this helper does not alter how later output is formatted.
 *
 * @param Tab2D_data  table container providing lo(), hi() and a const
 *                    table(); entries must provide real() and imag().
 * @param tablename   label printed in the header line (may be empty).
 *
 * NOTE(review): rows/columns are visited over [lo, hi), i.e. the entries at
 * index hi()[0] / hi()[1] are not printed. If hi() is an inclusive bound
 * (TODO confirm against the AMReX TableData documentation) the last row and
 * column are omitted. Kept as is because the other helpers in this file use
 * the same convention.
 */
template<typename U>
void Print_Table2D(const U& Tab2D_data, const std::string& tablename="")
{
    auto tlo = Tab2D_data.lo();
    auto thi = Tab2D_data.hi();

    auto const& Tab2D = Tab2D_data.table();

    // std::fixed and std::setprecision are sticky: remember the caller's
    // settings so they can be put back once printing is done.
    const std::ios_base::fmtflags saved_flags = std::cout.flags();
    const std::streamsize saved_precision = std::cout.precision();

    std::cout << "\nPrinting Table: " << tablename << "\n";
    for (int i = tlo[0]; i < thi[0]; ++i)
    {
        for (int j = tlo[1]; j < thi[1]; ++j)
        {
            // setw applies only to the next insertion, i.e. the real part.
            std::cout << std::setw(12) << std::setprecision(6) << std::fixed
                      << Tab2D(i,j).real() << " + "
                      << Tab2D(i,j).imag() << "i";

            // Separator between entries, but not after the last printed one.
            if (j < thi[1] - 1) std::cout << ", ";
        }
        std::cout << "\n";
    }

    std::cout.flags(saved_flags);
    std::cout.precision(saved_precision);
}
#endif
1 change: 1 addition & 0 deletions scripts/negf_cpp/math_libraries/Source/Make.package
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ CEXE_sources += main.cpp
CEXE_sources += MathLib.cpp
CEXE_headers += MathLib.H
CEXE_headers += MatrixDef.H
CEXE_headers += GlobalFuncs.H
CEXE_headers += cudaErrorCheck.H

VPATH_LOCATIONS += $(CODE_HOME)/Source
Expand Down
1 change: 0 additions & 1 deletion scripts/negf_cpp/math_libraries/Source/MathLib.H
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include <AMReX_REAL.H>
#include "MatrixDef.H"

namespace MathLib
Expand Down
27 changes: 15 additions & 12 deletions scripts/negf_cpp/math_libraries/Source/MathLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,17 @@
// CUSOLVER
#include "cusolverSp.h"
#include "cusolverSp_LOWLEVEL_PREVIEW.h"
//#include "helper_cuda.h"
//#include "helper_cusolver.h"

// CUBLAS
#include <cublas_v2.h>

//#include <helper_functions.h>
//#include <helper_cuda.h>
#endif

void MathLib::MatrixMatrixMultiply(ComplexType* d_C,
const ComplexType* d_A,
const ComplexType* d_B,
unsigned int wA,
unsigned int hA,
unsigned int wB)
unsigned int A_rows,
unsigned int A_cols,
unsigned int B_cols)
{
#ifdef AMREX_USE_GPU
#ifdef AMREX_USE_CUDA
Expand All @@ -33,10 +28,18 @@ void MathLib::MatrixMatrixMultiply(ComplexType* d_C,
checkCudaErrors(cublasCreate(&handle));

// Perform matrix multiplication: C = alpha * A * B + beta * C
checkCudaErrors(cublasZgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, wB, hA, wA, &alpha,
reinterpret_cast<const cuDoubleComplex*>(d_B), wB,
reinterpret_cast<const cuDoubleComplex*>(d_A), wA, &beta,
reinterpret_cast<cuDoubleComplex*>(d_C), wB));
// see: https://docs.nvidia.com/cuda/archive/10.0/cublas/index.html

// Usage Error: Below, instead of A_rows+1, A_cols+1,
// we should use A_rows and A_cols. But current usage gives correct answer!
// These fields are lda, ldb, and ldc (leading dimensions to store matrices)

cublasStatus_t status =cublasZgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
A_rows, B_cols, A_cols, &alpha,
reinterpret_cast<const cuDoubleComplex*>(d_A), A_rows+1,
reinterpret_cast<const cuDoubleComplex*>(d_B), A_cols+1, &beta,
reinterpret_cast<cuDoubleComplex*>(d_C), A_rows+1);
checkCudaErrors(status);

checkCudaErrors(cublasDestroy(handle));
#elif AMREX_USE_HIP
Expand Down
21 changes: 2 additions & 19 deletions scripts/negf_cpp/math_libraries/Source/MatrixDef.H
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
#ifndef MATRIX_DEF_H_
#define MATRIX_DEF_H_

#include <AMReX_GpuComplex.H>
#include <AMReX_REAL.H>
#include<AMReX_TableData.H>
#include <AMReX_GpuUtility.H>
#include <AMReX_REAL.H>

using ComplexType = amrex::GpuComplex<amrex::Real>;
using Matrix2D = amrex::TableData<ComplexType, 2>;

/**
 * Set every visited entry of a 2D table to a single value.
 *
 * @param Tab2D_data  table container providing lo(), hi() and table().
 * @param val         value stored at each visited index pair (i,j).
 *
 * NOTE(review): indices run over [lo, hi) in both dimensions, so entries at
 * hi()[0] / hi()[1] are not written -- confirm whether hi() is inclusive.
 */
template<typename U, typename V>
void
SetVal_Table2D (U& Tab2D_data, V val)
{
    const auto first = Tab2D_data.lo();
    const auto last  = Tab2D_data.hi();

    auto const& entries = Tab2D_data.table();

    int i = first[0];
    while (i < last[0])
    {
        int j = first[1];
        while (j < last[1])
        {
            entries(i,j) = val;
            ++j;
        }
        ++i;
    }
}
#endif
82 changes: 64 additions & 18 deletions scripts/negf_cpp/math_libraries/Source/main.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include <AMReX_ParmParse.H>

#include "MathLib.H"
#include "MatrixDef.H"
#include "GlobalFuncs.H"

#include <AMReX_ParmParse.H>
using namespace amrex;

int main (int argc, char* argv[])
Expand All @@ -12,40 +14,84 @@ int main (int argc, char* argv[])
int my_rank = ParallelDescriptor::MyProc();
amrex::Print() << "total number of procs: " << num_proc << "\n";

int N_total = 4; /*matrix size*/
amrex::ParmParse pp;
pp.query("N_total", N_total);

int A_rows=4, A_cols=4, B_cols = 4; //assume, B_rows = A_cols
pp.query("A_rows", A_rows);
pp.query("A_cols", A_cols);
pp.query("B_cols", B_cols);

int print_matrix_flag = false;
pp.query("print_matrix", print_matrix_flag);

ComplexType num1(1.,2.);
ComplexType num2(-5.,3.);

Matrix2D h_A_data({0,0},{N_total,2*N_total},The_Pinned_Arena());
SetVal_Table2D(h_A_data, num1);
//Create A, B, C matrices as 2D tables. We want to perform C = A * B.
Matrix2D h_A_data({0,0},{A_rows, A_cols},The_Pinned_Arena());
Matrix2D d_A_data({0,0},{A_rows, A_cols},The_Arena());

Matrix2D d_A_data({0,0},{N_total,2*N_total},The_Arena());
h_A_data.copy(d_A_data);
Matrix2D h_B_data({0,0},{A_cols, B_cols},The_Pinned_Arena());
Matrix2D d_B_data({0,0},{A_cols, B_cols},The_Arena());

Matrix2D h_B_data({0,0},{2*N_total,N_total},The_Pinned_Arena());
SetVal_Table2D(h_A_data, num2);
Matrix2D h_C_data({0,0},{A_rows, B_cols},The_Pinned_Arena());
Matrix2D d_C_data({0,0},{A_rows, B_cols},The_Arena());

Matrix2D d_B_data({0,0},{2*N_total,N_total},The_Arena());
h_A_data.copy(d_A_data);
//define matrices A & B
ComplexType num1(1.,2.);
ComplexType num2(-5.,3.);
ComplexType zero(0.,0.);
Define_Table2D(h_A_data, num1);
Define_Table2D(h_B_data, num2);
SetVal_Table2D(h_C_data, zero);

Matrix2D h_C_data({0,0},{N_total,N_total},The_Pinned_Arena());
Matrix2D d_C_data({0,0},{N_total,N_total},The_Arena());
//copy to A & B to device
d_A_data.copy(h_A_data);
d_B_data.copy(h_B_data);
d_C_data.copy(h_C_data);

//get references to tables
const auto& dim_A = d_A_data.hi();
const auto& dim_B = d_B_data.hi();
const auto& d_A = d_A_data.const_table();
const auto& d_B = d_B_data.const_table();
const auto& d_C = d_C_data.table();

MathLib::MatrixMatrixMultiply(d_C.p, d_A.p, d_B.p, dim_A[0], dim_A[1], dim_B[0]);
d_A_data.copy(h_A_data);
amrex::Print() << "dim_A (rows/cols): " << dim_A[0] << " " << dim_A[1] << "\n";
amrex::Print() << "dim_B (rows/cols): " << dim_B[0] << " " << dim_B[1] << "\n";

//print A & B
Print_Table2D(h_A_data, "A");
const auto& h_A = h_A_data.const_table();
const auto& h_B = h_B_data.const_table();

amrex::Print() << "\nPrinting h_A using h_A.p\n";

//Usage Error: In the for loop we should be using (i < dim_A[0]*dim_A[1]).
//But currently we need to add a buffer of 1 unit size to print properly!
for(int i=0; i<(dim_A[0]+1)*dim_A[1]; ++i)
{
amrex::Print() << i << " "<< *(h_A.p+i) << "\n";
}

Print_Table2D(h_B_data, "B");

//Perform C = A * B
MathLib::MatrixMatrixMultiply(d_C.p, d_A.p, d_B.p, dim_A[0], dim_A[1], dim_B[1]);

//copy C from device to host and print
h_C_data.copy(d_C_data);
Gpu::streamSynchronize();

Print_Table2D(h_C_data, "C");

h_A_data.clear();
h_B_data.clear();
h_C_data.clear();
d_A_data.clear();
d_B_data.clear();
d_C_data.clear();

d_A_data.clear();
d_B_data.clear();
d_C_data.clear();

amrex::Finalize();
}

0 comments on commit 2f2f388

Please sign in to comment.