|
32 | 32 |
|
33 | 33 | #pragma once
|
34 | 34 |
|
| 35 | +#include "matx/kernels/pwelch.cuh" |
| 36 | + |
35 | 37 | namespace matx
|
36 | 38 | {
|
37 |
| - |
38 |
/**
 * @brief Output scaling applied by pwelch to the segment-averaged power.
 *
 * Used as the non-type template argument of detail::pwelch_kernel. The
 * numeric values are the implicit ones (0..3) and must not be reordered.
 */
enum PwelchOutputScaleMode {
  /// Summed squared magnitude divided by the number of segments.
  PwelchOutputScaleMode_Spectrum = 0,
  /// Summed squared magnitude divided by (number of segments * fs).
  PwelchOutputScaleMode_Density,
  /// Spectrum scaling, then 10 * log10(); an exact zero maps to lowest().
  PwelchOutputScaleMode_Spectrum_dB,
  /// Density scaling, then 10 * log10(); an exact zero maps to lowest().
  PwelchOutputScaleMode_Density_dB
};
44 |
| - |
45 |
namespace detail {
  /**
   * @brief Convert a power value to decibels.
   *
   * @tparam T Floating-point value type
   * @param power Power value to convert
   * @return 10 * log10(power), or numeric_limits<T>::lowest() when power is
   *         exactly zero (so that log10(0) is never evaluated)
   */
  template <typename T>
  __device__ __forceinline__ T pwelch_power_to_db(T power)
  {
    constexpr T ten = 10;
    return (power != 0) ? ten * cuda::std::log10(power)
                        : cuda::std::numeric_limits<T>::lowest();
  }

  /**
   * @brief Accumulate squared magnitudes over the first dimension of t_in and
   *        write one scaled value per index of the second dimension.
   *
   * Expects a 1D launch: each thread derives a single flat index from
   * blockIdx.x / blockDim.x / threadIdx.x and threads whose index is not
   * below t_in.Shape()[1] do nothing. No shared memory is used.
   *
   * @tparam OUTPUT_SCALE_MODE Scaling applied to the accumulated sum
   * @tparam T_IN  Input operator type; must provide Shape() and a two-index
   *               call operator whose result is accepted by cuda::std::norm
   * @tparam T_OUT Output operator type; must provide value_type and a
   *               one-index call operator returning an assignable reference
   *
   * @param t_in  Input, indexed as t_in(batch, tid)
   * @param t_out Output, indexed as t_out(tid)
   * @param fs    Extra divisor used by the Density and Density_dB modes
   */
  template<PwelchOutputScaleMode OUTPUT_SCALE_MODE, typename T_IN, typename T_OUT>
  __global__ void pwelch_kernel(const T_IN t_in, T_OUT t_out, typename T_OUT::value_type fs)
  {
    using value_t = typename T_OUT::value_type;

    // Widen before multiplying: blockIdx.x * blockDim.x would otherwise be
    // evaluated in 32-bit unsigned arithmetic and could wrap around.
    const index_t tid = static_cast<index_t>(blockIdx.x) * static_cast<index_t>(blockDim.x)
                        + static_cast<index_t>(threadIdx.x);
    const index_t batches = t_in.Shape()[0];
    const index_t nfft = t_in.Shape()[1];

    // Tail guard: the grid may contain more threads than output elements.
    if (tid >= nfft)
    {
      return;
    }

    // Sum of squared magnitudes down the first dimension for this index.
    value_t pxx = 0;
    for (index_t batch = 0; batch < batches; batch++)
    {
      pxx += cuda::std::norm(t_in(batch, tid));
    }

    if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum)
    {
      t_out(tid) = pxx / batches;
    }
    else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density)
    {
      t_out(tid) = pxx / (batches * fs);
    }
    else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum_dB)
    {
      pxx /= batches;
      t_out(tid) = pwelch_power_to_db<value_t>(pxx);
    }
    else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density_dB)
    {
      pxx /= (batches * fs);
      t_out(tid) = pwelch_power_to_db<value_t>(pxx);
    }
  }
} // namespace detail
98 |
| - |
99 |
| - extern int g_pwelch_alg_mode; |
100 | 39 | template <typename PxxType, typename xType, typename wType>
|
101 | 40 | __MATX_INLINE__ void pwelch_impl(PxxType Pxx, const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft, PwelchOutputScaleMode output_scale_mode, typename PxxType::value_type fs, cudaStream_t stream=0)
|
102 |
| - { |
| 41 | + { |
| 42 | + #ifndef __CUDACC__ |
| 43 | + MATX_THROW(matxNotSupported, "pwelch not supported on host"); |
| 44 | + #else |
103 | 45 | MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
|
104 | 46 |
|
105 | 47 | MATX_ASSERT_STR(Pxx.Rank() == x.Rank(), matxInvalidDim, "pwelch: Pxx rank must be the same as x rank");
|
@@ -141,6 +83,6 @@ namespace matx
|
141 | 83 | {
|
142 | 84 | detail::pwelch_kernel<PwelchOutputScaleMode_Density_dB><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs);
|
143 | 85 | }
|
144 |
| - } |
145 |
| - |
| 86 | + #endif |
| 87 | + } |
146 | 88 | } // end namespace matx
|
0 commit comments