Skip to content

Commit 63ec538

Browse files
authored
Fixed fftw guards and temp allocation (#660)
1 parent ddd4577 commit 63ec538

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

include/matx/core/operator_utils.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ namespace matx {
120120
template <typename TensorType, typename Executor, typename ShapeType>
121121
__MATX_HOST__ __MATX_INLINE__ void AllocateTempTensor(TensorType &tensor, Executor &&ex, ShapeType &&shape, typename TensorType::scalar_type **ptr) {
122122
const auto ttl_size = std::accumulate(shape.begin(), shape.end(), static_cast<index_t>(1),
123-
std::multiplies<index_t>()) * sizeof(*ptr);
123+
std::multiplies<index_t>()) * sizeof(typename TensorType::scalar_type);
124124
if constexpr (is_cuda_executor_v<Executor>) {
125125
matxAlloc((void**)ptr, ttl_size, MATX_ASYNC_DEVICE_MEMORY, ex.getStream());
126126
make_tensor(tensor, *ptr, shape);

include/matx/transforms/fft/fft_fftw.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
#pragma once
3434

35+
#if MATX_EN_CPU_FFT
3536
#include "matx/core/cache.h"
3637
#include "matx/core/error.h"
3738
#include "matx/core/make_tensor.h"
@@ -560,7 +561,6 @@ template<typename OutTensorType, typename InTensorType> class matxFFTWPlan_t {
560561
[[maybe_unused]] const FftFFTWParams_t &params,
561562
[[maybe_unused]] detail::FFTDirection dir,
562563
[[maybe_unused]] const HostExecutor<MODE> &exec) {
563-
#if MATX_EN_CPU_FFT
564564
using cache_val_type = detail::matxFFTWPlan_t<OutputTensor, InputTensor>;
565565
detail::GetCache().LookupAndExec<detail::fft_fftw_cache_t>(
566566
detail::GetCacheIdFromType<detail::fft_fftw_cache_t>(),
@@ -572,7 +572,6 @@ template<typename OutTensorType, typename InTensorType> class matxFFTWPlan_t {
572572
ctype->Exec(o, i);
573573
}
574574
);
575-
#endif
576575
}
577576

578577
template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
@@ -760,3 +759,4 @@ template<typename OutTensorType, typename InTensorType> class matxFFTWPlan_t {
760759

761760

762761
}; // end namespace matx
762+
#endif

0 commit comments

Comments
 (0)