-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Init MKL for Pytorch XPU and enable fft_c2c
- Loading branch information
Showing
10 changed files
with
672 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
set(ONEMKL_XPU_FOUND FALSE) | ||
|
||
set(ONEMKL_XPU_LIBRARIES) | ||
|
||
if(DEFINED ENV{MKLROOT}) | ||
set(ONEMKL_ROOT $ENV{MKLROOT}) | ||
endif() | ||
|
||
if(NOT ONEMKL_ROOT) | ||
message(WARNING "Cannot find oneMKL in ENV{MKLROOT}, please setup oneMKL before building!!") | ||
return() | ||
endif() | ||
|
||
find_file( | ||
ONEMKL_INCLUDE_DIR | ||
NAMES include | ||
HINTS ${ONEMKL_ROOT} | ||
NO_DEFAULT_PATH | ||
) | ||
|
||
find_file( | ||
ONEMKL_LIB_DIR | ||
NAMES lib | ||
HINTS ${ONEMKL_ROOT} | ||
NO_DEFAULT_PATH | ||
) | ||
|
||
if((NOT ONEMKL_INCLUDE_DIR) OR (NOT ONEMKL_LIB_DIR)) | ||
message(WARNING "oneMKL sdk is incomplete!!") | ||
return() | ||
endif() | ||
|
||
set(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH} | ||
"${ONEMKL_INCLUDE_DIR}") | ||
set(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH} | ||
"${ONEMKL_LIB_DIR}") | ||
|
||
set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core" "mkl_sycl_dft") | ||
|
||
foreach(LIB_NAME IN LISTS MKL_LIB_NAMES) | ||
find_library( | ||
${LIB_NAME}_library | ||
NAMES ${LIB_NAME} | ||
HINTS ${ONEMKL_LIB_DIR} | ||
NO_CMAKE_PATH | ||
NO_CMAKE_ENVIRONMENT_PATH | ||
) | ||
list(APPEND ONEMKL_XPU_LIBRARIES ${${LIB_NAME}_library}) | ||
endforeach() | ||
|
||
list(INSERT ONEMKL_XPU_LIBRARIES 0 "-Wl,--no-as-needed") | ||
list(APPEND ONEMKL_XPU_LIBRARIES "-Wl,--as-needed") | ||
|
||
set(ONEMKL_XPU_FOUND TRUE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,13 @@ | ||
# ATen XPU sources | ||
|
||
file(GLOB xpu_cpp "xpu/*.cpp" "native/xpu/*.cpp" "native/sparse/*.cpp") | ||
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp") | ||
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp") | ||
|
||
list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp}) | ||
list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl}) | ||
list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl}) | ||
|
||
set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE) | ||
set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE) | ||
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#include <ATen/ATen.h> | ||
#include <ATen/native/xpu/mkl/SpectralOps.h> | ||
#include <ATen/xpu/XPUNativeFunctions.h> | ||
|
||
namespace at { | ||
|
||
Tensor XPUNativeFunctions::_fft_c2c( | ||
const Tensor& self, | ||
IntArrayRef dim, | ||
int64_t normalization, | ||
bool forward) { | ||
TORCH_CHECK(self.is_complex()); | ||
|
||
if (dim.empty()) { | ||
return self.clone(); | ||
} | ||
|
||
return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); | ||
} | ||
|
||
Tensor& XPUNativeFunctions::_fft_c2c_out( | ||
const Tensor& self, | ||
IntArrayRef dim, | ||
int64_t normalization, | ||
bool forward, | ||
Tensor& out) { | ||
TORCH_CHECK(self.is_complex()); | ||
|
||
if (dim.empty()) { | ||
out.copy_(self); | ||
return out; | ||
} | ||
|
||
Tensor result = native::xpu::_fft_c2c_mkl(self, dim, normalization, forward); | ||
out.copy_(result); | ||
return out; | ||
} | ||
|
||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.