Skip to content

Commit

Permalink
fix clang avx512bf16 build (#5842)
Browse files Browse the repository at this point in the history
* check compiler supports isa with optimization enabled
  • Loading branch information
nihui authored Dec 23, 2024
1 parent 6c438c4 commit 66cd40e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
7 changes: 4 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ endif()
##############################################

include(CheckCXXCompilerFlag)
set(CMAKE_TRY_COMPILE_CONFIGURATION release)

# gnu inline assembly in clang msvc does not work actually
if(NOT (CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")))
Expand Down Expand Up @@ -523,7 +524,7 @@ else()
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
Expand Down Expand Up @@ -560,7 +561,7 @@ else()
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
Expand Down Expand Up @@ -595,7 +596,7 @@ else()
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
Expand Down
4 changes: 2 additions & 2 deletions src/layer/x86/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2014,7 +2014,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m256i _pp = combine4x2_epi32(_pp0, _pp1);
#if !__AVXVNNIINT8__
_w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp);
_w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp);
#endif // !__AVXVNNIINT8__
_mm256_storeu_si256((__m256i*)pp, _pp);

Expand Down Expand Up @@ -2108,7 +2108,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m256i _pp = combine4x2_epi32(_pp0, _pp1);
#if !__AVXVNNIINT8__
_w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp);
_w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp);
#endif // !__AVXVNNIINT8__
_mm256_storeu_si256((__m256i*)pp, _pp);

Expand Down
6 changes: 3 additions & 3 deletions src/layer/x86/x86_usability.h
Original file line number Diff line number Diff line change
Expand Up @@ -1490,9 +1490,9 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx512(const __m512& v0)
static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m512& v1)
{
#if __AVX512BF16__
__m256bh _v0 = _mm512_cvtneps_pbh(v0);
__m256bh _v1 = _mm512_cvtneps_pbh(v1);
__m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512((__m256i)_v0), (__m256i)_v1, 1);
__m256i _v0 = (__m256i)_mm512_cvtneps_pbh(v0);
__m256i _v1 = (__m256i)_mm512_cvtneps_pbh(v1);
__m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512(_v0), _v1, 1);
#else
__m512i _a = _mm512_castps_si512(v0);
__m512i _b = _mm512_castps_si512(v1);
Expand Down

0 comments on commit 66cd40e

Please sign in to comment.