From 4b280d799cb58aeeb02e733fe301ea049ab87596 Mon Sep 17 00:00:00 2001
From: "Cui, Yifeng"
Date: Thu, 1 Aug 2024 22:54:01 -0700
Subject: [PATCH] Init MKL for Pytorch XPU and enable fft_c2c

---
 cmake/Modules/FindONEMKL.cmake           |  54 +++
 src/ATen/CMakeLists.txt                  |   3 +
 src/ATen/native/xpu/SpectralOps.cpp      |  39 ++
 src/ATen/native/xpu/XPUFallback.template |   1 -
 src/ATen/native/xpu/mkl/SpectralOps.cpp  | 457 +++++++++++++++++++++++
 src/ATen/native/xpu/mkl/SpectralOps.h    |  11 +
 src/CMakeLists.txt                       |   9 +
 test/xpu/run_test_with_skip.py           |  16 +
 test/xpu/test_spectral_ops_xpu.py        |  81 ++++
 yaml/xpu_functions.yaml                  |   2 +
 10 files changed, 672 insertions(+), 1 deletion(-)
 create mode 100644 cmake/Modules/FindONEMKL.cmake
 create mode 100644 src/ATen/native/xpu/SpectralOps.cpp
 create mode 100644 src/ATen/native/xpu/mkl/SpectralOps.cpp
 create mode 100644 src/ATen/native/xpu/mkl/SpectralOps.h
 create mode 100644 test/xpu/test_spectral_ops_xpu.py

diff --git a/cmake/Modules/FindONEMKL.cmake b/cmake/Modules/FindONEMKL.cmake
new file mode 100644
index 000000000..f76df4975
--- /dev/null
+++ b/cmake/Modules/FindONEMKL.cmake
@@ -0,0 +1,54 @@
+set(ONEMKL_XPU_FOUND FALSE)
+
+set(ONEMKL_XPU_LIBRARIES)
+
+if(DEFINED ENV{MKLROOT})
+  set(ONEMKL_ROOT $ENV{MKLROOT})
+endif()
+
+if(NOT ONEMKL_ROOT)
+  message(WARNING "Cannot find oneMKL in ENV{MKLROOT}, please setup oneMKL before building!!")
+  return()
+endif()
+
+find_file(
+  ONEMKL_INCLUDE_DIR
+  NAMES include
+  HINTS ${ONEMKL_ROOT}
+  NO_DEFAULT_PATH
+)
+
+find_file(
+  ONEMKL_LIB_DIR
+  NAMES lib
+  HINTS ${ONEMKL_ROOT}
+  NO_DEFAULT_PATH
+)
+
+if((NOT ONEMKL_INCLUDE_DIR) OR (NOT ONEMKL_LIB_DIR))
+  message(WARNING "oneMKL sdk is incomplete!!")
+  return()
+endif()
+
+set(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}
+    "${ONEMKL_INCLUDE_DIR}")
+set(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
+    "${ONEMKL_LIB_DIR}")
+
+set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core" "mkl_sycl_dft")
+
+foreach(LIB_NAME IN LISTS MKL_LIB_NAMES)
+  find_library(
+    ${LIB_NAME}_library
+    NAMES ${LIB_NAME}
+    HINTS ${ONEMKL_LIB_DIR}
+    NO_CMAKE_PATH
+    NO_CMAKE_ENVIRONMENT_PATH
+  )
+  list(APPEND ONEMKL_XPU_LIBRARIES ${${LIB_NAME}_library})
+endforeach()
+
+list(INSERT ONEMKL_XPU_LIBRARIES 0 "-Wl,--no-as-needed")
+list(APPEND ONEMKL_XPU_LIBRARIES "-Wl,--as-needed")
+
+set(ONEMKL_XPU_FOUND TRUE)
diff --git a/src/ATen/CMakeLists.txt b/src/ATen/CMakeLists.txt
index 815ad018f..040e7f21a 100644
--- a/src/ATen/CMakeLists.txt
+++ b/src/ATen/CMakeLists.txt
@@ -1,10 +1,13 @@
 # ATen XPU sources
 
 file(GLOB xpu_cpp "xpu/*.cpp" "native/xpu/*.cpp" "native/sparse/*.cpp")
+file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
 file(GLOB xpu_sycl "native/xpu/sycl/*.cpp")
 
 list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
+list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
 list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl})
 
 set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE)
+set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE)
 set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)
diff --git a/src/ATen/native/xpu/SpectralOps.cpp b/src/ATen/native/xpu/SpectralOps.cpp
new file mode 100644
index 000000000..3225ff51b
--- /dev/null
+++ b/src/ATen/native/xpu/SpectralOps.cpp
@@ -0,0 +1,39 @@
+#include <ATen/XPUNativeFunctions.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/native/xpu/mkl/SpectralOps.h>
+
+namespace at {
+
+Tensor XPUNativeFunctions::_fft_c2c(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool forward) {
+  TORCH_CHECK(self.is_complex());
+
+  if (dim.empty()) {
+    return self.clone();
+  }
+
+  return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
+}
+
+Tensor& XPUNativeFunctions::_fft_c2c_out(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool forward,
+    Tensor& out) {
+  TORCH_CHECK(self.is_complex());
+
+  if (dim.empty()) {
+    out.copy_(self);
+    return out;
+  }
+
+  Tensor result = native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
+  out.copy_(result);
+  return out;
+}
+
+} // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 5b2d6e5ff..c64427699 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -176,7 +176,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_efficient_attention_forward",
       "_embedding_bag_dense_backward",
       "_embedding_bag_per_sample_weights_backward",
-      "_fft_c2c",
       "_fft_c2r",
       "_fft_r2c",
       "_flash_attention_forward",
diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp
new file mode 100644
index 000000000..4b536bf3d
--- /dev/null
+++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp
@@ -0,0 +1,457 @@
+#include <ATen/ATen.h>
+#include <ATen/native/SpectralOpsUtils.h>
+#include <ATen/native/xpu/mkl/SpectralOps.h>
+#include <comm/SYCLContext.h>
+#include <oneapi/mkl/dfti.hpp>
+
+using namespace oneapi::mkl::dft;
+
+namespace at::native::xpu {
+
+namespace impl {
+
+constexpr int64_t mkl_max_ndim = 3;
+
+// Sort transform dimensions by input layout, for best performance
+// exclude_last is for onesided transforms where the last dimension cannot be
+// reordered
+static DimVector _sort_dims(
+    const Tensor& self,
+    IntArrayRef dim,
+    bool exclude_last = false) {
+  DimVector sorted_dims(dim.begin(), dim.end());
+  auto self_strides = self.strides();
+  std::sort(
+      sorted_dims.begin(),
+      sorted_dims.end() - exclude_last,
+      [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
+  return sorted_dims;
+}
+
+class dft_config_t {
+ public:
+  using config_int64_t = std::unordered_map<config_param, int64_t>;
+  using config_float_t = std::unordered_map<config_param, float>;
+  using config_double_t = std::unordered_map<config_param, double>;
+
+  dft_config_t() {
+    val_int64_.clear();
+    val_float_.clear();
+    val_double_.clear();
+    fwd_strides_.clear();
+    bwd_strides_.clear();
+  }
+
+  void set_strides(
+      std::vector<int64_t>& fwd_strides,
+      std::vector<int64_t>& bwd_strides) {
+    fwd_strides_ = fwd_strides;
+    bwd_strides_ = bwd_strides;
+  }
+
+  template <typename T>
+  void set_value(config_param key, T value) {
+    if (std::is_same<T, int64_t>::value ||
+        std::is_same<T, DFTI_CONFIG_VALUE>::value) {
+      val_int64_.insert({key, value});
+    } else if (std::is_same<T, float>::value) {
+      val_float_.insert({key, value});
+    } else if (std::is_same<T, double>::value) {
+      val_double_.insert({key, value});
+    } else {
+      TORCH_CHECK(0, "Unsupported value type in FFT config!");
+    }
+  }
+
+  template <precision prec, domain signal_type>
+  void commit_values(descriptor<prec, signal_type>& desc) {
+#define COMMIT_VAL(val_map)                    \
+  for (auto& value : (val_map)) {              \
+    desc.set_value(value.first, value.second); \
+  }
+
+    COMMIT_VAL(val_int64_);
+    COMMIT_VAL(val_float_);
+    COMMIT_VAL(val_double_);
+
+    if (!fwd_strides_.empty()) {
+      desc.set_value(config_param::FWD_STRIDES, fwd_strides_.data());
+    }
+    if (!bwd_strides_.empty()) {
+      desc.set_value(config_param::BWD_STRIDES, bwd_strides_.data());
+    }
+  }
+
+ private:
+  config_int64_t val_int64_;
+  config_float_t val_float_;
+  config_double_t val_double_;
+  std::vector<int64_t> fwd_strides_;
+  std::vector<int64_t> bwd_strides_;
+};
+
+template <precision prec, domain signal_type>
+class dft_desc_t {
+ public:
+  using mkl_desc_t = descriptor<prec, signal_type>;
+
+  dft_desc_t(
+      sycl::queue& q,
+      std::vector<std::int64_t>& dimensions,
+      std::shared_ptr<dft_config_t> configs)
+      : desc_(dimensions), configs_(configs) {
+    configs_->commit_values(desc_);
+    desc_.set_value(
+        oneapi::mkl::dft::config_param::WORKSPACE,
+        oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL);
+    desc_.commit(q);
+  }
+
+  mkl_desc_t& raw() {
+    return desc_;
+  }
+
+ private:
+  mkl_desc_t desc_;
+  std::shared_ptr<dft_config_t> configs_;
+};
+
+template <precision prec, domain signal_type, typename scalar_t>
+void _mkl_dft(
+    const Tensor& input,
+    Tensor& output,
+    int64_t signal_ndim,
+    bool complex_input,
+    bool complex_output,
+    bool inverse,
+    IntArrayRef checked_signal_sizes,
+    bool onesided,
+    int64_t batch) {
+  auto& queue = at::xpu::getCurrentSYCLQueue();
+  std::vector<std::int64_t> mkl_signal_sizes(
+      checked_signal_sizes.begin() + 1, checked_signal_sizes.end());
+
+  std::shared_ptr<dft_config_t> desc_config(new dft_config_t);
+  desc_config->set_value(config_param::PLACEMENT, DFTI_NOT_INPLACE);
+  desc_config->set_value(config_param::NUMBER_OF_TRANSFORMS, batch);
+
+  auto istrides = input.strides();
+  auto ostrides = output.strides();
+  int64_t idist = istrides[0];
+  int64_t odist = ostrides[0];
+
+  if (!inverse) {
+    desc_config->set_value(config_param::FWD_DISTANCE, idist);
+    desc_config->set_value(config_param::BWD_DISTANCE, odist);
+  } else {
+    desc_config->set_value(config_param::FWD_DISTANCE, odist);
+    desc_config->set_value(config_param::BWD_DISTANCE, idist);
+  }
+
+  std::vector<int64_t> fwd_strides(1 + signal_ndim, 0),
+      bwd_strides(1 + signal_ndim, 0);
+
+  for (int64_t i = 1; i <= signal_ndim; i++) {
+    if (!inverse) {
+      fwd_strides[i] = istrides[i];
+      bwd_strides[i] = ostrides[i];
+    } else {
+      fwd_strides[i] = ostrides[i];
+      bwd_strides[i] = istrides[i];
+    }
+  }
+
+  desc_config->set_strides(fwd_strides, bwd_strides);
+
+  if (!complex_input || !complex_output) {
+    desc_config->set_value(
+        config_param::CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
+  }
+
+  auto desc =
+      dft_desc_t<prec, signal_type>(queue, mkl_signal_sizes, desc_config);
+
+  // Obtain the size of workspace required after commit.
+  size_t workspaceSizeBytes = 0;
+  desc.raw().get_value(
+      oneapi::mkl::dft::config_param::WORKSPACE_BYTES, &workspaceSizeBytes);
+
+  // Allocate USM workspace and provide it to the descriptor.
+  Tensor workspaceBuf = at::empty(
+      {(long)(workspaceSizeBytes / sizeof(double))},
+      input.options().dtype(at::kDouble),
+      c10::nullopt);
+  desc.raw().set_workspace((double*)workspaceBuf.data_ptr());
+
+  auto in_data = (scalar_t*)input.data_ptr();
+  auto out_data = (scalar_t*)output.data_ptr();
+
+  sycl::event event;
+  if (!inverse) {
+    event = compute_forward(desc.raw(), in_data, out_data);
+  } else {
+    event = compute_backward(desc.raw(), in_data, out_data);
+  }
+  event.wait_and_throw();
+  queue.throw_asynchronous();
+}
+
+void _fft_with_size(
+    Tensor& output,
+    const Tensor& self,
+    int64_t signal_ndim,
+    bool complex_input,
+    bool complex_output,
+    bool inverse,
+    IntArrayRef checked_signal_sizes,
+    bool onesided) {
+  int64_t batch = self.size(0);
+  Tensor input_ = self;
+  // real/imag dimensions must be aligned when viewed as a complex type
+
+  if (complex_input) {
+    bool need_contiguous = input_.stride(-1) != 1;
+
+    for (int64_t i = 0; !need_contiguous && i <= signal_ndim; i++) {
+      need_contiguous |= input_.stride(i) % 2 != 0;
+    }
+
+    if (need_contiguous) {
+      input_ = input_.contiguous();
+    }
+  }
+
+  bool complex_type = inverse ? complex_output : complex_input;
+
+  void (*dft_func)(
+      const class at::Tensor&,
+      class at::Tensor&,
+      int64_t,
+      bool,
+      bool,
+      bool,
+      class c10::ArrayRef<int64_t>,
+      bool,
+      int64_t);
+  Tensor input = input_;
+
+  if (input.scalar_type() == ScalarType::Float ||
+      input.scalar_type() == ScalarType::ComplexFloat) {
+    dft_func = complex_type
+        ? _mkl_dft<precision::SINGLE, domain::COMPLEX, float>
+        : _mkl_dft<precision::SINGLE, domain::REAL, float>;
+  } else if (
+      input.scalar_type() == ScalarType::Double ||
+      input.scalar_type() == ScalarType::ComplexDouble) {
+    dft_func = complex_type
+        ? _mkl_dft<precision::DOUBLE, domain::COMPLEX, double>
+        : _mkl_dft<precision::DOUBLE, domain::REAL, double>;
+  } else {
+    AT_ERROR("MKL FFT doesn't support tensor of type");
+  }
+
+  dft_func(
+      input,
+      output,
+      signal_ndim,
+      complex_input,
+      complex_output,
+      inverse,
+      checked_signal_sizes,
+      onesided,
+      batch);
+}
+
+// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
+Tensor& _exec_fft(
+    Tensor& out,
+    Tensor self,
+    IntArrayRef out_sizes,
+    IntArrayRef dim,
+    bool onesided,
+    bool forward) {
+  const auto ndim = self.dim();
+  const int64_t signal_ndim = dim.size();
+  const auto batch_dims = ndim - signal_ndim;
+
+  // Permute dimensions so batch dimensions come first, and in stride order
+  // This maximizes data locality when collapsing to a single batch dimension
+  DimVector dim_permute(ndim);
+  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});
+
+  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
+  for (const auto& d : dim) {
+    is_transformed_dim[d] = true;
+  }
+
+  auto batch_end =
+      std::partition(dim_permute.begin(), dim_permute.end(), [&](int64_t d) {
+        return !is_transformed_dim[d];
+      });
+
+  auto self_strides = self.strides();
+  std::sort(dim_permute.begin(), batch_end, [&](int64_t a, int64_t b) {
+    return self_strides[a] > self_strides[b];
+  });
+  std::copy(dim.cbegin(), dim.cend(), batch_end);
+
+  auto input = self.permute(dim_permute);
+
+  // Collapse batch dimensions into a single dimension
+  DimVector batched_sizes(signal_ndim + 1);
+  batched_sizes[0] = -1;
+  std::copy(
+      input.sizes().cbegin() + batch_dims,
+      input.sizes().cend(),
+      batched_sizes.begin() + 1);
+  input = input.reshape(batched_sizes);
+
+  const auto batch_size = input.sizes()[0];
+  DimVector signal_size(signal_ndim + 1);
+  signal_size[0] = batch_size;
+
+  for (int64_t i = 0; i < signal_ndim; ++i) {
+    auto in_size = input.sizes()[i + 1];
+    auto out_size = out_sizes[dim[i]];
+    signal_size[i + 1] = std::max(in_size, out_size);
+    TORCH_INTERNAL_ASSERT(
+        in_size == signal_size[i + 1] ||
+        in_size == (signal_size[i + 1] / 2) + 1);
+    TORCH_INTERNAL_ASSERT(
+        out_size == signal_size[i + 1] ||
+        out_size == (signal_size[i + 1] / 2) + 1);
+  }
+
+  batched_sizes[0] = batch_size;
+  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
+
+  for (size_t i = 0; i < dim.size(); ++i) {
+    batched_out_sizes[i + 1] = out_sizes[dim[i]];
+  }
+
+  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);
+
+  // run the FFT
+  _fft_with_size(
+      out,
+      input,
+      signal_ndim,
+      input.is_complex(),
+      out.is_complex(),
+      !forward,
+      signal_size,
+      onesided);
+
+  // Inplace reshaping to original batch shape and inverting the dimension
+  // permutation
+  DimVector out_strides(ndim);
+  int64_t batch_numel = 1;
+
+  for (int64_t i = batch_dims - 1; i >= 0; --i) {
+    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
+    batch_numel *= out_sizes[dim_permute[i]];
+  }
+
+  for (int64_t i = batch_dims; i < ndim; ++i) {
+    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
+  }
+
+  out.as_strided_(out_sizes, out_strides, out.storage_offset());
+
+  return out;
+}
+
+double _dft_scale(
+    IntArrayRef dim,
+    IntArrayRef input_sizes,
+    IntArrayRef out_sizes,
+    int64_t normalization) {
+  const auto norm = static_cast<fft_norm_mode>(normalization);
+  double double_scale = 1.0;
+
+  if (norm == fft_norm_mode::none) {
+    return double_scale;
+  }
+
+  const int64_t signal_ndim = dim.size();
+  int64_t signal_numel = 1;
+
+  for (int64_t i = 0; i < signal_ndim; ++i) {
+    auto in_size = input_sizes[dim[i]];
+    auto out_size = out_sizes[dim[i]];
+    auto signal_size = std::max(in_size, out_size);
+
+    signal_numel *= signal_size;
+    TORCH_INTERNAL_ASSERT(
+        in_size == signal_size || in_size == (signal_size / 2) + 1);
+    TORCH_INTERNAL_ASSERT(
+        out_size == signal_size || out_size == (signal_size / 2) + 1);
+  }
+
+  if (norm == fft_norm_mode::by_root_n) {
+    double_scale = 1.0 / std::sqrt(signal_numel);
+  } else {
+    double_scale = 1.0 / static_cast<double>(signal_numel);
+  }
+
+  return double_scale;
+}
+
+const Tensor& _fft_apply_normalization(
+    const Tensor& self,
+    int64_t normalization,
+    IntArrayRef sizes,
+    IntArrayRef dims) {
+  auto scale = _dft_scale(dims, sizes, self.sizes(), normalization);
+  return (scale == 1.0) ? self : self.mul_(scale);
+}
+
+} // namespace impl
+
+Tensor _fft_c2c_mkl(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool forward) {
+  TORCH_CHECK(self.is_complex());
+
+  auto sorted_dims = impl::_sort_dims(self, dim);
+  auto out_sizes = self.sizes();
+  auto out = at::empty(out_sizes, self.options());
+  auto input_sizes = self.sizes();
+  auto working_tensor = self;
+
+  while (!sorted_dims.empty()) {
+    const auto max_dims =
+        std::min(static_cast<size_t>(impl::mkl_max_ndim), sorted_dims.size());
+    auto fft_dims =
+        IntArrayRef(sorted_dims).slice(sorted_dims.size() - max_dims, max_dims);
+
+    impl::_exec_fft(
+        out,
+        working_tensor,
+        out_sizes,
+        fft_dims,
+        /*onesided=*/false,
+        forward);
+
+    sorted_dims.resize(sorted_dims.size() - max_dims);
+
+    if (sorted_dims.empty()) {
+      break;
+    }
+
+    sorted_dims = impl::_sort_dims(self, sorted_dims);
+
+    if (working_tensor.is_same(self)) {
+      working_tensor = std::move(out);
+      out = at::empty(out_sizes, self.options());
+    } else {
+      std::swap(out, working_tensor);
+    }
+  }
+
+  return impl::_fft_apply_normalization(out, normalization, input_sizes, dim);
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/mkl/SpectralOps.h b/src/ATen/native/xpu/mkl/SpectralOps.h
new file mode 100644
index 000000000..405e099d2
--- /dev/null
+++ b/src/ATen/native/xpu/mkl/SpectralOps.h
@@ -0,0 +1,11 @@
+#pragma once
+
+namespace at::native::xpu {
+
+Tensor _fft_c2c_mkl(
+    const Tensor& self,
+    IntArrayRef dim,
+    int64_t normalization,
+    bool forward);
+
+} // namespace at::native::xpu
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index ba40a4b8c..a806c130f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -2,6 +2,7 @@
 include(${TORCH_XPU_OPS_ROOT}/cmake/Codegen.cmake)
 
 set(ATen_XPU_CPP_SRCS)
+set(ATen_XPU_MKL_SRCS)
 set(ATen_XPU_SYCL_SRCS)
 
 set(ATen_XPU_INCLUDE_DIRS ${TORCH_XPU_OPS_ROOT}/src CACHE STRING "ATen XPU Include directory")
@@ -15,6 +16,7 @@ add_library(
   torch_xpu_ops
   STATIC
   ${ATen_XPU_CPP_SRCS}
+  ${ATen_XPU_MKL_SRCS}
   ${ATen_XPU_GEN_SRCS})
 
 if(BUILD_SEPARATE_OPS)
@@ -91,3 +93,10 @@ if(CLANG_FORMAT)
   add_custom_target(CL_FORMAT_CSRCS COMMAND ${CLANG_FORMAT_EXEC} -i -style=file ${ALL_CSRCS})
   add_dependencies(torch_xpu_ops CL_FORMAT_CSRCS)
 endif()
+
+find_package(ONEMKL)
+if(NOT ONEMKL_XPU_FOUND)
+  message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!")
+endif()
+
+target_link_libraries(torch_xpu_ops PUBLIC ${ONEMKL_XPU_LIBRARIES})
diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index e7c34283e..62d118723 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -790,6 +790,16 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     # 2. Half dtype is a common dtype in workloads.
     # So far CUDA doesn't support Half, so that XPU fails as we aligned claimed dtypes with CUDA in test infra.
"test_dtypes_nextafter_xpu", + + # Deselect temporarily + "test_out_fft_fft2_xpu_float32", + "test_out_fft_fftn_xpu_float32", + "test_out_fft_ifft2_xpu_float32", + "test_out_fft_ifftn_xpu_float32", + "test_out_warning_fft_fft2_xpu", + "test_out_warning_fft_fftn_xpu", + "test_out_warning_fft_ifft2_xpu", + "test_out_warning_fft_ifftn_xpu", ) res += launch_test("test_ops_xpu.py", skip_list) @@ -2992,5 +3002,11 @@ def launch_test(test_case, skip_list=None, exe_list=None): res += launch_test("nn/test_parametrization_xpu.py") +# test_spectral_ops +skip_list = ( + "test_cufft_plan_cache_xpu_float64", +) +res += launch_test("test_spectral_ops_xpu.py", skip_list) + exit_code = os.WEXITSTATUS(res) sys.exit(exit_code) diff --git a/test/xpu/test_spectral_ops_xpu.py b/test/xpu/test_spectral_ops_xpu.py new file mode 100644 index 000000000..bc60cf2ae --- /dev/null +++ b/test/xpu/test_spectral_ops_xpu.py @@ -0,0 +1,81 @@ +# Owner(s): ["module: intel"] + +import torch +import numpy as np +from packaging import version +from itertools import product + +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, ops, onlyNativeDeviceTypes) +from torch.testing._internal.common_methods_invocations import ( + spectral_funcs, SpectralFuncType) +from torch.testing._internal.common_utils import run_tests + +try: + from .xpu_test_utils import XPUPatchForImport +except Exception as e: + from ..xpu_test_utils import XPUPatchForImport + +with XPUPatchForImport(False): + from test_spectral_ops import TestFFT + +has_scipy_fft = False +try: + import scipy.fft + has_scipy_fft = True +except ModuleNotFoundError: + pass + +REFERENCE_NORM_MODES = ( + (None, "forward", "backward", "ortho") + if version.parse(np.__version__) >= version.parse('1.20.0') and ( + not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0')) + else (None, "ortho")) + +@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD], + allowed_dtypes=(torch.float, torch.cfloat)) +def _test_reference_1d(self, device, dtype, op): + if op.ref is None: + raise unittest.SkipTest("No reference implementation") + + norm_modes = REFERENCE_NORM_MODES + test_args = [ + *product( + # input + (torch.randn(67, device=device, dtype=dtype), + torch.randn(80, device=device, dtype=dtype), + torch.randn(12, 14, device=device, dtype=dtype), + torch.randn(9, 6, 3, device=device, dtype=dtype)), + # n + (None, 50, 6), + # dim + (-1, 0), + # norm + norm_modes + ), + # Test transforming middle dimensions of multi-dim tensor + *product( + (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),), + (None,), + (1, 2, -2,), + norm_modes + ) + ] + + for iargs in test_args: + args = list(iargs) + input = args[0] + args = args[1:] + + expected = op.ref(input.cpu().numpy(), *args) + exact_dtype = dtype in (torch.double, torch.complex128) + actual = op(input, *args) + self.assertEqual(actual, expected, exact_dtype=exact_dtype, atol=1e-4, rtol=1e-5) + +TestFFT.test_reference_1d = _test_reference_1d + +instantiate_device_type_tests(TestFFT, globals(), only_for=("xpu"), allow_xpu=True) + + +if __name__ == "__main__": + run_tests() diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index db3df0667..ad75c8359 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -719,3 +719,5 @@ supported: - _weight_norm_interface - _weight_norm_interface_backward - range.out + - _fft_c2c + - _fft_c2c.out