Skip to content

Commit aabb88a

Browse files
authored Jul 12, 2024
NVPL BLAS Support (#665)
1 parent 940c4b8 commit aabb88a

File tree

18 files changed

+1334
-698
lines changed

18 files changed

+1334
-698
lines changed
 

‎CMakeLists.txt

+20-15
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ set(WARN_FLAGS ${WARN_FLAGS} $<$<COMPILE_LANGUAGE:CXX>:-Werror>)
172172
set (CUTLASS_INC "")
173173
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)
174174

175+
if (MATX_NVTX_FLAGS)
176+
add_definitions(-DMATX_NVTX_FLAGS)
177+
target_compile_definitions(matx INTERFACE MATX_NVTX_FLAGS)
178+
endif()
179+
if (MATX_BUILD_32_BIT)
180+
set(INT_TYPE "lp64")
181+
add_definitions(-DINDEX_32_BIT)
182+
target_compile_definitions(matx INTERFACE INDEX_32_BIT)
183+
else()
184+
set(INT_TYPE "ilp64")
185+
add_definitions(-DINDEX_64_BIT)
186+
target_compile_definitions(matx INTERFACE INDEX_64_BIT)
187+
endif()
188+
175189
# Host support
176190
if (MATX_EN_NVPL OR MATX_EN_X86_FFTW)
177191
message(STATUS "Enabling OpenMP support")
@@ -180,9 +194,12 @@ if (MATX_EN_NVPL OR MATX_EN_X86_FFTW)
180194
target_compile_options(matx INTERFACE ${OpenMP_CXX_FLAGS})
181195
target_compile_definitions(matx INTERFACE MATX_EN_OMP=1)
182196
if (MATX_EN_NVPL)
183-
message(STATUS "Enabling NVPL library support for ARM CPUs")
184-
find_package(nvpl REQUIRED COMPONENTS fft)
185-
target_link_libraries(matx INTERFACE nvpl::fftw)
197+
message(STATUS "Enabling NVPL library support for ARM CPUs with ${INT_TYPE} interface")
198+
find_package(nvpl REQUIRED COMPONENTS fft blas)
199+
if (NOT MATX_BUILD_32_BIT)
200+
target_compile_definitions(matx INTERFACE NVPL_ILP64)
201+
endif()
202+
target_link_libraries(matx INTERFACE nvpl::fftw nvpl::blas_${INT_TYPE}_omp)
186203
target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
187204
else()
188205
if (MATX_EN_X86_FFTW)
@@ -316,18 +333,6 @@ if (NOT_SUBPROJECT)
316333
endif()
317334

318335

319-
if (MATX_NVTX_FLAGS)
320-
add_definitions(-DMATX_NVTX_FLAGS)
321-
target_compile_definitions(matx INTERFACE MATX_NVTX_FLAGS)
322-
endif()
323-
if (MATX_BUILD_32_BIT)
324-
add_definitions(-DINDEX_32_BIT)
325-
target_compile_definitions(matx INTERFACE INDEX_32_BIT)
326-
else()
327-
add_definitions(-DINDEX_64_BIT)
328-
target_compile_definitions(matx INTERFACE INDEX_64_BIT)
329-
endif()
330-
331336
if (MATX_BUILD_EXAMPLES)
332337
add_subdirectory(examples)
333338
endif()

‎include/matx/executors/support.h

+24
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,13 @@ namespace matx {
4646
#define MATX_EN_CPU_FFT 0
4747
#endif
4848

49+
// MatMul
50+
#if defined(MATX_EN_NVPL)
51+
#define MATX_EN_CPU_MATMUL 1
52+
#else
53+
#define MATX_EN_CPU_MATMUL 0
54+
#endif
55+
4956
template <typename Exec, typename T>
5057
constexpr bool CheckFFTSupport() {
5158
if constexpr (is_host_executor_v<Exec>) {
@@ -70,5 +77,22 @@ constexpr bool CheckDirect1DConvSupport() {
7077
}
7178
}
7279

80+
template <typename Exec, typename T>
81+
constexpr bool CheckMatMulSupport() {
82+
if constexpr (is_host_executor_v<Exec>) {
83+
if constexpr (std::is_same_v<T, float> ||
84+
std::is_same_v<T, double> ||
85+
std::is_same_v<T, cuda::std::complex<float>> ||
86+
std::is_same_v<T, cuda::std::complex<double>>) {
87+
return MATX_EN_CPU_MATMUL;
88+
} else {
89+
return false;
90+
}
91+
}
92+
else {
93+
return true;
94+
}
95+
}
96+
7397
}; // detail
7498
}; // matx

‎include/matx/kernels/channelize_poly.cuh

+1-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ __global__ void ChannelizePoly1D_Smem(OutType output, InType input, FilterType f
256256
__syncthreads();
257257

258258
// Load next elems_per_channel_per_cta elements for each channel
259-
const index_t next_last_elem = cuda::std::min(next_start_elem + by - 1, last_elem);
259+
const index_t next_last_elem = cuda::std::min(next_start_elem + static_cast<index_t>(by) - 1, last_elem);
260260
const uint32_t out_samples_this_iter = static_cast<uint32_t>(next_last_elem - next_start_elem + 1);
261261
if (ty < out_samples_this_iter) {
262262
indims[InRank-1] = (next_start_elem + ty) * num_channels + chan;

‎include/matx/operators/cov.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ namespace matx
8585
template <typename Out, typename Executor>
8686
void Exec(Out &&out, Executor &&ex) const {
8787
static_assert(is_cuda_executor_v<Executor>, "cov() only supports the CUDA executor currently");
88-
cov_impl(cuda::std::get<0>(out), a_, ex.getStream());
88+
cov_impl(cuda::std::get<0>(out), a_, ex);
8989
}
9090

9191
template <typename ShapeType, typename Executor>

‎include/matx/operators/matmul.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535

3636
#include "matx/core/type_utils.h"
3737
#include "matx/operators/base_operator.h"
38-
#include "matx/transforms/matmul.h"
38+
#include "matx/transforms/matmul/matmul_cuda.h"
39+
#ifdef MATX_EN_CPU_MATMUL
40+
#include "matx/transforms/matmul/matmul_cblas.h"
41+
#endif
3942

4043
namespace matx
4144
{
@@ -108,12 +111,11 @@ namespace matx
108111

109112
template <typename Out, typename Executor>
110113
void Exec(Out &&out, Executor &&ex) const {
111-
static_assert(is_cuda_executor_v<Executor>, "matmul() only supports the CUDA executor currently");
112114
if constexpr (!std::is_same_v<PermDims, no_permute_t>) {
113-
matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex.getStream(), alpha_, beta_);
115+
matmul_impl(permute(cuda::std::get<0>(out), perm_), a_, b_, ex, alpha_, beta_);
114116
}
115117
else {
116-
matmul_impl(cuda::std::get<0>(out), a_, b_, ex.getStream(), alpha_, beta_);
118+
matmul_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
117119
}
118120
}
119121

‎include/matx/operators/matvec.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,7 @@ namespace matx
8989

9090
template <typename Out, typename Executor>
9191
void Exec(Out &&out, Executor &&ex) const{
92-
static_assert(is_cuda_executor_v<Executor>, "matvec() only supports the CUDA executor currently");
93-
matvec_impl(cuda::std::get<0>(out), a_, b_, ex.getStream(), alpha_, beta_);
92+
matvec_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
9493
}
9594

9695
template <typename ShapeType, typename Executor>

‎include/matx/operators/outer.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ namespace matx
9898

9999
template <typename Out, typename Executor>
100100
void Exec(Out &&out, Executor &&ex) const{
101-
static_assert(is_cuda_executor_v<Executor>, "outer() only supports the CUDA executor currently");
102-
outer_impl(cuda::std::get<0>(out), a_, b_, ex.getStream(), alpha_, beta_);
101+
outer_impl(cuda::std::get<0>(out), a_, b_, ex, alpha_, beta_);
103102
}
104103

105104
template <typename ShapeType, typename Executor>

‎include/matx/operators/qr.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ namespace detail {
6464
static_assert(is_cuda_executor_v<Executor>, "svd() only supports the CUDA executor currently");
6565
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 3, "Must use mtie with 3 outputs on qr(). ie: (mtie(Q, R) = qr(A))");
6666

67-
qr_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex.getStream());
67+
qr_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), a_, ex);
6868
}
6969

7070
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()

‎include/matx/operators/svd.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ namespace detail {
127127
static_assert(is_cuda_executor_v<Executor>, "svdpi() only supports the CUDA executor currently");
128128
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 4, "Must use mtie with 3 outputs on svdpi(). ie: (mtie(U, S, V) = svdpi(A))");
129129

130-
svdpi_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, x_, iterations_, ex.getStream(), k_);
130+
svdpi_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, x_, iterations_, ex, k_);
131131
}
132132

133133
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
@@ -204,7 +204,7 @@ namespace detail {
204204
static_assert(is_cuda_executor_v<Executor>, "svdbpi() only supports the CUDA executor currently");
205205
static_assert(cuda::std::tuple_size_v<remove_cvref_t<Out>> == 4, "Must use mtie with 3 outputs on svdbpi(). ie: (mtie(U, S, V) = svdbpi(A))");
206206

207-
svdbpi_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, max_iters_, tol_, ex.getStream());
207+
svdbpi_impl(cuda::std::get<0>(out), cuda::std::get<1>(out), cuda::std::get<2>(out), a_, max_iters_, tol_, ex);
208208
}
209209

210210
static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()

‎include/matx/transforms/cov.h

+14-10
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
#include "matx/core/error.h"
4040
#include "matx/core/nvtx.h"
4141
#include "matx/core/tensor.h"
42-
#include "matx/transforms/matmul.h"
42+
#include "matx/transforms/matmul/matmul_cuda.h"
4343
#include "matx/transforms/transpose.h"
4444

4545
namespace matx {
@@ -137,16 +137,18 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
137137
* Output covariance matrix
138138
* @param a
139139
* Input tensor A
140-
* @param stream
141-
* CUDA stream
140+
* @param exec
141+
* CUDA executor
142142
*
143143
*/
144144
inline void Exec(TensorTypeC &c, const TensorTypeA &a,
145-
cudaStream_t stream)
145+
const cudaExecutor &exec)
146146
{
147147
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
148+
const auto stream = exec.getStream();
149+
148150
// Calculate a matrix of means
149-
matmul_impl(means, onesM, a, stream,
151+
matmul_impl(means, onesM, a, exec,
150152
1.0f / static_cast<float>(a.Size(RANK - 2)));
151153

152154
// Subtract the means from the observations to get the deviations
@@ -165,7 +167,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
165167
}
166168

167169
// Multiply by itself and scale by N-1 for the final covariance
168-
matmul_impl(c, devsT, devs, stream,
170+
matmul_impl(c, devsT, devs, exec,
169171
1.0f / static_cast<float>(a.Size(RANK - 2) - 1));
170172
}
171173

@@ -224,14 +226,16 @@ using cov_cache_t = std::unordered_map<CovParams_t, std::any, CovParamsKeyHash,
224226
* Covariance matrix output view
225227
* @param a
226228
* Covariance matrix input view
227-
* @param stream
228-
* CUDA stream
229+
* @param exec
230+
* CUDA executor
229231
*/
230232
template <typename TensorTypeC, typename TensorTypeA>
231233
void cov_impl(TensorTypeC &c, const TensorTypeA &a,
232-
cudaStream_t stream = 0)
234+
const cudaExecutor &exec)
233235
{
234236
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
237+
const auto stream = exec.getStream();
238+
235239
// Get parameters required by these tensors
236240
auto params = detail::matxCovHandle_t<TensorTypeC, TensorTypeA>::GetCovParams(c, a, stream);
237241

@@ -243,7 +247,7 @@ void cov_impl(TensorTypeC &c, const TensorTypeA &a,
243247
return std::make_shared<cache_val_type>(c, a);
244248
},
245249
[&](std::shared_ptr<cache_val_type> ctype) {
246-
ctype->Exec(c, a, stream);
250+
ctype->Exec(c, a, exec);
247251
}
248252
);
249253
}

0 commit comments

Comments (0)
Please sign in to comment.