39
39
#include " matx/core/error.h"
40
40
#include " matx/core/nvtx.h"
41
41
#include " matx/core/tensor.h"
42
- #include " matx/transforms/matmul.h"
42
+ #include " matx/transforms/matmul/matmul_cuda .h"
43
43
#include " matx/transforms/transpose.h"
44
44
45
45
namespace matx {
@@ -137,16 +137,18 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
137
137
* Output covariance matrix
138
138
* @param a
139
139
* Input tensor A
140
- * @param stream
141
- * CUDA stream
140
+ * @param exec
141
+ * CUDA executor
142
142
*
143
143
*/
144
144
inline void Exec (TensorTypeC &c, const TensorTypeA &a,
145
- cudaStream_t stream )
145
+ const cudaExecutor &exec )
146
146
{
147
147
MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_INTERNAL)
148
+ const auto stream = exec.getStream ();
149
+
148
150
// Calculate a matrix of means
149
- matmul_impl (means, onesM, a, stream ,
151
+ matmul_impl (means, onesM, a, exec ,
150
152
1 .0f / static_cast <float >(a.Size (RANK - 2 )));
151
153
152
154
// Subtract the means from the observations to get the deviations
@@ -165,7 +167,7 @@ template <typename TensorTypeC, typename TensorTypeA> class matxCovHandle_t {
165
167
}
166
168
167
169
// Multiply by itself and scale by N-1 for the final covariance
168
- matmul_impl (c, devsT, devs, stream ,
170
+ matmul_impl (c, devsT, devs, exec ,
169
171
1 .0f / static_cast <float >(a.Size (RANK - 2 ) - 1 ));
170
172
}
171
173
@@ -224,14 +226,16 @@ using cov_cache_t = std::unordered_map<CovParams_t, std::any, CovParamsKeyHash,
224
226
* Covariance matrix output view
225
227
* @param a
226
228
* Covariance matrix input view
227
- * @param stream
228
- * CUDA stream
229
+ * @param exec
230
+ * CUDA executor
229
231
*/
230
232
template <typename TensorTypeC, typename TensorTypeA>
231
233
void cov_impl (TensorTypeC &c, const TensorTypeA &a,
232
- cudaStream_t stream = 0 )
234
+ const cudaExecutor &exec )
233
235
{
234
236
MATX_NVTX_START (" " , matx::MATX_NVTX_LOG_API)
237
+ const auto stream = exec.getStream ();
238
+
235
239
// Get parameters required by these tensors
236
240
auto params = detail::matxCovHandle_t<TensorTypeC, TensorTypeA>::GetCovParams (c, a, stream);
237
241
@@ -243,7 +247,7 @@ void cov_impl(TensorTypeC &c, const TensorTypeA &a,
243
247
return std::make_shared<cache_val_type>(c, a);
244
248
},
245
249
[&](std::shared_ptr<cache_val_type> ctype) {
246
- ctype->Exec (c, a, stream );
250
+ ctype->Exec (c, a, exec );
247
251
}
248
252
);
249
253
}
0 commit comments