Skip to content

Commit c3420a0

Browse files
authored
Fix bugs in 2D FFTs and add tests (#587)
- Fixes several issues where input/output, dimensions, or types were swapped in internal calculations.
- Adds tests for non-square, batched, strided batched, R2C, and C2R 2D transforms.
1 parent d75d702 commit c3420a0

File tree

3 files changed

+193
-9
lines changed

3 files changed

+193
-9
lines changed

include/matx/transforms/fft.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -235,23 +235,23 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
235235
else if (fft_rank == 2) {
236236
if (params.transform_type == CUFFT_C2R ||
237237
params.transform_type == CUFFT_Z2D) {
238-
params.n[0] = o.Size(RANK-1);
239-
params.n[1] = o.Size(RANK-2);
238+
params.n[1] = o.Size(RANK-1);
239+
params.n[0] = o.Size(RANK-2);
240240
}
241241
else {
242242
params.n[1] = i.Size(RANK-1);
243243
params.n[0] = i.Size(RANK-2);
244244
}
245245

246246
params.batch = (RANK == 2) ? 1 : i.Size(RANK - 3);
247-
params.inembed[1] = o.Size(RANK-1);
248-
params.onembed[1] = i.Size(RANK-1);
247+
params.inembed[1] = i.Size(RANK-1);
248+
params.onembed[1] = o.Size(RANK-1);
249249
params.istride = i.Stride(RANK-1);
250250
params.ostride = o.Stride(RANK-1);
251251
params.idist = (RANK<=2) ? 1 : (int) i.Stride(RANK-3);
252252
params.odist = (RANK<=2) ? 1 : (int) o.Stride(RANK-3);
253253

254-
if constexpr (is_complex_half_v<T1> || is_complex_half_v<T1>) {
254+
if constexpr (is_complex_half_v<T1> || is_half_v<T1>) {
255255
if ((params.n[0] & (params.n[0] - 1)) != 0 ||
256256
(params.n[1] & (params.n[1] - 1)) != 0) {
257257
MATX_THROW(matxInvalidDim,
@@ -367,7 +367,7 @@ template <typename OutTensorType, typename InTensorType> class matxFFTPlan_t {
367367
if constexpr (is_complex_half_v<T2>) {
368368
return CUFFT_C2C;
369369
}
370-
else if constexpr (is_half_v<T1>) {
370+
else if constexpr (is_half_v<T2>) {
371371
return CUFFT_R2C;
372372
}
373373
}
@@ -1057,7 +1057,7 @@ __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i,
10571057
}
10581058

10591059
// Get parameters required by these tensors
1060-
auto params = detail::matxFFTPlan_t<decltype(in), decltype(out)>::GetFFTParams(out, in, 2);
1060+
auto params = detail::matxFFTPlan_t<decltype(out), decltype(in)>::GetFFTParams(out, in, 2);
10611061
params.stream = stream;
10621062

10631063
// Get cache or new FFT plan if it doesn't exist

test/00_transform/FFT.cu

+152
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,65 @@ TYPED_TEST(FFTTestComplexTypes, FFT2D16C2C)
640640
MATX_EXIT_HANDLER();
641641
}
642642

643+
TYPED_TEST(FFTTestComplexTypes, FFT2D16x32C2C) // Forward 2D complex-to-complex FFT on a non-square 16x32 input.
644+
{
645+
MATX_ENTER_HANDLER();
646+
const index_t fft_dim[] = {16, 32}; // sizes of the two tensor dimensions; unequal so swapped dimensions are detectable
647+
this->pb->template InitAndRunTVGenerator<TypeParam>( // run the Python "fft_2d" generator to produce "a_in" / "a_out"
648+
"00_transforms", "fft_operators", "fft_2d", {fft_dim[0], fft_dim[1]});
649+
650+
tensor_t<TypeParam, 2> av{{fft_dim[0], fft_dim[1]}}; // input tensor
651+
tensor_t<TypeParam, 2> avo{{fft_dim[0], fft_dim[1]}}; // output tensor, same shape as the input for C2C
652+
this->pb->NumpyToTensorView(av, "a_in"); // load the generated input into av
653+
654+
(avo = fft2(av)).run();
655+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
656+
657+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
658+
MATX_EXIT_HANDLER();
659+
}
660+
661+
TYPED_TEST(FFTTestComplexTypes, FFT2D16BatchedC2C) // Forward 2D C2C FFT on a rank-3 tensor shaped {batch, 16, 16}.
662+
{
663+
MATX_ENTER_HANDLER();
664+
const index_t batch_size = 10; // size of the leading dimension
665+
const index_t fft_dim = 16; // size of each of the two trailing dimensions
666+
this->pb->template InitAndRunTVGenerator<TypeParam>( // run the Python "fft_2d_batched" generator to produce "a_in" / "a_out"
667+
"00_transforms", "fft_operators", "fft_2d_batched",
668+
{batch_size, fft_dim, fft_dim});
669+
670+
tensor_t<TypeParam, 3> av{{batch_size, fft_dim, fft_dim}}; // input tensor
671+
tensor_t<TypeParam, 3> avo{{batch_size, fft_dim, fft_dim}}; // output tensor, same shape as the input
672+
this->pb->NumpyToTensorView(av, "a_in"); // load the generated input into av
673+
674+
(avo = fft2(av)).run(); // no axes given; presumably transforms the last two dimensions per batch entry - confirm
675+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
676+
677+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
678+
MATX_EXIT_HANDLER();
679+
}
680+
681+
TYPED_TEST(FFTTestComplexTypes, FFT2D16BatchedStridedC2C) // Forward 2D C2C FFT over dimensions 0 and 2 of a {16, batch, 16} tensor.
682+
{
683+
MATX_ENTER_HANDLER();
684+
const index_t batch_size = 10; // size of the middle dimension, which is not listed in axes below
685+
const index_t fft_dim = 16; // size of dimensions 0 and 2
686+
this->pb->template InitAndRunTVGenerator<TypeParam>( // run the Python "fft_2d_batched_strided" generator to produce "a_in" / "a_out"
687+
"00_transforms", "fft_operators", "fft_2d_batched_strided",
688+
{fft_dim, batch_size, fft_dim});
689+
690+
tensor_t<TypeParam, 3> av{{fft_dim, batch_size, fft_dim}}; // input tensor
691+
tensor_t<TypeParam, 3> avo{{fft_dim, batch_size, fft_dim}}; // output tensor, same shape as the input
692+
this->pb->NumpyToTensorView(av, "a_in"); // load the generated input into av
693+
694+
const int32_t axes[] = {0, 2}; // same axes the generator passes to np.fft.fft2 (axes=(0, 2))
695+
(avo = fft2(av, axes)).run();
696+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
697+
698+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
699+
MATX_EXIT_HANDLER();
700+
}
701+
643702
TYPED_TEST(FFTTestComplexTypes, IFFT2D16C2C)
644703
{
645704
MATX_ENTER_HANDLER();
@@ -658,6 +717,99 @@ TYPED_TEST(FFTTestComplexTypes, IFFT2D16C2C)
658717
MATX_EXIT_HANDLER();
659718
}
660719

720+
TYPED_TEST(FFTTestComplexTypes, IFFT2D16x32C2C) // Inverse 2D complex-to-complex FFT on a non-square 16x32 input.
721+
{
722+
MATX_ENTER_HANDLER();
723+
const index_t fft_dim[] = {16, 32}; // sizes of the two tensor dimensions; unequal so swapped dimensions are detectable
724+
this->pb->template InitAndRunTVGenerator<TypeParam>( // run the Python "ifft_2d" generator to produce "a_in" / "a_out"
725+
"00_transforms", "fft_operators", "ifft_2d", {fft_dim[0], fft_dim[1]});
726+
727+
tensor_t<TypeParam, 2> av{{fft_dim[0], fft_dim[1]}}; // input tensor
728+
tensor_t<TypeParam, 2> avo{{fft_dim[0], fft_dim[1]}}; // output tensor, same shape as the input for C2C
729+
this->pb->NumpyToTensorView(av, "a_in"); // load the generated input into av
730+
731+
(avo = ifft2(av)).run();
732+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
733+
734+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
735+
MATX_EXIT_HANDLER();
736+
}
737+
738+
TYPED_TEST(FFTTestComplexNonHalfTypes, FFT2D16R2C) // Forward 2D FFT from an rtype 16x16 input to a TypeParam 16x9 output.
739+
{
740+
MATX_ENTER_HANDLER();
741+
const index_t fft_dim = 16; // size of both input dimensions
742+
using rtype = typename TypeParam::value_type; // element type of the input; presumably the real scalar of the complex TypeParam - confirm
743+
this->pb->template InitAndRunTVGenerator<rtype>( // generator is instantiated with rtype, not TypeParam
744+
"00_transforms", "fft_operators", "rfft_2d", {fft_dim, fft_dim});
745+
746+
tensor_t<rtype, 2> av{{fft_dim, fft_dim}}; // input tensor of rtype elements
747+
tensor_t<TypeParam, 2> avo{{fft_dim, fft_dim / 2 + 1}}; // output keeps dim 0 and has fft_dim / 2 + 1 entries in the last dimension
748+
this->pb->NumpyToTensorView(av, "a_in"); // load the generated input into av
749+
750+
(avo = fft2(av)).run();
751+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
752+
753+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
754+
MATX_EXIT_HANDLER();
755+
}
756+
757+
TYPED_TEST(FFTTestComplexNonHalfTypes, FFT2D16x32R2C) // Forward 2D FFT from an rtype 16x32 input to a TypeParam 16x17 output.
758+
{
759+
MATX_ENTER_HANDLER();
760+
const index_t fft_dim[] = {16, 32}; // sizes of the two input dimensions; unequal so swapped dimensions are detectable
761+
using rtype = typename TypeParam::value_type; // element type of the input; presumably the real scalar of the complex TypeParam - confirm
762+
this->pb->template InitAndRunTVGenerator<rtype>( // generator is instantiated with rtype, not TypeParam
763+
"00_transforms", "fft_operators", "rfft_2d", {fft_dim[0], fft_dim[1]});
764+
765+
tensor_t<rtype, 2> av{{fft_dim[0], fft_dim[1]}}; // input tensor of rtype elements
766+
tensor_t<TypeParam, 2> avo{{fft_dim[0], fft_dim[1] / 2 + 1}}; // output keeps dim 0 and has fft_dim[1] / 2 + 1 entries in the last dimension
767+
this->pb->NumpyToTensorView(av, "a_in"); // load the generated input into av
768+
769+
(avo = fft2(av)).run();
770+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
771+
772+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
773+
MATX_EXIT_HANDLER();
774+
}
775+
776+
TYPED_TEST(FFTTestComplexNonHalfTypes, IFFT2D16C2R) // Inverse 2D FFT from a TypeParam 16x9 input to an rtype 16x16 output.
777+
{
778+
MATX_ENTER_HANDLER();
779+
const index_t fft_dim = 16; // size of both output dimensions
780+
using rtype = typename TypeParam::value_type; // element type of the output; presumably the real scalar of the complex TypeParam - confirm
781+
this->pb->template InitAndRunTVGenerator<TypeParam>( // generator receives the full {fft_dim, fft_dim} size
782+
"00_transforms", "fft_operators", "irfft_2d", {fft_dim, fft_dim});
783+
784+
tensor_t<TypeParam, 2> av{{fft_dim, fft_dim / 2 + 1}}; // input has fft_dim / 2 + 1 entries in the last dimension
785+
tensor_t<rtype, 2> avo{{fft_dim, fft_dim}}; // output tensor of rtype elements at the full size
786+
this->pb->NumpyToTensorView(av, "a_in"); // NOTE(review): the irfft_2d generator builds a_in at {fft_dim, fft_dim} while av is narrower - confirm how the extra columns are handled
787+
788+
(avo = ifft2(av)).run();
789+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
790+
791+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
792+
MATX_EXIT_HANDLER();
793+
}
794+
795+
TYPED_TEST(FFTTestComplexNonHalfTypes, IFFT2D16x32C2R) // Inverse 2D FFT from a TypeParam 16x17 input to an rtype 16x32 output.
796+
{
797+
MATX_ENTER_HANDLER();
798+
const index_t fft_dim[] = {16, 32}; // sizes of the two output dimensions; unequal so swapped dimensions are detectable
799+
using rtype = typename TypeParam::value_type; // element type of the output; presumably the real scalar of the complex TypeParam - confirm
800+
this->pb->template InitAndRunTVGenerator<TypeParam>( // generator receives the full {fft_dim[0], fft_dim[1]} size
801+
"00_transforms", "fft_operators", "irfft_2d", {fft_dim[0], fft_dim[1]});
802+
803+
tensor_t<TypeParam, 2> av{{fft_dim[0], fft_dim[1] / 2 + 1}}; // input has fft_dim[1] / 2 + 1 entries in the last dimension
804+
tensor_t<rtype, 2> avo{{fft_dim[0], fft_dim[1]}}; // output tensor of rtype elements at the full size
805+
this->pb->NumpyToTensorView(av, "a_in"); // NOTE(review): the irfft_2d generator builds a_in at {fft_dim[0], fft_dim[1]} while av is narrower - confirm how the extra columns are handled
806+
807+
(avo = ifft2(av)).run();
808+
cudaStreamSynchronize(0); // wait on stream 0 before comparing; presumably run() launched there - confirm
809+
810+
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); // compare against the generator's reference within thresh
811+
MATX_EXIT_HANDLER();
812+
}
661813

662814
TYPED_TEST(FFTTestComplexNonHalfTypes, FFT1D1024C2CShort)
663815
{

test/test_vectors/generators/00_transforms.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -319,15 +319,47 @@ def fft_2d(self) -> Dict[str, np.ndarray]:
319319
(self.size[0], self.size[1]), self.dtype)
320320
return {
321321
'a_in': seq,
322-
'a_out': np.fft.fft2(seq, (self.size[1], self.size[1]))
322+
'a_out': np.fft.fft2(seq, (self.size[0], self.size[1]))
323+
}
324+
325+
def fft_2d_batched(self) -> Dict[str, np.ndarray]:  # Test vectors for a 2D FFT applied to a 3-D array.
326+
seq = matx_common.randn_ndarray(  # input array of shape (size[0], size[1], size[2]) and dtype self.dtype
327+
(self.size[0], self.size[1], self.size[2]), self.dtype)
328+
return {
329+
'a_in': seq,  # input handed to the test
330+
'a_out': np.fft.fft2(seq, (self.size[1], self.size[2]))  # no axes given, so NumPy's default (last two axes) applies; size[0] acts as the batch dimension
331+
}
332+
333+
def fft_2d_batched_strided(self) -> Dict[str, np.ndarray]:  # Test vectors for a 2D FFT over axes 0 and 2 of a 3-D array.
334+
seq = matx_common.randn_ndarray(  # input array of shape (size[0], size[1], size[2]) and dtype self.dtype
335+
(self.size[0], self.size[1], self.size[2]), self.dtype)
336+
return {
337+
'a_in': seq,  # input handed to the test
338+
'a_out': np.fft.fft2(seq, (self.size[0], self.size[2]), axes=(0, 2))  # transform lengths follow axes 0 and 2; axis 1 is not transformed
323339
}
324340

325341
def ifft_2d(self) -> Dict[str, np.ndarray]:  # Test vectors for a 2D inverse FFT of a (size[0], size[1]) array.
326342
seq = matx_common.randn_ndarray(  # input array of shape (size[0], size[1]) and dtype self.dtype
327343
(self.size[0], self.size[1]), self.dtype)
328344
return {
329345
'a_in': seq,  # input handed to the test
330-
'a_out': np.fft.ifft2(seq, (self.size[1], self.size[1]))  # removed by this commit: used size[1] for both transform lengths
346+
'a_out': np.fft.ifft2(seq, (self.size[0], self.size[1]))  # added by this commit: transform lengths match the input shape
347+
}
348+
349+
def rfft_2d(self) -> Dict[str, np.ndarray]:  # Test vectors for a 2D FFT computed with np.fft.rfft2.
350+
seq = matx_common.randn_ndarray(  # input array of shape (size[0], size[1]) and dtype self.dtype; the R2C tests request a real dtype
351+
(self.size[0], self.size[1]), self.dtype)
352+
return {
353+
'a_in': seq,  # input handed to the test
354+
'a_out': np.fft.rfft2(seq, (self.size[0], self.size[1]))  # reference output; the tests expect shape (size[0], size[1] / 2 + 1)
355+
}
356+
357+
def irfft_2d(self) -> Dict[str, np.ndarray]:  # Test vectors for a 2D inverse FFT computed with np.fft.irfft2.
358+
seq = matx_common.randn_ndarray(  # input array of shape (size[0], size[1]) and dtype self.dtype
359+
(self.size[0], self.size[1]), self.dtype)
360+
return {
361+
'a_in': seq,  # NOTE(review): generated at full width, but the C2R tests load it into a (size[0], size[1] / 2 + 1) tensor - confirm this is intended
362+
'a_out': np.fft.irfft2(seq, (self.size[0], self.size[1]))  # reference output with requested shape (size[0], size[1])
331363
}
332364

333365

0 commit comments

Comments
 (0)