diff --git a/src/ATen/native/xpu/SpectralOps.cpp b/src/ATen/native/xpu/SpectralOps.cpp index cca2810f9..adc25737f 100644 --- a/src/ATen/native/xpu/SpectralOps.cpp +++ b/src/ATen/native/xpu/SpectralOps.cpp @@ -12,7 +12,13 @@ Tensor XPUNativeFunctions::_fft_c2c( bool forward) { TORCH_CHECK(self.is_complex()); +#if defined(__MKL_FALLBACK_TO_CPU) + Tensor out_cpu = native::_fft_c2c_mkl( + self.to(Device(at::kCPU)), dim, normalization, forward); + return out_cpu.to(Device(at::kXPU)); +#else return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); +#endif // __MKL_FALLBACK_TO_CPU } Tensor& XPUNativeFunctions::_fft_c2c_out( @@ -23,7 +29,15 @@ Tensor& XPUNativeFunctions::_fft_c2c_out( Tensor& out) { TORCH_CHECK(self.is_complex()); +#if defined(__MKL_FALLBACK_TO_CPU) + Tensor out_cpu = out.to(Device(at::kCPU)); + native::_fft_c2c_mkl_out( + self.to(Device(at::kCPU)), dim, normalization, forward, out_cpu); + out.copy_(out_cpu); + return out; +#else return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out); +#endif // __MKL_FALLBACK_TO_CPU } } // namespace at diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5d162ccbb..caaeb9b6c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -95,4 +95,7 @@ if(CLANG_FORMAT) endif() target_include_directories(torch_xpu_ops PUBLIC ${ONEMKL_INCLUDE_DIR}) +if(WIN32) + target_compile_options(torch_xpu_ops PRIVATE "-D__MKL_FALLBACK_TO_CPU") +endif() target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES})