Skip to content

Commit

Permalink
Refine fft_c2c_out
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Aug 8, 2024
1 parent 2bfa3a6 commit 03b3264
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
5 changes: 1 addition & 4 deletions src/ATen/native/xpu/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ Tensor& XPUNativeFunctions::_fft_c2c_out(
Tensor& out) {
TORCH_CHECK(self.is_complex());

Tensor result = native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
at::native::resize_output(out, result.sizes());
out.copy_(result);
return out;
return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out);
}

} // namespace at
24 changes: 24 additions & 0 deletions src/ATen/native/xpu/mkl/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <ATen/native/xpu/mkl/SpectralOps.h>
#include <comm/SYCLContext.h>
Expand Down Expand Up @@ -406,6 +407,16 @@ const Tensor& _fft_apply_normalization(
return (scale == 1.0) ? self : self.mul_(scale);
}

Tensor& _fft_apply_normalization_out(
Tensor& out,
const Tensor& self,
int64_t normalization,
IntArrayRef sizes,
IntArrayRef dims) {
auto scale = _dft_scale(dims, sizes, self.sizes(), normalization);
return at::mul_out(out, self, c10::scalar_to_tensor(scale));
}

} // namespace impl

Tensor _fft_c2c_mkl(
Expand Down Expand Up @@ -456,4 +467,17 @@ Tensor _fft_c2c_mkl(
return impl::_fft_apply_normalization(out, normalization, input_sizes, dim);
}

Tensor& _fft_c2c_mkl_out(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward,
Tensor& out) {
auto result = _fft_c2c_mkl(
self, dim, static_cast<int64_t>(fft_norm_mode::none), forward);
at::native::resize_output(out, result.sizes());
return impl::_fft_apply_normalization_out(
out, result, normalization, result.sizes(), dim);
}

} // namespace at::native::xpu
7 changes: 7 additions & 0 deletions src/ATen/native/xpu/mkl/SpectralOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,11 @@ Tensor _fft_c2c_mkl(
int64_t normalization,
bool forward);

Tensor& _fft_c2c_mkl_out(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward,
Tensor& out);

} // namespace at::native::xpu

0 comments on commit 03b3264

Please sign in to comment.