Skip to content

Commit efbb2a0

Browse files
authored
Do not create CUDA events in ephemeral executors (#889)
* Do not create CUDA events in ephemeral executors
1 parent 84dfc1b commit efbb2a0

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

include/matx/executors/cuda.h

+18-9
Original file line numberDiff line numberDiff line change
@@ -53,24 +53,31 @@ namespace matx
5353
* @brief Construct a new cudaExecutor with a stream
5454
*
5555
* @param stream CUDA stream
56+
* @param profiling Whether to enable profiling
5657
*/
57-
cudaExecutor(cudaStream_t stream) : stream_(stream) {
58-
MATX_CUDA_CHECK(cudaEventCreate(&start_));
59-
MATX_CUDA_CHECK(cudaEventCreate(&stop_));
58+
cudaExecutor(cudaStream_t stream, bool profiling = true) : stream_(stream), profiling_(profiling) {
59+
if (profiling_) {
60+
MATX_CUDA_CHECK(cudaEventCreate(&start_));
61+
MATX_CUDA_CHECK(cudaEventCreate(&stop_));
62+
}
6063
}
6164

62-
cudaExecutor(int stream) : stream_(reinterpret_cast<cudaStream_t>(stream)) {
63-
MATX_CUDA_CHECK(cudaEventCreate(&start_));
64-
MATX_CUDA_CHECK(cudaEventCreate(&stop_));
65+
cudaExecutor(int stream, bool profiling = true) : stream_(reinterpret_cast<cudaStream_t>(stream)), profiling_(profiling) {
66+
if (profiling_) {
67+
MATX_CUDA_CHECK(cudaEventCreate(&start_));
68+
MATX_CUDA_CHECK(cudaEventCreate(&stop_));
69+
}
6570
}
6671

6772
/**
6873
* @brief Construct a new cudaExecutor object using the default stream
6974
*
7075
*/
71-
cudaExecutor() : stream_(0) {
72-
MATX_CUDA_CHECK(cudaEventCreate(&start_));
73-
MATX_CUDA_CHECK(cudaEventCreate(&stop_));
76+
cudaExecutor() : stream_(0), profiling_(true) {
77+
if (profiling_) {
78+
MATX_CUDA_CHECK(cudaEventCreate(&start_));
79+
MATX_CUDA_CHECK(cudaEventCreate(&stop_));
80+
}
7481
}
7582

7683
/**
@@ -99,6 +106,7 @@ namespace matx
99106
* This will block until the event is synchronized
100107
*/
101108
float get_time_ms() {
109+
MATX_ASSERT_STR(profiling_, matxInvalidParameter, "Profiling not enabled when using get_time_ms()");
102110
float time;
103111
cudaEventSynchronize(stop_);
104112
cudaEventElapsedTime(&time, start_, stop_);
@@ -169,6 +177,7 @@ namespace matx
169177
}
170178

171179
private:
180+
bool profiling_;
172181
cudaStream_t stream_;
173182
cudaEvent_t start_;
174183
cudaEvent_t stop_;

include/matx/operators/base_operator.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ namespace matx
109109
__MATX_INLINE__ void run(cudaStream_t stream = 0)
110110
{
111111
MATX_NVTX_START(detail::get_type_str(*static_cast<T *>(this)), matx::MATX_NVTX_LOG_API)
112-
run(cudaExecutor{stream});
112+
run(cudaExecutor{stream, false});
113113
}
114114

115115
/**
@@ -122,7 +122,7 @@ namespace matx
122122
{
123123
MATX_NVTX_START(static_cast<T *>(this)->str(), matx::MATX_NVTX_LOG_API)
124124

125-
run(cudaExecutor{stream});
125+
run(cudaExecutor{stream, false});
126126
cudaEventRecord(ev, stream);
127127
}
128128

0 commit comments

Comments
 (0)