-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Init MKL for Pytorch XPU and enable fft_c2c
- Loading branch information
Showing
13 changed files
with
760 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
set(ONEMKL_FOUND FALSE) | ||
|
||
set(ONEMKL_LIBRARIES) | ||
|
||
# In order to be compatible with various situations of Pytorch development | ||
# bundle setup, ENV{MKLROOT} and SYCL_ROOT will be checked sequentially to get | ||
# the root directory of oneMKL. | ||
if(DEFINED ENV{MKLROOT}) | ||
# Directly get the root directory of oneMKL if ENV{MKLROOT} exists. | ||
set(ONEMKL_ROOT $ENV{MKLROOT}) | ||
elseif(SYCL_FOUND) | ||
# oneMKL configuration may not be imported into the build system. Get the root | ||
# directory of oneMKL based on the root directory of compiler relatively. | ||
get_filename_component(ONEMKL_ROOT "${SYCL_ROOT}/../../mkl/latest" REALPATH) | ||
endif() | ||
|
||
if(NOT DEFINED ONEMKL_ROOT) | ||
message( | ||
WARNING | ||
"Cannot find either ENV{MKLROOT} or SYCL_ROOT, please setup oneAPI environment before building!!" | ||
) | ||
return() | ||
endif() | ||
|
||
if(NOT EXISTS ${ONEMKL_ROOT}) | ||
message( | ||
WARNING | ||
"${ONEMKL_ROOT} not found, please setup oneAPI environment before building!!" | ||
) | ||
return() | ||
endif() | ||
|
||
find_file( | ||
ONEMKL_INCLUDE_DIR | ||
NAMES include | ||
HINTS ${ONEMKL_ROOT} | ||
NO_DEFAULT_PATH) | ||
|
||
find_file( | ||
ONEMKL_LIB_DIR | ||
NAMES lib | ||
HINTS ${ONEMKL_ROOT} | ||
NO_DEFAULT_PATH) | ||
|
||
if((ONEMKL_INCLUDE_DIR STREQUAL "ONEMKL_INCLUDE_DIR-NOTFOUND") | ||
OR (ONEMKL_LIB_DIR STREQUAL "ONEMKL_LIB_DIR-NOTFOUND")) | ||
message(WARNING "oneMKL sdk is incomplete!!") | ||
return() | ||
endif() | ||
|
||
if(WIN32) | ||
set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_intel_thread" "mkl_core" "mkl_sycl") | ||
else() | ||
set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core" "mkl_sycl_dft") | ||
endif() | ||
|
||
foreach(LIB_NAME IN LISTS MKL_LIB_NAMES) | ||
find_library( | ||
${LIB_NAME}_library | ||
NAMES ${LIB_NAME} | ||
HINTS ${ONEMKL_LIB_DIR} | ||
NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH) | ||
list(APPEND ONEMKL_LIBRARIES ${${LIB_NAME}_library}) | ||
endforeach() | ||
|
||
set(ONEMKL_FOUND TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
find_package(ONEMKL) | ||
if(NOT ONEMKL_FOUND) | ||
message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!") | ||
endif() | ||
|
||
set(TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR ${ONEMKL_INCLUDE_DIR}) | ||
|
||
set(TORCH_XPU_OPS_ONEMKL_LIBRARIES ${ONEMKL_LIBRARIES}) | ||
|
||
list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 0 "-Wl,--no-as-needed") | ||
list(APPEND TORCH_XPU_OPS_ONEMKL_LIBRARIES "-Wl,--as-needed") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
# ATen XPU sources | ||
|
||
file(GLOB xpu_cpp "xpu/*.cpp" "native/xpu/*.cpp" "native/sparse/*.cpp") | ||
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp") | ||
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp") | ||
|
||
list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp}) | ||
list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl}) | ||
list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl}) | ||
|
||
set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE) | ||
set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE) | ||
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/native/Resize.h> | ||
#include <ATen/native/xpu/mkl/SpectralOps.h> | ||
#include <ATen/xpu/XPUNativeFunctions.h> | ||
|
||
namespace at { | ||
|
||
Tensor XPUNativeFunctions::_fft_c2c( | ||
const Tensor& self, | ||
IntArrayRef dim, | ||
int64_t normalization, | ||
bool forward) { | ||
TORCH_CHECK(self.is_complex()); | ||
|
||
#if defined(__MKL_FALLBACK_TO_CPU) | ||
Tensor out_cpu = native::_fft_c2c_mkl( | ||
self.to(Device(at::kCPU)), dim, normalization, forward); | ||
return out_cpu.to(Device(at::kXPU)); | ||
#else | ||
return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); | ||
#endif // __MKL_FALLBACK_TO_CPU | ||
} | ||
|
||
Tensor& XPUNativeFunctions::_fft_c2c_out( | ||
const Tensor& self, | ||
IntArrayRef dim, | ||
int64_t normalization, | ||
bool forward, | ||
Tensor& out) { | ||
TORCH_CHECK(self.is_complex()); | ||
|
||
#if defined(__MKL_FALLBACK_TO_CPU) | ||
Tensor out_cpu = out.to(Device(at::kCPU)); | ||
native::_fft_c2c_mkl_out( | ||
self.to(Device(at::kCPU)), dim, normalization, forward, out_cpu); | ||
out.copy_(out_cpu); | ||
return out; | ||
#else | ||
return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out); | ||
#endif // __MKL_FALLBACK_TO_CPU | ||
} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.