Skip to content

Commit

Permalink
Init MKL for Pytorch XPU and enable fft_c2c
Browse files Browse the repository at this point in the history
  • Loading branch information
CuiYifeng committed Dec 24, 2024
1 parent bc99386 commit f1204c4
Show file tree
Hide file tree
Showing 15 changed files with 694 additions and 3 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ set(TORCH_XPU_OPS_ROOT ${PROJECT_SOURCE_DIR})
list(APPEND CMAKE_MODULE_PATH ${TORCH_XPU_OPS_ROOT}/cmake/Modules)

include(${TORCH_XPU_OPS_ROOT}/cmake/SYCL.cmake)
include(${TORCH_XPU_OPS_ROOT}/cmake/ONEMKL.cmake)
include(${TORCH_XPU_OPS_ROOT}/cmake/BuildFlags.cmake)

if(BUILD_TEST)
Expand Down
66 changes: 66 additions & 0 deletions cmake/Modules/FindONEMKL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
set(ONEMKL_FOUND FALSE)

set(ONEMKL_LIBRARIES)

# In order to be compatible with various situations of Pytorch development
# bundle setup, ENV{MKLROOT} and SYCL_ROOT will be checked sequentially to get
# the root directory of oneMKL.
if(DEFINED ENV{MKLROOT})
# Directly get the root directory of oneMKL if ENV{MKLROOT} exists.
set(ONEMKL_ROOT $ENV{MKLROOT})
elseif(SYCL_FOUND)
# oneMKL configuration may not be imported into the build system. Get the root
# directory of oneMKL based on the root directory of compiler relatively.
get_filename_component(ONEMKL_ROOT "${SYCL_ROOT}/../../mkl/latest" REALPATH)
endif()

if(NOT DEFINED ONEMKL_ROOT)
message(
WARNING
"Cannot find either ENV{MKLROOT} or SYCL_ROOT, please setup oneAPI environment before building!!"
)
return()
endif()

if(NOT EXISTS ${ONEMKL_ROOT})
message(
WARNING
"${ONEMKL_ROOT} not found, please setup oneAPI environment before building!!"
)
return()
endif()

find_file(
ONEMKL_INCLUDE_DIR
NAMES include
HINTS ${ONEMKL_ROOT}
NO_DEFAULT_PATH)

find_file(
ONEMKL_LIB_DIR
NAMES lib
HINTS ${ONEMKL_ROOT}
NO_DEFAULT_PATH)

if((ONEMKL_INCLUDE_DIR STREQUAL "ONEMKL_INCLUDE_DIR-NOTFOUND")
OR (ONEMKL_LIB_DIR STREQUAL "ONEMKL_LIB_DIR-NOTFOUND"))
message(WARNING "oneMKL sdk is incomplete!!")
return()
endif()

if(WIN32)
set(MKL_LIB_NAMES "mkl_sycl" "mkl_intel_lp64" "mkl_intel_thread" "mkl_core")
else()
set(MKL_LIB_NAMES "mkl_sycl_dft" "mkl_intel_lp64" "mkl_gnu_thread" "mkl_core")
endif()

foreach(LIB_NAME IN LISTS MKL_LIB_NAMES)
find_library(
${LIB_NAME}_library
NAMES ${LIB_NAME}
HINTS ${ONEMKL_LIB_DIR}
NO_CMAKE_PATH NO_CMAKE_ENVIRONMENT_PATH)
list(APPEND ONEMKL_LIBRARIES ${${LIB_NAME}_library})
endforeach()

set(ONEMKL_FOUND TRUE)
11 changes: 11 additions & 0 deletions cmake/ONEMKL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
find_package(ONEMKL)
if(NOT ONEMKL_FOUND)
message(FATAL_ERROR "Can NOT find ONEMKL cmake helpers module!")
endif()

set(TORCH_XPU_OPS_ONEMKL_INCLUDE_DIR ${ONEMKL_INCLUDE_DIR})

set(TORCH_XPU_OPS_ONEMKL_LIBRARIES ${ONEMKL_LIBRARIES})

list(INSERT TORCH_XPU_OPS_ONEMKL_LIBRARIES 1 "-Wl,--start-group")
list(APPEND TORCH_XPU_OPS_ONEMKL_LIBRARIES "-Wl,--end-group")
3 changes: 3 additions & 0 deletions src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

file(GLOB xpu_h "xpu/*.h")
file(GLOB xpu_cpp "xpu/*.cpp")
file(GLOB xpu_mkl "native/xpu/mkl/*.cpp")
file(GLOB xpu_native_cpp "native/xpu/*.cpp" "native/sparse/*.cpp" "native/sparse/xpu/*.cpp" "native/transformers/*.cpp" "native/quantized/*.cpp")
file(GLOB xpu_sycl "native/xpu/sycl/*.cpp" "native/sparse/xpu/sycl/*.cpp" "native/transformers/sycl/*.cpp" "native/quantized/sycl/*.cpp")

list(APPEND ATen_XPU_CPP_SRCS ${xpu_cpp})
list(APPEND ATen_XPU_MKL_SRCS ${xpu_mkl})
list(APPEND ATen_XPU_NATIVE_CPP_SRCS ${xpu_native_cpp})
list(APPEND ATen_XPU_SYCL_SRCS ${xpu_sycl})

set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE)
set(ATen_XPU_MKL_SRCS ${ATen_XPU_MKL_SRCS} PARENT_SCOPE)
set(ATen_XPU_NATIVE_CPP_SRCS ${ATen_XPU_NATIVE_CPP_SRCS} PARENT_SCOPE)
set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)

Expand Down
28 changes: 28 additions & 0 deletions src/ATen/native/xpu/SpectralOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/mkl/SpectralOps.h>
#include <comm/xpu_aten.h>

namespace at::native {

Tensor _fft_c2c_xpu(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward) {
TORCH_CHECK(self.is_complex());

return native::xpu::_fft_c2c_mkl(self, dim, normalization, forward);
}

Tensor& _fft_c2c_xpu_out(
const Tensor& self,
IntArrayRef dim,
int64_t normalization,
bool forward,
Tensor& out) {
TORCH_CHECK(self.is_complex());

return native::xpu::_fft_c2c_mkl_out(self, dim, normalization, forward, out);
}

} // namespace at::native
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_cholesky_solve_helper",
"dot",
"_efficient_attention_forward",
"_fft_c2c",
"_fft_c2r",
"_fft_r2c",
"_flash_attention_forward",
Expand Down
Loading

0 comments on commit f1204c4

Please sign in to comment.