Skip to content

Commit 9bb79f1

Browse files
committed
Move nvcc-specific features behind __CUDACC__ guards and add
static_asserts for signal type
1 parent 7491b19 commit 9bb79f1

File tree

4 files changed

+153
-80
lines changed

4 files changed

+153
-80
lines changed

examples/pwelch.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
6969
(x = tmp_x).run(exec); // pre-compute x, tmp_x is otherwise lazily evaluated
7070

7171
// Create window
72-
auto w = make_tensor<complex>({nperseg});
72+
auto w = make_tensor<float>({nperseg});
7373
(w = flattop<0>({nperseg})).run(exec);
7474

7575
// Create output tensor

include/matx/kernels/pwelch.cuh

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
////////////////////////////////////////////////////////////////////////////////
2+
// BSD 3-Clause License
3+
//
4+
// Copyright (c) 2023, NVIDIA Corporation
5+
// All rights reserved.
6+
//
7+
// Redistribution and use in source and binary forms, with or without
8+
// modification, are permitted provided that the following conditions are met:
9+
//
10+
// 1. Redistributions of source code must retain the above copyright notice, this
11+
// list of conditions and the following disclaimer.
12+
//
13+
// 2. Redistributions in binary form must reproduce the above copyright notice,
14+
// this list of conditions and the following disclaimer in the documentation
15+
// and/or other materials provided with the distribution.
16+
//
17+
// 3. Neither the name of the copyright holder nor the names of its
18+
// contributors may be used to endorse or promote products derived from
19+
// this software without specific prior written permission.
20+
//
21+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
/////////////////////////////////////////////////////////////////////////////////
32+
33+
#pragma once
34+
35+
namespace matx {
36+
37+
enum PwelchOutputScaleMode {
38+
PwelchOutputScaleMode_Spectrum,
39+
PwelchOutputScaleMode_Density,
40+
PwelchOutputScaleMode_Spectrum_dB,
41+
PwelchOutputScaleMode_Density_dB
42+
};
43+
44+
namespace detail {
45+
46+
#ifdef __CUDACC__
47+
template<PwelchOutputScaleMode OUTPUT_SCALE_MODE, typename T_IN, typename T_OUT, typename fsType>
48+
__global__ void pwelch_kernel(const T_IN t_in, T_OUT t_out, fsType fs)
49+
{
50+
const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
51+
const index_t batches = t_in.Shape()[0];
52+
const index_t nfft = t_in.Shape()[1];
53+
54+
if (tid < nfft)
55+
{
56+
typename T_OUT::value_type pxx = 0;
57+
constexpr typename T_OUT::value_type ten = 10;
58+
59+
for (index_t batch = 0; batch < batches; batch++)
60+
{
61+
pxx += cuda::std::norm(t_in(batch, tid));
62+
}
63+
64+
if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum)
65+
{
66+
t_out(tid) = pxx / batches;
67+
}
68+
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density)
69+
{
70+
t_out(tid) = pxx / (batches * fs);
71+
}
72+
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum_dB)
73+
{
74+
pxx /= batches;
75+
if (pxx != 0)
76+
{
77+
t_out(tid) = ten * cuda::std::log10(pxx);
78+
}
79+
else
80+
{
81+
t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
82+
}
83+
}
84+
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density_dB)
85+
{
86+
pxx /= (batches * fs);
87+
if (pxx != 0)
88+
{
89+
t_out(tid) = ten * cuda::std::log10(pxx);
90+
}
91+
else
92+
{
93+
t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
94+
}
95+
}
96+
}
97+
}
98+
#endif
99+
100+
};
101+
};

include/matx/operators/pwelch.h

+41-11
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@
4040
namespace matx
4141
{
4242
namespace detail {
43-
template <typename OpX, typename OpW>
44-
class PWelchOp : public BaseOp<PWelchOp<OpX,OpW>>
43+
template <typename OpX, typename OpW, typename fsType>
44+
class PWelchOp : public BaseOp<PWelchOp<OpX,OpW,fsType>>
4545
{
4646
public:
47+
static_assert(is_complex_v<typename OpX::value_type>, "pwelch() must have a complex input type");
4748
using matxop = bool;
4849
using value_type = typename OpX::value_type::value_type;
4950
using matx_transform_op = bool;
5051
using pwelch_xform_op = bool;
5152

52-
static_assert(is_complex_v<typename OpX::value_type>, "pwelch() must have a complex input type");
5353

5454
__MATX_INLINE__ std::string str() const {
5555
return "pwelch(" + get_type_str(x_) + "," + get_type_str(w_) + ")";
@@ -62,7 +62,7 @@ namespace matx
6262
index_t noverlap,
6363
index_t nfft,
6464
PwelchOutputScaleMode output_scale_mode,
65-
value_type fs
65+
fsType fs
6666
) :
6767
x_(x),
6868
w_(w),
@@ -146,7 +146,7 @@ namespace matx
146146
index_t noverlap_;
147147
index_t nfft_;
148148
PwelchOutputScaleMode output_scale_mode_;
149-
value_type fs_;
149+
fsType fs_;
150150
cuda::std::array<index_t, 1> out_dims_;
151151
mutable detail::tensor_impl_t<typename remove_cvref_t<OpX>::value_type, 1> tmp_out_;
152152
mutable typename remove_cvref_t<OpX>::value_type *ptr = nullptr;
@@ -160,6 +160,8 @@ namespace matx
160160
* Input time domain data type
161161
* @tparam wType
162162
* Input window type
163+
* @tparam fsType
164+
* Sampling frequency type
163165
* @param x
164166
* Input time domain tensor
165167
* @param w
@@ -179,32 +181,60 @@ namespace matx
179181
*
180182
*/
181183

182-
template <typename xType, typename wType>
184+
template <
185+
typename xType,
186+
typename wType,
187+
typename fsType>
183188
__MATX_INLINE__ auto pwelch(
184189
const xType& x,
185190
const wType& w,
186191
index_t nperseg,
187192
index_t noverlap,
188193
index_t nfft,
189194
PwelchOutputScaleMode output_scale_mode = PwelchOutputScaleMode_Spectrum,
190-
typename xType::value_type::value_type fs = 1
195+
fsType fs = 1
191196
)
192197
{
193-
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
194-
195198
return detail::PWelchOp(x, w, nperseg, noverlap, nfft, output_scale_mode, fs);
196199
}
197200

198-
template <typename xType>
201+
template <
202+
typename xType,
203+
typename fsType>
199204
__MATX_INLINE__ auto pwelch(
200205
const xType& x,
201206
index_t nperseg,
202207
index_t noverlap,
203208
index_t nfft,
204209
PwelchOutputScaleMode output_scale_mode = PwelchOutputScaleMode_Spectrum,
205-
typename xType::value_type::value_type fs = 1
210+
fsType fs = 1
206211
)
207212
{
208213
return detail::PWelchOp(x, std::nullopt, nperseg, noverlap, nfft, output_scale_mode, fs);
209214
}
215+
216+
template <typename xType, typename wType>
217+
__MATX_INLINE__ auto pwelch(
218+
const xType& x,
219+
const wType& w,
220+
index_t nperseg,
221+
index_t noverlap,
222+
index_t nfft,
223+
PwelchOutputScaleMode output_scale_mode = PwelchOutputScaleMode_Spectrum
224+
)
225+
{
226+
return detail::PWelchOp(x, w, nperseg, noverlap, nfft, output_scale_mode, 1.f);
227+
}
228+
229+
template <typename xType>
230+
__MATX_INLINE__ auto pwelch(
231+
const xType& x,
232+
index_t nperseg,
233+
index_t noverlap,
234+
index_t nfft,
235+
PwelchOutputScaleMode output_scale_mode = PwelchOutputScaleMode_Spectrum
236+
)
237+
{
238+
return detail::PWelchOp(x, std::nullopt, nperseg, noverlap, nfft, output_scale_mode, 1.f);
239+
}
210240
}

include/matx/transforms/pwelch.h

+10-68
Original file line numberDiff line numberDiff line change
@@ -32,74 +32,16 @@
3232

3333
#pragma once
3434

35+
#include "matx/kernels/pwelch.cuh"
36+
3537
namespace matx
3638
{
37-
38-
enum PwelchOutputScaleMode {
39-
PwelchOutputScaleMode_Spectrum,
40-
PwelchOutputScaleMode_Density,
41-
PwelchOutputScaleMode_Spectrum_dB,
42-
PwelchOutputScaleMode_Density_dB
43-
};
44-
45-
namespace detail {
46-
template<PwelchOutputScaleMode OUTPUT_SCALE_MODE, typename T_IN, typename T_OUT>
47-
__global__ void pwelch_kernel(const T_IN t_in, T_OUT t_out, typename T_OUT::value_type fs)
48-
{
49-
const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
50-
const index_t batches = t_in.Shape()[0];
51-
const index_t nfft = t_in.Shape()[1];
52-
53-
if (tid < nfft)
54-
{
55-
typename T_OUT::value_type pxx = 0;
56-
constexpr typename T_OUT::value_type ten = 10;
57-
58-
for (index_t batch = 0; batch < batches; batch++)
59-
{
60-
pxx += cuda::std::norm(t_in(batch, tid));
61-
}
62-
63-
if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum)
64-
{
65-
t_out(tid) = pxx / batches;
66-
}
67-
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density)
68-
{
69-
t_out(tid) = pxx / (batches * fs);
70-
}
71-
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum_dB)
72-
{
73-
pxx /= batches;
74-
if (pxx != 0)
75-
{
76-
t_out(tid) = ten * cuda::std::log10(pxx);
77-
}
78-
else
79-
{
80-
t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
81-
}
82-
}
83-
else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density_dB)
84-
{
85-
pxx /= (batches * fs);
86-
if (pxx != 0)
87-
{
88-
t_out(tid) = ten * cuda::std::log10(pxx);
89-
}
90-
else
91-
{
92-
t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
93-
}
94-
}
95-
}
96-
}
97-
};
98-
99-
extern int g_pwelch_alg_mode;
100-
template <typename PxxType, typename xType, typename wType>
101-
__MATX_INLINE__ void pwelch_impl(PxxType Pxx, const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft, PwelchOutputScaleMode output_scale_mode, typename PxxType::value_type fs, cudaStream_t stream=0)
102-
{
39+
template <typename PxxType, typename xType, typename wType, typename fsType>
40+
__MATX_INLINE__ void pwelch_impl(PxxType Pxx, const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft, PwelchOutputScaleMode output_scale_mode, fsType fs, cudaStream_t stream=0)
41+
{
42+
#ifndef __CUDACC__
43+
MATX_THROW(matxNotSupported, "pwelch not supported on host");
44+
#else
10345
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
10446

10547
MATX_ASSERT_STR(Pxx.Rank() == x.Rank(), matxInvalidDim, "pwelch: Pxx rank must be the same as x rank");
@@ -141,6 +83,6 @@ namespace matx
14183
{
14284
detail::pwelch_kernel<PwelchOutputScaleMode_Density_dB><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs);
14385
}
144-
}
145-
86+
#endif
87+
}
14688
} // end namespace matx

0 commit comments

Comments
 (0)