Skip to content

Commit 7c49418

Browse files
committed
Add configurable output scaling modes for pwelch, using a custom reduction kernel that performs better than CUB when the in-memory FFT bin powers are laid out as {batches, nfft}
1 parent 6bc555d commit 7c49418

File tree

5 files changed

+206
-72
lines changed

5 files changed

+206
-72
lines changed

examples/pwelch.cu

+18-22
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,15 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
5050
MATX_ENTER_HANDLER();
5151
using complex = cuda::std::complex<float>;
5252

53-
float exec_time_ms;
54-
const int num_iterations = 100;
55-
index_t signal_size = 256;
56-
index_t nperseg = 32;
57-
index_t nfft = nperseg;
58-
index_t noverlap = 8;
59-
float ftone = 3.0;
53+
const int num_iterations = 500;
54+
index_t signal_size = 256000;
55+
index_t nperseg = 512;
56+
index_t noverlap = 256;
57+
index_t nfft = 65536;
58+
59+
float ftone = 2048.0;
6060
cudaStream_t stream;
6161
cudaStreamCreate(&stream);
62-
cudaEvent_t start, stop;
63-
cudaEventCreate(&start);
64-
cudaEventCreate(&stop);
6562
cudaExecutor exec{stream};
6663

6764
// Create input signal as a complex exponential
@@ -71,31 +68,30 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
7168
auto x = make_tensor<complex>({signal_size});
7269
(x = tmp_x).run(exec); // pre-compute x, tmp_x is otherwise lazily evaluated
7370

71+
// Create window
72+
auto w = make_tensor<complex>({nperseg});
73+
(w = flattop<0>({nperseg})).run(exec);
74+
7475
// Create output tensor
7576
auto Pxx = make_tensor<typename complex::value_type>({nfft});
7677

7778
// Run one time to pre-cache the FFT plan
78-
(Pxx = pwelch(x, nperseg, noverlap, nfft)).run(exec);
79+
(Pxx = pwelch(x, w, nperseg, noverlap, nfft)).run(exec);
7980
exec.sync();
8081

8182
// Start the timing
82-
cudaEventRecord(start, stream);
83-
84-
// Start the timing
85-
cudaEventRecord(start, stream);
83+
exec.start_timer();
8684

8785
for (int iteration = 0; iteration < num_iterations; iteration++) {
8886
// Use the PWelch operator
89-
(Pxx = pwelch(x, nperseg, noverlap, nfft)).run(exec);
87+
(Pxx = pwelch(x, w, nperseg, noverlap, nfft)).run(exec);
9088
}
91-
92-
cudaEventRecord(stop, stream);
9389
exec.sync();
94-
cudaEventElapsedTime(&exec_time_ms, start, stop);
90+
exec.stop_timer();
9591

96-
printf("Output Pxx:\n");
97-
print(Pxx);
98-
printf("PWelchOp avg runtime = %.3f ms\n", exec_time_ms / num_iterations);
92+
printf("Pxx(0) = %f\n", Pxx(0));
93+
printf("Pxx(ftone) = %f\n", Pxx(2048));
94+
printf("PWelchOp avg runtime = %.3f ms\n", exec.get_time_ms() / num_iterations);
9995

10096
MATX_CUDA_CHECK_LAST_ERROR();
10197
MATX_EXIT_HANDLER();

include/matx/operators/pwelch.h

+59-24
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,6 @@ namespace matx
4343
template <typename OpX, typename OpW>
4444
class PWelchOp : public BaseOp<PWelchOp<OpX,OpW>>
4545
{
46-
private:
47-
typename detail::base_type_t<OpX> x_;
48-
typename detail::base_type_t<OpW> w_;
49-
50-
index_t nperseg_;
51-
index_t noverlap_;
52-
index_t nfft_;
53-
cuda::std::array<index_t, 1> out_dims_;
54-
mutable detail::tensor_impl_t<typename remove_cvref_t<OpX>::value_type, 1> tmp_out_;
55-
mutable typename remove_cvref_t<OpX>::value_type *ptr = nullptr;
56-
5746
public:
5847
using matxop = bool;
5948
using value_type = typename OpX::value_type::value_type;
@@ -66,9 +55,23 @@ namespace matx
6655
return "pwelch(" + get_type_str(x_) + "," + get_type_str(w_) + ")";
6756
}
6857

69-
__MATX_INLINE__ PWelchOp(const OpX &x, const OpW &w, index_t nperseg, index_t noverlap, index_t nfft) :
70-
x_(x), w_(w), nperseg_(nperseg), noverlap_(noverlap), nfft_(nfft) {
71-
58+
__MATX_INLINE__ PWelchOp(
59+
const OpX &x,
60+
const OpW &w,
61+
index_t nperseg,
62+
index_t noverlap,
63+
index_t nfft,
64+
PwelchOutputScaleMode output_scale_mode,
65+
value_type fs
66+
) :
67+
x_(x),
68+
w_(w),
69+
nperseg_(nperseg),
70+
noverlap_(noverlap),
71+
nfft_(nfft),
72+
output_scale_mode_(output_scale_mode),
73+
fs_(fs)
74+
{
7275
MATX_STATIC_ASSERT_STR(OpX::Rank() == 1, matxInvalidDim, "pwelch: Only input rank of 1 is supported presently");
7376
for (int r = 0; r < OpX::Rank(); r++) {
7477
out_dims_[r] = nfft_;
@@ -96,25 +99,25 @@ namespace matx
9699
template <typename Out, typename Executor>
97100
void Exec(Out &&out, Executor &&ex) const{
98101
static_assert(is_cuda_executor_v<Executor>, "pwelch() only supports the CUDA executor currently");
99-
pwelch_impl(cuda::std::get<0>(out), x_, w_, nperseg_, noverlap_, nfft_, ex.getStream());
102+
pwelch_impl(cuda::std::get<0>(out), x_, w_, nperseg_, noverlap_, nfft_, output_scale_mode_, fs_, ex.getStream());
100103
}
101104

102105
template <typename ShapeType, typename Executor>
103106
__MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
104107
{
105108
if constexpr (is_matx_op<OpX>()) {
106109
x_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
107-
}
110+
}
108111

109112
if constexpr (is_matx_op<OpW>()) {
110113
w_.PreRun(Shape(w_), std::forward<Executor>(ex));
111-
}
112-
}
114+
}
115+
}
113116

114117
template <typename ShapeType, typename Executor>
115118
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept
116119
{
117-
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
120+
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
118121

119122
detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_, &ptr);
120123

@@ -133,7 +136,20 @@ namespace matx
133136
}
134137

135138
matxFree(ptr);
136-
}
139+
}
140+
141+
private:
142+
typename detail::base_type_t<OpX> x_;
143+
typename detail::base_type_t<OpW> w_;
144+
145+
index_t nperseg_;
146+
index_t noverlap_;
147+
index_t nfft_;
148+
PwelchOutputScaleMode output_scale_mode_;
149+
value_type fs_;
150+
cuda::std::array<index_t, 1> out_dims_;
151+
mutable detail::tensor_impl_t<typename remove_cvref_t<OpX>::value_type, 1> tmp_out_;
152+
mutable typename remove_cvref_t<OpX>::value_type *ptr = nullptr;
137153
};
138154
}
139155

@@ -154,22 +170,41 @@ namespace matx
154170
* Number of points to overlap between segments. Defaults to 0
155171
* @param nfft
156172
* Length of FFT used per segment. nfft >= nperseg. Defaults to nfft = nperseg
173+
* @param output_scale_mode
174+
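* Output scale mode: one of PwelchOutputScaleMode_Spectrum, PwelchOutputScaleMode_Density, PwelchOutputScaleMode_Spectrum_dB, or PwelchOutputScaleMode_Density_dB. Defaults to PwelchOutputScaleMode_Spectrum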
175+
* @param fs
176+
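* Sampling frequency. Only used by the density output scale modes, where the averaged power is additionally divided by fs. Defaults to 1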
157177
*
158178
* @returns Operator with the power spectrum or power spectral density of x (linear or in dB), as selected by output_scale_mode
159179
*
160180
*/
161181

162182
template <typename xType, typename wType>
163-
__MATX_INLINE__ auto pwelch(const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft)
183+
__MATX_INLINE__ auto pwelch(
184+
const xType& x,
185+
const wType& w,
186+
index_t nperseg,
187+
index_t noverlap,
188+
index_t nfft,
189+
PwelchOutputScaleMode output_scale_mode = PwelchOutputScaleMode_Spectrum,
190+
typename xType::value_type::value_type fs = 1
191+
)
164192
{
165193
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
166194

167-
return detail::PWelchOp(x, w, nperseg, noverlap, nfft);
195+
return detail::PWelchOp(x, w, nperseg, noverlap, nfft, output_scale_mode, fs);
168196
}
169197

170198
template <typename xType>
171-
__MATX_INLINE__ auto pwelch(const xType& x, index_t nperseg, index_t noverlap, index_t nfft)
199+
__MATX_INLINE__ auto pwelch(
200+
const xType& x,
201+
index_t nperseg,
202+
index_t noverlap,
203+
index_t nfft,
204+
PwelchOutputScaleMode output_scale_mode = PwelchOutputScaleMode_Spectrum,
205+
typename xType::value_type::value_type fs = 1
206+
)
172207
{
173-
return detail::PWelchOp(x, std::nullopt, nperseg, noverlap, nfft);
208+
return detail::PWelchOp(x, std::nullopt, nperseg, noverlap, nfft, output_scale_mode, fs);
174209
}
175210
}

include/matx/transforms/pwelch.h

+82-7
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,71 @@
3434

3535
namespace matx
3636
{
37+
38+
enum PwelchOutputScaleMode {
39+
PwelchOutputScaleMode_Spectrum,
40+
PwelchOutputScaleMode_Density,
41+
PwelchOutputScaleMode_Spectrum_dB,
42+
PwelchOutputScaleMode_Density_dB
43+
};
44+
45+
namespace detail {
46+
template<PwelchOutputScaleMode OUTPUT_SCALE_MODE, typename T_IN, typename T_OUT>
47+
__global__ void pwelch_kernel(const T_IN t_in, T_OUT t_out, typename T_OUT::value_type fs)
48+
{
49+
const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
50+
const index_t batches = t_in.Shape()[0];
51+
const index_t nfft = t_in.Shape()[1];
52+
53+
if (tid < nfft)
54+
{
55+
typename T_OUT::value_type pxx = 0;
56+
constexpr typename T_OUT::value_type ten = 10;
57+
58+
for (index_t batch = 0; batch < batches; batch++)
59+
{
60+
pxx += cuda::std::norm(t_in(batch, tid));
61+
}
62+
63+
if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum)
64+
{
65+
t_out(tid) = pxx / batches;
66+
}
67+
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density)
68+
{
69+
t_out(tid) = pxx / (batches * fs);
70+
}
71+
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum_dB)
72+
{
73+
pxx /= batches;
74+
if (pxx != 0)
75+
{
76+
t_out(tid) = ten * cuda::std::log10(pxx);
77+
}
78+
else
79+
{
80+
t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
81+
}
82+
}
83+
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density_dB)
84+
{
85+
pxx /= (batches * fs);
86+
if (pxx != 0)
87+
{
88+
t_out(tid) = ten * cuda::std::log10(pxx);
89+
}
90+
else
91+
{
92+
t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
93+
}
94+
}
95+
}
96+
}
97+
};
98+
99+
extern int g_pwelch_alg_mode;
37100
template <typename PxxType, typename xType, typename wType>
38-
__MATX_INLINE__ void pwelch_impl(PxxType Pxx, const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft, cudaStream_t stream=0)
101+
__MATX_INLINE__ void pwelch_impl(PxxType Pxx, const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft, PwelchOutputScaleMode output_scale_mode, typename PxxType::value_type fs, cudaStream_t stream=0)
39102
{
40103
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
41104

@@ -59,13 +122,25 @@ namespace matx
59122
(X_with_overlaps = fft(x_with_overlaps * w,nfft)).run(stream);
60123
}
61124

62-
// Compute magnitude squared in-place
63-
(X_with_overlaps = conj(X_with_overlaps) * X_with_overlaps).run(stream);
64-
auto mag_sq_X_with_overlaps = X_with_overlaps.RealView();
125+
int tpb = 512;
126+
int bpk = (static_cast<int>(nfft) + tpb - 1) / tpb;
65127

66-
// Perform the reduction across 'batches' rows and normalize
67-
auto norm_factor = static_cast<typename PxxType::value_type>(1.) / static_cast<typename PxxType::value_type>(batches);
68-
(Pxx = sum(mag_sq_X_with_overlaps, {0}) * norm_factor).run(stream);
128+
if (output_scale_mode == PwelchOutputScaleMode_Spectrum)
129+
{
130+
detail::pwelch_kernel<PwelchOutputScaleMode_Spectrum><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs);
131+
}
132+
else if (output_scale_mode == PwelchOutputScaleMode_Density)
133+
{
134+
detail::pwelch_kernel<PwelchOutputScaleMode_Density><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs);
135+
}
136+
else if (output_scale_mode == PwelchOutputScaleMode_Spectrum_dB)
137+
{
138+
detail::pwelch_kernel<PwelchOutputScaleMode_Spectrum_dB><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs);
139+
}
140+
else //if (output_scale_mode == PwelchOutputScaleMode_Density_dB)
141+
{
142+
detail::pwelch_kernel<PwelchOutputScaleMode_Density_dB><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs);
143+
}
69144
}
70145

71146
} // end namespace matx

test/00_operators/PWelch.cu

+29-9
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,23 @@ struct TestParams {
4747
index_t nperseg;
4848
index_t noverlap;
4949
index_t nfft;
50+
PwelchOutputScaleMode output_scale_mode;
51+
float fs;
5052
float ftone;
5153
float sigma;
5254
};
5355

5456
const std::vector<TestParams> CONFIGS = {
55-
{"none", 8, 8, 2, 8, 0., 0.},
56-
{"none", 16, 8, 4, 8, 1., 0.},
57-
{"none", 16, 8, 4, 8, 2., 1.},
58-
{"none", 16384, 256, 64, 256, 63., 0.},
59-
{"boxcar", 8, 8, 2, 8, 0., 0.},
60-
{"hann", 16, 8, 4, 8, 1., 0.},
61-
{"flattop", 1024, 64, 32, 128, 2., 1.},
57+
{"none", 8, 8, 2, 8, PwelchOutputScaleMode_Spectrum, 1.0, 0., 0.},
58+
{"none", 16, 8, 4, 8, PwelchOutputScaleMode_Spectrum, 1.0, 1., 0.},
59+
{"none", 16, 8, 4, 8, PwelchOutputScaleMode_Spectrum, 1.0, 2., 1.},
60+
{"none", 16384, 256, 64, 256, PwelchOutputScaleMode_Spectrum, 1.0, 63., 0.},
61+
{"boxcar", 8, 8, 2, 8, PwelchOutputScaleMode_Spectrum, 1.0, 0., 0.},
62+
{"hann", 16, 8, 4, 8, PwelchOutputScaleMode_Spectrum, 1.0, 1., 0.},
63+
{"flattop", 1024, 64, 32, 128, PwelchOutputScaleMode_Spectrum, 2.0, 2., 1.},
64+
{"flattop", 1024, 64, 32, 128, PwelchOutputScaleMode_Density, 2.0, 2., 1.},
65+
{"flattop", 1024, 64, 32, 128, PwelchOutputScaleMode_Spectrum_dB, 2.0, 2., 1.},
66+
{"flattop", 1024, 64, 32, 128, PwelchOutputScaleMode_Density_dB, 2.0, 2., 1.},
6267
};
6368

6469
class PWelchComplexExponentialTest : public ::testing::TestWithParam<TestParams>
@@ -87,11 +92,26 @@ void helper(PWelchComplexExponentialTest& test)
8792
"nperseg"_a=test.params.nperseg,
8893
"noverlap"_a=test.params.noverlap,
8994
"nfft"_a=test.params.nfft,
95+
"scaling"_a="spectrum",
96+
"fs"_a =test.params.fs,
9097
"ftone"_a=test.params.ftone,
9198
"sigma"_a=test.params.sigma,
9299
"window_name"_a=test.params.window_name
93100
);
94101

102+
if (test.params.output_scale_mode == PwelchOutputScaleMode_Density)
103+
{
104+
cfg["scaling"] = "density";
105+
}
106+
else if (test.params.output_scale_mode == PwelchOutputScaleMode_Density_dB)
107+
{
108+
cfg["scaling"] = "density_dB";
109+
}
110+
else if (test.params.output_scale_mode == PwelchOutputScaleMode_Spectrum_dB)
111+
{
112+
cfg["scaling"] = "spectrum_dB";
113+
}
114+
95115
test.pb->template InitAndRunTVGeneratorWithCfg<TypeParam>(
96116
"00_operators", "pwelch_operators", "pwelch_complex_exponential", cfg);
97117

@@ -104,7 +124,7 @@ void helper(PWelchComplexExponentialTest& test)
104124

105125
if (test.params.window_name == "none")
106126
{
107-
(Pxx = pwelch(x, test.params.nperseg, test.params.noverlap, test.params.nfft)).run(exec);
127+
(Pxx = pwelch(x, test.params.nperseg, test.params.noverlap, test.params.nfft, test.params.output_scale_mode, test.params.fs)).run(exec);
108128
}
109129
else
110130
{
@@ -125,7 +145,7 @@ void helper(PWelchComplexExponentialTest& test)
125145
{
126146
ASSERT_TRUE(false) << "Unknown window parameter name " + test.params.window_name;
127147
}
128-
(Pxx = pwelch(x, w, test.params.nperseg, test.params.noverlap, test.params.nfft)).run(exec);
148+
(Pxx = pwelch(x, w, test.params.nperseg, test.params.noverlap, test.params.nfft, test.params.output_scale_mode, test.params.fs)).run(exec);
129149
}
130150

131151
exec.sync();

0 commit comments

Comments
 (0)