Skip to content

Commit

Permalink
Init MKL for Pytorch XPU and enable fft_c2c
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Aug 6, 2024
1 parent 718bc42 commit c94442e
Show file tree
Hide file tree
Showing 11 changed files with 699 additions and 2 deletions.
54 changes: 54 additions & 0 deletions cmake/Modules/FindONEMKL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
set(ONEMKL_XPU_FOUND FALSE)

set(ONEMKL_XPU_LIBRARIES)

if(DEFINED ENV{MKLROOT})
set(ONEMKL_ROOT $ENV{MKLROOT})
endif()

if(NOT ONEMKL_ROOT)
message(WARNING "Cannot find oneMKL in ENV{MKLROOT}, please setup oneMKL before building!!")
return()
endif()

find_file(
ONEMKL_INCLUDE_DIR
NAMES include
HINTS ${ONEMKL_ROOT}
NO_DEFAULT_PATH
)

find_file(
ONEMKL_LIB_DIR
NAMES lib
HINTS ${ONEMKL_ROOT}
NO_DEFAULT_PATH
)

if((NOT ONEMKL_INCLUDE_DIR) OR (NOT ONEMKL_LIB_DIR))
message(WARNING "oneMKL sdk is incomplete!!")
return()
endif()

set(CMAKE_INCLUDE_PATH ${CMAKE_INCLUDE_PATH}
"${ONEMKL_INCLUDE_DIR}")
set(CMAKE_LIBRARY_PATH ${CMAKE_LIBRARY_PATH}
"${ONEMKL_LIB_DIR}")

set(MKL_LIB_NAMES "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core" "mkl_sycl_dft")

foreach(LIB_NAME IN LISTS MKL_LIB_NAMES)
find_library(
${LIB_NAME}_library
NAMES ${LIB_NAME}
HINTS ${ONEMKL_LIB_DIR}
NO_CMAKE_PATH
NO_CMAKE_ENVIRONMENT_PATH
)
list(APPEND ONEMKL_XPU_LIBRARIES ${${LIB_NAME}_library})
endforeach()

list(INSERT ONEMKL_XPU_LIBRARIES 0 "-Wl,--no-as-needed")
list(APPEND ONEMKL_XPU_LIBRARIES "-Wl,--as-needed")

set(ONEMKL_XPU_FOUND TRUE)
3 changes: 3 additions & 0 deletions src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# ATen XPU sources

file(GLOB xpu_cpp "xpu/*.cpp" "native/xpu/*.cpp" "native/sparse/*.cpp")
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl})

set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE)
set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE)
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)
32 changes: 32 additions & 0 deletions src/ATen/native/xpu/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#include <ATen/ATen.h>
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/mkl/SpectralOps.h>
#include <ATen/xpu/XPUNativeFunctions.h>

namespace at {

Tensor XPUNativeFunctions::_fft_c2c(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward) {
TORCH_CHECK(self.is_complex());

return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
}

Tensor& XPUNativeFunctions::_fft_c2c_out(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward,
Tensor& out) {
TORCH_CHECK(self.is_complex());

Tensor result = native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
at::native::resize_output(out, result.sizes());
out.copy_(result);
return out;
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_efficient_attention_forward",
"_embedding_bag_dense_backward",
"_embedding_bag_per_sample_weights_backward",
"_fft_c2c",
"_fft_c2r",
"_fft_r2c",
"_flash_attention_forward",
Expand Down
Loading

0 comments on commit c94442e

Please sign in to comment.