From 03b3264128d857823912e6b3551641160239f559 Mon Sep 17 00:00:00 2001 From: "Cui, Yifeng" Date: Thu, 8 Aug 2024 07:01:22 -0700 Subject: [PATCH] Refine fft_c2c_out --- src/ATen/native/xpu/SpectralOps.cpp | 5 +---- src/ATen/native/xpu/mkl/SpectralOps.cpp | 24 ++++++++++++++++++++++++ src/ATen/native/xpu/mkl/SpectralOps.h | 7 +++++++ 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/SpectralOps.cpp b/src/ATen/native/xpu/SpectralOps.cpp index e1a3a476d..cca2810f9 100644 --- a/src/ATen/native/xpu/SpectralOps.cpp +++ b/src/ATen/native/xpu/SpectralOps.cpp @@ -23,10 +23,7 @@ Tensor& XPUNativeFunctions::_fft_c2c_out( Tensor& out) { TORCH_CHECK(self.is_complex()); - Tensor result = native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); - at::native::resize_output(out, result.sizes()); - out.copy_(result); - return out; + return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out); } } // namespace at diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp index 664aedc36..1f22b17c7 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.cpp +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -406,6 +407,16 @@ const Tensor& _fft_apply_normalization( return (scale == 1.0) ? self : self.mul_(scale); } +Tensor& _fft_apply_normalization_out( + Tensor& out, + const Tensor& self, + int64_t normalization, + IntArrayRef sizes, + IntArrayRef dims) { + auto scale = _dft_scale(dims, sizes, self.sizes(), normalization); + return at::mul_out(out, self, c10::scalar_to_tensor(scale)); +} + } // namespace impl Tensor _fft_c2c_mkl( @@ -456,4 +467,17 @@ Tensor _fft_c2c_mkl( return impl::_fft_apply_normalization(out, normalization, input_sizes, dim); } +Tensor& _fft_c2c_mkl_out( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward, + Tensor& out) { + auto result = _fft_c2c_mkl( + self, dim, static_cast(fft_norm_mode::none), forward); + at::native::resize_output(out, result.sizes()); + return impl::_fft_apply_normalization_out( + out, result, normalization, result.sizes(), dim); +} + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/mkl/SpectralOps.h b/src/ATen/native/xpu/mkl/SpectralOps.h index 405e099d2..39763e428 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.h +++ b/src/ATen/native/xpu/mkl/SpectralOps.h @@ -8,4 +8,11 @@ Tensor _fft_c2c_mkl( int64_t normalization, bool forward); +Tensor& _fft_c2c_mkl_out( + const Tensor& self, + IntArrayRef dim, + int64_t normalization, + bool forward, + Tensor& out); + } // namespace at::native::xpu