Skip to content

Commit

Permalink
Fallback MKL XPU to CPU on Windows
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Aug 16, 2024
1 parent e39ea93 commit b544908
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/SpectralOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ Tensor XPUNativeFunctions::_fft_c2c(
bool forward) {
TORCH_CHECK(self.is_complex());

#if defined(__MKL_FALLBACK_TO_CPU)
Tensor out_cpu = native::_fft_c2c_mkl(
self.to(Device(at::kCPU)), dim, normalization, forward);
return out_cpu.to(Device(at::kXPU));
#else
return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
#endif // __MKL_FALLBACK_TO_CPU
}

Tensor& XPUNativeFunctions::_fft_c2c_out(
Expand All @@ -23,7 +29,15 @@ Tensor& XPUNativeFunctions::_fft_c2c_out(
Tensor& out) {
TORCH_CHECK(self.is_complex());

#if defined(__MKL_FALLBACK_TO_CPU)
Tensor out_cpu = out.to(Device(at::kCPU));
native::_fft_c2c_mkl_out(
self.to(Device(at::kCPU)), dim, normalization, forward, out_cpu);
out.copy_(out_cpu);
return out;
#else
return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out);
#endif // __MKL_FALLBACK_TO_CPU
}

} // namespace at
3 changes: 3 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,7 @@ if(CLANG_FORMAT)
endif()

target_include_directories(torch_xpu_ops PUBLIC ${ONEMKL_INCLUDE_DIR})
if(WIN32)
target_compile_options(torch_xpu_ops PRIVATE "-D__MKL_FALLBACK_TO_CPU")
endif()
target_link_libraries(torch_xpu_ops PUBLIC ${TORCH_XPU_OPS_ONEMKL_LIBRARIES})

0 comments on commit b544908

Please sign in to comment.