
Commit 592c593

Cleanup to use pass-by-reference more consistently (#861)
* Cleanup to use pass-by-reference more consistently
1 parent b1d0b1e commit 592c593

5 files changed (+27, −70 lines)

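All five files follow the same pattern: the transform entry points now take their tensor operands by (const) reference instead of by value, and handle/cache types are spelled with the function's own template parameters instead of being re-derived with decltype from local copies. The sketch below illustrates the before/after shape of such a function in plain C++; Handle_t, impl_by_value, and impl_by_reference are made-up names for illustration, not the MatX API.

// Hypothetical stand-in for the real MatX detail:: handle types.
template <typename OutT, typename InT>
struct Handle_t {
  Handle_t(const OutT &, const InT &) {}
  void exec(OutT &out, const InT &in) { out = static_cast<OutT>(in); }
};

// Before: operands arrive by value, so the implementation makes local
// copies and names the handle type through decltype aliases.
template <typename OutT, typename InT>
void impl_by_value(OutT O, const InT A) {
  auto o = O;                 // copy of the output operand
  auto a = A;                 // copy of the input operand
  using otype = decltype(o);
  using atype = decltype(a);
  Handle_t<otype, atype> h(o, a);
  h.exec(o, a);               // with plain types this writes the copy; MatX
                              // tensor views alias the same data, which is
                              // why the by-value form also produced output
}

// After: operands arrive by reference, the template parameters already
// name the operand types, and writes land directly in the caller's output.
template <typename OutT, typename InT>
void impl_by_reference(OutT &o, const InT &a) {
  Handle_t<OutT, InT> h(o, a);
  h.exec(o, a);
}

int main() {
  double out = 0.0;
  int in = 3;
  impl_by_reference(out, in);  // out becomes 3.0
  return out == 3.0 ? 0 : 1;
}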

include/matx/operators/solve.h

+6 −5
@@ -102,10 +102,11 @@ class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
   }

   template <typename ShapeType, typename Executor>
-  __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape,
-                                   [[maybe_unused]] Executor &&ex) const noexcept {
+  __MATX_INLINE__ void
+  InnerPreRun([[maybe_unused]] ShapeType &&shape,
+              [[maybe_unused]] Executor &&ex) const noexcept {
     static_assert(is_sparse_tensor_v<OpA>,
-                   "Direct solver currently only supports sparse system");
+                  "Direct solver currently only supports sparse system");
     if constexpr (is_matx_op<OpB>()) {
       b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
     }
@@ -122,9 +123,9 @@ class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {

   template <typename ShapeType, typename Executor>
   __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape,
-                               [[maybe_unused]]Executor &&ex) const noexcept {
+                               [[maybe_unused]] Executor &&ex) const noexcept {
     static_assert(is_sparse_tensor_v<OpA>,
-                   "Direct solver currently only supports sparse system");
+                  "Direct solver currently only supports sparse system");
     if constexpr (is_matx_op<OpB>()) {
       b_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
     }
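
Both hunks above only reflow the code (the InnerPreRun signature is wrapped differently and a space is added after [[maybe_unused]]); the logic is untouched. For context, the guard-and-forward pattern these methods use looks roughly like the following; SparseTag, Inner, and SolveLike are placeholder names for illustration, not MatX types.

#include <cstdio>
#include <utility>

struct SparseTag {};
template <typename T> inline constexpr bool is_sparse_v = false;
template <> inline constexpr bool is_sparse_v<SparseTag> = true;

struct Inner {
  template <typename S, typename E>
  void PreRun(S &&, E &&) const { std::puts("inner PreRun"); }
};

template <typename OpA, typename OpB>
struct SolveLike {
  OpB b_;
  template <typename ShapeType, typename Executor>
  void InnerPreRun(ShapeType &&shape, Executor &&ex) const {
    // Compile-time guard on the left-hand operand type.
    static_assert(is_sparse_v<OpA>, "only sparse systems are supported");
    // Forward the lifecycle call to the nested operand, preserving the
    // value category of shape and executor.
    b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
  }
};

int main() {
  SolveLike<SparseTag, Inner> op{};
  op.InnerPreRun(0, 0);  // integer stand-ins for shape and executor
}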

include/matx/operators/sparse2dense.h

+2 −3
@@ -86,13 +86,12 @@ class Sparse2DenseOp : public BaseOp<Sparse2DenseOp<OpA>> {
   template <typename Out, typename Executor>
   void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
     if constexpr (is_sparse_tensor_v<OpA>) {
-      auto ref = cuda::std::get<0>(out);
-      using Rtype = decltype(ref);
+      using Rtype = decltype(cuda::std::get<0>(out));
       if constexpr (is_sparse_tensor_v<Rtype>) {
         MATX_THROW(matxNotSupported,
                    "Cannot use sparse2dense for sparse output");
       } else {
-        sparse2dense_impl(ref, a_, ex);
+        sparse2dense_impl(cuda::std::get<0>(out), a_, ex);
       }
     } else {
       MATX_THROW(matxNotSupported, "Cannot use sparse2dense on dense input");
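
The hunk drops the named local ref and applies decltype to the cuda::std::get<0>(out) expression directly. One language-level detail: decltype of a variable yields its declared type, while decltype of a call expression keeps any reference in the return type, so Rtype may now be a reference; presumably is_sparse_tensor_v tolerates that (or get<0> returns by value here). A plain-C++ illustration using std::get as a stand-in for cuda::std::get:

#include <tuple>
#include <type_traits>

int main() {
  std::tuple<int> out{7};
  auto ref = std::get<0>(out);  // `ref` is a copy with declared type int
  static_assert(std::is_same_v<decltype(ref), int>);
  // decltype on the call expression keeps the reference returned by get<0>.
  static_assert(std::is_same_v<decltype(std::get<0>(out)), int &>);
  (void)ref;
  return 0;
}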

include/matx/transforms/convert/sparse2dense_cusparse.h

+3 −16
@@ -215,33 +215,20 @@ using sparse2dense_cache_t =

 } // end namespace detail

-template <typename Op>
-__MATX_INLINE__ auto getSparse2DenseSupportedTensor(const Op &in,
-                                                    cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename OutputTensorType, typename InputTensorType>
-void sparse2dense_impl(OutputTensorType O, const InputTensorType A,
+void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
                        const cudaExecutor &exec) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();

-  auto a = A; // always sparse
-  auto o = getSparse2DenseSupportedTensor(O, stream);
-
   // TODO: some more checking, supported type? on device? etc.

-  using atype = decltype(a);
-  using otype = decltype(o);
-
   // Get parameters required by these tensors (for caching).
   auto params =
-      detail::Sparse2DenseHandle_t<otype, atype>::GetConvParams(o, a, stream);
+      detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>::GetConvParams(o, a, stream);

   // Lookup and cache.
-  using cache_val_type = detail::Sparse2DenseHandle_t<otype, atype>;
+  using cache_val_type = detail::Sparse2DenseHandle_t<OutputTensorType, InputTensorType>;
   detail::GetCache().LookupAndExec<detail::sparse2dense_cache_t>(
       detail::GetCacheIdFromType<detail::sparse2dense_cache_t>(), params,
       [&]() { return std::make_shared<cache_val_type>(o, a, stream); },
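
With the staging helper removed, sparse2dense_impl takes its operands by reference and keys the handle cache on OutputTensorType and InputTensorType directly. The lookup-or-create flow behind detail::GetCache().LookupAndExec is, in spirit, the following; Params, Handle, and convert are placeholder names, not the MatX cache API.

#include <map>
#include <memory>

// Placeholder key built from the operand shapes/types.
struct Params {
  int rows, cols;
  bool operator<(const Params &p) const {
    return rows != p.rows ? rows < p.rows : cols < p.cols;
  }
};

// Placeholder handle that would own library descriptors and buffers.
struct Handle {
  explicit Handle(const Params &) { /* create descriptors once */ }
  void Exec() { /* run the conversion on the stream */ }
};

std::map<Params, std::shared_ptr<Handle>> cache;

void convert(int rows, int cols) {
  Params p{rows, cols};
  auto it = cache.find(p);
  if (it == cache.end()) {
    // Cache miss: build the handle once; later calls with the same
    // parameters reuse it instead of re-creating library state.
    it = cache.emplace(p, std::make_shared<Handle>(p)).first;
  }
  it->second->Exec();
}

int main() {
  convert(4, 8);  // creates the handle
  convert(4, 8);  // reuses it
}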

include/matx/transforms/matmul/matmul_cusparse.h

+3 −18
@@ -254,37 +254,22 @@ using gemm_cusparse_cache_t =

 } // end namespace detail

-template <typename Op>
-__MATX_INLINE__ auto getCUSPARSESupportedTensor(const Op &in,
-                                                cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_matmul_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B,
+void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b,
                         const cudaExecutor &exec, float alpha = 1.0,
                         float beta = 0.0) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();

-  auto a = A; // always sparse
-  auto b = getCUSPARSESupportedTensor(B, stream);
-  auto c = getCUSPARSESupportedTensor(C, stream);
-
   // TODO: some more checking, supported type? on device? etc.

-  using atype = decltype(a);
-  using btype = decltype(b);
-  using ctype = decltype(c);
-
   // Get parameters required by these tensors (for caching).
   auto params =
-      detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>::GetGemmParams(
+      detail::MatMulCUSPARSEHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetGemmParams(
           c, a, b, stream, alpha, beta);

   // Lookup and cache.
-  using cache_val_type = detail::MatMulCUSPARSEHandle_t<ctype, atype, btype>;
+  using cache_val_type = detail::MatMulCUSPARSEHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>;
   detail::GetCache().LookupAndExec<detail::gemm_cusparse_cache_t>(
       detail::GetCacheIdFromType<detail::gemm_cusparse_cache_t>(), params,
       [&]() {
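
Beyond the same by-reference cleanup, the signature keeps the alpha/beta scaling parameters, presumably with the usual cuSPARSE convention C = alpha * (A * B) + beta * C, so the defaults alpha = 1.0 and beta = 0.0 simply overwrite C with the product. A dense, host-side reference of that convention, for illustration only:

#include <vector>

// Reference of the alpha/beta GEMM convention: C = alpha * (A * B) + beta * C.
void gemm_reference(std::vector<float> &C, const std::vector<float> &A,
                    const std::vector<float> &B, int m, int k, int n,
                    float alpha = 1.0f, float beta = 0.0f) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += A[i * k + p] * B[p * n + j];  // row-major A (m x k), B (k x n)
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}

int main() {
  std::vector<float> A{1, 2, 3, 4};  // 2 x 2
  std::vector<float> B{5, 6, 7, 8};  // 2 x 2
  std::vector<float> C(4, 0.0f);
  gemm_reference(C, A, B, 2, 2, 2);  // defaults: C = A * B
  return C[0] == 19.0f ? 0 : 1;      // [1 2; 3 4] * [5 6; 7 8] = [19 22; 43 50]
}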

include/matx/transforms/solve/solve_cudss.h

+13 −28
@@ -238,35 +238,20 @@ using gemm_cudss_cache_t =

 } // end namespace detail

-template <typename Op>
-__MATX_INLINE__ auto getCUDSSSupportedTensor(const Op &in,
-                                             cudaStream_t stream) {
-  const auto support_func = [&in]() { return true; };
-  return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream);
-}
-
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A,
-                             const TensorTypeB B, const cudaExecutor &exec) {
+void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
+                             const TensorTypeB &b, const cudaExecutor &exec) {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
   const auto stream = exec.getStream();

-  auto a = A; // always sparse
-  auto b = getCUDSSSupportedTensor(B, stream);
-  auto c = getCUDSSSupportedTensor(C, stream);
-
   // TODO: some more checking, supported type? on device? etc.

-  using atype = decltype(a);
-  using btype = decltype(b);
-  using ctype = decltype(c);
-
   // Get parameters required by these tensors (for caching).
-  auto params = detail::SolveCUDSSHandle_t<ctype, atype, btype>::GetSolveParams(
+  auto params = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>::GetSolveParams(
       c, a, b, stream);

   // Lookup and cache.
-  using cache_val_type = detail::SolveCUDSSHandle_t<ctype, atype, btype>;
+  using cache_val_type = detail::SolveCUDSSHandle_t<TensorTypeC, TensorTypeA, TensorTypeB>;
   detail::GetCache().LookupAndExec<detail::gemm_cudss_cache_t>(
       detail::GetCacheIdFromType<detail::gemm_cudss_cache_t>(), params,
       [&]() { return std::make_shared<cache_val_type>(c, a, b, stream); },
@@ -282,29 +267,29 @@ void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A,
 // supports MATX native row-major storage, which will clean up the copies from
 // and to memory.
 template <typename TensorTypeC, typename TensorTypeA, typename TensorTypeB>
-void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B,
-                       const cudaExecutor &exec) {
+void sparse_solve_impl(TensorTypeC &c, const TensorTypeA &a,
+                       const TensorTypeB &b, const cudaExecutor &exec) {
   const auto stream = exec.getStream();

   // Some copying-in hacks, assumes rank 2.
   using TB = typename TensorTypeB::value_type;
   using TC = typename TensorTypeB::value_type;
   TB *bptr;
   matxAlloc(reinterpret_cast<void **>(&bptr),
-            sizeof(TB) * B.Size(0) * B.Size(1), MATX_ASYNC_DEVICE_MEMORY,
+            sizeof(TB) * b.Size(0) * b.Size(1), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto bT = make_tensor(bptr, {B.Size(1), B.Size(0)});
-  (bT = transpose(B)).run(exec);
+  auto bT = make_tensor(bptr, {b.Size(1), b.Size(0)});
+  (bT = transpose(b)).run(exec);
   TC *cptr;
   matxAlloc(reinterpret_cast<void **>(&cptr),
-            sizeof(TC) * C.Size(0) * C.Size(1), MATX_ASYNC_DEVICE_MEMORY,
+            sizeof(TC) * c.Size(0) * c.Size(1), MATX_ASYNC_DEVICE_MEMORY,
             stream);
-  auto cT = make_tensor(cptr, {C.Size(1), C.Size(0)});
+  auto cT = make_tensor(cptr, {c.Size(1), c.Size(0)});

-  sparse_solve_impl_trans(cT, A, bT, exec);
+  sparse_solve_impl_trans(cT, a, bT, exec);

   // Some copying-back hacks.
-  (C = transpose(cT)).run(exec);
+  (c = transpose(cT)).run(exec);
 }

 } // end namespace matx
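
The copy-in/copy-out "hacks" exist because, as the comment above notes, this path does not yet consume MatX's native row-major storage directly; transposing b into bT (and cT back into c) is presumably a storage-order adaptation, since the column-major image of an m x n matrix is exactly the row-major image of its n x m transpose. A small, self-contained check of that identity in plain C++ (unrelated to the cuDSS API):

#include <cassert>

int main() {
  const int m = 2, n = 3;
  // B, stored row-major (as a MatX tensor would be).
  float B_rm[m * n] = {1, 2, 3,
                       4, 5, 6};
  // Column-major image of B...
  float B_cm[m * n];
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      B_cm[j * m + i] = B_rm[i * n + j];
  // ...equals the row-major image of B^T (an n x m matrix).
  float BT_rm[n * m];
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      BT_rm[j * m + i] = B_rm[i * n + j];
  for (int k = 0; k < m * n; ++k)
    assert(B_cm[k] == BT_rm[k]);
  return 0;
}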

0 commit comments

Comments
 (0)