Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix clang avx512bf16 build #5842

Merged
merged 6 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ endif()
##############################################

include(CheckCXXCompilerFlag)
set(CMAKE_TRY_COMPILE_CONFIGURATION release)

# gnu inline assembly in clang msvc does not work actually
if(NOT (CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")))
Expand Down Expand Up @@ -523,7 +524,7 @@ else()
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
Expand Down Expand Up @@ -560,7 +561,7 @@ else()
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
Expand Down Expand Up @@ -595,7 +596,7 @@ else()
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
Expand Down
4 changes: 2 additions & 2 deletions src/layer/x86/gemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -2014,7 +2014,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m256i _pp = combine4x2_epi32(_pp0, _pp1);
#if !__AVXVNNIINT8__
_w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp);
_w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp);
#endif // !__AVXVNNIINT8__
_mm256_storeu_si256((__m256i*)pp, _pp);

Expand Down Expand Up @@ -2108,7 +2108,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i

__m256i _pp = combine4x2_epi32(_pp0, _pp1);
#if !__AVXVNNIINT8__
_w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _pp);
_w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _pp);
#endif // !__AVXVNNIINT8__
_mm256_storeu_si256((__m256i*)pp, _pp);

Expand Down
6 changes: 3 additions & 3 deletions src/layer/x86/x86_usability.h
Original file line number Diff line number Diff line change
Expand Up @@ -1490,9 +1490,9 @@ static NCNN_FORCEINLINE __m256i float2bfloat_avx512(const __m512& v0)
static NCNN_FORCEINLINE __m512i float2bfloat_avx512(const __m512& v0, const __m512& v1)
{
#if __AVX512BF16__
__m256bh _v0 = _mm512_cvtneps_pbh(v0);
__m256bh _v1 = _mm512_cvtneps_pbh(v1);
__m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512((__m256i)_v0), (__m256i)_v1, 1);
__m256i _v0 = (__m256i)_mm512_cvtneps_pbh(v0);
__m256i _v1 = (__m256i)_mm512_cvtneps_pbh(v1);
__m512i _v = _mm512_inserti32x8(_mm512_castsi256_si512(_v0), _v1, 1);
#else
__m512i _a = _mm512_castps_si512(v0);
__m512i _b = _mm512_castps_si512(v1);
Expand Down
Loading